asgi-tools 1.2.0__cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
asgi_tools/tests.py ADDED
@@ -0,0 +1,405 @@
1
+ """Testing tools."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import binascii
7
+ import io
8
+ import mimetypes
9
+ import os
10
+ import random
11
+ from collections import deque
12
+ from contextlib import asynccontextmanager, suppress
13
+ from functools import partial
14
+ from http.cookies import SimpleCookie
15
+ from json import loads
16
+ from pathlib import Path
17
+ from typing import (
18
+ TYPE_CHECKING,
19
+ Any,
20
+ AsyncGenerator,
21
+ Awaitable,
22
+ Callable,
23
+ Coroutine,
24
+ Deque,
25
+ cast,
26
+ )
27
+ from urllib.parse import urlencode
28
+
29
+ from yarl import URL
30
+
31
+ from ._compat import aio_cancel, aio_sleep, aio_spawn, aio_timeout, aio_wait
32
+ from .constants import BASE_ENCODING, DEFAULT_CHARSET
33
+ from .errors import ASGIConnectionClosedError, ASGIInvalidMessageError
34
+ from .response import Response, ResponseJSON, ResponseWebSocket, parse_websocket_msg
35
+ from .utils import CIMultiDict, parse_headers
36
+
37
+ if TYPE_CHECKING:
38
+ from multidict import MultiDict
39
+
40
+ from .types import TJSON, TASGIApp, TASGIMessage, TASGIReceive, TASGIScope, TASGISend
41
+
42
+
43
class TestResponse(Response):
    """Response for test client.

    Created empty; calling it reads the app's ``http.response.start`` message to
    fill status, headers and cookies. The body is read lazily through
    :meth:`stream` / :meth:`body` and cached in ``self.content`` after the first
    full read.
    """

    def __init__(self):
        super().__init__(b"")
        # Cached body; stays ``None`` until :meth:`body` has consumed the stream.
        self.content = None

    async def __call__(self, _: TASGIScope, receive: TASGIReceive, send: TASGISend):  # noqa: ARG002
        """Read the start message and populate status code, headers and cookies."""
        self._receive = receive
        msg = await self._receive()
        assert msg.get("type") == "http.response.start", "Invalid Response"
        self.status_code = int(msg.get("status", 502))
        self.headers = cast("MultiDict", parse_headers(msg.get("headers", [])))
        self.content_type = self.headers.get("content-type")
        for cookie in self.headers.getall("set-cookie", []):
            self.cookies.load(cookie)

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Stream the response body chunk by chunk.

        Once :meth:`body` has cached the content, the underlying messages are
        already consumed; the cached bytes are replayed instead of waiting for
        messages that will never arrive.
        """
        if self.content is not None:
            if self.content:
                yield self.content
            return

        more_body = True
        while more_body:
            msg = await self._receive()
            if msg.get("type") == "http.response.body":
                chunk = msg.get("body")
                if chunk:
                    yield chunk
                more_body = msg.get("more_body", False)

    async def body(self) -> bytes:
        """Load the whole response body (once) and cache it in ``self.content``."""
        if self.content is None:
            # join() instead of repeated ``bytes +=`` keeps this linear in body size.
            self.content = b"".join([chunk async for chunk in self.stream()])

        return self.content

    async def text(self) -> str:
        """Return the body decoded with ``DEFAULT_CHARSET``."""
        body = await self.body()
        return body.decode(DEFAULT_CHARSET)

    async def json(self) -> TJSON:
        """Return the body parsed as JSON."""
        text = await self.text()
        return loads(text)
