latch-asgi 0.3.0__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/PKG-INFO +1 -2
  2. latch_asgi-1.0.0/latch_asgi/asgi_iface.py +72 -0
  3. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/latch_asgi/auth.py +7 -15
  4. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/latch_asgi/config.py +2 -0
  5. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/latch_asgi/context/common.py +30 -19
  6. latch_asgi-1.0.0/latch_asgi/context/http.py +28 -0
  7. latch_asgi-1.0.0/latch_asgi/context/websocket.py +50 -0
  8. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/latch_asgi/datadog_propagator.py +14 -12
  9. latch_asgi-1.0.0/latch_asgi/framework/common.py +40 -0
  10. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/latch_asgi/framework/http.py +58 -56
  11. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/latch_asgi/framework/websocket.py +98 -93
  12. latch_asgi-1.0.0/latch_asgi/server.py +495 -0
  13. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/pyproject.toml +67 -11
  14. latch_asgi-0.3.0/latch_asgi/asgi_iface.py +0 -328
  15. latch_asgi-0.3.0/latch_asgi/context/http.py +0 -42
  16. latch_asgi-0.3.0/latch_asgi/context/websocket.py +0 -34
  17. latch_asgi-0.3.0/latch_asgi/framework/common.py +0 -7
  18. latch_asgi-0.3.0/latch_asgi/server.py +0 -338
  19. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/COPYING +0 -0
  20. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/README.md +0 -0
  21. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/latch_asgi/__init__.py +0 -0
  22. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/latch_asgi/context/__init__.py +0 -0
  23. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/latch_asgi/framework/__init__.py +0 -0
  24. {latch_asgi-0.3.0 → latch_asgi-1.0.0}/latch_asgi/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: latch-asgi
3
- Version: 0.3.0
3
+ Version: 1.0.0
4
4
  Summary: ASGI python server
5
5
  Author-Email: Max Smolin <max@latch.bio>
6
6
  License: CC0-1.0
@@ -13,7 +13,6 @@ Requires-Dist: PyJWT[crypto]<3.0.0,>=2.6.0
13
13
  Requires-Dist: orjson<4.0.0,>=3.8.5
14
14
  Requires-Dist: opentelemetry-sdk<2.0.0,>=1.15.0
15
15
  Requires-Dist: opentelemetry-api<2.0.0,>=1.15.0
16
- Requires-Dist: opentelemetry-instrumentation-asgi<1.0,>=0.36b0
17
16
  Description-Content-Type: text/markdown
18
17
 
19
18
  # python-asgi
