latch-asgi 0.3.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- latch_asgi/asgi_iface.py +33 -289
- latch_asgi/auth.py +7 -15
- latch_asgi/config.py +2 -0
- latch_asgi/context/common.py +30 -19
- latch_asgi/context/http.py +9 -23
- latch_asgi/context/websocket.py +35 -19
- latch_asgi/datadog_propagator.py +14 -12
- latch_asgi/framework/common.py +33 -0
- latch_asgi/framework/http.py +58 -56
- latch_asgi/framework/websocket.py +98 -93
- latch_asgi/server.py +377 -220
- {latch_asgi-0.3.0.dist-info → latch_asgi-1.0.0.dist-info}/METADATA +1 -2
- latch_asgi-1.0.0.dist-info/RECORD +19 -0
- latch_asgi-0.3.0.dist-info/RECORD +0 -19
- {latch_asgi-0.3.0.dist-info → latch_asgi-1.0.0.dist-info}/WHEEL +0 -0
- {latch_asgi-0.3.0.dist-info → latch_asgi-1.0.0.dist-info}/licenses/COPYING +0 -0
latch_asgi/datadog_propagator.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Self
|
|
2
|
+
|
|
1
3
|
from opentelemetry import trace
|
|
2
4
|
from opentelemetry.context.context import Context
|
|
3
5
|
from opentelemetry.propagators import textmap
|
|
@@ -6,7 +8,7 @@ from opentelemetry.propagators import textmap
|
|
|
6
8
|
class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
|
|
7
9
|
# https://github.com/open-telemetry/opentelemetry-python-contrib/blob/934af7ea4f9b1e0294ced6a014d6eefdda156b2b/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py
|
|
8
10
|
def extract(
|
|
9
|
-
self,
|
|
11
|
+
self: Self,
|
|
10
12
|
carrier: textmap.CarrierT,
|
|
11
13
|
context: Context | None = None,
|
|
12
14
|
getter: textmap.Getter[textmap.CarrierT] = textmap.default_getter,
|
|
@@ -28,12 +30,16 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
|
|
|
28
30
|
sampling_priority = getter.get(carrier, "x-datadog-sampling-priority")
|
|
29
31
|
|
|
30
32
|
trace_flags = trace.TraceFlags()
|
|
31
|
-
if
|
|
32
|
-
|
|
33
|
+
if (
|
|
34
|
+
sampling_priority is not None
|
|
35
|
+
and len(sampling_priority) > 0
|
|
36
|
+
and int(sampling_priority[0])
|
|
37
|
+
in (
|
|
33
38
|
1, # auto keep
|
|
34
39
|
2, # user keep
|
|
35
|
-
)
|
|
36
|
-
|
|
40
|
+
)
|
|
41
|
+
):
|
|
42
|
+
trace_flags = trace.TraceFlags(trace.TraceFlags.SAMPLED)
|
|
37
43
|
|
|
38
44
|
dd_origin = getter.get(carrier, "x-datadog-origin")
|
|
39
45
|
|
|
@@ -51,7 +57,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
|
|
|
51
57
|
return trace.set_span_in_context(trace.NonRecordingSpan(span_context), context)
|
|
52
58
|
|
|
53
59
|
def inject(
|
|
54
|
-
self,
|
|
60
|
+
self: Self,
|
|
55
61
|
carrier: textmap.CarrierT,
|
|
56
62
|
context: Context | None = None,
|
|
57
63
|
setter: textmap.Setter[textmap.CarrierT] = textmap.default_setter,
|
|
@@ -69,11 +75,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
|
|
|
69
75
|
"x-datadog-trace-id",
|
|
70
76
|
str(span_context.trace_id & 0xFFFF_FFFF_FFFF_FFFF),
|
|
71
77
|
)
|
|
72
|
-
setter.set(
|
|
73
|
-
carrier,
|
|
74
|
-
"x-datadog-parent-id",
|
|
75
|
-
str(span_context.span_id),
|
|
76
|
-
)
|
|
78
|
+
setter.set(carrier, "x-datadog-parent-id", str(span_context.span_id))
|
|
77
79
|
setter.set(
|
|
78
80
|
carrier,
|
|
79
81
|
"x-datadog-sampling-priority",
|
|
@@ -87,7 +89,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
|
|
|
87
89
|
)
|
|
88
90
|
|
|
89
91
|
@property
|
|
90
|
-
def fields(self) -> set[str]:
|
|
92
|
+
def fields(self: Self) -> set[str]:
|
|
91
93
|
return {
|
|
92
94
|
"x-datadog-trace-id",
|
|
93
95
|
"x-datadog-parent-id",
|
latch_asgi/framework/common.py
CHANGED
|
@@ -5,3 +5,36 @@ from opentelemetry.trace import get_tracer
|
|
|
5
5
|
Headers: TypeAlias = dict[str | bytes, str | bytes]
|
|
6
6
|
|
|
7
7
|
tracer = get_tracer(__name__)
|
|
8
|
+
|
|
9
|
+
otel_header_whitelist = {
|
|
10
|
+
"host",
|
|
11
|
+
"content-type",
|
|
12
|
+
"content-length",
|
|
13
|
+
"accept",
|
|
14
|
+
"accept-encoding",
|
|
15
|
+
"accept-language",
|
|
16
|
+
"user-agent",
|
|
17
|
+
"dnt",
|
|
18
|
+
"sec-fetch-dest",
|
|
19
|
+
"sec-fetch-mode",
|
|
20
|
+
"sec-fetch-site",
|
|
21
|
+
"sec-fetch-user",
|
|
22
|
+
"sec-gpc",
|
|
23
|
+
"te",
|
|
24
|
+
"upgrade-insecure-requests",
|
|
25
|
+
"device-memory",
|
|
26
|
+
"downlink",
|
|
27
|
+
"dpr",
|
|
28
|
+
"ect",
|
|
29
|
+
"rtt",
|
|
30
|
+
"sec-ch-prefers-color-scheme",
|
|
31
|
+
"sec-ch-prefers-reduced-motion",
|
|
32
|
+
"sec-ch-ua",
|
|
33
|
+
"sec-ch-ua-arch",
|
|
34
|
+
"sec-ch-ua-full-version",
|
|
35
|
+
"sec-ch-ua-mobile",
|
|
36
|
+
"sec-ch-ua-model",
|
|
37
|
+
"sec-ch-ua-platform",
|
|
38
|
+
"sec-ch-ua-platfrom-version",
|
|
39
|
+
"viewport-width",
|
|
40
|
+
}
|
latch_asgi/framework/http.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
from http import HTTPStatus
|
|
2
|
-
from typing import Any, Literal, TypeAlias, TypeVar, cast
|
|
2
|
+
from typing import Any, Literal, Self, TypeAlias, TypeVar, cast
|
|
3
3
|
|
|
4
|
-
import opentelemetry.context as context
|
|
5
4
|
import orjson
|
|
6
5
|
from latch_data_validation.data_validation import DataValidationError, validate
|
|
7
6
|
from latch_o11y.o11y import trace_function, trace_function_with_span
|
|
7
|
+
from opentelemetry import context
|
|
8
8
|
from opentelemetry.trace.span import Span
|
|
9
9
|
|
|
10
10
|
from ..asgi_iface import (
|
|
@@ -17,17 +17,9 @@ from .common import Headers, tracer
|
|
|
17
17
|
|
|
18
18
|
T = TypeVar("T")
|
|
19
19
|
|
|
20
|
-
HTTPMethod: TypeAlias =
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
| Literal["POST"]
|
|
24
|
-
| Literal["PUT"]
|
|
25
|
-
| Literal["DELETE"]
|
|
26
|
-
| Literal["CONNECT"]
|
|
27
|
-
| Literal["OPTIONS"]
|
|
28
|
-
| Literal["TRACE"]
|
|
29
|
-
| Literal["PATCH"]
|
|
30
|
-
)
|
|
20
|
+
HTTPMethod: TypeAlias = Literal[
|
|
21
|
+
"GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"
|
|
22
|
+
]
|
|
31
23
|
|
|
32
24
|
# >>> O11y
|
|
33
25
|
|
|
@@ -42,34 +34,68 @@ def current_http_request_span() -> Span:
|
|
|
42
34
|
|
|
43
35
|
|
|
44
36
|
class HTTPErrorResponse(RuntimeError):
|
|
45
|
-
def __init__(
|
|
37
|
+
def __init__(
|
|
38
|
+
self: Self, status: HTTPStatus, data: Any, *, headers: Headers | None = None
|
|
39
|
+
) -> None:
|
|
40
|
+
if headers is None:
|
|
41
|
+
headers = {}
|
|
42
|
+
|
|
46
43
|
self.status = status
|
|
47
44
|
self.data = data
|
|
48
45
|
self.headers = headers
|
|
49
46
|
|
|
50
47
|
|
|
51
48
|
class HTTPInternalServerError(HTTPErrorResponse):
|
|
52
|
-
def __init__(self, data: Any, *, headers: Headers =
|
|
49
|
+
def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
|
|
53
50
|
super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR, data, headers=headers)
|
|
54
51
|
|
|
55
52
|
|
|
56
53
|
class HTTPBadRequest(HTTPErrorResponse):
|
|
57
|
-
def __init__(self, data: Any, *, headers: Headers =
|
|
54
|
+
def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
|
|
58
55
|
super().__init__(HTTPStatus.BAD_REQUEST, data, headers=headers)
|
|
59
56
|
|
|
60
57
|
|
|
61
58
|
class HTTPForbidden(HTTPErrorResponse):
|
|
62
|
-
def __init__(self, data: Any, *, headers: Headers =
|
|
59
|
+
def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
|
|
63
60
|
super().__init__(HTTPStatus.FORBIDDEN, data, headers=headers)
|
|
64
61
|
|
|
65
62
|
|
|
66
|
-
class HTTPConnectionClosedError(RuntimeError):
|
|
63
|
+
class HTTPConnectionClosedError(RuntimeError):
|
|
64
|
+
...
|
|
67
65
|
|
|
68
66
|
|
|
69
|
-
# >>>
|
|
67
|
+
# >>>
|
|
68
|
+
# I/O
|
|
69
|
+
# >>>
|
|
70
70
|
|
|
71
71
|
# todo(maximsmol): add max body length limit by default
|
|
72
72
|
|
|
73
|
+
# >>> Receive
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@trace_function(tracer)
|
|
77
|
+
async def receive_data(receive: HTTPReceiveCallable) -> bytes:
|
|
78
|
+
res = b""
|
|
79
|
+
more_body = True
|
|
80
|
+
while more_body:
|
|
81
|
+
msg = await receive()
|
|
82
|
+
if msg["type"] == "http.disconnect":
|
|
83
|
+
raise HTTPConnectionClosedError
|
|
84
|
+
|
|
85
|
+
res += msg["body"]
|
|
86
|
+
more_body = msg["more_body"]
|
|
87
|
+
|
|
88
|
+
# todo(maximsmol): accumulate instead of overriding
|
|
89
|
+
# todo(maximsmol): probably use the content-length header if present?
|
|
90
|
+
current_http_request_span().set_attribute("http.request.body.size", len(res))
|
|
91
|
+
|
|
92
|
+
return res
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@trace_function(tracer)
|
|
96
|
+
async def receive_json(receive: HTTPReceiveCallable) -> Any:
|
|
97
|
+
return orjson.loads(await receive_data(receive))
|
|
98
|
+
|
|
73
99
|
|
|
74
100
|
async def receive_class_ext(
|
|
75
101
|
receive: HTTPReceiveCallable, cls: type[T]
|
|
@@ -87,47 +113,25 @@ async def receive_class(receive: HTTPReceiveCallable, cls: type[T]) -> T:
|
|
|
87
113
|
return (await receive_class_ext(receive, cls))[1]
|
|
88
114
|
|
|
89
115
|
|
|
90
|
-
|
|
91
|
-
async def receive_json(receive: HTTPReceiveCallable) -> Any:
|
|
92
|
-
return orjson.loads(await receive_data(receive))
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
async def receive_data(receive: HTTPReceiveCallable):
|
|
96
|
-
res = b""
|
|
97
|
-
more_body = True
|
|
98
|
-
while more_body:
|
|
99
|
-
with tracer.start_as_current_span("read chunk") as s:
|
|
100
|
-
msg = await receive()
|
|
101
|
-
if msg.type == "http.disconnect":
|
|
102
|
-
raise HTTPConnectionClosedError()
|
|
103
|
-
|
|
104
|
-
res += msg.body
|
|
105
|
-
more_body = msg.more_body
|
|
116
|
+
# >>> Send
|
|
106
117
|
|
|
107
|
-
s.set_attributes({"size": len(msg.body), "more_body": more_body})
|
|
108
118
|
|
|
109
|
-
|
|
110
|
-
# todo(maximsmol): probably use the content-length header if present?
|
|
111
|
-
current_http_request_span().set_attribute("http.request_content_length", len(res))
|
|
112
|
-
|
|
113
|
-
return res
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
@trace_function_with_span(tracer)
|
|
119
|
+
@trace_function(tracer)
|
|
117
120
|
async def send_http_data(
|
|
118
|
-
s: Span,
|
|
119
121
|
send: HTTPSendCallable,
|
|
120
122
|
status: HTTPStatus,
|
|
121
123
|
data: str | bytes,
|
|
122
124
|
/,
|
|
123
125
|
*,
|
|
124
|
-
content_type: str | bytes | None = "text/plain",
|
|
125
|
-
headers: Headers =
|
|
126
|
-
):
|
|
126
|
+
content_type: str | bytes | None = b"text/plain",
|
|
127
|
+
headers: Headers | None = None,
|
|
128
|
+
) -> None:
|
|
129
|
+
if headers is None:
|
|
130
|
+
headers = {}
|
|
131
|
+
|
|
127
132
|
if isinstance(data, str):
|
|
128
133
|
data = data.encode("utf-8")
|
|
129
134
|
|
|
130
|
-
s.set_attribute("size", len(data))
|
|
131
135
|
headers_to_send: list[tuple[bytes, bytes]] = [
|
|
132
136
|
(b"Content-Length", str(len(data)).encode("latin-1"))
|
|
133
137
|
]
|
|
@@ -152,8 +156,6 @@ async def send_http_data(
|
|
|
152
156
|
HTTPResponseBodyEvent(type="http.response.body", body=data, more_body=False)
|
|
153
157
|
)
|
|
154
158
|
|
|
155
|
-
current_http_request_span().set_attribute("http.response_content_length", len(data))
|
|
156
|
-
|
|
157
159
|
|
|
158
160
|
@trace_function(tracer)
|
|
159
161
|
async def send_json(
|
|
@@ -163,8 +165,8 @@ async def send_json(
|
|
|
163
165
|
/,
|
|
164
166
|
*,
|
|
165
167
|
content_type: str = "application/json",
|
|
166
|
-
headers: Headers =
|
|
167
|
-
):
|
|
168
|
+
headers: Headers | None = None,
|
|
169
|
+
) -> None:
|
|
168
170
|
return await send_http_data(
|
|
169
171
|
send, status, orjson.dumps(data), content_type=content_type, headers=headers
|
|
170
172
|
)
|
|
@@ -177,9 +179,9 @@ async def send_auto(
|
|
|
177
179
|
data: str | bytes | Any,
|
|
178
180
|
/,
|
|
179
181
|
*,
|
|
180
|
-
headers: Headers =
|
|
181
|
-
):
|
|
182
|
-
if isinstance(data, str
|
|
182
|
+
headers: Headers | None = None,
|
|
183
|
+
) -> None:
|
|
184
|
+
if isinstance(data, str | bytes):
|
|
183
185
|
return await send_http_data(send, status, data, headers=headers)
|
|
184
186
|
|
|
185
187
|
return await send_json(send, status, data, headers=headers)
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
from enum import Enum
|
|
2
|
-
from typing import Any, TypeVar, cast
|
|
2
|
+
from typing import Any, Self, TypeVar, cast
|
|
3
3
|
|
|
4
|
-
import
|
|
4
|
+
import orjson
|
|
5
5
|
from latch_data_validation.data_validation import DataValidationError, validate
|
|
6
6
|
from latch_o11y.o11y import trace_function, trace_function_with_span
|
|
7
|
+
from opentelemetry import context
|
|
7
8
|
from opentelemetry.trace.span import Span
|
|
8
|
-
import
|
|
9
|
+
from opentelemetry.util.types import AttributeValue
|
|
9
10
|
|
|
10
11
|
from ..asgi_iface import (
|
|
11
12
|
WebsocketAcceptEvent,
|
|
@@ -18,20 +19,18 @@ from .common import Headers, tracer
|
|
|
18
19
|
|
|
19
20
|
T = TypeVar("T")
|
|
20
21
|
|
|
21
|
-
|
|
22
|
+
websocket_session_span_key = context.create_key("websocket_message_span")
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
def
|
|
25
|
-
return cast(Span, context.get_value(
|
|
25
|
+
def current_websocket_session_span() -> Span:
|
|
26
|
+
return cast(Span, context.get_value(websocket_session_span_key))
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
# >>> Error classes
|
|
29
30
|
|
|
30
31
|
|
|
31
32
|
class WebsocketStatus(int, Enum):
|
|
32
|
-
"""
|
|
33
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
34
|
-
"""
|
|
33
|
+
"""https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1"""
|
|
35
34
|
|
|
36
35
|
normal = 1000
|
|
37
36
|
"""
|
|
@@ -134,63 +133,108 @@ class WebsocketStatus(int, Enum):
|
|
|
134
133
|
|
|
135
134
|
|
|
136
135
|
class WebsocketErrorResponse(RuntimeError):
|
|
137
|
-
def __init__(self, status: WebsocketStatus, data: Any
|
|
138
|
-
super().__init__()
|
|
136
|
+
def __init__(self: Self, status: WebsocketStatus, data: Any) -> None:
|
|
139
137
|
self.status = status
|
|
140
138
|
self.data = data
|
|
141
|
-
self.headers = headers
|
|
142
139
|
|
|
143
140
|
|
|
144
141
|
class WebsocketBadMessage(WebsocketErrorResponse):
|
|
145
|
-
def __init__(self, data: Any
|
|
146
|
-
super().__init__(WebsocketStatus.policy_violation, data
|
|
142
|
+
def __init__(self: Self, data: Any) -> None:
|
|
143
|
+
super().__init__(WebsocketStatus.policy_violation, data)
|
|
147
144
|
|
|
148
145
|
|
|
149
146
|
class WebsocketInternalServerError(WebsocketErrorResponse):
|
|
150
|
-
def __init__(self, data: Any
|
|
151
|
-
super().__init__(WebsocketStatus.server_error, data
|
|
147
|
+
def __init__(self: Self, data: Any) -> None:
|
|
148
|
+
super().__init__(WebsocketStatus.server_error, data)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class WebsocketConnectionClosedError(RuntimeError):
|
|
152
|
+
def __init__(self: Self, code: WebsocketStatus) -> None:
|
|
153
|
+
self.code = code
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# >>>
|
|
157
|
+
# I/O
|
|
158
|
+
# >>>
|
|
159
|
+
|
|
160
|
+
# >>> Send Lifecycle
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@trace_function(tracer)
|
|
164
|
+
async def accept_connection(
|
|
165
|
+
send: WebsocketSendCallable,
|
|
166
|
+
/,
|
|
167
|
+
*,
|
|
168
|
+
subprotocol: str | None = None,
|
|
169
|
+
headers: Headers | None = None,
|
|
170
|
+
) -> None:
|
|
171
|
+
if headers is None:
|
|
172
|
+
headers = {}
|
|
173
|
+
|
|
174
|
+
headers_to_send: list[tuple[bytes, bytes]] = []
|
|
175
|
+
for k, v in headers.items():
|
|
176
|
+
if isinstance(k, str):
|
|
177
|
+
k = k.encode("latin-1")
|
|
178
|
+
if isinstance(v, str):
|
|
179
|
+
v = v.encode("latin-1")
|
|
180
|
+
headers_to_send.append((k, v))
|
|
152
181
|
|
|
182
|
+
await send(
|
|
183
|
+
WebsocketAcceptEvent(
|
|
184
|
+
type="websocket.accept", subprotocol=subprotocol, headers=headers_to_send
|
|
185
|
+
)
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@trace_function_with_span(tracer)
|
|
190
|
+
async def close_connection(
|
|
191
|
+
s: Span, send: WebsocketSendCallable, /, *, status: WebsocketStatus, data: str
|
|
192
|
+
) -> None:
|
|
193
|
+
s.set_attributes({"status": status.name, "reason": data})
|
|
194
|
+
|
|
195
|
+
await send(
|
|
196
|
+
WebsocketCloseEvent(type="websocket.close", code=status.value, reason=data)
|
|
197
|
+
)
|
|
153
198
|
|
|
154
|
-
|
|
199
|
+
current_websocket_session_span().set_attribute("websocket.close_reason", data)
|
|
155
200
|
|
|
156
201
|
|
|
157
|
-
# >>>
|
|
202
|
+
# >>> Receive
|
|
158
203
|
|
|
204
|
+
# todo(maximsmol): add max message length limit by default
|
|
159
205
|
|
|
160
|
-
async def receive_websocket_data(receive: WebsocketReceiveCallable):
|
|
161
|
-
with tracer.start_as_current_span("read websocket message") as s:
|
|
162
|
-
msg = await receive()
|
|
163
206
|
|
|
164
|
-
|
|
207
|
+
@trace_function(tracer)
|
|
208
|
+
async def receive_data(receive: WebsocketReceiveCallable) -> bytes | str:
|
|
209
|
+
msg = await receive()
|
|
165
210
|
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
211
|
+
if msg["type"] == "websocket.connect":
|
|
212
|
+
# todo(ayush): allow upgrades here as well?
|
|
213
|
+
raise ValueError("ASGI protocol violation: duplicate websocket.connect event")
|
|
169
214
|
|
|
170
|
-
|
|
171
|
-
|
|
215
|
+
if msg["type"] == "websocket.disconnect":
|
|
216
|
+
raise WebsocketConnectionClosedError(WebsocketStatus(msg["code"]))
|
|
172
217
|
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
218
|
+
if msg["bytes"] is not None:
|
|
219
|
+
res = msg["bytes"]
|
|
220
|
+
elif msg["text"] is not None:
|
|
221
|
+
res = msg["text"]
|
|
222
|
+
else:
|
|
223
|
+
raise WebsocketBadMessage("empty message")
|
|
179
224
|
|
|
180
|
-
|
|
181
|
-
return res
|
|
225
|
+
return res
|
|
182
226
|
|
|
183
227
|
|
|
184
228
|
@trace_function(tracer)
|
|
185
|
-
async def
|
|
186
|
-
orjson.loads(await
|
|
229
|
+
async def receive_json(receive: WebsocketReceiveCallable) -> Any:
|
|
230
|
+
return orjson.loads(await receive_data(receive))
|
|
187
231
|
|
|
188
232
|
|
|
189
233
|
@trace_function(tracer)
|
|
190
|
-
async def
|
|
234
|
+
async def receive_class_ext(
|
|
191
235
|
receive: WebsocketReceiveCallable, cls: type[T]
|
|
192
236
|
) -> tuple[Any, T]:
|
|
193
|
-
data = await
|
|
237
|
+
data = await receive_json(receive)
|
|
194
238
|
|
|
195
239
|
try:
|
|
196
240
|
return data, validate(data, cls)
|
|
@@ -200,67 +244,28 @@ async def receive_websocket_class_ext(
|
|
|
200
244
|
|
|
201
245
|
@trace_function(tracer)
|
|
202
246
|
async def receive_websocket_class(receive: WebsocketReceiveCallable, cls: type[T]) -> T:
|
|
203
|
-
return (await
|
|
247
|
+
return (await receive_class_ext(receive, cls))[1]
|
|
204
248
|
|
|
205
249
|
|
|
206
|
-
|
|
207
|
-
async def send_websocket_data(
|
|
208
|
-
s: Span,
|
|
209
|
-
send: WebsocketSendCallable,
|
|
210
|
-
data: str | bytes,
|
|
211
|
-
):
|
|
212
|
-
if isinstance(data, str):
|
|
213
|
-
data = data.encode("utf-8")
|
|
214
|
-
|
|
215
|
-
s.set_attribute("body.size", len(data))
|
|
216
|
-
|
|
217
|
-
await send(WebsocketSendEvent(type="websocket.send", bytes=data, text=None))
|
|
218
|
-
|
|
219
|
-
current_websocket_request_span().set_attribute(
|
|
220
|
-
"websocket.sent_message_content_length", len(data)
|
|
221
|
-
)
|
|
250
|
+
# >>> Send
|
|
222
251
|
|
|
223
252
|
|
|
224
253
|
@trace_function(tracer)
|
|
225
|
-
async def
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
subprotocol: str | None = None,
|
|
231
|
-
headers: Headers = {},
|
|
232
|
-
):
|
|
233
|
-
msg = await receive()
|
|
234
|
-
if msg.type != "websocket.connect":
|
|
235
|
-
raise WebsocketBadMessage("cannot accept connection without connection request")
|
|
236
|
-
|
|
237
|
-
headers_to_send: list[tuple[bytes, bytes]] = []
|
|
238
|
-
|
|
239
|
-
for k, v in headers.items():
|
|
240
|
-
if isinstance(k, str):
|
|
241
|
-
k = k.encode("latin-1")
|
|
242
|
-
if isinstance(v, str):
|
|
243
|
-
v = v.encode("latin-1")
|
|
244
|
-
headers_to_send.append((k, v))
|
|
254
|
+
async def send_data(send: WebsocketSendCallable, data: str | bytes, /) -> None:
|
|
255
|
+
if isinstance(data, bytes):
|
|
256
|
+
await send(WebsocketSendEvent(type="websocket.send", bytes=data, text=None))
|
|
257
|
+
else:
|
|
258
|
+
await send(WebsocketSendEvent(type="websocket.send", bytes=None, text=data))
|
|
245
259
|
|
|
246
|
-
await send(
|
|
247
|
-
WebsocketAcceptEvent(
|
|
248
|
-
type="websocket.accept", subprotocol=subprotocol, headers=headers_to_send
|
|
249
|
-
)
|
|
250
|
-
)
|
|
251
260
|
|
|
261
|
+
@trace_function(tracer)
|
|
262
|
+
async def send_json(send: WebsocketSendCallable, data: Any, /) -> None:
|
|
263
|
+
return await send_data(send, orjson.dumps(data))
|
|
252
264
|
|
|
253
|
-
@trace_function_with_span(tracer)
|
|
254
|
-
async def close_websocket_connection(
|
|
255
|
-
s: Span,
|
|
256
|
-
send: WebsocketSendCallable,
|
|
257
|
-
/,
|
|
258
|
-
*,
|
|
259
|
-
status: WebsocketStatus,
|
|
260
|
-
data: str,
|
|
261
|
-
):
|
|
262
|
-
s.set_attribute("reason", data)
|
|
263
265
|
|
|
264
|
-
|
|
266
|
+
@trace_function(tracer)
|
|
267
|
+
async def send_websocket_auto(send: WebsocketSendCallable, data: Any, /) -> None:
|
|
268
|
+
if isinstance(data, str | bytes):
|
|
269
|
+
return await send_data(send, data)
|
|
265
270
|
|
|
266
|
-
|
|
271
|
+
return await send_json(send, data)
|