88
+
89
+
90
class TestWebSocketResponse(ResponseWebSocket):
    """Client side of a websocket connection used by the test client.

    Messages sent from here reach the application as ``websocket.receive``
    events; messages read here are the events the application emitted.
    """

    def connect(self) -> Coroutine[TASGIMessage, Any, Any]:
        """Return the coroutine that sends the ``websocket.connect`` handshake event."""
        return self.send({"type": "websocket.connect"})

    async def disconnect(self):
        """Tell the application the client went away (code 1005) and mark this side closed."""
        await self.send({"type": "websocket.disconnect", "code": 1005})
        self.state = self.STATES.DISCONNECTED

    def send(self, msg, msg_type="websocket.receive"):
        """Send a message to a client."""
        return super().send(msg, msg_type=msg_type)

    async def receive(self, *, raw=False):
        """Return the next data message emitted by the application.

        ``websocket.accept`` events are absorbed (they mark the partner as
        connected) and reading continues. A ``websocket.close`` event, or reading
        after the partner closed, raises ``ASGIConnectionClosedError``. Non-websocket
        events raise ``ASGIInvalidMessageError``. With ``raw=True`` the event dict is
        returned as-is; otherwise it is parsed.
        """
        while True:
            if self.partner_state == self.STATES.DISCONNECTED:
                raise ASGIConnectionClosedError

            event = await self._receive()
            event_type = event["type"]
            if not event_type.startswith("websocket."):
                raise ASGIInvalidMessageError(event)

            if event_type == "websocket.accept":
                self.partner_state = self.STATES.CONNECTED
                continue

            if event_type == "websocket.close":
                self.partner_state = self.STATES.DISCONNECTED
                raise ASGIConnectionClosedError

            if raw:
                return event
            return parse_websocket_msg(event, charset=DEFAULT_CHARSET)