@@ -0,0 +1,72 @@
1
+ from collections.abc import Awaitable, Callable
2
+ from typing import TypeAlias
3
+
4
+ from hypercorn.typing import (
5
+ HTTPDisconnectEvent,
6
+ HTTPRequestEvent,
7
+ HTTPResponseBodyEvent,
8
+ HTTPResponseStartEvent,
9
+ HTTPServerPushEvent,
10
+ LifespanShutdownCompleteEvent,
11
+ LifespanShutdownEvent,
12
+ LifespanShutdownFailedEvent,
13
+ LifespanStartupCompleteEvent,
14
+ LifespanStartupEvent,
15
+ LifespanStartupFailedEvent,
16
+ WebsocketAcceptEvent,
17
+ WebsocketCloseEvent,
18
+ WebsocketConnectEvent,
19
+ WebsocketDisconnectEvent,
20
+ WebsocketReceiveEvent,
21
+ WebsocketResponseBodyEvent,
22
+ WebsocketResponseStartEvent,
23
+ WebsocketSendEvent,
24
+ )
25
+
26
+ # >>> Lifespan
27
+ LifespanReceiveEvent: TypeAlias = LifespanStartupEvent | LifespanShutdownEvent
28
+ LifespanReceiveCallable: TypeAlias = Callable[[], Awaitable[LifespanReceiveEvent]]
29
+
30
+ LifespanShutdownSendEvent: TypeAlias = (
31
+ LifespanShutdownCompleteEvent | LifespanShutdownFailedEvent
32
+ )
33
+ LifespanStartupSendEvent: TypeAlias = (
34
+ LifespanStartupCompleteEvent | LifespanStartupFailedEvent
35
+ )
36
+
37
+ LifespanSendEvent: TypeAlias = LifespanStartupSendEvent | LifespanShutdownSendEvent
38
+ LifespanSendCallable: TypeAlias = Callable[[LifespanSendEvent], Awaitable[None]]
39
+
40
+ # >>> HTTP
41
+
42
+ HTTPReceiveEvent: TypeAlias = HTTPRequestEvent | HTTPDisconnectEvent
43
+ HTTPReceiveCallable: TypeAlias = Callable[[], Awaitable[HTTPReceiveEvent]]
44
+
45
+ HTTPSendEvent: TypeAlias = (
46
+ HTTPResponseStartEvent
47
+ | HTTPResponseBodyEvent
48
+ | HTTPServerPushEvent
49
+ | HTTPDisconnectEvent
50
+ )
51
+ HTTPSendCallable: TypeAlias = Callable[[HTTPSendEvent], Awaitable[None]]
52
+
53
+ # >>> Websocket
54
+
55
+ WebsocketReceiveEventT: TypeAlias = (
56
+ WebsocketConnectEvent | WebsocketReceiveEvent | WebsocketDisconnectEvent
57
+ )
58
+ WebsocketReceiveCallable: TypeAlias = Callable[[], Awaitable[WebsocketReceiveEventT]]
59
+
60
+ WebsocketSendEventT: TypeAlias = (
61
+ WebsocketAcceptEvent
62
+ | WebsocketSendEvent
63
+ | WebsocketResponseBodyEvent
64
+ | WebsocketResponseStartEvent
65
+ | WebsocketCloseEvent
66
+ )
67
+ WebsocketSendCallable: TypeAlias = Callable[[WebsocketSendEventT], Awaitable[None]]
68
+
69
+ # >>> WWW
70
+
71
+ WWWReceiveCallable = HTTPReceiveCallable | WebsocketReceiveCallable
72
+ WWWSendCallable = HTTPSendCallable | WebsocketSendCallable
@@ -3,7 +3,7 @@
3
3
  import re
4
4
  from dataclasses import dataclass
5
5
  from http import HTTPStatus
6
- from typing import Literal
6
+ from typing import Literal, Self
7
7
 
8
8
  import jwt
9
9
  from jwt import PyJWKClient
@@ -37,14 +37,10 @@ class _HTTPUnauthorized(HTTPErrorResponse):
37
37
  """WARNING: HTTPForbidden is the correct error to use in virtually all cases"""
38
38
 
39
39
  def __init__(
40
- self,
40
+ self: Self,
41
41
  error_description: str,
42
- error: (
43
- Literal["invalid_request"]
44
- | Literal["invalid_token"]
45
- | Literal["insufficient_scope"]
46
- ),
47
- ):
42
+ error: (Literal["invalid_request", "invalid_token", "insufficient_scope"]),
43
+ ) -> None:
48
44
  escaped_description = error_description.replace('"', '\\"')
