latch-asgi 0.3.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
1
+ from typing import Self
2
+
1
3
  from opentelemetry import trace
2
4
  from opentelemetry.context.context import Context
3
5
  from opentelemetry.propagators import textmap
@@ -6,7 +8,7 @@ from opentelemetry.propagators import textmap
6
8
  class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
7
9
  # https://github.com/open-telemetry/opentelemetry-python-contrib/blob/934af7ea4f9b1e0294ced6a014d6eefdda156b2b/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py
8
10
  def extract(
9
- self,
11
+ self: Self,
10
12
  carrier: textmap.CarrierT,
11
13
  context: Context | None = None,
12
14
  getter: textmap.Getter[textmap.CarrierT] = textmap.default_getter,
@@ -28,12 +30,16 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
28
30
  sampling_priority = getter.get(carrier, "x-datadog-sampling-priority")
29
31
 
30
32
  trace_flags = trace.TraceFlags()
31
- if sampling_priority is not None and len(sampling_priority) > 0:
32
- if int(sampling_priority[0]) in (
33
+ if (
34
+ sampling_priority is not None
35
+ and len(sampling_priority) > 0
36
+ and int(sampling_priority[0])
37
+ in (
33
38
  1, # auto keep
34
39
  2, # user keep
35
- ):
36
- trace_flags = trace.TraceFlags(trace.TraceFlags.SAMPLED)
40
+ )
41
+ ):
42
+ trace_flags = trace.TraceFlags(trace.TraceFlags.SAMPLED)
37
43
 
38
44
  dd_origin = getter.get(carrier, "x-datadog-origin")
39
45
 
@@ -51,7 +57,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
51
57
  return trace.set_span_in_context(trace.NonRecordingSpan(span_context), context)
52
58
 
53
59
  def inject(
54
- self,
60
+ self: Self,
55
61
  carrier: textmap.CarrierT,
56
62
  context: Context | None = None,
57
63
  setter: textmap.Setter[textmap.CarrierT] = textmap.default_setter,
@@ -69,11 +75,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
69
75
  "x-datadog-trace-id",
70
76
  str(span_context.trace_id & 0xFFFF_FFFF_FFFF_FFFF),
71
77
  )
72
- setter.set(
73
- carrier,
74
- "x-datadog-parent-id",
75
- str(span_context.span_id),
76
- )
78
+ setter.set(carrier, "x-datadog-parent-id", str(span_context.span_id))
77
79
  setter.set(
78
80
  carrier,
79
81
  "x-datadog-sampling-priority",
@@ -87,7 +89,7 @@ class DDTraceContextTextMapPropagator(textmap.TextMapPropagator):
87
89
  )
88
90
 
89
91
  @property
90
- def fields(self) -> set[str]:
92
+ def fields(self: Self) -> set[str]:
91
93
  return {
92
94
  "x-datadog-trace-id",
93
95
  "x-datadog-parent-id",
@@ -5,3 +5,36 @@ from opentelemetry.trace import get_tracer
5
5
  Headers: TypeAlias = dict[str | bytes, str | bytes]
6
6
 
7
7
  tracer = get_tracer(__name__)
8
+
9
+ otel_header_whitelist = {
10
+ "host",
11
+ "content-type",
12
+ "content-length",
13
+ "accept",
14
+ "accept-encoding",
15
+ "accept-language",
16
+ "user-agent",
17
+ "dnt",
18
+ "sec-fetch-dest",
19
+ "sec-fetch-mode",
20
+ "sec-fetch-site",
21
+ "sec-fetch-user",
22
+ "sec-gpc",
23
+ "te",
24
+ "upgrade-insecure-requests",
25
+ "device-memory",
26
+ "downlink",
27
+ "dpr",
28
+ "ect",
29
+ "rtt",
30
+ "sec-ch-prefers-color-scheme",
31
+ "sec-ch-prefers-reduced-motion",
32
+ "sec-ch-ua",
33
+ "sec-ch-ua-arch",
34
+ "sec-ch-ua-full-version",
35
+ "sec-ch-ua-mobile",
36
+ "sec-ch-ua-model",
37
+ "sec-ch-ua-platform",
38
+ "sec-ch-ua-platfrom-version",
39
+ "viewport-width",
40
+ }
@@ -1,10 +1,10 @@
1
1
  from http import HTTPStatus
2
- from typing import Any, Literal, TypeAlias, TypeVar, cast
2
+ from typing import Any, Literal, Self, TypeAlias, TypeVar, cast
3
3
 
4
- import opentelemetry.context as context
5
4
  import orjson
6
5
  from latch_data_validation.data_validation import DataValidationError, validate
7
6
  from latch_o11y.o11y import trace_function, trace_function_with_span
7
+ from opentelemetry import context
8
8
  from opentelemetry.trace.span import Span
9
9
 
10
10
  from ..asgi_iface import (
@@ -17,17 +17,9 @@ from .common import Headers, tracer
17
17
 
18
18
  T = TypeVar("T")
19
19
 
20
- HTTPMethod: TypeAlias = (
21
- Literal["GET"]
22
- | Literal["HEAD"]
23
- | Literal["POST"]
24
- | Literal["PUT"]
25
- | Literal["DELETE"]
26
- | Literal["CONNECT"]
27
- | Literal["OPTIONS"]
28
- | Literal["TRACE"]
29
- | Literal["PATCH"]
30
- )
20
+ HTTPMethod: TypeAlias = Literal[
21
+ "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"
22
+ ]
31
23
 
32
24
  # >>> O11y
33
25
 
@@ -42,34 +34,68 @@ def current_http_request_span() -> Span:
42
34
 
43
35
 
44
36
  class HTTPErrorResponse(RuntimeError):
45
- def __init__(self, status: HTTPStatus, data: Any, *, headers: Headers = {}):
37
+ def __init__(
38
+ self: Self, status: HTTPStatus, data: Any, *, headers: Headers | None = None
39
+ ) -> None:
40
+ if headers is None:
41
+ headers = {}
42
+
46
43
  self.status = status
47
44
  self.data = data
48
45
  self.headers = headers
49
46
 
50
47
 
51
48
  class HTTPInternalServerError(HTTPErrorResponse):
52
- def __init__(self, data: Any, *, headers: Headers = {}):
49
+ def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
53
50
  super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR, data, headers=headers)
54
51
 
55
52
 
56
53
  class HTTPBadRequest(HTTPErrorResponse):
57
- def __init__(self, data: Any, *, headers: Headers = {}):
54
+ def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
58
55
  super().__init__(HTTPStatus.BAD_REQUEST, data, headers=headers)
59
56
 
60
57
 
61
58
  class HTTPForbidden(HTTPErrorResponse):
62
- def __init__(self, data: Any, *, headers: Headers = {}):
59
+ def __init__(self: Self, data: Any, *, headers: Headers | None = None) -> None:
63
60
  super().__init__(HTTPStatus.FORBIDDEN, data, headers=headers)
64
61
 
65
62
 
66
- class HTTPConnectionClosedError(RuntimeError): ...
63
+ class HTTPConnectionClosedError(RuntimeError):
64
+ ...
67
65
 
68
66
 
69
- # >>> I/O
67
+ # >>>
68
+ # I/O
69
+ # >>>
70
70
 
71
71
  # todo(maximsmol): add max body length limit by default
72
72
 
73
+ # >>> Receive
74
+
75
+
76
+ @trace_function(tracer)
77
+ async def receive_data(receive: HTTPReceiveCallable) -> bytes:
78
+ res = b""
79
+ more_body = True
80
+ while more_body:
81
+ msg = await receive()
82
+ if msg["type"] == "http.disconnect":
83
+ raise HTTPConnectionClosedError
84
+
85
+ res += msg["body"]
86
+ more_body = msg["more_body"]
87
+
88
+ # todo(maximsmol): accumulate instead of overriding
89
+ # todo(maximsmol): probably use the content-length header if present?
90
+ current_http_request_span().set_attribute("http.request.body.size", len(res))
91
+
92
+ return res
93
+
94
+
95
+ @trace_function(tracer)
96
+ async def receive_json(receive: HTTPReceiveCallable) -> Any:
97
+ return orjson.loads(await receive_data(receive))
98
+
73
99
 
74
100
  async def receive_class_ext(
75
101
  receive: HTTPReceiveCallable, cls: type[T]
@@ -87,47 +113,25 @@ async def receive_class(receive: HTTPReceiveCallable, cls: type[T]) -> T:
87
113
  return (await receive_class_ext(receive, cls))[1]
88
114
 
89
115
 
90
- @trace_function(tracer)
91
- async def receive_json(receive: HTTPReceiveCallable) -> Any:
92
- return orjson.loads(await receive_data(receive))
93
-
94
-
95
- async def receive_data(receive: HTTPReceiveCallable):
96
- res = b""
97
- more_body = True
98
- while more_body:
99
- with tracer.start_as_current_span("read chunk") as s:
100
- msg = await receive()
101
- if msg.type == "http.disconnect":
102
- raise HTTPConnectionClosedError()
103
-
104
- res += msg.body
105
- more_body = msg.more_body
116
+ # >>> Send
106
117
 
107
- s.set_attributes({"size": len(msg.body), "more_body": more_body})
108
118
 
109
- # todo(maximsmol): accumulate instead of overriding
110
- # todo(maximsmol): probably use the content-length header if present?
111
- current_http_request_span().set_attribute("http.request_content_length", len(res))
112
-
113
- return res
114
-
115
-
116
- @trace_function_with_span(tracer)
119
+ @trace_function(tracer)
117
120
  async def send_http_data(
118
- s: Span,
119
121
  send: HTTPSendCallable,
120
122
  status: HTTPStatus,
121
123
  data: str | bytes,
122
124
  /,
123
125
  *,
124
- content_type: str | bytes | None = "text/plain",
125
- headers: Headers = {},
126
- ):
126
+ content_type: str | bytes | None = b"text/plain",
127
+ headers: Headers | None = None,
128
+ ) -> None:
129
+ if headers is None:
130
+ headers = {}
131
+
127
132
  if isinstance(data, str):
128
133
  data = data.encode("utf-8")
129
134
 
130
- s.set_attribute("size", len(data))
131
135
  headers_to_send: list[tuple[bytes, bytes]] = [
132
136
  (b"Content-Length", str(len(data)).encode("latin-1"))
133
137
  ]
@@ -152,8 +156,6 @@ async def send_http_data(
152
156
  HTTPResponseBodyEvent(type="http.response.body", body=data, more_body=False)
153
157
  )
154
158
 
155
- current_http_request_span().set_attribute("http.response_content_length", len(data))
156
-
157
159
 
158
160
  @trace_function(tracer)
159
161
  async def send_json(
@@ -163,8 +165,8 @@ async def send_json(
163
165
  /,
164
166
  *,
165
167
  content_type: str = "application/json",
166
- headers: Headers = {},
167
- ):
168
+ headers: Headers | None = None,
169
+ ) -> None:
168
170
  return await send_http_data(
169
171
  send, status, orjson.dumps(data), content_type=content_type, headers=headers
170
172
  )
@@ -177,9 +179,9 @@ async def send_auto(
177
179
  data: str | bytes | Any,
178
180
  /,
179
181
  *,
180
- headers: Headers = {},
181
- ):
182
- if isinstance(data, str) or isinstance(data, bytes):
182
+ headers: Headers | None = None,
183
+ ) -> None:
184
+ if isinstance(data, str | bytes):
183
185
  return await send_http_data(send, status, data, headers=headers)
184
186
 
185
187
  return await send_json(send, status, data, headers=headers)
@@ -1,11 +1,12 @@
1
1
  from enum import Enum
2
- from typing import Any, TypeVar, cast
2
+ from typing import Any, Self, TypeVar, cast
3
3
 
4
- import opentelemetry.context as context
4
+ import orjson
5
5
  from latch_data_validation.data_validation import DataValidationError, validate
6
6
  from latch_o11y.o11y import trace_function, trace_function_with_span
7
+ from opentelemetry import context
7
8
  from opentelemetry.trace.span import Span
8
- import orjson
9
+ from opentelemetry.util.types import AttributeValue
9
10
 
10
11
  from ..asgi_iface import (
11
12
  WebsocketAcceptEvent,
@@ -18,20 +19,18 @@ from .common import Headers, tracer
18
19
 
19
20
  T = TypeVar("T")
20
21
 
21
- websocket_request_span_key = context.create_key("websocket_request_span")
22
+ websocket_session_span_key = context.create_key("websocket_message_span")
22
23
 
23
24
 
24
- def current_websocket_request_span() -> Span:
25
- return cast(Span, context.get_value(websocket_request_span_key))
25
+ def current_websocket_session_span() -> Span:
26
+ return cast(Span, context.get_value(websocket_session_span_key))
26
27
 
27
28
 
28
29
  # >>> Error classes
29
30
 
30
31
 
31
32
  class WebsocketStatus(int, Enum):
32
- """
33
- https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
34
- """
33
+ """https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1"""
35
34
 
36
35
  normal = 1000
37
36
  """
@@ -134,63 +133,108 @@ class WebsocketStatus(int, Enum):
134
133
 
135
134
 
136
135
  class WebsocketErrorResponse(RuntimeError):
137
- def __init__(self, status: WebsocketStatus, data: Any, *, headers: Headers = {}):
138
- super().__init__()
136
+ def __init__(self: Self, status: WebsocketStatus, data: Any) -> None:
139
137
  self.status = status
140
138
  self.data = data
141
- self.headers = headers
142
139
 
143
140
 
144
141
  class WebsocketBadMessage(WebsocketErrorResponse):
145
- def __init__(self, data: Any, *, headers: Headers = {}):
146
- super().__init__(WebsocketStatus.policy_violation, data, headers=headers)
142
+ def __init__(self: Self, data: Any) -> None:
143
+ super().__init__(WebsocketStatus.policy_violation, data)
147
144
 
148
145
 
149
146
  class WebsocketInternalServerError(WebsocketErrorResponse):
150
- def __init__(self, data: Any, *, headers: Headers = {}):
151
- super().__init__(WebsocketStatus.server_error, data, headers=headers)
147
+ def __init__(self: Self, data: Any) -> None:
148
+ super().__init__(WebsocketStatus.server_error, data)
149
+
150
+
151
+ class WebsocketConnectionClosedError(RuntimeError):
152
+ def __init__(self: Self, code: WebsocketStatus) -> None:
153
+ self.code = code
154
+
155
+
156
+ # >>>
157
+ # I/O
158
+ # >>>
159
+
160
+ # >>> Send Lifecycle
161
+
162
+
163
+ @trace_function(tracer)
164
+ async def accept_connection(
165
+ send: WebsocketSendCallable,
166
+ /,
167
+ *,
168
+ subprotocol: str | None = None,
169
+ headers: Headers | None = None,
170
+ ) -> None:
171
+ if headers is None:
172
+ headers = {}
173
+
174
+ headers_to_send: list[tuple[bytes, bytes]] = []
175
+ for k, v in headers.items():
176
+ if isinstance(k, str):
177
+ k = k.encode("latin-1")
178
+ if isinstance(v, str):
179
+ v = v.encode("latin-1")
180
+ headers_to_send.append((k, v))
152
181
 
182
+ await send(
183
+ WebsocketAcceptEvent(
184
+ type="websocket.accept", subprotocol=subprotocol, headers=headers_to_send
185
+ )
186
+ )
187
+
188
+
189
+ @trace_function_with_span(tracer)
190
+ async def close_connection(
191
+ s: Span, send: WebsocketSendCallable, /, *, status: WebsocketStatus, data: str
192
+ ) -> None:
193
+ s.set_attributes({"status": status.name, "reason": data})
194
+
195
+ await send(
196
+ WebsocketCloseEvent(type="websocket.close", code=status.value, reason=data)
197
+ )
153
198
 
154
- class WebsocketConnectionClosedError(RuntimeError): ...
199
+ current_websocket_session_span().set_attribute("websocket.close_reason", data)
155
200
 
156
201
 
157
- # >>> I/O
202
+ # >>> Receive
158
203
 
204
+ # todo(maximsmol): add max message length limit by default
159
205
 
160
- async def receive_websocket_data(receive: WebsocketReceiveCallable):
161
- with tracer.start_as_current_span("read websocket message") as s:
162
- msg = await receive()
163
206
 
164
- s.set_attribute("type", msg.type)
207
+ @trace_function(tracer)
208
+ async def receive_data(receive: WebsocketReceiveCallable) -> bytes | str:
209
+ msg = await receive()
165
210
 
166
- if msg.type == "websocket.connect":
167
- # todo(ayush): allow upgrades here as well?
168
- raise WebsocketBadMessage("connection has already been established")
211
+ if msg["type"] == "websocket.connect":
212
+ # todo(ayush): allow upgrades here as well?
213
+ raise ValueError("ASGI protocol violation: duplicate websocket.connect event")
169
214
 
170
- if msg.type == "websocket.disconnect":
171
- raise WebsocketConnectionClosedError()
215
+ if msg["type"] == "websocket.disconnect":
216
+ raise WebsocketConnectionClosedError(WebsocketStatus(msg["code"]))
172
217
 
173
- if msg.bytes is not None:
174
- res = msg.bytes
175
- elif msg.text is not None:
176
- res = msg.text.encode("utf-8")
177
- else:
178
- raise WebsocketBadMessage("empty message")
218
+ if msg["bytes"] is not None:
219
+ res = msg["bytes"]
220
+ elif msg["text"] is not None:
221
+ res = msg["text"]
222
+ else:
223
+ raise WebsocketBadMessage("empty message")
179
224
 
180
- s.set_attribute("size", len(res))
181
- return res
225
+ return res
182
226
 
183
227
 
184
228
  @trace_function(tracer)
185
- async def receive_websocket_json(receive: WebsocketReceiveCallable) -> Any:
186
- orjson.loads(await receive_websocket_data(receive))
229
+ async def receive_json(receive: WebsocketReceiveCallable) -> Any:
230
+ return orjson.loads(await receive_data(receive))
187
231
 
188
232
 
189
233
  @trace_function(tracer)
190
- async def receive_websocket_class_ext(
234
+ async def receive_class_ext(
191
235
  receive: WebsocketReceiveCallable, cls: type[T]
192
236
  ) -> tuple[Any, T]:
193
- data = await receive_websocket_json(receive)
237
+ data = await receive_json(receive)
194
238
 
195
239
  try:
196
240
  return data, validate(data, cls)
@@ -200,67 +244,28 @@ async def receive_websocket_class_ext(
200
244
 
201
245
  @trace_function(tracer)
202
246
  async def receive_websocket_class(receive: WebsocketReceiveCallable, cls: type[T]) -> T:
203
- return (await receive_websocket_class_ext(receive, cls))[1]
247
+ return (await receive_class_ext(receive, cls))[1]
204
248
 
205
249
 
206
- @trace_function_with_span(tracer)
207
- async def send_websocket_data(
208
- s: Span,
209
- send: WebsocketSendCallable,
210
- data: str | bytes,
211
- ):
212
- if isinstance(data, str):
213
- data = data.encode("utf-8")
214
-
215
- s.set_attribute("body.size", len(data))
216
-
217
- await send(WebsocketSendEvent(type="websocket.send", bytes=data, text=None))
218
-
219
- current_websocket_request_span().set_attribute(
220
- "websocket.sent_message_content_length", len(data)
221
- )
250
+ # >>> Send
222
251
 
223
252
 
224
253
  @trace_function(tracer)
225
- async def accept_websocket_connection(
226
- send: WebsocketSendCallable,
227
- receive: WebsocketReceiveCallable,
228
- /,
229
- *,
230
- subprotocol: str | None = None,
231
- headers: Headers = {},
232
- ):
233
- msg = await receive()
234
- if msg.type != "websocket.connect":
235
- raise WebsocketBadMessage("cannot accept connection without connection request")
236
-
237
- headers_to_send: list[tuple[bytes, bytes]] = []
238
-
239
- for k, v in headers.items():
240
- if isinstance(k, str):
241
- k = k.encode("latin-1")
242
- if isinstance(v, str):
243
- v = v.encode("latin-1")
244
- headers_to_send.append((k, v))
254
+ async def send_data(send: WebsocketSendCallable, data: str | bytes, /) -> None:
255
+ if isinstance(data, bytes):
256
+ await send(WebsocketSendEvent(type="websocket.send", bytes=data, text=None))
257
+ else:
258
+ await send(WebsocketSendEvent(type="websocket.send", bytes=None, text=data))
245
259
 
246
- await send(
247
- WebsocketAcceptEvent(
248
- type="websocket.accept", subprotocol=subprotocol, headers=headers_to_send
249
- )
250
- )
251
260
 
261
+ @trace_function(tracer)
262
+ async def send_json(send: WebsocketSendCallable, data: Any, /) -> None:
263
+ return await send_data(send, orjson.dumps(data))
252
264
 
253
- @trace_function_with_span(tracer)
254
- async def close_websocket_connection(
255
- s: Span,
256
- send: WebsocketSendCallable,
257
- /,
258
- *,
259
- status: WebsocketStatus,
260
- data: str,
261
- ):
262
- s.set_attribute("reason", data)
263
265
 
264
- await send(WebsocketCloseEvent("websocket.close", status.value, data))
266
+ @trace_function(tracer)
267
+ async def send_websocket_auto(send: WebsocketSendCallable, data: Any, /) -> None:
268
+ if isinstance(data, str | bytes):
269
+ return await send_data(send, data)
265
270
 
266
- current_websocket_request_span().set_attribute("websocket.http.close_reason", data)
271
+ return await send_json(send, data)