latch-asgi 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- latch_asgi/asgi_iface.py +33 -289
- latch_asgi/auth.py +7 -15
- latch_asgi/config.py +2 -0
- latch_asgi/context/common.py +30 -19
- latch_asgi/context/http.py +9 -23
- latch_asgi/context/websocket.py +35 -19
- latch_asgi/datadog_propagator.py +14 -12
- latch_asgi/framework/common.py +33 -0
- latch_asgi/framework/http.py +60 -66
- latch_asgi/framework/websocket.py +100 -128
- latch_asgi/server.py +377 -221
- latch_asgi-1.0.0.dist-info/METADATA +18 -0
- latch_asgi-1.0.0.dist-info/RECORD +19 -0
- {latch_asgi-0.2.0.dist-info → latch_asgi-1.0.0.dist-info}/WHEEL +1 -1
- latch_asgi-1.0.0.dist-info/licenses/COPYING +121 -0
- latch_asgi-0.2.0.dist-info/METADATA +0 -25
- latch_asgi-0.2.0.dist-info/RECORD +0 -18
latch_asgi/datadog_propagator.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Self
|
|
2
|
+
|
|
1
3
|
from opentelemetry import trace
|
|
2
4
|
from opentelemetry.context.context import Context
|
|
3
5
|
from opentelemetry.propagators import textmap
|
|
@@ -6,7 +8,7 @@ from opentelemetry.propagators import textmap
|
|
|
6
8
|
class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
|
|
7
9
|
# https://github.com/open-telemetry/opentelemetry-python-contrib/blob/934af7ea4f9b1e0294ced6a014d6eefdda156b2b/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py
|
|
8
10
|
def extract(
|
|
9
|
-
self,
|
|
11
|
+
self: Self,
|
|
10
12
|
carrier: textmap.CarrierT,
|
|
11
13
|
context: Context | None = None,
|
|
12
14
|
getter: textmap.Getter[textmap.CarrierT] = textmap.default_getter,
|
|
@@ -28,12 +30,16 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
|
|
|
28
30
|
sampling_priority = getter.get(carrier, "x-datadog-sampling-priority")
|
|
29
31
|
|
|
30
32
|
trace_flags = trace.TraceFlags()
|
|
31
|
-
if
|
|
32
|
-
|
|
33
|
+
if (
|
|
34
|
+
sampling_priority is not None
|
|
35
|
+
and len(sampling_priority) > 0
|
|
36
|
+
and int(sampling_priority[0])
|
|
37
|
+
in (
|
|
33
38
|
1, # auto keep
|
|
34
39
|
2, # user keep
|
|
35
|
-
)
|
|
36
|
-
|
|
40
|
+
)
|
|
41
|
+
):
|
|
42
|
+
trace_flags = trace.TraceFlags(trace.TraceFlags.SAMPLED)
|
|
37
43
|
|
|
38
44
|
dd_origin = getter.get(carrier, "x-datadog-origin")
|
|
39
45
|
|
|
@@ -51,7 +57,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
|
|
|
51
57
|
return trace.set_span_in_context(trace.NonRecordingSpan(span_context), context)
|
|
52
58
|
|
|
53
59
|
def inject(
|
|
54
|
-
self,
|
|
60
|
+
self: Self,
|
|
55
61
|
carrier: textmap.CarrierT,
|
|
56
62
|
context: Context | None = None,
|
|
57
63
|
setter: textmap.Setter[textmap.CarrierT] = textmap.default_setter,
|
|
@@ -69,11 +75,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
|
|
|
69
75
|
"x-datadog-trace-id",
|
|
70
76
|
str(span_context.trace_id & 0xFFFF_FFFF_FFFF_FFFF),
|
|
71
77
|
)
|
|
72
|
-
setter.set(
|
|
73
|
-
carrier,
|
|
74
|
-
"x-datadog-parent-id",
|
|
75
|
-
str(span_context.span_id),
|
|
76
|
-
)
|
|
78
|
+
setter.set(carrier, "x-datadog-parent-id", str(span_context.span_id))
|
|
77
79
|
setter.set(
|
|
78
80
|
carrier,
|
|
79
81
|
"x-datadog-sampling-priority",
|
|
@@ -87,7 +89,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
|
|
|
87
89
|
)
|
|
88
90
|
|
|
89
91
|
@property
|
|
90
|
-
def fields(self) -> set[str]:
|
|
92
|
+
def fields(self: Self) -> set[str]:
|
|
91
93
|
return {
|
|
92
94
|
"x-datadog-trace-id",
|
|
93
95
|
"x-datadog-parent-id",
|
latch_asgi/framework/common.py
CHANGED
|
@@ -5,3 +5,36 @@ from opentelemetry.trace import get_tracer
|
|
|
5
5
|
Headers: TypeAlias = dict[str | bytes, str | bytes]
|
|
6
6
|
|
|
7
7
|
tracer = get_tracer(__name__)
|
|
8
|
+
|
|
9
|
+
otel_header_whitelist = {
|
|
10
|
+
"host",
|
|
11
|
+
"content-type",
|
|
12
|
+
"content-length",
|
|
13
|
+
"accept",
|
|
14
|
+
"accept-encoding",
|
|
15
|
+
"accept-language",
|
|
16
|
+
"user-agent",
|
|
17
|
+
"dnt",
|
|
18
|
+
"sec-fetch-dest",
|
|
19
|
+
"sec-fetch-mode",
|
|
20
|
+
"sec-fetch-site",
|
|
21
|
+
"sec-fetch-user",
|
|
22
|
+
"sec-gpc",
|
|
23
|
+
"te",
|
|
24
|
+
"upgrade-insecure-requests",
|
|
25
|
+
"device-memory",
|
|
26
|
+
"downlink",
|
|
27
|
+
"dpr",
|
|
28
|
+
"ect",
|
|
29
|
+
"rtt",
|
|
30
|
+
"sec-ch-prefers-color-scheme",
|
|
31
|
+
"sec-ch-prefers-reduced-motion",
|
|
32
|
+
"sec-ch-ua",
|
|
33
|
+
"sec-ch-ua-arch",
|
|
34
|
+
"sec-ch-ua-full-version",
|
|
35
|
+
"sec-ch-ua-mobile",
|
|
36
|
+
"sec-ch-ua-model",
|
|
37
|
+
"sec-ch-ua-platform",
|
|
38
|
+
"sec-ch-ua-platfrom-version",
|
|
39
|
+
"viewport-width",
|
|
40
|
+
}
|
latch_asgi/framework/http.py
CHANGED
|
@@ -1,12 +1,11 @@
|
|
|
1
1
|
from http import HTTPStatus
|
|
2
|
-
from typing import Any, Literal, TypeAlias, TypeVar, cast
|
|
2
|
+
from typing import Any, Literal, Self, TypeAlias, TypeVar, cast
|
|
3
3
|
|
|
4
|
-
import
|
|
5
|
-
import simdjson
|
|
4
|
+
import orjson
|
|
6
5
|
from latch_data_validation.data_validation import DataValidationError, validate
|
|
7
6
|
from latch_o11y.o11y import trace_function, trace_function_with_span
|
|
7
|
+
from opentelemetry import context
|
|
8
8
|
from opentelemetry.trace.span import Span
|
|
9
|
-
from orjson import dumps
|
|
10
9
|
|
|
11
10
|
from ..asgi_iface import (
|
|
12
11
|
HTTPReceiveCallable,
|
|
@@ -18,17 +17,9 @@ from .common import Headers, tracer
|
|
|
18
17
|
|
|
19
18
|
T = TypeVar("T")
|
|
20
19
|
|
|
21
|
-
HTTPMethod: TypeAlias =
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
| Literal["POST"]
|
|
25
|
-
| Literal["PUT"]
|
|
26
|
-
| Literal["DELETE"]
|
|
27
|
-
| Literal["CONNECT"]
|
|
28
|
-
| Literal["OPTIONS"]
|
|
29
|
-
| Literal["TRACE"]
|
|
30
|
-
| Literal["PATCH"]
|
|
31
|
-
)
|
|
20
|
+
HTTPMethod: TypeAlias = Literal[
|
|
21
|
+
"GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"
|
|
22
|
+
]
|
|
32
23
|
|
|
33
24
|
# >>> O11y
|
|
34
25
|
|
|
@@ -43,35 +34,68 @@ def current_http_request_span() -> Span:
|
|
|
43
34
|
|
|
44
35
|
|
|
45
36
|
class HTTPErrorResponse(RuntimeError):
|
|
46
|
-
def __init__(
|
|
47
|
-
|
|
37
|
+
def __init__(
|
|
38
|
+
self: Self, status: HTTPStatus, data: Any, *, headers: Headers | None = None
|
|
39
|
+
) -> None:
|
|
40
|
+
if headers is None:
|
|
41
|
+
headers = {}
|
|
42
|
+
|
|
48
43
|
self.status = status
|
|
49
44
|
self.data = data
|
|
50
45
|
self.headers = headers
|
|
51
46
|
|
|
52
47
|
|
|
53
48
|
class HTTPInternalServerError(HTTPErrorResponse):
|
|
54
|
-
def __init__(self, data: Any, *, headers: Headers =
|
|
49
|
+
def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
|
|
55
50
|
super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR, data, headers=headers)
|
|
56
51
|
|
|
57
52
|
|
|
58
53
|
class HTTPBadRequest(HTTPErrorResponse):
|
|
59
|
-
def __init__(self, data: Any, *, headers: Headers =
|
|
54
|
+
def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
|
|
60
55
|
super().__init__(HTTPStatus.BAD_REQUEST, data, headers=headers)
|
|
61
56
|
|
|
62
57
|
|
|
63
58
|
class HTTPForbidden(HTTPErrorResponse):
|
|
64
|
-
def __init__(self, data: Any, *, headers: Headers =
|
|
59
|
+
def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
|
|
65
60
|
super().__init__(HTTPStatus.FORBIDDEN, data, headers=headers)
|
|
66
61
|
|
|
67
62
|
|
|
68
|
-
class HTTPConnectionClosedError(RuntimeError):
|
|
63
|
+
class HTTPConnectionClosedError(RuntimeError):
|
|
64
|
+
...
|
|
69
65
|
|
|
70
66
|
|
|
71
|
-
# >>>
|
|
67
|
+
# >>>
|
|
68
|
+
# I/O
|
|
69
|
+
# >>>
|
|
72
70
|
|
|
73
71
|
# todo(maximsmol): add max body length limit by default
|
|
74
72
|
|
|
73
|
+
# >>> Receive
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@trace_function(tracer)
|
|
77
|
+
async def receive_data(receive: HTTPReceiveCallable) -> bytes:
|
|
78
|
+
res = b""
|
|
79
|
+
more_body = True
|
|
80
|
+
while more_body:
|
|
81
|
+
msg = await receive()
|
|
82
|
+
if msg["type"] == "http.disconnect":
|
|
83
|
+
raise HTTPConnectionClosedError
|
|
84
|
+
|
|
85
|
+
res += msg["body"]
|
|
86
|
+
more_body = msg["more_body"]
|
|
87
|
+
|
|
88
|
+
# todo(maximsmol): accumulate instead of overriding
|
|
89
|
+
# todo(maximsmol): probably use the content-length header if present?
|
|
90
|
+
current_http_request_span().set_attribute("http.request.body.size", len(res))
|
|
91
|
+
|
|
92
|
+
return res
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@trace_function(tracer)
|
|
96
|
+
async def receive_json(receive: HTTPReceiveCallable) -> Any:
|
|
97
|
+
return orjson.loads(await receive_data(receive))
|
|
98
|
+
|
|
75
99
|
|
|
76
100
|
async def receive_class_ext(
|
|
77
101
|
receive: HTTPReceiveCallable, cls: type[T]
|
|
@@ -89,53 +113,25 @@ async def receive_class(receive: HTTPReceiveCallable, cls: type[T]) -> T:
|
|
|
89
113
|
return (await receive_class_ext(receive, cls))[1]
|
|
90
114
|
|
|
91
115
|
|
|
92
|
-
|
|
93
|
-
async def receive_json(receive: HTTPReceiveCallable) -> Any:
|
|
94
|
-
res = await receive_data(receive)
|
|
95
|
-
|
|
96
|
-
p = simdjson.Parser()
|
|
97
|
-
try:
|
|
98
|
-
return p.parse(res, True)
|
|
99
|
-
except ValueError as e:
|
|
100
|
-
raise HTTPBadRequest("Failed to parse JSON") from e
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
async def receive_data(receive: HTTPReceiveCallable):
|
|
104
|
-
res = b""
|
|
105
|
-
more_body = True
|
|
106
|
-
while more_body:
|
|
107
|
-
with tracer.start_as_current_span("read chunk") as s:
|
|
108
|
-
msg = await receive()
|
|
109
|
-
if msg.type == "http.disconnect":
|
|
110
|
-
raise HTTPConnectionClosedError()
|
|
111
|
-
|
|
112
|
-
res += msg.body
|
|
113
|
-
more_body = msg.more_body
|
|
116
|
+
# >>> Send
|
|
114
117
|
|
|
115
|
-
s.set_attributes({"size": len(msg.body), "more_body": more_body})
|
|
116
118
|
|
|
117
|
-
|
|
118
|
-
# todo(maximsmol): probably use the content-length header if present?
|
|
119
|
-
current_http_request_span().set_attribute("http.request_content_length", len(res))
|
|
120
|
-
|
|
121
|
-
return res
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
@trace_function_with_span(tracer)
|
|
119
|
+
@trace_function(tracer)
|
|
125
120
|
async def send_http_data(
|
|
126
|
-
s: Span,
|
|
127
121
|
send: HTTPSendCallable,
|
|
128
122
|
status: HTTPStatus,
|
|
129
123
|
data: str | bytes,
|
|
130
124
|
/,
|
|
131
125
|
*,
|
|
132
|
-
content_type: str | bytes | None = "text/plain",
|
|
133
|
-
headers: Headers =
|
|
134
|
-
):
|
|
126
|
+
content_type: str | bytes | None = b"text/plain",
|
|
127
|
+
headers: Headers | None = None,
|
|
128
|
+
) -> None:
|
|
129
|
+
if headers is None:
|
|
130
|
+
headers = {}
|
|
131
|
+
|
|
135
132
|
if isinstance(data, str):
|
|
136
133
|
data = data.encode("utf-8")
|
|
137
134
|
|
|
138
|
-
s.set_attribute("size", len(data))
|
|
139
135
|
headers_to_send: list[tuple[bytes, bytes]] = [
|
|
140
136
|
(b"Content-Length", str(len(data)).encode("latin-1"))
|
|
141
137
|
]
|
|
@@ -160,8 +156,6 @@ async def send_http_data(
|
|
|
160
156
|
HTTPResponseBodyEvent(type="http.response.body", body=data, more_body=False)
|
|
161
157
|
)
|
|
162
158
|
|
|
163
|
-
current_http_request_span().set_attribute("http.response_content_length", len(data))
|
|
164
|
-
|
|
165
159
|
|
|
166
160
|
@trace_function(tracer)
|
|
167
161
|
async def send_json(
|
|
@@ -171,10 +165,10 @@ async def send_json(
|
|
|
171
165
|
/,
|
|
172
166
|
*,
|
|
173
167
|
content_type: str = "application/json",
|
|
174
|
-
headers: Headers =
|
|
175
|
-
):
|
|
168
|
+
headers: Headers | None = None,
|
|
169
|
+
) -> None:
|
|
176
170
|
return await send_http_data(
|
|
177
|
-
send, status, dumps(data), content_type=content_type, headers=headers
|
|
171
|
+
send, status, orjson.dumps(data), content_type=content_type, headers=headers
|
|
178
172
|
)
|
|
179
173
|
|
|
180
174
|
|
|
@@ -185,9 +179,9 @@ async def send_auto(
|
|
|
185
179
|
data: str | bytes | Any,
|
|
186
180
|
/,
|
|
187
181
|
*,
|
|
188
|
-
headers: Headers =
|
|
189
|
-
):
|
|
190
|
-
if isinstance(data, str
|
|
182
|
+
headers: Headers | None = None,
|
|
183
|
+
) -> None:
|
|
184
|
+
if isinstance(data, str | bytes):
|
|
191
185
|
return await send_http_data(send, status, data, headers=headers)
|
|
192
186
|
|
|
193
187
|
return await send_json(send, status, data, headers=headers)
|
|
latch_asgi/framework/websocket.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
from enum import Enum
|
|
2
|
-
from typing import Any, TypeVar, cast
|
|
2
|
+
from typing import Any, Self, TypeVar, cast
|
|
3
3
|
|
|
4
|
-
import
|
|
5
|
-
import simdjson
|
|
4
|
+
import orjson
|
|
6
5
|
from latch_data_validation.data_validation import DataValidationError, validate
|
|
7
6
|
from latch_o11y.o11y import trace_function, trace_function_with_span
|
|
7
|
+
from opentelemetry import context
|
|
8
8
|
from opentelemetry.trace.span import Span
|
|
9
|
+
from opentelemetry.util.types import AttributeValue
|
|
9
10
|
|
|
10
11
|
from ..asgi_iface import (
|
|
11
12
|
WebsocketAcceptEvent,
|
|
@@ -18,39 +19,35 @@ from .common import Headers, tracer
|
|
|
18
19
|
|
|
19
20
|
T = TypeVar("T")
|
|
20
21
|
|
|
21
|
-
|
|
22
|
+
websocket_session_span_key = context.create_key("websocket_message_span")
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
def
|
|
25
|
-
return cast(Span, context.get_value(
|
|
25
|
+
def current_websocket_session_span() -> Span:
|
|
26
|
+
return cast(Span, context.get_value(websocket_session_span_key))
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
# >>> Error classes
|
|
29
30
|
|
|
30
31
|
|
|
31
32
|
class WebsocketStatus(int, Enum):
|
|
33
|
+
"""https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1"""
|
|
34
|
+
|
|
32
35
|
normal = 1000
|
|
33
36
|
"""
|
|
34
37
|
1000 indicates a normal closure, meaning that the purpose for
|
|
35
38
|
which the connection was established has been fulfilled.
|
|
36
|
-
|
|
37
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
38
39
|
"""
|
|
39
40
|
|
|
40
41
|
going_away = 1001
|
|
41
42
|
"""
|
|
42
43
|
1001 indicates that an endpoint is "going away", such as a server
|
|
43
44
|
going down or a browser having navigated away from a page.
|
|
44
|
-
|
|
45
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
46
45
|
"""
|
|
47
46
|
|
|
48
47
|
protocol_error = 1002
|
|
49
48
|
"""
|
|
50
49
|
1002 indicates that an endpoint is terminating the connection due
|
|
51
50
|
to a protocol error.
|
|
52
|
-
|
|
53
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
54
51
|
"""
|
|
55
52
|
|
|
56
53
|
unsupported = 1003
|
|
@@ -59,15 +56,11 @@ class WebsocketStatus(int, Enum):
|
|
|
59
56
|
because it has received a type of data it cannot accept (e.g., an
|
|
60
57
|
endpoint that understands only text data MAY send this if it
|
|
61
58
|
receives a binary message).
|
|
62
|
-
|
|
63
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
64
59
|
"""
|
|
65
60
|
|
|
66
61
|
reserved = 1004
|
|
67
62
|
"""
|
|
68
63
|
Reserved. The specific meaning might be defined in the future.
|
|
69
|
-
|
|
70
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
71
64
|
"""
|
|
72
65
|
|
|
73
66
|
no_status = 1005
|
|
@@ -76,8 +69,6 @@ class WebsocketStatus(int, Enum):
|
|
|
76
69
|
Close control frame by an endpoint. It is designated for use in
|
|
77
70
|
applications expecting a status code to indicate that no status
|
|
78
71
|
code was actually present.
|
|
79
|
-
|
|
80
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
81
72
|
"""
|
|
82
73
|
|
|
83
74
|
abnormal = 1006
|
|
@@ -87,8 +78,6 @@ class WebsocketStatus(int, Enum):
|
|
|
87
78
|
applications expecting a status code to indicate that the
|
|
88
79
|
connection was closed abnormally, e.g., without sending or
|
|
89
80
|
receiving a Close control frame.
|
|
90
|
-
|
|
91
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
92
81
|
"""
|
|
93
82
|
|
|
94
83
|
unsupported_payload = 1007
|
|
@@ -97,8 +86,6 @@ class WebsocketStatus(int, Enum):
|
|
|
97
86
|
because it has received data within a message that was not
|
|
98
87
|
consistent with the type of the message (e.g., non-UTF-8 [RFC3629]
|
|
99
88
|
data within a text message).
|
|
100
|
-
|
|
101
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
102
89
|
"""
|
|
103
90
|
|
|
104
91
|
policy_violation = 1008
|
|
@@ -108,8 +95,6 @@ class WebsocketStatus(int, Enum):
|
|
|
108
95
|
is a generic status code that can be returned when there is no
|
|
109
96
|
other more suitable status code (e.g., 1003 or 1009) or if there
|
|
110
97
|
is a need to hide specific details about the policy.
|
|
111
|
-
|
|
112
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
113
98
|
"""
|
|
114
99
|
|
|
115
100
|
too_large = 1009
|
|
@@ -117,8 +102,6 @@ class WebsocketStatus(int, Enum):
|
|
|
117
102
|
1009 indicates that an endpoint is terminating the connection
|
|
118
103
|
because it has received a message that is too big for it to
|
|
119
104
|
process.
|
|
120
|
-
|
|
121
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
122
105
|
"""
|
|
123
106
|
|
|
124
107
|
mandatory_extension = 1010
|
|
@@ -130,8 +113,6 @@ class WebsocketStatus(int, Enum):
|
|
|
130
113
|
are needed SHOULD appear in the /reason/ part of the Close frame.
|
|
131
114
|
Note that this status code is not used by the server, because it
|
|
132
115
|
can fail the WebSocket handshake instead.
|
|
133
|
-
|
|
134
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
135
116
|
"""
|
|
136
117
|
|
|
137
118
|
server_error = 1011
|
|
@@ -139,8 +120,6 @@ class WebsocketStatus(int, Enum):
|
|
|
139
120
|
1011 indicates that a server is terminating the connection because
|
|
140
121
|
it encountered an unexpected condition that prevented it from
|
|
141
122
|
fulfilling the request.
|
|
142
|
-
|
|
143
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
144
123
|
"""
|
|
145
124
|
|
|
146
125
|
tls_handshake_fail = 1015
|
|
@@ -150,79 +129,112 @@ class WebsocketStatus(int, Enum):
|
|
|
150
129
|
applications expecting a status code to indicate that the
|
|
151
130
|
connection was closed due to a failure to perform a TLS handshake
|
|
152
131
|
(e.g., the server certificate can't be verified).
|
|
153
|
-
|
|
154
|
-
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
155
132
|
"""
|
|
156
133
|
|
|
157
134
|
|
|
158
|
-
class
|
|
159
|
-
def __init__(self, status: WebsocketStatus, data: Any
|
|
160
|
-
super().__init__()
|
|
135
|
+
class WebsocketErrorResponse(RuntimeError):
|
|
136
|
+
def __init__(self: Self, status: WebsocketStatus, data: Any) -> None:
|
|
161
137
|
self.status = status
|
|
162
138
|
self.data = data
|
|
163
|
-
self.headers = headers
|
|
164
139
|
|
|
165
140
|
|
|
166
|
-
class WebsocketErrorResponse
|
|
141
|
+
class WebsocketBadMessage(WebsocketErrorResponse):
|
|
142
|
+
def __init__(self: Self, data: Any) -> None:
|
|
143
|
+
super().__init__(WebsocketStatus.policy_violation, data)
|
|
167
144
|
|
|
168
145
|
|
|
169
|
-
class
|
|
146
|
+
class WebsocketInternalServerError(WebsocketErrorResponse):
|
|
147
|
+
def __init__(self: Self, data: Any) -> None:
|
|
148
|
+
super().__init__(WebsocketStatus.server_error, data)
|
|
170
149
|
|
|
171
150
|
|
|
172
|
-
class
|
|
173
|
-
def __init__(self
|
|
174
|
-
|
|
151
|
+
class WebsocketConnectionClosedError(RuntimeError):
|
|
152
|
+
def __init__(self: Self, code: WebsocketStatus) -> None:
|
|
153
|
+
self.code = code
|
|
175
154
|
|
|
176
155
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
156
|
+
# >>>
|
|
157
|
+
# I/O
|
|
158
|
+
# >>>
|
|
180
159
|
|
|
160
|
+
# >>> Send Lifecycle
|
|
181
161
|
|
|
182
|
-
# >>> I/O
|
|
183
162
|
|
|
163
|
+
@trace_function(tracer)
|
|
164
|
+
async def accept_connection(
|
|
165
|
+
send: WebsocketSendCallable,
|
|
166
|
+
/,
|
|
167
|
+
*,
|
|
168
|
+
subprotocol: str | None = None,
|
|
169
|
+
headers: Headers | None = None,
|
|
170
|
+
) -> None:
|
|
171
|
+
if headers is None:
|
|
172
|
+
headers = {}
|
|
173
|
+
|
|
174
|
+
headers_to_send: list[tuple[bytes, bytes]] = []
|
|
175
|
+
for k, v in headers.items():
|
|
176
|
+
if isinstance(k, str):
|
|
177
|
+
k = k.encode("latin-1")
|
|
178
|
+
if isinstance(v, str):
|
|
179
|
+
v = v.encode("latin-1")
|
|
180
|
+
headers_to_send.append((k, v))
|
|
184
181
|
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
182
|
+
await send(
|
|
183
|
+
WebsocketAcceptEvent(
|
|
184
|
+
type="websocket.accept", subprotocol=subprotocol, headers=headers_to_send
|
|
185
|
+
)
|
|
186
|
+
)
|
|
189
187
|
|
|
190
|
-
s.set_attribute("type", msg.type)
|
|
191
188
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
189
|
+
@trace_function_with_span(tracer)
|
|
190
|
+
async def close_connection(
|
|
191
|
+
s: Span, send: WebsocketSendCallable, /, *, status: WebsocketStatus, data: str
|
|
192
|
+
) -> None:
|
|
193
|
+
s.set_attributes({"status": status.name, "reason": data})
|
|
194
|
+
|
|
195
|
+
await send(
|
|
196
|
+
WebsocketCloseEvent(type="websocket.close", code=status.value, reason=data)
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
current_websocket_session_span().set_attribute("websocket.close_reason", data)
|
|
195
200
|
|
|
196
|
-
if msg.type == "websocket.disconnect":
|
|
197
|
-
raise WebsocketConnectionClosedError()
|
|
198
201
|
|
|
199
|
-
|
|
200
|
-
res = msg.bytes
|
|
201
|
-
elif msg.text is not None:
|
|
202
|
-
res = msg.text.encode("utf-8")
|
|
203
|
-
else:
|
|
204
|
-
raise WebsocketBadMessage("empty message")
|
|
202
|
+
# >>> Receive
|
|
205
203
|
|
|
206
|
-
|
|
207
|
-
return res
|
|
204
|
+
# todo(maximsmol): add max message length limit by default
|
|
208
205
|
|
|
209
206
|
|
|
210
207
|
@trace_function(tracer)
|
|
211
|
-
async def
|
|
212
|
-
|
|
208
|
+
async def receive_data(receive: WebsocketReceiveCallable) -> bytes | str:
|
|
209
|
+
msg = await receive()
|
|
213
210
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
211
|
+
if msg["type"] == "websocket.connect":
|
|
212
|
+
# todo(ayush): allow upgrades here as well?
|
|
213
|
+
raise ValueError("ASGI protocol violation: duplicate websocket.connect event")
|
|
214
|
+
|
|
215
|
+
if msg["type"] == "websocket.disconnect":
|
|
216
|
+
raise WebsocketConnectionClosedError(WebsocketStatus(msg["code"]))
|
|
217
|
+
|
|
218
|
+
if msg["bytes"] is not None:
|
|
219
|
+
res = msg["bytes"]
|
|
220
|
+
elif msg["text"] is not None:
|
|
221
|
+
res = msg["text"]
|
|
222
|
+
else:
|
|
223
|
+
raise WebsocketBadMessage("empty message")
|
|
224
|
+
|
|
225
|
+
return res
|
|
219
226
|
|
|
220
227
|
|
|
221
228
|
@trace_function(tracer)
|
|
222
|
-
async def
|
|
229
|
+
async def receive_json(receive: WebsocketReceiveCallable) -> Any:
|
|
230
|
+
return orjson.loads(await receive_data(receive))
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
@trace_function(tracer)
|
|
234
|
+
async def receive_class_ext(
|
|
223
235
|
receive: WebsocketReceiveCallable, cls: type[T]
|
|
224
236
|
) -> tuple[Any, T]:
|
|
225
|
-
data = await
|
|
237
|
+
data = await receive_json(receive)
|
|
226
238
|
|
|
227
239
|
try:
|
|
228
240
|
return data, validate(data, cls)
|
|
@@ -232,68 +244,28 @@ async def receive_websocket_class_ext(
|
|
|
232
244
|
|
|
233
245
|
@trace_function(tracer)
|
|
234
246
|
async def receive_websocket_class(receive: WebsocketReceiveCallable, cls: type[T]) -> T:
|
|
235
|
-
return (await
|
|
236
|
-
|
|
247
|
+
return (await receive_class_ext(receive, cls))[1]
|
|
237
248
|
|
|
238
|
-
@trace_function_with_span(tracer)
|
|
239
|
-
async def send_websocket_data(
|
|
240
|
-
s: Span,
|
|
241
|
-
send: WebsocketSendCallable,
|
|
242
|
-
data: str | bytes,
|
|
243
|
-
):
|
|
244
|
-
if isinstance(data, str):
|
|
245
|
-
data = data.encode("utf-8")
|
|
246
249
|
|
|
247
|
-
|
|
250
|
+
# >>> Send
|
|
248
251
|
|
|
249
|
-
await send(WebsocketSendEvent(type="websocket.send", bytes=data, text=None))
|
|
250
|
-
|
|
251
|
-
current_websocket_request_span().set_attribute(
|
|
252
|
-
"websocket.sent_message_content_length", len(data)
|
|
253
|
-
)
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
@trace_function_with_span(tracer)
|
|
257
|
-
async def accept_websocket_connection(
|
|
258
|
-
s: Span,
|
|
259
|
-
send: WebsocketSendCallable,
|
|
260
|
-
receive: WebsocketReceiveCallable,
|
|
261
|
-
/,
|
|
262
|
-
*,
|
|
263
|
-
subprotocol: str | None = None,
|
|
264
|
-
headers: Headers = {},
|
|
265
|
-
):
|
|
266
|
-
msg = await receive()
|
|
267
|
-
if msg.type != "websocket.connect":
|
|
268
|
-
raise WebsocketBadMessage("cannot accept connection without connection request")
|
|
269
252
|
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
v = v.encode("latin-1")
|
|
277
|
-
headers_to_send.append((k, v))
|
|
253
|
+
@trace_function(tracer)
|
|
254
|
+
async def send_data(send: WebsocketSendCallable, data: str | bytes, /) -> None:
|
|
255
|
+
if isinstance(data, bytes):
|
|
256
|
+
await send(WebsocketSendEvent(type="websocket.send", bytes=data, text=None))
|
|
257
|
+
else:
|
|
258
|
+
await send(WebsocketSendEvent(type="websocket.send", bytes=None, text=data))
|
|
278
259
|
|
|
279
|
-
await send(
|
|
280
|
-
WebsocketAcceptEvent(
|
|
281
|
-
type="websocket.accept", subprotocol=subprotocol, headers=headers_to_send
|
|
282
|
-
)
|
|
283
|
-
)
|
|
284
260
|
|
|
261
|
+
@trace_function(tracer)
|
|
262
|
+
async def send_json(send: WebsocketSendCallable, data: Any, /) -> None:
|
|
263
|
+
return await send_data(send, orjson.dumps(data))
|
|
285
264
|
|
|
286
|
-
@trace_function_with_span(tracer)
|
|
287
|
-
async def close_websocket_connection(
|
|
288
|
-
s: Span,
|
|
289
|
-
send: WebsocketSendCallable,
|
|
290
|
-
/,
|
|
291
|
-
*,
|
|
292
|
-
status: WebsocketStatus,
|
|
293
|
-
data: str,
|
|
294
|
-
):
|
|
295
|
-
s.set_attribute("reason", data)
|
|
296
265
|
|
|
297
|
-
|
|
266
|
+
@trace_function(tracer)
|
|
267
|
+
async def send_websocket_auto(send: WebsocketSendCallable, data: Any, /) -> None:
|
|
268
|
+
if isinstance(data, str | bytes):
|
|
269
|
+
return await send_data(send, data)
|
|
298
270
|
|
|
299
|
-
|
|
271
|
+
return await send_json(send, data)
|