49
45
  super().__init__(
50
46
  HTTPStatus.UNAUTHORIZED,
@@ -63,7 +59,7 @@ class Authorization:
63
59
  execution_token: str | None = None
64
60
  sdk_token: str | None = None
65
61
 
66
- def unauthorized_if_none(self):
62
+ def unauthorized_if_none(self: Self) -> None:
67
63
  if self.oauth_sub is not None:
68
64
  return
69
65
  if self.execution_token is not None:
@@ -71,10 +67,7 @@ class Authorization:
71
67
  if self.sdk_token is not None:
72
68
  return
73
69
 
74
- raise _HTTPUnauthorized(
75
- "Authenticaton required",
76
- error="invalid_request",
77
- )
70
+ raise _HTTPUnauthorized("Authenticaton required", error="invalid_request")
78
71
 
79
72
 
80
73
  @trace_app_function
@@ -121,8 +114,7 @@ def get_signer_sub(auth_header: str) -> Authorization:
121
114
  jwt_key = jwk_client.get_signing_key_from_jwt(oauth_token).key
122
115
  except jwt.exceptions.InvalidTokenError as e:
123
116
  raise _HTTPUnauthorized(
124
- error_description="JWT decoding failed",
125
- error="invalid_token",
117
+ error_description="JWT decoding failed", error="invalid_token"
126
118
  ) from e
127
119
  except jwt.exceptions.PyJWKClientError:
128
120
  # fixme(maximsmol): gut this abomination
@@ -1,4 +1,5 @@
1
1
  from dataclasses import dataclass
2
+
2
3
  from latch_config.config import read_config
3
4
 
4
5
 
@@ -8,4 +9,5 @@ class AuthConfig:
8
9
  self_signed_jwk: str
9
10
  allow_spoofing: bool = False
10
11
 
12
+
11
13
  config = read_config(AuthConfig, "auth_")
@@ -1,61 +1,72 @@
1
1
  from dataclasses import dataclass, field
2
- from typing import Generic, TypeVar
3
-
4
- from latch_o11y.o11y import AttributesDict, app_tracer, trace_app_function
5
-
6
- from ..asgi_iface import WWWReceiveCallable, WWWScope, WWWSendCallable
2
+ from typing import Generic, Self, TypeVar
3
+
4
+ from hypercorn.typing import WWWScope
5
+ from latch_o11y.o11y import (
6
+ AttributesDict,
7
+ app_tracer,
8
+ dict_to_attrs,
9
+ trace_app_function,
10
+ )
11
+ from opentelemetry.util.types import AttributeValue
12
+
13
+ from ..asgi_iface import WWWReceiveCallable, WWWSendCallable
7
14
  from ..auth import Authorization, get_signer_sub
15
+ from ..framework.common import otel_header_whitelist
16
+ from ..framework.http import current_http_request_span
8
17
 
9
- # todo(ayush): this sucks
10
18
  Scope = TypeVar("Scope", bound=WWWScope)
11
- Send = TypeVar("Send", bound=WWWSendCallable)
12
- Receive = TypeVar("Receive", bound=WWWReceiveCallable)
19
+ SendCallable = TypeVar("SendCallable", bound=WWWSendCallable)
20
+ ReceiveCallable = TypeVar("ReceiveCallable", bound=WWWReceiveCallable)
13
21
 
14
22
 
15
23
  @dataclass
16
- class Context(Generic[Scope, Receive, Send]):
24
+ class Context(Generic[Scope, ReceiveCallable, SendCallable]):
17
25
  scope: Scope
18
- receive: Receive
19
- send: Send
26
+ receive: ReceiveCallable
27
+ send: SendCallable
20
28
 
21
29
  auth: Authorization = field(default_factory=Authorization, init=False)
22
30
 
23
31
  _header_cache: dict[bytes, bytes] = field(default_factory=dict, init=False)
24
32
  _db_response_idx: int = field(default=0, init=False)
25
33
 
26
- def __post_init__(self):
34
+ def __post_init__(self: Self) -> None:
27
35
  with app_tracer.start_as_current_span("find Authentication header"):
28
36
  auth_header = self.header_str("authorization")
29
37
 
30
38
  if auth_header is not None:
31
39
  self.auth = get_signer_sub(auth_header)
32
40
 
33
- def header(self, x: str | bytes):
41
+ if self.auth.oauth_sub is not None:
42
+ current_http_request_span().set_attribute("enduser.id", self.auth.oauth_sub)
43
+
44
+ def header(self: Self, x: str | bytes) -> bytes | None:
34
45
  if isinstance(x, str):
35
- x = x.encode("utf-8")
46
+ x = x.encode("latin-1")
36
47
 
37
48
  if x in self._header_cache:
38
49
  return self._header_cache[x]
39
50
 
40
- for k, v in self.scope.headers:
51
+ for k, v in self.scope["headers"]:
41
52
  self._header_cache[k] = v
42
53
  if k == x:
43
54
  return v
44
55
 
45
56
  return None
46
57
 
47
- def header_str(self, x: str | bytes):
58
+ def header_str(self: Self, x: str | bytes) -> str | None:
48
59
  res = self.header(x)
49
60
  if res is None:
50
61
  return None
51
62
 
52
63
  return res.decode("latin-1")
53
64
 
54
- def add_request_span_attrs(self, data: AttributesDict, prefix: str):
55
- raise NotImplementedError()
65
+ def add_request_span_attrs(self: Self, data: AttributesDict, prefix: str) -> None:
66
+ current_http_request_span().set_attributes(dict_to_attrs(data, prefix))
56
67
 
57
68
  @trace_app_function
58
- def add_db_response(self, data: AttributesDict):
69
+ def add_db_response(self: Self, data: AttributesDict) -> None:
59
70
  # todo(maximsmol): datadog has shit support for events
60
71
  # current_http_request_span().add_event(
61
72
  # f"database response {self._db_response_idx}", dict_to_attrs(data, "data")
@@ -0,0 +1,28 @@
1
+ from collections.abc import Awaitable, Callable
2
+ from dataclasses import dataclass
3
+ from typing import Any, Self, TypeAlias, TypeVar
4
+
5
+ from hypercorn.typing import HTTPScope
6
+ from latch_o11y.o11y import trace_app_function
7
+
8
+ from ..asgi_iface import HTTPReceiveCallable, HTTPSendCallable
9
+ from ..framework.http import HTTPMethod, receive_class_ext
10
+ from . import common
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ @dataclass
16
+ class Context(common.Context[HTTPScope, HTTPReceiveCallable, HTTPSendCallable]):
17
+ @trace_app_function
18
+ async def receive_request_payload(self: Self, cls: type[T]) -> T:
19
+ json, res = await receive_class_ext(self.receive, cls)
20
+
21
+ self.add_request_span_attrs(json, "http.request.body.data")
22
+
23
+ return res
24
+
25
+
26
+ HandlerResult: TypeAlias = Any | None
27
+ Handler: TypeAlias = Callable[[Context], Awaitable[HandlerResult]]
28
+ Route: TypeAlias = Handler | tuple[list[HTTPMethod], Handler]
@@ -0,0 +1,50 @@
1
+ from collections.abc import Awaitable, Callable
2
+ from dataclasses import dataclass
3
+ from typing import Any, Self, TypeAlias, TypeVar
4
+
5
+ from hypercorn.typing import WebsocketScope
6
+ from latch_o11y.o11y import AttributesDict, dict_to_attrs, trace_app_function
7
+ from opentelemetry.trace import get_current_span
8
+
9
+ from ..asgi_iface import WebsocketReceiveCallable, WebsocketSendCallable
10
+ from ..framework.common import Headers
11
+ from ..framework.websocket import (
12
+ accept_connection,
13
+ current_websocket_session_span,
14
+ receive_class_ext,
15
+ send_websocket_auto,
16
+ )
17
+ from . import common
18
+
19
+ T = TypeVar("T")
20
+
21
+
22
+ @dataclass
23
+ class Context(
24
+ common.Context[WebsocketScope, WebsocketReceiveCallable, WebsocketSendCallable]
25
+ ):
26
+ def add_session_span_attrs(self: Self, data: AttributesDict, prefix: str) -> None:
27
+ current_websocket_session_span().set_attributes(dict_to_attrs(data, prefix))
28
+
29
+ @trace_app_function
30
+ async def accept_connection(
31
+ self: Self, *, subprotocol: str | None = None, headers: Headers | None = None
32
+ ) -> None:
33
+ await accept_connection(self.send, subprotocol=subprotocol, headers=headers)
34
+
35
+ @trace_app_function
36
+ async def receive_message(self: Self, cls: type[T]) -> T:
37
+ json, res = await receive_class_ext(self.receive, cls)
38
+
39
+ get_current_span().set_attributes(dict_to_attrs(json, "payload"))
40
+
41
+ return res
42
+
43
+ @trace_app_function
44
+ async def send_message(self: Self, data: Any) -> None:
45
+ await send_websocket_auto(self.send, data)
46
+
47
+
48
+ HandlerResult = str
49
+ Handler: TypeAlias = Callable[[Context], Awaitable[HandlerResult]]
50
+ Route: TypeAlias = Handler
@@ -1,3 +1,5 @@
1
+ from typing import Self
2
+
1
3
  from opentelemetry import trace
2
4
  from opentelemetry.context.context import Context
3
5
  from opentelemetry.propagators import textmap
@@ -6,7 +8,7 @@ from opentelemetry.propagators import textmap
6
8
  class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
7
9
  # https://github.com/open-telemetry/opentelemetry-python-contrib/blob/934af7ea4f9b1e0294ced6a014d6eefdda156b2b/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py
8
10
  def extract(
9
- self,
11
+ self: Self,
10
12
  carrier: textmap.CarrierT,
11
13
  context: Context | None = None,
12
14
  getter: textmap.Getter[textmap.CarrierT] = textmap.default_getter,
@@ -28,12 +30,16 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
28
30
  sampling_priority = getter.get(carrier, "x-datadog-sampling-priority")
29
31
 
30
32
  trace_flags = trace.TraceFlags()
31
- if sampling_priority is not None and len(sampling_priority) > 0:
32
- if int(sampling_priority[0]) in (
33
+ if (
34
+ sampling_priority is not None
35
+ and len(sampling_priority) > 0
36
+ and int(sampling_priority[0])
37
+ in (
33
38
  1, # auto keep
34
39
  2, # user keep
35
- ):
36
- trace_flags = trace.TraceFlags(trace.TraceFlags.SAMPLED)
40
+ )
41
+ ):
42
+ trace_flags = trace.TraceFlags(trace.TraceFlags.SAMPLED)
37
43
 
38
44
  dd_origin = getter.get(carrier, "x-datadog-origin")
39
45
 
@@ -51,7 +57,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
51
57
  return trace.set_span_in_context(trace.NonRecordingSpan(span_context), context)
52
58
 
53
59
  def inject(
54
- self,
60
+ self: Self,
55
61
  carrier: textmap.CarrierT,
56
62
  context: Context | None = None,
57
63
  setter: textmap.Setter[textmap.CarrierT] = textmap.default_setter,
@@ -69,11 +75,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
69
75
  "x-datadog-trace-id",
70
76
  str(span_context.trace_id & 0xFFFF_FFFF_FFFF_FFFF),
71
77
  )
72
- setter.set(
73
- carrier,
74
- "x-datadog-parent-id",
75
- str(span_context.span_id),
76
- )
78
+ setter.set(carrier, "x-datadog-parent-id", str(span_context.span_id))
77
79
  setter.set(
78
80
  carrier,
79
81
  "x-datadog-sampling-priority",
@@ -87,7 +89,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
87
89
  )
88
90
 
89
91
  @property
90
- def fields(self) -> set[str]:
92
+ def fields(self: Self) -> set[str]:
91
93
  return {
92
94
  "x-datadog-trace-id",
93
95
  "x-datadog-parent-id",
@@ -0,0 +1,40 @@
1
+ from typing import TypeAlias
2
+
3
+ from opentelemetry.trace import get_tracer
4
+
5
+ Headers: TypeAlias = dict[str | bytes, str | bytes]
6
+
7
+ tracer = get_tracer(__name__)
8
+
9
+ otel_header_whitelist = {
10
+ "host",
11
+ "content-type",
12
+ "content-length",
13
+ "accept",
14
+ "accept-encoding",
15
+ "accept-language",
16
+ "user-agent",
17
+ "dnt",
18
+ "sec-fetch-dest",
19
+ "sec-fetch-mode",
20
+ "sec-fetch-site",
21
+ "sec-fetch-user",
22
+ "sec-gpc",
23
+ "te",
24
+ "upgrade-insecure-requests",
25
+ "device-memory",
26
+ "downlink",
27
+ "dpr",
28
+ "ect",
29
+ "rtt",
30
+ "sec-ch-prefers-color-scheme",
31
+ "sec-ch-prefers-reduced-motion",
32
+ "sec-ch-ua",
33
+ "sec-ch-ua-arch",
34
+ "sec-ch-ua-full-version",
35
+ "sec-ch-ua-mobile",
36
+ "sec-ch-ua-model",
37
+ "sec-ch-ua-platform",
38
+ "sec-ch-ua-platfrom-version",
39
+ "viewport-width",
40
+ }
@@ -1,10 +1,10 @@
1
1
  from http import HTTPStatus
2
- from typing import Any, Literal, TypeAlias, TypeVar, cast
2
+ from typing import Any, Literal, Self, TypeAlias, TypeVar, cast
3
3
 
4
- import opentelemetry.context as context
5
4
  import orjson
6
5
  from latch_data_validation.data_validation import DataValidationError, validate
7
6
  from latch_o11y.o11y import trace_function, trace_function_with_span
7
+ from opentelemetry import context
8
8
  from opentelemetry.trace.span import Span
9
9
 
10
10
  from ..asgi_iface import (
@@ -17,17 +17,9 @@ from .common import Headers, tracer
17
17
 
18
18
  T = TypeVar("T")
19
19
 
20
- HTTPMethod: TypeAlias = (
21
- Literal["GET"]
22
- | Literal["HEAD"]
23
- | Literal["POST"]
24
- | Literal["PUT"]
25
- | Literal["DELETE"]
26
- | Literal["CONNECT"]
27
- | Literal["OPTIONS"]
28
- | Literal["TRACE"]
29
- | Literal["PATCH"]
30
- )
20
+ HTTPMethod: TypeAlias = Literal[
21
+ "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"
22
+ ]
31
23
 
32
24
  # >>> O11y
33
25
 
@@ -42,34 +34,68 @@ def current_http_request_span() -> Span:
42
34
 
43
35
 
44
36
  class HTTPErrorResponse(RuntimeError):
45
- def __init__(self, status: HTTPStatus, data: Any, *, headers: Headers = {}):
37
+ def __init__(
38
+ self: Self, status: HTTPStatus, data: Any, *, headers: Headers | None = None
39
+ ) -> None:
40
+ if headers is None:
41
+ headers = {}
42
+
46
43
  self.status = status
47
44
  self.data = data
48
45
  self.headers = headers
49
46
 
50
47
 
51
48
  class HTTPInternalServerError(HTTPErrorResponse):
52
- def __init__(self, data: Any, *, headers: Headers = {}):
49
+ def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
53
50
  super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR, data, headers=headers)
54
51
 
55
52
 
56
53
  class HTTPBadRequest(HTTPErrorResponse):
57
- def __init__(self, data: Any, *, headers: Headers = {}):
54
+ def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
58
55
  super().__init__(HTTPStatus.BAD_REQUEST, data, headers=headers)
59
56
 
60
57
 
61
58
  class HTTPForbidden(HTTPErrorResponse):
62
- def __init__(self, data: Any, *, headers: Headers = {}):
59
+ def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
63
60
  super().__init__(HTTPStatus.FORBIDDEN, data, headers=headers)
64
61
 
65
62
 
66
- class HTTPConnectionClosedError(RuntimeError): ...
63
+ class HTTPConnectionClosedError(RuntimeError):
64
+ ...
67
65
 
68
66
 
69
- # >>> I/O
67
+ # >>>
68
+ # I/O
69
+ # >>>
70
70
 
71
71
  # todo(maximsmol): add max body length limit by default
72
72
 
73
+ # >>> Receive
74
+
75
+
76
+ @trace_function(tracer)
77
+ async def receive_data(receive: HTTPReceiveCallable) -> bytes:
78
+ res = b""
79
+ more_body = True
80
+ while more_body:
81
+ msg = await receive()
82
+ if msg["type"] == "http.disconnect":
83
+ raise HTTPConnectionClosedError
84
+
85
+ res += msg["body"]
86
+ more_body = msg["more_body"]
87
+
88
+ # todo(maximsmol): accumulate instead of overriding
89
+ # todo(maximsmol): probably use the content-length header if present?
90
+ current_http_request_span().set_attribute("http.request.body.size", len(res))
91
+
92
+ return res
93
+
94
+
95
+ @trace_function(tracer)
96
+ async def receive_json(receive: HTTPReceiveCallable) -> Any:
97
+ return orjson.loads(await receive_data(receive))
98
+
73
99
 
74
100
  async def receive_class_ext(
75
101
  receive: HTTPReceiveCallable, cls: type[T]
@@ -87,47 +113,25 @@ async def receive_class(receive: HTTPReceiveCallable, cls: type[T]) -> T:
87
113
  return (await receive_class_ext(receive, cls))[1]
88
114
 
89
115
 
90
- @trace_function(tracer)
91
- async def receive_json(receive: HTTPReceiveCallable) -> Any:
92
- return orjson.loads(await receive_data(receive))
93
-
94
-
95
- async def receive_data(receive: HTTPReceiveCallable):
96
- res = b""
97
- more_body = True
98
- while more_body:
99
- with tracer.start_as_current_span("read chunk") as s:
100
- msg = await receive()
101
- if msg.type == "http.disconnect":
102
- raise HTTPConnectionClosedError()
103
-
104
- res += msg.body
105
- more_body = msg.more_body
116
+ # >>> Send
106
117
 
107
- s.set_attributes({"size": len(msg.body), "more_body": more_body})
108
118
 
109
- # todo(maximsmol): accumulate instead of overriding
110
- # todo(maximsmol): probably use the content-length header if present?
111
- current_http_request_span().set_attribute("http.request_content_length", len(res))
112
-
113
- return res
114
-
115
-
116
- @trace_function_with_span(tracer)
119
+ @trace_function(tracer)
117
120
  async def send_http_data(
118
- s: Span,
119
121
  send: HTTPSendCallable,
120
122
  status: HTTPStatus,
121
123
  data: str | bytes,
122
124
  /,
123
125
  *,
124
- content_type: str | bytes | None = "text/plain",
125
- headers: Headers = {},
126
- ):
126
+ content_type: str | bytes | None = b"text/plain",
127
+ headers: Headers | None = None,
128
+ ) -> None:
129
+ if headers is None:
130
+ headers = {}
131
+
127
132
  if isinstance(data, str):
128
133
  data = data.encode("utf-8")
129
134
 
130
- s.set_attribute("size", len(data))
131
135
  headers_to_send: list[tuple[bytes, bytes]] = [
132
136
  (b"Content-Length", str(len(data)).encode("latin-1"))
133
137
  ]
@@ -152,8 +156,6 @@ async def send_http_data(
152
156
  HTTPResponseBodyEvent(type="http.response.body", body=data, more_body=False)
153
157
  )
154
158
 
155
- current_http_request_span().set_attribute("http.response_content_length", len(data))
156
-
157
159
 
158
160
  @trace_function(tracer)
159
161
  async def send_json(
@@ -163,8 +165,8 @@ async def send_json(
163
165
  /,
164
166
  *,
165
167
  content_type: str = "application/json",
166
- headers: Headers = {},
167
- ):
168
+ headers: Headers | None = None,
169
+ ) -> None:
168
170
  return await send_http_data(
169
171
  send, status, orjson.dumps(data), content_type=content_type, headers=headers
170
172
  )
@@ -177,9 +179,9 @@ async def send_auto(
177
179
  data: str | bytes | Any,
178
180
  /,
179
181
  *,
180
- headers: Headers = {},
181
- ):
182
- if isinstance(data, str) or isinstance(data, bytes):
182
+ headers: Headers | None = None,
183
+ ) -> None:
184
+ if isinstance(data, str | bytes):
183
185
  return await send_http_data(send, status, data, headers=headers)
184
186
 
185
187
  return await send_json(send, status, data, headers=headers)