122
+
123
+
124
class ASGITestClient:
    """Built-in test client for ASGI applications.

    Features:

    * cookies
    * multipart/form-data
    * follow redirects
    * request streams
    * response streams
    * websocket support
    * lifespan management

    """

    def __init__(self, app: TASGIApp, base_url: str = "http://localhost"):
        self.app = app
        self.base_url = URL(base_url)
        # Cookie jar shared by all requests made through this client; updated
        # from responses and from per-request ``cookies`` arguments.
        self.cookies: SimpleCookie = SimpleCookie()
        # Default headers, used when a request does not pass its own ``headers``.
        self.headers: dict[str, str] = {}

    def __getattr__(self, name: str) -> Callable[..., Awaitable]:
        """Map ``client.<verb>(...)`` to ``client.request(..., method=VERB)``.

        Private and dunder names raise ``AttributeError`` so protocol probes
        (``copy``/``pickle`` looking up ``__setstate__``, ``hasattr(client,
        "__aenter__")``, ...) are not mistaken for HTTP verbs.
        """
        if name.startswith("_"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return partial(self.request, method=name.upper())

    async def request(
        self,
        path: str,
        method: str = "GET",
        *,
        query: str | dict = "",
        headers: dict[str, str] | None = None,
        cookies: dict[str, str] | None = None,
        data: bytes | str | dict | AsyncGenerator[Any, bytes] = b"",
        json: TJSON = None,
        follow_redirect: bool = True,
        timeout: float = 10.0,
    ) -> TestResponse:
        """Make an HTTP request to the application and return its response.

        ``data`` may be bytes, a str, a dict (urlencoded, or multipart when any
        value is a file object) or an async iterator of body chunks. ``json`` is
        used only when ``data`` is not a str/dict. Redirect responses are followed
        with a GET when ``follow_redirect`` is true.
        """
        # Work on a copy: Content-Type/Content-Length/Host/... are added below, and
        # mutating the caller's dict would leak e.g. a stale Content-Length into
        # later requests that reuse it (``setdefault`` never overwrites it).
        headers = dict(headers) if headers else dict(self.headers)

        if isinstance(data, str):
            data = Response.process_content(data)

        elif isinstance(data, dict):
            is_multipart = any(isinstance(value, io.IOBase) for value in data.values())
            if is_multipart:
                data, headers["Content-Type"] = encode_multipart(data)

            else:
                headers["Content-Type"] = "application/x-www-form-urlencoded"
                data = urlencode(data).encode(DEFAULT_CHARSET)

        elif json is not None:
            headers["Content-Type"] = "application/json"
            data = ResponseJSON.process_content(json)

        pipe = Pipe()

        if isinstance(data, bytes):
            headers.setdefault("Content-Length", str(len(data)))

        scope = self.build_scope(
            path,
            type="http",
            query=query,
            method=method,
            headers=headers,
            cookies=cookies,
        )

        async with aio_timeout(timeout):
            await aio_wait(
                pipe.stream(data),
                self.app(scope, pipe.receive_from_app, pipe.send_to_client),
            )

        res = TestResponse()
        await res(scope, pipe.receive_from_client, pipe.send_to_app)
        for n, v in res.cookies.items():
            self.cookies[n] = v

        if follow_redirect and res.status_code in {301, 302, 303, 307, 308}:
            return await self.get(res.headers["location"])

        return res

    # TODO: Timeouts for websockets
    @asynccontextmanager
    async def websocket(
        self,
        path: str,
        query: str | dict | None = None,
        headers: dict | None = None,
        cookies: dict | None = None,
    ):
        """Connect to a websocket.

        The application runs in a background task. ``websocket.connect`` is sent
        on enter and ``websocket.disconnect`` when the block exits normally.
        """
        pipe = Pipe()

        ci_headers = CIMultiDict(headers or {})
        # Comma-separated header -> list of offered subprotocols. A missing header
        # gives [] (not [""]), and the spaces after commas are dropped.
        offered = str(ci_headers.get("Sec-WebSocket-Protocol", "")).split(",")
        subprotocols = [proto.strip() for proto in offered if proto.strip()]

        scope = self.build_scope(
            path,
            headers=ci_headers,
            query=query,
            cookies=cookies,
            type="websocket",
            subprotocols=subprotocols,
        )
        ws = TestWebSocketResponse(scope, pipe.receive_from_client, pipe.send_to_app)
        async with aio_spawn(
            self.app,
            scope,
            pipe.receive_from_app,
            pipe.send_to_client,
        ):
            await ws.connect()
            yield ws
            await ws.disconnect()

    def lifespan(self, timeout: float = 3e-2):
        """Manage `Lifespan <https://asgi.readthedocs.io/en/latest/specs/lifespan.html>`_
        protocol."""
        return manage_lifespan(self.app, timeout=timeout)

    def build_scope(
        self,
        path: str,
        headers: dict | CIMultiDict | None = None,
        query: str | dict | None = None,
        cookies: dict | None = None,
        **scope,
    ) -> TASGIScope:
        """Prepare a request scope.

        ``cookies`` are stored in the client's jar. The whole jar is then sent as a
        ``Cookie`` header unless one was given. Extra keyword arguments are merged
        into the scope and override the defaults.
        """
        headers = headers or {}
        headers.setdefault("User-Agent", "ASGI-Tools-Test-Client")
        headers.setdefault("Host", self.base_url.host)

        if cookies:
            for c, v in cookies.items():
                self.cookies[c] = v

        if self.cookies:
            # A Cookie request header carries only name=value pairs. Attributes
            # received via Set-Cookie (Path, Expires, HttpOnly, ...) must not be echoed.
            headers.setdefault(
                "Cookie",
                "; ".join(f"{key}={morsel.coded_value}" for key, morsel in sorted(self.cookies.items())),
            )

        url = URL(path)
        if query:
            url = url.with_query(query)

        # Setup client
        scope.setdefault("client", ("127.0.0.1", random.randint(1024, 65535)))  # noqa: S311

        if scope.get("type") == "http":
            scheme = self.base_url.scheme or "ws"
        else:
            # Websocket scheme follows the base URL's security: https -> wss.
            scheme = "wss" if self.base_url.scheme == "https" else "ws"

        return dict(
            {
                "asgi": {"version": "3.0"},
                "http_version": "1.1",
                "path": url.path,
                "query_string": url.raw_query_string.encode(),
                "raw_path": url.raw_path.encode(),
                "root_path": "",
                "scheme": scheme,
                "headers": [
                    (key.lower().encode(BASE_ENCODING), str(val).encode(BASE_ENCODING))
                    for key, val in (headers or {}).items()
                ],
                "server": ("127.0.0.1", self.base_url.port),
            },
            **scope,
        )
293
+
294
+
295
# Characters that would break out of a quoted multipart header parameter are
# percent-encoded, the way browsers encode form-data field names and filenames.
_MULTIPART_QUOTE = str.maketrans({'"': "%22", "\r": "%0D", "\n": "%0A"})


def encode_multipart(data: dict) -> tuple[bytes, str]:
    """Encode *data* as a ``multipart/form-data`` request body.

    Values may be ``bytes``-like, ``str`` (UTF-8 encoded), file-like objects
    (anything with ``read()``) or any other object, which is sent as
    ``str(value)``. When a file-like object has a non-empty ``str`` ``name``,
    its base name becomes the part's ``filename``. Its ``Content-Type`` is
    guessed from that name, defaulting to ``application/octet-stream``.

    Returns:
        ``(body, content_type)``; ``content_type`` carries the random boundary.
    """
    body = io.BytesIO()
    boundary = binascii.hexlify(os.urandom(16))
    for name, value in data.items():
        headers = f'Content-Disposition: form-data; name="{str(name).translate(_MULTIPART_QUOTE)}"'
        if hasattr(value, "read"):
            filename = getattr(value, "name", None)
            # ``name`` may also be an int (file descriptor) or bytes; only a
            # non-empty str is a usable file name.
            if isinstance(filename, str) and filename:
                basename = Path(filename).name.translate(_MULTIPART_QUOTE)
                content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
                headers = f'{headers}; filename="{basename}"\r\nContent-Type: {content_type}'
            value = value.read()

        if isinstance(value, str):
            value = value.encode("utf-8")
        elif not isinstance(value, (bytes, bytearray, memoryview)):
            # Numbers and other scalars: send their text form instead of failing in write().
            value = str(value).encode("utf-8")

        body.write(b"--%b\r\n" % boundary)
        body.write(headers.encode("utf-8"))
        body.write(b"\r\n\r\n")
        body.write(value)
        body.write(b"\r\n")

    body.write(b"--%b--\r\n" % boundary)
    return body.getvalue(), (b"multipart/form-data; boundary=%s" % boundary).decode()
319
+
320
+
321
class Pipe:
    """In-memory two-way message channel between a test client and an ASGI app.

    Each direction is a deque. Readers poll it every ``delay`` seconds until a
    message shows up. A direction becomes closed after ``websocket.close``, a
    final ``http.response.body`` (toward the client) or ``http.disconnect``
    (toward the app). Sending on a closed direction raises
    ``ASGIInvalidMessageError``.
    """

    __slots__ = (
        "app_is_closed",
        "app_queue",
        "client_is_closed",
        "client_queue",
        "delay",
    )

    def __init__(self, delay: float = 1e-3):
        self.delay = delay
        self.app_is_closed = False
        self.client_is_closed = False
        self.app_queue: Deque[TASGIMessage] = deque()
        self.client_queue: Deque[TASGIMessage] = deque()

    async def send_to_client(self, msg: TASGIMessage):
        """Queue a message emitted by the app for the client."""
        msg_type = msg.get("type")
        if self.client_is_closed:
            raise ASGIInvalidMessageError(msg_type)

        if msg_type == "websocket.close":
            self.client_is_closed = True
        elif msg_type == "http.response.body":
            self.client_is_closed = not msg.get("more_body", False)

        self.client_queue.append(msg)

    async def send_to_app(self, msg: TASGIMessage):
        """Queue a message from the client for the app."""
        msg_type = msg.get("type")
        if self.app_is_closed:
            raise ASGIInvalidMessageError(msg_type)

        if msg_type == "http.disconnect":
            self.app_is_closed = True

        self.app_queue.append(msg)

    async def _pop(self, queue: Deque[TASGIMessage]) -> TASGIMessage:
        # Poll until something is queued, then hand out the oldest message.
        while not queue:
            await aio_sleep(self.delay)
        return queue.popleft()

    async def receive_from_client(self):
        """Wait for and return the next message addressed to the client."""
        return await self._pop(self.client_queue)

    async def receive_from_app(self):
        """Wait for and return the next message addressed to the app."""
        return await self._pop(self.app_queue)

    async def stream(self, data: bytes | AsyncGenerator[Any, bytes]):
        """Feed a request body to the app as ``http.request`` messages.

        ``bytes`` go out as a single final message. An async iterator is sent
        chunk by chunk with ``more_body=True``, followed by an empty final message.
        """
        final_body = data
        if not isinstance(data, bytes):
            async for piece in data:
                await self.send_to_app({"type": "http.request", "body": piece, "more_body": True})
            final_body = b""

        await self.send_to_app({"type": "http.request", "body": final_body, "more_body": False})
        return None
378
+
379
+
380
@asynccontextmanager
async def manage_lifespan(app, timeout: float = 3e-2):
    """Manage `Lifespan <https://asgi.readthedocs.io/en/latest/specs/lifespan.html>`_ protocol.

    Runs *app* with a ``{"type": "lifespan"}`` scope in a background task.

    - On enter, ``lifespan.startup`` is sent and the app's reply is awaited for up
      to *timeout* seconds. On ``lifespan.startup.failed`` the app task is
      cancelled, but the managed block still runs.
    - On exit, including when the managed block raises, ``lifespan.shutdown`` is
      sent and the reply is awaited for up to *timeout* seconds.

    Any exception raised by the app is suppressed, so apps without lifespan
    support are tolerated.
    """
    pipe = Pipe()

    scope = {"type": "lifespan"}

    async def safe_spawn():
        # Deliberate best effort: apps that do not implement lifespan may raise.
        with suppress(BaseException):
            await app(scope, pipe.receive_from_app, pipe.send_to_client)

    async with aio_spawn(safe_spawn) as task:
        await pipe.send_to_app({"type": "lifespan.startup"})

        with suppress(TimeoutError, asyncio.TimeoutError):  # python 310
            async with aio_timeout(timeout):
                msg = await pipe.receive_from_client()
                if msg["type"] == "lifespan.startup.failed":
                    await aio_cancel(task)

        try:
            yield
        finally:
            # Shut down even when the managed block failed, so app cleanup still runs.
            await pipe.send_to_app({"type": "lifespan.shutdown"})
            with suppress(TimeoutError, asyncio.TimeoutError):  # python 310
                async with aio_timeout(timeout):
                    await pipe.receive_from_client()
asgi_tools/types.py ADDED
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ Awaitable,
7
+ Callable,
8
+ Coroutine,
9
+ Mapping,
10
+ MutableMapping,
11
+ TypeVar,
12
+ Union,
13
+ )
14
+
15
+ if TYPE_CHECKING:
16
+ from .request import Request
17
+
18
# One ASGI event: a read-only mapping of string keys (such as "type") to values.
TASGIMessage = Mapping[str, Any]
# ``send`` callable given to an app: awaitable delivery of one message.
TASGISend = Callable[[TASGIMessage], Awaitable[None]]
# ``receive`` callable given to an app: awaits and returns the next message.
TASGIReceive = Callable[[], Awaitable[TASGIMessage]]
# Mutable connection scope (``type``, ``path``, ``headers``, ...).
TASGIScope = MutableMapping[str, Any]
# Raw header list as it appears in a scope/message: (name, value) byte pairs.
TASGIHeaders = list[tuple[bytes, bytes]]
# Single-callable application invoked as ``app(scope, receive, send)``.
TASGIApp = Callable[[TASGIScope, TASGIReceive, TASGISend], Awaitable[Any]]

