latch-asgi 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
1
+ from typing import Self
2
+
1
3
  from opentelemetry import trace
2
4
  from opentelemetry.context.context import Context
3
5
  from opentelemetry.propagators import textmap
@@ -6,7 +8,7 @@ from opentelemetry.propagators import textmap
6
8
  class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
7
9
  # https://github.com/open-telemetry/opentelemetry-python-contrib/blob/934af7ea4f9b1e0294ced6a014d6eefdda156b2b/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py
8
10
  def extract(
9
- self,
11
+ self: Self,
10
12
  carrier: textmap.CarrierT,
11
13
  context: Context | None = None,
12
14
  getter: textmap.Getter[textmap.CarrierT] = textmap.default_getter,
@@ -28,12 +30,16 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
28
30
  sampling_priority = getter.get(carrier, "x-datadog-sampling-priority")
29
31
 
30
32
  trace_flags = trace.TraceFlags()
31
- if sampling_priority is not None and len(sampling_priority) > 0:
32
- if int(sampling_priority[0]) in (
33
+ if (
34
+ sampling_priority is not None
35
+ and len(sampling_priority) > 0
36
+ and int(sampling_priority[0])
37
+ in (
33
38
  1, # auto keep
34
39
  2, # user keep
35
- ):
36
- trace_flags = trace.TraceFlags(trace.TraceFlags.SAMPLED)
40
+ )
41
+ ):
42
+ trace_flags = trace.TraceFlags(trace.TraceFlags.SAMPLED)
37
43
 
38
44
  dd_origin = getter.get(carrier, "x-datadog-origin")
39
45
 
@@ -51,7 +57,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
51
57
  return trace.set_span_in_context(trace.NonRecordingSpan(span_context), context)
52
58
 
53
59
  def inject(
54
- self,
60
+ self: Self,
55
61
  carrier: textmap.CarrierT,
56
62
  context: Context | None = None,
57
63
  setter: textmap.Setter[textmap.CarrierT] = textmap.default_setter,
@@ -69,11 +75,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
69
75
  "x-datadog-trace-id",
70
76
  str(span_context.trace_id & 0xFFFF_FFFF_FFFF_FFFF),
71
77
  )
72
- setter.set(
73
- carrier,
74
- "x-datadog-parent-id",
75
- str(span_context.span_id),
76
- )
78
+ setter.set(carrier, "x-datadog-parent-id", str(span_context.span_id))
77
79
  setter.set(
78
80
  carrier,
79
81
  "x-datadog-sampling-priority",
@@ -87,7 +89,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
87
89
  )
88
90
 
89
91
  @property
90
- def fields(self) -> set[str]:
92
+ def fields(self: Self) -> set[str]:
91
93
  return {
92
94
  "x-datadog-trace-id",
93
95
  "x-datadog-parent-id",
@@ -5,3 +5,36 @@ from opentelemetry.trace import get_tracer
5
5
  Headers: TypeAlias = dict[str | bytes, str | bytes]
6
6
 
7
7
  tracer = get_tracer(__name__)
8
+
9
# Lowercase names of HTTP request headers on the allow-list.
# NOTE(review): presumably these are the headers permitted to be recorded on
# OpenTelemetry spans — confirm against the code that consumes this set.
otel_header_whitelist = {
    # general request metadata
    "host",
    "content-type",
    "content-length",
    "accept",
    "accept-encoding",
    "accept-language",
    "user-agent",
    "dnt",
    # fetch metadata
    "sec-fetch-dest",
    "sec-fetch-mode",
    "sec-fetch-site",
    "sec-fetch-user",
    "sec-gpc",
    "te",
    "upgrade-insecure-requests",
    # client hints
    "device-memory",
    "downlink",
    "dpr",
    "ect",
    "rtt",
    "sec-ch-prefers-color-scheme",
    "sec-ch-prefers-reduced-motion",
    "sec-ch-ua",
    "sec-ch-ua-arch",
    "sec-ch-ua-full-version",
    "sec-ch-ua-mobile",
    "sec-ch-ua-model",
    "sec-ch-ua-platform",
    # was misspelled "platfrom", which can never match the real header
    "sec-ch-ua-platform-version",
    "viewport-width",
}
@@ -1,12 +1,11 @@
1
1
  from http import HTTPStatus
2
- from typing import Any, Literal, TypeAlias, TypeVar, cast
2
+ from typing import Any, Literal, Self, TypeAlias, TypeVar, cast
3
3
 
4
- import opentelemetry.context as context
5
- import simdjson
4
+ import orjson
6
5
  from latch_data_validation.data_validation import DataValidationError, validate
7
6
  from latch_o11y.o11y import trace_function, trace_function_with_span
7
+ from opentelemetry import context
8
8
  from opentelemetry.trace.span import Span
9
- from orjson import dumps
10
9
 
11
10
  from ..asgi_iface import (
12
11
  HTTPReceiveCallable,
@@ -18,17 +17,9 @@ from .common import Headers, tracer
18
17
 
19
18
  T = TypeVar("T")
20
19
 
21
- HTTPMethod: TypeAlias = (
22
- Literal["GET"]
23
- | Literal["HEAD"]
24
- | Literal["POST"]
25
- | Literal["PUT"]
26
- | Literal["DELETE"]
27
- | Literal["CONNECT"]
28
- | Literal["OPTIONS"]
29
- | Literal["TRACE"]
30
- | Literal["PATCH"]
31
- )
20
+ HTTPMethod: TypeAlias = Literal[
21
+ "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"
22
+ ]
32
23
 
33
24
  # >>> O11y
34
25
 
@@ -43,35 +34,68 @@ def current_http_request_span() -> Span:
43
34
 
44
35
 
45
36
class HTTPErrorResponse(RuntimeError):
    """Exception that carries an HTTP status, a payload and response headers.

    Attributes are stored as given; when ``headers`` is omitted (or ``None``)
    the instance gets its own new empty dict.
    """

    def __init__(
        self: Self,
        status: HTTPStatus,
        data: Any,
        *,
        headers: Headers | None = None,
    ) -> None:
        self.status = status
        self.data = data
        # the fallback dict is built per call so instances never share one
        self.headers = {} if headers is None else headers
51
46
 
52
47
 
53
48
class HTTPInternalServerError(HTTPErrorResponse):
    """`HTTPErrorResponse` whose status is fixed to 500 Internal Server Error."""

    def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
        super().__init__(
            status=HTTPStatus.INTERNAL_SERVER_ERROR, data=data, headers=headers
        )
56
51
 
57
52
 
58
53
class HTTPBadRequest(HTTPErrorResponse):
    """`HTTPErrorResponse` whose status is fixed to 400 Bad Request."""

    def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
        super().__init__(status=HTTPStatus.BAD_REQUEST, data=data, headers=headers)
61
56
 
62
57
 
63
58
class HTTPForbidden(HTTPErrorResponse):
    """`HTTPErrorResponse` whose status is fixed to 403 Forbidden."""

    def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
        super().__init__(status=HTTPStatus.FORBIDDEN, data=data, headers=headers)
66
61
 
67
62
 
68
- class HTTPConnectionClosedError(RuntimeError): ...
63
class HTTPConnectionClosedError(RuntimeError):
    """Raised when an ``http.disconnect`` event arrives while the body is being read."""
69
65
 
70
66
 
71
- # >>> I/O
67
+ # >>>
68
+ # I/O
69
+ # >>>
72
70
 
73
71
  # todo(maximsmol): add max body length limit by default
74
72
 
73
+ # >>> Receive
74
+
75
+
76
@trace_function(tracer)
async def receive_data(receive: HTTPReceiveCallable) -> bytes:
    """Read the whole request body from the ASGI ``receive`` channel.

    Events are awaited until one reports that no more body follows. The
    total size is recorded on the current HTTP request span as
    ``http.request.body.size``.

    Args:
        receive: ASGI receive callable for the HTTP connection.

    Returns:
        The concatenated body bytes (``b""`` for an empty body).

    Raises:
        HTTPConnectionClosedError: an ``http.disconnect`` event arrived
            before the body was complete.
    """
    chunks: list[bytes] = []
    more_body = True
    while more_body:
        msg = await receive()
        if msg["type"] == "http.disconnect":
            raise HTTPConnectionClosedError

        # ASGI marks both keys as optional: a missing `body` means b"" and a
        # missing `more_body` means False
        chunks.append(msg.get("body", b""))
        more_body = msg.get("more_body", False)

    # one join instead of `+=` per chunk, which re-copies the buffer each time
    res = b"".join(chunks)

    # todo(maximsmol): accumulate instead of overriding
    # todo(maximsmol): probably use the content-length header if present?
    current_http_request_span().set_attribute("http.request.body.size", len(res))

    return res
93
+
94
+
95
@trace_function(tracer)
async def receive_json(receive: HTTPReceiveCallable) -> Any:
    """Read the whole request body and decode it as JSON.

    Args:
        receive: ASGI receive callable for the HTTP connection.

    Returns:
        The decoded JSON value.

    Raises:
        HTTPBadRequest: the body is not valid JSON. A malformed body is a
            client error, so it is reported as 400 rather than left to
            escape as a raw decoder exception.
        HTTPConnectionClosedError: propagated from `receive_data`.
    """
    body = await receive_data(receive)

    try:
        return orjson.loads(body)
    except orjson.JSONDecodeError as e:
        raise HTTPBadRequest("Failed to parse JSON") from e
98
+
75
99
 
76
100
  async def receive_class_ext(
77
101
  receive: HTTPReceiveCallable, cls: type[T]
@@ -89,53 +113,25 @@ async def receive_class(receive: HTTPReceiveCallable, cls: type[T]) -> T:
89
113
  return (await receive_class_ext(receive, cls))[1]
90
114
 
91
115
 
92
- @trace_function(tracer)
93
- async def receive_json(receive: HTTPReceiveCallable) -> Any:
94
- res = await receive_data(receive)
95
-
96
- p = simdjson.Parser()
97
- try:
98
- return p.parse(res, True)
99
- except ValueError as e:
100
- raise HTTPBadRequest("Failed to parse JSON") from e
101
-
102
-
103
- async def receive_data(receive: HTTPReceiveCallable):
104
- res = b""
105
- more_body = True
106
- while more_body:
107
- with tracer.start_as_current_span("read chunk") as s:
108
- msg = await receive()
109
- if msg.type == "http.disconnect":
110
- raise HTTPConnectionClosedError()
111
-
112
- res += msg.body
113
- more_body = msg.more_body
116
+ # >>> Send
114
117
 
115
- s.set_attributes({"size": len(msg.body), "more_body": more_body})
116
118
 
117
- # todo(maximsmol): accumulate instead of overriding
118
- # todo(maximsmol): probably use the content-length header if present?
119
- current_http_request_span().set_attribute("http.request_content_length", len(res))
120
-
121
- return res
122
-
123
-
124
- @trace_function_with_span(tracer)
119
+ @trace_function(tracer)
125
120
  async def send_http_data(
126
- s: Span,
127
121
  send: HTTPSendCallable,
128
122
  status: HTTPStatus,
129
123
  data: str | bytes,
130
124
  /,
131
125
  *,
132
- content_type: str | bytes | None = "text/plain",
133
- headers: Headers = {},
134
- ):
126
+ content_type: str | bytes | None = b"text/plain",
127
+ headers: Headers | None = None,
128
+ ) -> None:
129
+ if headers is None:
130
+ headers = {}
131
+
135
132
  if isinstance(data, str):
136
133
  data = data.encode("utf-8")
137
134
 
138
- s.set_attribute("size", len(data))
139
135
  headers_to_send: list[tuple[bytes, bytes]] = [
140
136
  (b"Content-Length", str(len(data)).encode("latin-1"))
141
137
  ]
@@ -160,8 +156,6 @@ async def send_http_data(
160
156
  HTTPResponseBodyEvent(type="http.response.body", body=data, more_body=False)
161
157
  )
162
158
 
163
- current_http_request_span().set_attribute("http.response_content_length", len(data))
164
-
165
159
 
166
160
  @trace_function(tracer)
167
161
  async def send_json(
@@ -171,10 +165,10 @@ async def send_json(
171
165
  /,
172
166
  *,
173
167
  content_type: str = "application/json",
174
- headers: Headers = {},
175
- ):
168
+ headers: Headers | None = None,
169
+ ) -> None:
176
170
  return await send_http_data(
177
- send, status, dumps(data), content_type=content_type, headers=headers
171
+ send, status, orjson.dumps(data), content_type=content_type, headers=headers
178
172
  )
179
173
 
180
174
 
@@ -185,9 +179,9 @@ async def send_auto(
185
179
  data: str | bytes | Any,
186
180
  /,
187
181
  *,
188
- headers: Headers = {},
189
- ):
190
- if isinstance(data, str) or isinstance(data, bytes):
182
+ headers: Headers | None = None,
183
+ ) -> None:
184
+ if isinstance(data, str | bytes):
191
185
  return await send_http_data(send, status, data, headers=headers)
192
186
 
193
187
  return await send_json(send, status, data, headers=headers)
@@ -1,11 +1,12 @@
1
1
  from enum import Enum
2
- from typing import Any, TypeVar, cast
2
+ from typing import Any, Self, TypeVar, cast
3
3
 
4
- import opentelemetry.context as context
5
- import simdjson
4
+ import orjson
6
5
  from latch_data_validation.data_validation import DataValidationError, validate
7
6
  from latch_o11y.o11y import trace_function, trace_function_with_span
7
+ from opentelemetry import context
8
8
  from opentelemetry.trace.span import Span
9
+ from opentelemetry.util.types import AttributeValue
9
10
 
10
11
  from ..asgi_iface import (
11
12
  WebsocketAcceptEvent,
@@ -18,39 +19,35 @@ from .common import Headers, tracer
18
19
 
19
20
  T = TypeVar("T")
20
21
 
21
- websocket_request_span_key = context.create_key("websocket_request_span")
22
# Context key under which the websocket session span is stored. The string is
# only the key's label, so it is named after the session span it identifies
# rather than the unrelated "message span".
websocket_session_span_key = context.create_key("websocket_session_span")
22
23
 
23
24
 
24
- def current_websocket_request_span() -> Span:
25
- return cast(Span, context.get_value(websocket_request_span_key))
25
def current_websocket_session_span() -> Span:
    """Return the value stored under `websocket_session_span_key` in the current context.

    The value is only `cast` to `Span`; nothing is checked at runtime.
    """
    stored = context.get_value(websocket_session_span_key)
    return cast(Span, stored)
26
27
 
27
28
 
28
29
  # >>> Error classes
29
30
 
30
31
 
31
32
  class WebsocketStatus(int, Enum):
33
+ """https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1"""
34
+
32
35
  normal = 1000
33
36
  """
34
37
  1000 indicates a normal closure, meaning that the purpose for
35
38
  which the connection was established has been fulfilled.
36
-
37
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
38
39
  """
39
40
 
40
41
  going_away = 1001
41
42
  """
42
43
  1001 indicates that an endpoint is "going away", such as a server
43
44
  going down or a browser having navigated away from a page.
44
-
45
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
46
45
  """
47
46
 
48
47
  protocol_error = 1002
49
48
  """
50
49
  1002 indicates that an endpoint is terminating the connection due
51
50
  to a protocol error.
52
-
53
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
54
51
  """
55
52
 
56
53
  unsupported = 1003
@@ -59,15 +56,11 @@ class WebsocketStatus(int, Enum):
59
56
  because it has received a type of data it cannot accept (e.g., an
60
57
  endpoint that understands only text data MAY send this if it
61
58
  receives a binary message).
62
-
63
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
64
59
  """
65
60
 
66
61
  reserved = 1004
67
62
  """
68
63
  Reserved. The specific meaning might be defined in the future.
69
-
70
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
71
64
  """
72
65
 
73
66
  no_status = 1005
@@ -76,8 +69,6 @@ class WebsocketStatus(int, Enum):
76
69
  Close control frame by an endpoint. It is designated for use in
77
70
  applications expecting a status code to indicate that no status
78
71
  code was actually present.
79
-
80
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
81
72
  """
82
73
 
83
74
  abnormal = 1006
@@ -87,8 +78,6 @@ class WebsocketStatus(int, Enum):
87
78
  applications expecting a status code to indicate that the
88
79
  connection was closed abnormally, e.g., without sending or
89
80
  receiving a Close control frame.
90
-
91
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
92
81
  """
93
82
 
94
83
  unsupported_payload = 1007
@@ -97,8 +86,6 @@ class WebsocketStatus(int, Enum):
97
86
  because it has received data within a message that was not
98
87
  consistent with the type of the message (e.g., non-UTF-8 [RFC3629]
99
88
  data within a text message).
100
-
101
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
102
89
  """
103
90
 
104
91
  policy_violation = 1008
@@ -108,8 +95,6 @@ class WebsocketStatus(int, Enum):
108
95
  is a generic status code that can be returned when there is no
109
96
  other more suitable status code (e.g., 1003 or 1009) or if there
110
97
  is a need to hide specific details about the policy.
111
-
112
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
113
98
  """
114
99
 
115
100
  too_large = 1009
@@ -117,8 +102,6 @@ class WebsocketStatus(int, Enum):
117
102
  1009 indicates that an endpoint is terminating the connection
118
103
  because it has received a message that is too big for it to
119
104
  process.
120
-
121
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
122
105
  """
123
106
 
124
107
  mandatory_extension = 1010
@@ -130,8 +113,6 @@ class WebsocketStatus(int, Enum):
130
113
  are needed SHOULD appear in the /reason/ part of the Close frame.
131
114
  Note that this status code is not used by the server, because it
132
115
  can fail the WebSocket handshake instead.
133
-
134
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
135
116
  """
136
117
 
137
118
  server_error = 1011
@@ -139,8 +120,6 @@ class WebsocketStatus(int, Enum):
139
120
  1011 indicates that a server is terminating the connection because
140
121
  it encountered an unexpected condition that prevented it from
141
122
  fulfilling the request.
142
-
143
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
144
123
  """
145
124
 
146
125
  tls_handshake_fail = 1015
@@ -150,79 +129,112 @@ class WebsocketStatus(int, Enum):
150
129
  applications expecting a status code to indicate that the
151
130
  connection was closed due to a failure to perform a TLS handshake
152
131
  (e.g., the server certificate can't be verified).
153
-
154
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
155
132
  """
156
133
 
157
134
 
158
- class WebsocketError(RuntimeError):
159
- def __init__(self, status: WebsocketStatus, data: Any, *, headers: Headers = {}):
160
- super().__init__()
135
class WebsocketErrorResponse(RuntimeError):
    """Exception that carries a websocket close status and a payload."""

    def __init__(self: Self, status: WebsocketStatus, data: Any) -> None:
        self.status, self.data = status, data
163
- self.headers = headers
164
139
 
165
140
 
166
- class WebsocketErrorResponse(WebsocketError): ...
141
class WebsocketBadMessage(WebsocketErrorResponse):
    """`WebsocketErrorResponse` with the status fixed to ``policy_violation`` (1008)."""

    def __init__(self: Self, data: Any) -> None:
        super().__init__(status=WebsocketStatus.policy_violation, data=data)
167
144
 
168
145
 
169
- class WebsocketConnectionClosedError(RuntimeError): ...
146
class WebsocketInternalServerError(WebsocketErrorResponse):
    """`WebsocketErrorResponse` with the status fixed to ``server_error`` (1011)."""

    def __init__(self: Self, data: Any) -> None:
        super().__init__(status=WebsocketStatus.server_error, data=data)
170
149
 
171
150
 
172
- class WebsocketBadMessage(WebsocketErrorResponse):
173
- def __init__(self, data: Any, *, headers: Headers = {}):
174
- super().__init__(WebsocketStatus.policy_violation, data, headers=headers)
151
class WebsocketConnectionClosedError(RuntimeError):
    """Raised when the peer disconnects; ``code`` holds the reported close code."""

    def __init__(self: Self, code: WebsocketStatus) -> None:
        self.code: WebsocketStatus = code
175
154
 
176
155
 
177
- class WebsocketInternalServerError(WebsocketErrorResponse):
178
- def __init__(self, data: Any, *, headers: Headers = {}):
179
- super().__init__(WebsocketStatus.server_error, data, headers=headers)
156
+ # >>>
157
+ # I/O
158
+ # >>>
180
159
 
160
+ # >>> Send Lifecycle
181
161
 
182
- # >>> I/O
183
162
 
163
@trace_function(tracer)
async def accept_connection(
    send: WebsocketSendCallable,
    /,
    *,
    subprotocol: str | None = None,
    headers: Headers | None = None,
) -> None:
    """Send a ``websocket.accept`` event.

    Args:
        send: ASGI send callable for the websocket connection.
        subprotocol: subprotocol to put on the accept event, if any.
        headers: extra headers for the accept event; ``str`` names and
            values are encoded as latin-1, ``bytes`` are passed through.
    """
    raw_headers = {} if headers is None else headers

    # the accept event takes (bytes, bytes) pairs
    encoded_headers: list[tuple[bytes, bytes]] = [
        (
            name.encode("latin-1") if isinstance(name, str) else name,
            value.encode("latin-1") if isinstance(value, str) else value,
        )
        for name, value in raw_headers.items()
    ]

    accept_event = WebsocketAcceptEvent(
        type="websocket.accept", subprotocol=subprotocol, headers=encoded_headers
    )
    await send(accept_event)
189
187
 
190
- s.set_attribute("type", msg.type)
191
188
 
192
- if msg.type == "websocket.connect":
193
- # todo(ayush): allow upgrades here as well?
194
- raise WebsocketBadMessage("connection has already been established")
189
@trace_function_with_span(tracer)
async def close_connection(
    s: Span,
    send: WebsocketSendCallable,
    /,
    *,
    status: WebsocketStatus,
    data: str,
) -> None:
    """Send a ``websocket.close`` event and record the reason on the spans.

    Args:
        s: span of this call; receives ``status`` and ``reason`` attributes.
        send: ASGI send callable for the websocket connection.
        status: close status; its integer value becomes the close code.
        data: close reason text.
    """
    s.set_attributes({"status": status.name, "reason": data})

    close_event = WebsocketCloseEvent(
        type="websocket.close", code=status.value, reason=data
    )
    await send(close_event)

    # the session span is only annotated once the close event has been sent
    session_span = current_websocket_session_span()
    session_span.set_attribute("websocket.close_reason", data)
195
200
 
196
- if msg.type == "websocket.disconnect":
197
- raise WebsocketConnectionClosedError()
198
201
 
199
- if msg.bytes is not None:
200
- res = msg.bytes
201
- elif msg.text is not None:
202
- res = msg.text.encode("utf-8")
203
- else:
204
- raise WebsocketBadMessage("empty message")
202
+ # >>> Receive
205
203
 
206
- s.set_attributes({"size": len(res)})
207
- return res
204
+ # todo(maximsmol): add max message length limit by default
208
205
 
209
206
 
210
207
@trace_function(tracer)
async def receive_data(receive: WebsocketReceiveCallable) -> bytes | str:
    """Receive one websocket message and return its payload.

    Args:
        receive: ASGI receive callable for the websocket connection.

    Returns:
        The ``bytes`` payload when present, otherwise the ``text`` payload.

    Raises:
        ValueError: a second ``websocket.connect`` event was received.
        WebsocketConnectionClosedError: the peer disconnected; ``code`` holds
            the close code.
        WebsocketBadMessage: the message carried neither bytes nor text.
    """
    msg = await receive()

    if msg["type"] == "websocket.connect":
        # todo(ayush): allow upgrades here as well?
        raise ValueError("ASGI protocol violation: duplicate websocket.connect event")

    if msg["type"] == "websocket.disconnect":
        raw_code = msg["code"]
        try:
            code = WebsocketStatus(raw_code)
        except ValueError:
            # Codes missing from the enum are still legitimate closes (e.g.
            # the 3000-4999 range RFC 6455 leaves to libraries and
            # applications). Keep the raw integer instead of crashing;
            # WebsocketStatus is an int subclass, so comparisons against its
            # members keep working.
            code = cast(WebsocketStatus, raw_code)
        raise WebsocketConnectionClosedError(code)

    # ASGI allows either key to be absent, which is equivalent to None
    payload = msg.get("bytes")
    if payload is None:
        payload = msg.get("text")
    if payload is None:
        raise WebsocketBadMessage("empty message")

    return payload
219
226
 
220
227
 
221
228
@trace_function(tracer)
async def receive_json(receive: WebsocketReceiveCallable) -> Any:
    """Receive one websocket message and decode its payload as JSON.

    Args:
        receive: ASGI receive callable for the websocket connection.

    Returns:
        The decoded JSON value.

    Raises:
        WebsocketBadMessage: the payload is not valid JSON. A malformed
            message is a client error, so it is reported as such rather than
            left to escape as a raw decoder exception.
        WebsocketConnectionClosedError: propagated from `receive_data`.
    """
    payload = await receive_data(receive)

    try:
        return orjson.loads(payload)
    except orjson.JSONDecodeError as e:
        raise WebsocketBadMessage("Failed to parse JSON") from e
231
+
232
+
233
+ @trace_function(tracer)
234
+ async def receive_class_ext(
223
235
  receive: WebsocketReceiveCallable, cls: type[T]
224
236
  ) -> tuple[Any, T]:
225
- data = await receive_websocket_json(receive)
237
+ data = await receive_json(receive)
226
238
 
227
239
  try:
228
240
  return data, validate(data, cls)
@@ -232,68 +244,28 @@ async def receive_websocket_class_ext(
232
244
 
233
245
@trace_function(tracer)
async def receive_websocket_class(receive: WebsocketReceiveCallable, cls: type[T]) -> T:
    """Receive one message and return only the validated instance of ``cls``.

    Delegates to `receive_class_ext` and drops the raw data it returns.
    """
    _raw, validated = await receive_class_ext(receive, cls)
    return validated
237
248
 
238
- @trace_function_with_span(tracer)
239
- async def send_websocket_data(
240
- s: Span,
241
- send: WebsocketSendCallable,
242
- data: str | bytes,
243
- ):
244
- if isinstance(data, str):
245
- data = data.encode("utf-8")
246
249
 
247
- s.set_attribute("body.size", len(data))
250
+ # >>> Send
248
251
 
249
- await send(WebsocketSendEvent(type="websocket.send", bytes=data, text=None))
250
-
251
- current_websocket_request_span().set_attribute(
252
- "websocket.sent_message_content_length", len(data)
253
- )
254
-
255
-
256
- @trace_function_with_span(tracer)
257
- async def accept_websocket_connection(
258
- s: Span,
259
- send: WebsocketSendCallable,
260
- receive: WebsocketReceiveCallable,
261
- /,
262
- *,
263
- subprotocol: str | None = None,
264
- headers: Headers = {},
265
- ):
266
- msg = await receive()
267
- if msg.type != "websocket.connect":
268
- raise WebsocketBadMessage("cannot accept connection without connection request")
269
252
 
270
- headers_to_send: list[tuple[bytes, bytes]] = []
271
-
272
- for k, v in headers.items():
273
- if isinstance(k, str):
274
- k = k.encode("latin-1")
275
- if isinstance(v, str):
276
- v = v.encode("latin-1")
277
- headers_to_send.append((k, v))
253
@trace_function(tracer)
async def send_data(send: WebsocketSendCallable, data: str | bytes, /) -> None:
    """Send one ``websocket.send`` event carrying ``data``.

    ``bytes`` go into the event's ``bytes`` field and anything else into its
    ``text`` field; the other field is ``None``.
    """
    binary, text = (data, None) if isinstance(data, bytes) else (None, data)
    await send(WebsocketSendEvent(type="websocket.send", bytes=binary, text=text))
278
259
 
279
- await send(
280
- WebsocketAcceptEvent(
281
- type="websocket.accept", subprotocol=subprotocol, headers=headers_to_send
282
- )
283
- )
284
260
 
261
@trace_function(tracer)
async def send_json(send: WebsocketSendCallable, data: Any, /) -> None:
    """Serialize ``data`` with ``orjson.dumps`` and send the result via `send_data`."""
    encoded = orjson.dumps(data)
    return await send_data(send, encoded)
285
264
 
286
- @trace_function_with_span(tracer)
287
- async def close_websocket_connection(
288
- s: Span,
289
- send: WebsocketSendCallable,
290
- /,
291
- *,
292
- status: WebsocketStatus,
293
- data: str,
294
- ):
295
- s.set_attribute("reason", data)
296
265
 
297
- await send(WebsocketCloseEvent("websocket.close", status.value, data))
266
@trace_function(tracer)
async def send_websocket_auto(send: WebsocketSendCallable, data: Any, /) -> None:
    """Send ``data`` as-is when it is ``str``/``bytes``, otherwise as JSON."""
    # str and bytes are sent unchanged; everything else is JSON-encoded first
    sender = send_data if isinstance(data, (str, bytes)) else send_json
    return await sender(send, data)