# Recursive JSON value: null, bool, number, string, array or object.
TJSON = Union[None, bool, int, float, str, list["TJSON"], Mapping[str, "TJSON"]]
# Async error handler called with the current Request and the raised exception.
TExceptionHandler = Callable[["Request", BaseException], Coroutine[None, None, Any]]

# Generic type variables used in decorator/helper signatures.
TV = TypeVar("TV")
# Any callable (preserves the exact callable type through decorators).
TVCallable = TypeVar("TVCallable", bound=Callable)
# A callable returning a coroutine.
TVAsyncCallable = TypeVar("TVAsyncCallable", bound=Callable[..., Coroutine])
# A callable matching ``TExceptionHandler``.
TVExceptionHandler = TypeVar("TVExceptionHandler", bound=TExceptionHandler)
asgi_tools/utils.py ADDED
@@ -0,0 +1,110 @@
1
+ """ASGI-Tools Utils."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from functools import wraps
7
+ from inspect import isasyncgenfunction, iscoroutinefunction
8
+ from typing import TYPE_CHECKING, Callable, Coroutine, overload
9
+ from urllib.parse import unquote_to_bytes
10
+
11
+ from multidict import CIMultiDict
12
+
13
+ from .constants import BASE_ENCODING
14
+
15
+ if TYPE_CHECKING:
16
+ from .types import TV, TASGIHeaders, TVAsyncCallable
17
+
18
+
19
def is_awaitable(fn: Callable) -> bool:
    """Tell whether *fn* is a coroutine function or an async generator function."""
    if iscoroutinefunction(fn):
        return True
    return isasyncgenfunction(fn)
22
+
23
+
24
@overload
def to_awaitable(fn: TVAsyncCallable) -> TVAsyncCallable: ...


@overload
def to_awaitable(fn: Callable[..., TV]) -> Callable[..., Coroutine[None, None, TV]]: ...


def to_awaitable(fn: Callable):
    """Return *fn* unchanged if it is already async, else wrap it in a coroutine function.

    The wrapper copies *fn*'s metadata (``functools.wraps``) and returns the
    result of calling *fn* synchronously.
    """
    if not is_awaitable(fn):

        @wraps(fn)
        async def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)

        return wrapper

    return fn
42
+
43
+
44
def parse_headers(headers: TASGIHeaders) -> CIMultiDict:
    """Build a case-insensitive multidict from raw ``(name, value)`` byte pairs.

    Names and values are decoded with ``BASE_ENCODING``; repeated names are kept
    in their original order.
    """
    pairs = []
    for raw_name, raw_value in headers:
        pairs.append((raw_name.decode(BASE_ENCODING), raw_value.decode(BASE_ENCODING)))
    return CIMultiDict(pairs)
49
+
50
+
51
# Matches one ``; key=value`` piece of an options header. It handles quoted or
# token keys, RFC 2231 continuation indexes (``key*0=``), extended values
# (``key*=charset'lang'value``) and quoted or token values.
OPTION_HEADER_PIECE_RE = re.compile(
    r"""
    \s*,?\s*  # newlines were replaced with commas
    (?P<key>
        "[^"\\]*(?:\\.[^"\\]*)*"  # quoted string
    |
        [^\s;,=*]+  # token
    )
    (?:\*(?P<count>\d+))?  # *1, optional continuation index
    \s*
    (?:  # optionally followed by =value
        (?:  # equals sign, possibly with encoding
            \*\s*=\s*  # * indicates extended notation
            (?:  # optional encoding
                (?P<encoding>[^\s]+?)
                '(?P<language>[^\s]*?)'
            )?
        |
            =\s*  # basic notation
        )
        (?P<value>
            "[^"\\]*(?:\\.[^"\\]*)*"  # quoted string
        |
            [^;,]+  # token
        )?
    )?
    \s*;?
    """,
    flags=re.VERBOSE,
)


def parse_options_header(value: str) -> tuple[str, dict[str, str]]:
    """Parse a header value with ``;``-separated options (e.g. Content-Disposition).

    ``'form-data; name="f"; filename="a.txt"'`` gives
    ``("form-data", {"name": "f", "filename": "a.txt"})``.

    The following forms are supported:

    - quoted values with backslash escapes;
    - RFC 2231 extended values (``filename*=UTF-8''%E2%82%AC``);
    - continuations (``title*0="foo"; title*1="bar"``).

    Options without a value are ignored.

    Returns:
        ``(main_value, options)``; ``("", {})`` for an empty input.
    """
    options: dict[str, str] = {}
    if not value:
        return "", options

    if ";" not in value:
        return value, options

    ctype, rest = value.split(";", 1)
    while rest:
        match = OPTION_HEADER_PIECE_RE.match(rest)
        if not match:
            break

        # A match always consumes at least the key, so the loop makes progress.
        rest = rest[match.end() :]
        option, count, encoding, _, piece = match.groups()
        if piece is None:
            # Bare option without ``=value``: nothing to store.
            continue

        if encoding is not None:
            raw = unquote_to_bytes(piece)
            try:
                piece = raw.decode(encoding)
            except (LookupError, UnicodeDecodeError):
                # Unknown charset, or bytes invalid for it (header is client-controlled):
                # fall back to latin-1, which decodes any byte sequence.
                piece = raw.decode("latin-1")

        # Unquote each piece *before* joining continuations; otherwise the quotes
        # of the second and later pieces end up inside the value.
        piece = piece.strip('" ').replace("\\\\", "\\").replace('\\"', '"')
        if count:
            piece = options.get(option, "") + piece

        options[option] = piece

    return ctype, options
asgi_tools/view.py ADDED
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from typing import TYPE_CHECKING, Final
5
+
6
+ if TYPE_CHECKING:
7
+ from collections.abc import Awaitable
8
+
9
+ from http_router.types import TMethods
10
+
11
+ from .request import Request
12
+ from .router import Router
13
+
14
# Standard HTTP request methods a class-based view may implement (as lower-case
# method names); listed alphabetically.
HTTP_METHODS: Final = {
    "CONNECT",
    "DELETE",
    "GET",
    "HEAD",
    "OPTIONS",
    "PATCH",
    "POST",
    "PUT",
    "TRACE",
}
25
+
26
+
27
class HTTPView:
    """Base class for class-based views that dispatch on the HTTP method.

    Subclass it and define lower-case async methods named after HTTP verbs
    (``get``, ``post``, ...). When routed without explicit methods, only the
    verbs the subclass implements are bound. The example below shows an
    unbound verb answered with 405:

    .. code-block:: python

        @app.route('/custom')
        class CustomEndpoint(HTTPView):

            async def get(self, request):
                return 'Hello from GET'

            async def post(self, request):
                return 'Hello from POST'

        # ...
        async def test_my_endpoint(client):
            response = await client.get('/custom')
            assert await response.text() == 'Hello from GET'

            response = await client.post('/custom')
            assert await response.text() == 'Hello from POST'

            response = await client.put('/custom')
            assert response.status_code == 405

    """

    def __new__(cls, request: Request, **opts):
        """Create a view and immediately dispatch *request* to it.

        Returns whatever the matching handler returns (not a view instance).
        """
        instance = super().__new__(cls)
        return instance(request, **opts)

    @classmethod
    def __route__(cls, router: Router, *paths: str, methods: TMethods | None = None, **params):
        """Bind this view to *router* for the given *paths*.

        Without explicit *methods*, binds every verb from ``HTTP_METHODS`` whose
        lower-case name is a function on the class, inherited ones included.
        """
        if not methods:
            implemented = {name for name, _ in inspect.getmembers(cls, inspect.isfunction)}
            methods = [verb for verb in HTTP_METHODS if verb.lower() in implemented]
        return router.bind(cls, *paths, methods=methods, **params)

    def __call__(self, request: Request, **opts) -> Awaitable:
        """Call the handler named after the request's (lower-cased) HTTP method."""
        return getattr(self, request.method.lower())(request, **opts)