modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +402 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1025 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +196 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +943 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
  128. modal-0.72.13.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
modal/_runtime/asgi.py ADDED
@@ -0,0 +1,519 @@
1
+ # Copyright Modal Labs 2022
2
+
3
+ # Note: this module isn't imported unless it's needed.
4
+ # This is because aiohttp is a pretty big dependency that adds significant latency when imported
5
+
6
+ import asyncio
7
+ from collections.abc import AsyncGenerator
8
+ from typing import Any, Callable, NoReturn, Optional, cast
9
+
10
+ import aiohttp
11
+
12
+ from modal._utils.async_utils import TaskContext
13
+ from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES
14
+ from modal._utils.package_utils import parse_major_minor_version
15
+ from modal.config import logger
16
+ from modal.exception import ExecutionError, InvalidError
17
+ from modal.experimental import stop_fetching_inputs
18
+
19
+ from .execution_context import current_function_call_id
20
+
# Seconds to wait for the first ASGI message of a call ("http.request" or "websocket.connect")
# before the request is treated as expired or cancelled.
FIRST_MESSAGE_TIMEOUT_SECONDS: float = 5.0
22
+
23
+
class LifespanManager:
    """Drive the ASGI lifespan protocol ("lifespan.startup" / "lifespan.shutdown") for an app.

    ``background_task()`` runs the app with a ``{"type": "lifespan", "state": ...}`` scope and feeds
    it events from an internal queue. ``lifespan_startup()`` and ``lifespan_shutdown()`` enqueue the
    corresponding event and wait until the app reports completion. A reported failure, or an error
    in the lifespan handler, surfaces as ``ExecutionError``. Apps that never call ``receive()`` are
    treated as not supporting lifespan; once their handler returns or raises, both phases complete
    without error.
    """

    _startup: asyncio.Future
    _shutdown: asyncio.Future
    _queue: asyncio.Queue
    _has_run_init: bool = False
    # Becomes True the first time the app calls `receive()`, i.e. once it engages with the protocol.
    _lifespan_supported: bool = False

    def __init__(self, asgi_app, state):
        self.asgi_app = asgi_app
        # Lifespan state dict, passed to the app as scope["state"].
        self.state = state

    async def ensure_init(self):
        """Lazily create the event queue and the startup/shutdown futures.

        This is async (even though it awaits nothing) so that it runs inside the event loop,
        which ties the asyncio objects to the correct loop on Python 3.9.
        """
        if not self._has_run_init:
            loop = asyncio.get_running_loop()
            self._queue = asyncio.Queue()
            self._startup = loop.create_future()
            self._shutdown = loop.create_future()
            self._has_run_init = True

    async def background_task(self):
        """Run the app's lifespan handler until it returns or raises.

        The startup/shutdown futures are resolved from the messages the app sends. Any future still
        pending when the handler exits is resolved here, so awaiting callers never hang.
        """
        await self.ensure_init()

        async def receive():
            self._lifespan_supported = True
            return await self._queue.get()

        async def send(message):
            if message["type"] == "lifespan.startup.complete":
                self._startup.set_result(None)
            elif message["type"] == "lifespan.startup.failed":
                self._startup.set_exception(ExecutionError("ASGI lifespan startup failed"))
            elif message["type"] == "lifespan.shutdown.complete":
                self._shutdown.set_result(None)
            elif message["type"] == "lifespan.shutdown.failed":
                self._shutdown.set_exception(ExecutionError("ASGI lifespan shutdown failed"))
            else:
                raise ExecutionError(f"Unexpected message type: {message['type']}")

        try:
            await self.asgi_app({"type": "lifespan", "state": self.state}, receive, send)
        except Exception as e:
            if not self._lifespan_supported:
                # The app raised without ever reading a lifespan event: treat lifespan as unsupported.
                logger.info(f"ASGI lifespan task exited before receiving any messages with exception:\n{e}")
                if not self._startup.done():
                    self._startup.set_result(None)
                if not self._shutdown.done():
                    self._shutdown.set_result(None)
                return

            logger.error(f"Error in ASGI lifespan task: {e}")
            if not self._startup.done():
                self._startup.set_exception(ExecutionError("ASGI lifespan task exited startup"))
            if not self._shutdown.done():
                self._shutdown.set_exception(ExecutionError("ASGI lifespan task exited shutdown"))
        else:
            # A handler that used the protocol and then returned normally (e.g. right after
            # "lifespan.shutdown.complete") does support lifespan, so only log when it never
            # called `receive()`.
            if not self._lifespan_supported:
                logger.info("ASGI Lifespan protocol is probably not supported by this library")
            if not self._startup.done():
                self._startup.set_result(None)
            if not self._shutdown.done():
                self._shutdown.set_result(None)

    async def lifespan_startup(self):
        """Send "lifespan.startup" to the app and wait until startup has finished.

        Raises:
            ExecutionError: If the app reports a startup failure or the lifespan task errors out.
        """
        await self.ensure_init()
        self._queue.put_nowait({"type": "lifespan.startup"})
        await self._startup

    async def lifespan_shutdown(self):
        """Send "lifespan.shutdown" to the app and wait until shutdown has finished.

        Raises:
            ExecutionError: If the app reports a shutdown failure or the lifespan task errors out.
        """
        await self.ensure_init()
        self._queue.put_nowait({"type": "lifespan.shutdown"})
        await self._shutdown
97
+
98
+
def asgi_app_wrapper(asgi_app, container_io_manager) -> tuple[Callable[..., AsyncGenerator], LifespanManager]:
    """Adapt an ASGI app to Modal's input/output model.

    Returns ``(fn, lifespan_manager)``:

    * ``fn(scope)`` is an async generator that runs ``asgi_app`` for a single call. Messages for
      the app are read from ``container_io_manager.get_data_in`` for the current function call,
      and every message the app sends is yielded back to the caller. If the first input message
      does not arrive within ``FIRST_MESSAGE_TIMEOUT_SECONDS``, a 502 response (HTTP) or a 1011
      close (WebSocket) is produced instead.
    * ``lifespan_manager`` drives the app's lifespan protocol and shares the ``state`` dict that is
      injected as ``scope["state"]`` into every call.
    """
    state: dict[str, Any] = {}  # used for lifespan state

    async def fn(scope):
        if "state" in scope:
            # we don't expect users to set state in ASGI scope
            # this should be handled internally by the LifespanManager
            raise ExecutionError("Unexpected state in ASGI scope")
        scope["state"] = state
        function_call_id = current_function_call_id()
        assert function_call_id, "internal error: function_call_id not set in asgi_app() scope"

        # Bounded queues, size 1: each put waits until the previous message is taken.
        messages_from_app: asyncio.Queue[dict[str, Any]] = asyncio.Queue(1)
        messages_to_app: asyncio.Queue[dict[str, Any]] = asyncio.Queue(1)

        async def disconnect_app():
            if scope["type"] == "http":
                await messages_to_app.put({"type": "http.disconnect"})
            elif scope["type"] == "websocket":
                await messages_to_app.put({"type": "websocket.disconnect"})

        async def handle_first_input_timeout():
            if scope["type"] == "http":
                await messages_from_app.put({"type": "http.response.start", "status": 502})
                await messages_from_app.put(
                    {
                        "type": "http.response.body",
                        "body": b"Missing request, possibly due to expiry or cancellation",
                    }
                )
            elif scope["type"] == "websocket":
                await messages_from_app.put(
                    {
                        "type": "websocket.close",
                        "code": 1011,
                        "reason": "Missing request, possibly due to expiry or cancellation",
                    }
                )
            await disconnect_app()

        async def fetch_data_in():
            # Cancel an ASGI app call if the initial message is not received within a short timeout.
            #
            # This initial message, "http.request" or "websocket.connect", should be sent
            # immediately after starting the ASGI app's function call. If it is not received, that
            # indicates a request cancellation or other abnormal circumstance.
            message_gen = container_io_manager.get_data_in.aio(function_call_id)
            first_message_task = asyncio.create_task(message_gen.__anext__())

            try:
                # we are intentionally shielding + manually cancelling first_message_task, since cancellations
                # can otherwise get ignored in case the cancellation and an awaited future resolve gets
                # triggered in the same sequence before handing back control to the event loop.
                first_message = await asyncio.shield(
                    asyncio.wait_for(first_message_task, FIRST_MESSAGE_TIMEOUT_SECONDS)
                )
            except asyncio.CancelledError:
                if not first_message_task.done():
                    # see comment above about manual cancellation
                    first_message_task.cancel()
                raise
            except (asyncio.TimeoutError, StopAsyncIteration):
                # About `StopAsyncIteration` above: The generator shouldn't typically exit,
                # but if it does, we handle it like a timeout in that case.
                await handle_first_input_timeout()
                return
            except Exception:
                logger.exception("Internal error in asgi_app_wrapper")
                await disconnect_app()
                return

            await messages_to_app.put(first_message)
            async for message in message_gen:
                await messages_to_app.put(message)

        async def send(msg):
            # Automatically split body chunks that are greater than the output size limit, to
            # prevent them from being uploaded to S3.
            if msg["type"] == "http.response.body":
                body_chunk_size = MAX_OBJECT_SIZE_BYTES - 1024  # reserve 1 KiB for framing
                body_chunk_limit = 20 * body_chunk_size
                s3_chunk_size = 50 * body_chunk_size

                size = len(msg.get("body", b""))
                if size <= body_chunk_limit:
                    chunk_size = body_chunk_size
                else:
                    # If the body is _very large_, we should still split it up to avoid sending all
                    # of the data in a huge chunk in S3.
                    chunk_size = s3_chunk_size

                if size > chunk_size:
                    indices = list(range(0, size, chunk_size))
                    for i in indices[:-1]:
                        chunk = msg["body"][i : i + chunk_size]
                        await messages_from_app.put({"type": "http.response.body", "body": chunk, "more_body": True})
                    # Build a new dict for the final chunk instead of mutating the app's message.
                    msg = {**msg, "body": msg["body"][indices[-1] :]}

            await messages_from_app.put(msg)

        # Run the ASGI app, while draining the send message queue at the same time,
        # and yielding results.
        async with TaskContext() as tc:
            tc.create_task(fetch_data_in())

            async def receive():
                return await messages_to_app.get()

            app_task = tc.create_task(asgi_app(scope, receive, send))
            while True:
                pop_task = tc.create_task(messages_from_app.get())

                try:
                    done, _pending = await asyncio.wait([pop_task, app_task], return_when=asyncio.FIRST_COMPLETED)
                except asyncio.CancelledError:
                    break

                if pop_task in done:
                    yield pop_task.result()
                else:
                    # clean up the popping task, or we will leak unresolved tasks every loop iteration
                    pop_task.cancel()

                if app_task in done:
                    # The app has finished: flush anything it queued before exiting.
                    while not messages_from_app.empty():
                        yield messages_from_app.get_nowait()
                    app_task.result()  # consume/raise exceptions if there are any!
                    break

    return fn, LifespanManager(asgi_app, state)
230
+
231
+
def wsgi_app_wrapper(wsgi_app, container_io_manager):
    """Adapt a WSGI app by converting it to ASGI and wrapping it like any other ASGI app.

    Returns the same ``(fn, lifespan_manager)`` pair as ``asgi_app_wrapper``.
    """
    from modal._vendor.a2wsgi_wsgi import WSGIMiddleware

    # workers=10000 is meant to make the worker pool effectively unlimited.
    # NOTE(review): send_queue_size=1 presumably bounds the middleware's outgoing message buffer;
    # confirm against the vendored a2wsgi implementation.
    return asgi_app_wrapper(
        WSGIMiddleware(wsgi_app, workers=10000, send_queue_size=1),
        container_io_manager,
    )
237
+
238
+
def webhook_asgi_app(fn: Callable[..., Any], method: str, docs: bool):
    """Build a FastAPI app that serves ``fn`` at ``/`` for the given HTTP ``method``.

    Args:
        fn: The endpoint handler registered with FastAPI.
        method: HTTP method the route accepts.
        docs: When False, the OpenAPI schema (and therefore all generated docs pages) is disabled.

    Raises:
        InvalidError: If FastAPI is not installed in the image.
    """
    try:
        from fastapi import FastAPI
        from fastapi.middleware.cors import CORSMiddleware
    except ImportError as exc:
        raise InvalidError(
            "Modal web_endpoint functions require FastAPI to be installed in the modal.Image."
            ' Please update your Image definition code, e.g. with `.pip_install("fastapi[standard]")`.'
        ) from exc

    # disabling openapi spec disables all docs
    openapi_url = "/openapi.json" if docs else None
    web_app = FastAPI(openapi_url=openapi_url)

    # Permissive CORS: any origin, method and header; credentials allowed.
    cors_options = dict(
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    web_app.add_middleware(CORSMiddleware, **cors_options)
    web_app.add_api_route("/", fn, methods=[method])
    return web_app
261
+
262
+
def get_ip_address(ifname: bytes):
    """Get the IPv4 address associated with a network interface in Linux.

    Args:
        ifname: Interface name as bytes (e.g. ``b"eth0"``). Only the first 15 bytes are used.

    Returns:
        The address as a dotted-quad string.

    Raises:
        OSError: If the ioctl fails, e.g. because the interface does not exist.
    """
    import fcntl
    import socket
    import struct

    # The socket is only a handle for the ioctl; close it deterministically instead of leaking a
    # file descriptor on every call.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        ifreq = fcntl.ioctl(
            s.fileno(),
            0x8915,  # SIOCGIFADDR
            struct.pack("256s", ifname[:15]),
        )
    # Bytes 20:24 of the returned struct ifreq hold the IPv4 address (16-byte name, then
    # sockaddr_in's family and port fields, then sin_addr).
    return socket.inet_ntoa(ifreq[20:24])
277
+
278
+
def wait_for_web_server(host: str, port: int, *, timeout: float) -> None:
    """Wait until a web server port starts accepting TCP connections.

    Connection attempts are retried every ~10 ms. Each attempt is bounded by the time left until
    the overall deadline, so the total wait stays close to ``timeout``. The original implementation
    gave every attempt the full ``timeout`` and could take up to twice as long.

    Raises:
        TimeoutError: If no connection succeeds within ``timeout`` seconds.
    """
    import socket
    import time

    deadline = time.monotonic() + timeout
    while True:
        # Floor the per-attempt timeout: 0 would make the socket non-blocking, and a negative
        # value is rejected by the socket module with ValueError.
        attempt_timeout = max(deadline - time.monotonic(), 0.01)
        try:
            with socket.create_connection((host, port), timeout=attempt_timeout):
                return
        except OSError as ex:
            time.sleep(0.01)
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Waited too long for port {port} to start accepting connections. "
                    "Make sure the web server is bound to 0.0.0.0 (rather than localhost or 127.0.0.1), "
                    "or adjust `startup_timeout`."
                ) from ex
297
+
298
+
299
+ def _add_forwarded_for_header(scope):
300
+ # we strip X-Forwarded-For headers from the scope
301
+ # but we can add it back from the ASGI scope
302
+ # https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
303
+
304
+ # X-Forwarded-For headers is a comma separated list of IP addresses
305
+ # there may be multiple X-Forwarded-For headers
306
+ # we want to prepend the client IP to the first one
307
+ # but only if it doesn't exist in one of the headers already
308
+
309
+ first_x_forwarded_for_idx = None
310
+ if "headers" in scope and "client" in scope:
311
+ client_host = scope["client"][0]
312
+
313
+ for idx, header in enumerate(scope["headers"]):
314
+ if header[0] == b"X-Forwarded-For":
315
+ if first_x_forwarded_for_idx is None:
316
+ first_x_forwarded_for_idx = idx
317
+ values = header[1].decode().split(", ")
318
+
319
+ if client_host in values:
320
+ # we already have the client IP in this header
321
+ # return early
322
+ return scope
323
+
324
+ if first_x_forwarded_for_idx is not None:
325
+ # we have X-Forwarded-For headers but they don't have the client IP
326
+ # we need to prepend the client IP to the first one
327
+ values = [client_host] + scope["headers"][first_x_forwarded_for_idx][1].decode().split(", ")
328
+ scope["headers"][first_x_forwarded_for_idx] = (b"X-Forwarded-For", ", ".join(values).encode())
329
+ else:
330
+ # we don't have X-Forwarded-For headers, we need to add one
331
+ scope["headers"].append((b"X-Forwarded-For", client_host.encode()))
332
+
333
+ return scope
334
+
335
+
async def _proxy_http_request(session: aiohttp.ClientSession, scope, receive, send) -> None:
    """Forward one ASGI HTTP request through ``session`` and stream the upstream reply back.

    The request body is streamed from ``receive()`` unless the method is in
    ``aiohttp.ClientRequest.GET_METHODS``. If the client disconnects while the request is being
    sent, the call ends quietly. While the response is relayed, a second task watches for
    ``http.disconnect`` and aborts the upstream connection's transport when it arrives.
    """
    scope = _add_forwarded_for_header(scope)

    async def body_chunks() -> AsyncGenerator[bytes, None]:
        more_body = True
        while more_body:
            message = await receive()
            message_type = message["type"]
            if message_type == "http.disconnect":
                raise ConnectionAbortedError("Disconnect message received")
            if message_type != "http.request":
                raise ExecutionError(f"Unexpected message type: {message['type']}")
            chunk = message.get("body", b"")
            if chunk:
                yield chunk
            more_body = message.get("more_body", False)

    query_string = scope.get("query_string")
    target = scope["path"] + ("?" + scope["query_string"].decode() if query_string else "")

    if scope["method"] in aiohttp.ClientRequest.GET_METHODS:
        request_body = None
    else:
        request_body = body_chunks()

    try:
        proxy_response: aiohttp.ClientResponse = await session.request(
            method=scope["method"],
            url=target,
            headers=[(k.decode(), v.decode()) for k, v in scope["headers"]],
            data=request_body,
            allow_redirects=False,
        )
    except ConnectionAbortedError:
        return
    except aiohttp.ClientConnectionError as e:  # some versions of aiohttp wrap the error
        if isinstance(e.__cause__, ConnectionAbortedError):
            return
        raise

    async def relay_response() -> None:
        await send(
            {
                "type": "http.response.start",
                "status": proxy_response.status,
                "headers": [(k.encode(), v.encode()) for k, v in proxy_response.headers.items()],
            }
        )
        async for data in proxy_response.content.iter_any():
            await send({"type": "http.response.body", "body": data, "more_body": True})
        await send({"type": "http.response.body"})

    async def abort_on_disconnect() -> NoReturn:
        while True:
            message = await receive()
            if message["type"] != "http.disconnect":
                continue
            connection = proxy_response.connection
            if connection is not None and connection.transport is not None:
                connection.transport.abort()

    async with TaskContext() as tc:
        relay_task = tc.create_task(relay_response())
        watcher_task = tc.create_task(abort_on_disconnect())
        await asyncio.wait([relay_task, watcher_task], return_when=asyncio.FIRST_COMPLETED)
400
+
401
+
async def _proxy_websocket_request(session: aiohttp.ClientSession, scope, receive, send) -> None:
    """Proxy one ASGI WebSocket connection to the upstream server through ``session``.

    Consumes the client's initial ``websocket.connect`` message and opens the upstream WebSocket,
    forwarding the request headers and requested subprotocols. It then accepts the client with the
    subprotocol the upstream chose and relays text/binary messages in both directions until either
    side closes. The upstream close code and reason are forwarded to the client when present.

    Raises:
        ExecutionError: On an unexpected ASGI message type from the client.
    """
    first_message = await receive()  # Consume the initial "websocket.connect" message.
    if first_message["type"] == "websocket.disconnect":
        return
    elif first_message["type"] != "websocket.connect":
        raise ExecutionError(f"Unexpected message type: {first_message['type']}")

    path = scope["path"]
    if scope.get("query_string"):
        path += "?" + scope["query_string"].decode()

    async with session.ws_connect(
        url=path,
        headers=[(k.decode(), v.decode()) for k, v in scope["headers"]],  # type: ignore
        protocols=scope.get("subprotocols", []),
    ) as upstream_ws:

        async def client_to_upstream():
            while True:
                client_message = await receive()
                if client_message["type"] == "websocket.disconnect":
                    await upstream_ws.close(code=client_message.get("code", 1005))
                    break
                elif client_message["type"] == "websocket.receive":
                    if client_message.get("text") is not None:
                        await upstream_ws.send_str(client_message["text"])
                    elif client_message.get("bytes") is not None:
                        await upstream_ws.send_bytes(client_message["bytes"])
                else:
                    raise ExecutionError(f"Unexpected message type: {client_message['type']}")

        async def upstream_to_client():
            msg: dict[str, Any] = {
                "type": "websocket.accept",
                "subprotocol": upstream_ws.protocol,
            }
            await send(msg)

            while True:
                upstream_message = await upstream_ws.receive()
                if upstream_message.type in (aiohttp.WSMsgType.close, aiohttp.WSMsgType.closed):
                    # aiohttp yields CLOSE when the upstream sends a close frame (data = close code,
                    # extra = reason) and CLOSED (data = None) once the socket is already closed.
                    # Handling CLOSE here lets the upstream's close code and reason reach the client.
                    msg = {"type": "websocket.close"}
                    close_code = upstream_message.data
                    if close_code:  # None for CLOSED; 0 when the close frame carried no status code
                        # The code may be a plain int or a WSCloseCode; int() handles both.
                        msg["code"] = int(cast(aiohttp.WSCloseCode, close_code))
                        msg["reason"] = upstream_message.extra or ""
                    await send(msg)
                    break
                elif upstream_message.type == aiohttp.WSMsgType.text:
                    await send({"type": "websocket.send", "text": upstream_message.data})
                elif upstream_message.type == aiohttp.WSMsgType.binary:
                    await send({"type": "websocket.send", "bytes": upstream_message.data})
                else:
                    pass  # Ignore all other upstream WebSocket message types.

        async with TaskContext() as tc:
            client_to_upstream_task = tc.create_task(client_to_upstream())
            upstream_to_client_task = tc.create_task(upstream_to_client())
            await asyncio.wait([client_to_upstream_task, upstream_to_client_task], return_when=asyncio.FIRST_COMPLETED)
460
+
461
+
async def _proxy_lifespan_request(base_url, scope, receive, send) -> None:
    """Serve the ASGI lifespan protocol for the web server proxy.

    On "lifespan.startup" an ``aiohttp.ClientSession`` rooted at ``base_url`` is created (only the
    first time) and stored in ``scope["state"]["session"]`` for the HTTP/WebSocket handlers. On
    "lifespan.shutdown" the session is closed and the function returns. Any other message type
    raises ``ExecutionError``.
    """

    def new_session() -> aiohttp.ClientSession:
        extra_options: dict[str, Any] = {}
        # These options were introduced in aiohttp 3.9, and we can remove the
        # conditional after deprecating image builder version 2023.12.
        if parse_major_minor_version(aiohttp.__version__) >= (3, 9):
            extra_options["max_line_size"] = 64 * 1024  # 64 KiB
            extra_options["max_field_size"] = 64 * 1024  # 64 KiB
        return aiohttp.ClientSession(
            base_url,
            cookie_jar=aiohttp.DummyCookieJar(),
            timeout=aiohttp.ClientTimeout(total=3600),
            auto_decompress=False,
            read_bufsize=1024 * 1024,  # 1 MiB
            **extra_options,
        )

    session: Optional[aiohttp.ClientSession] = None
    while True:
        message = await receive()
        message_type = message["type"]

        if message_type == "lifespan.shutdown":
            if session is not None:
                await session.close()
            await send({"type": "lifespan.shutdown.complete"})
            return

        if message_type != "lifespan.startup":
            raise ExecutionError(f"Unexpected message type: {message['type']}")

        if session is None:
            session = new_session()
        scope["state"]["session"] = session
        await send({"type": "lifespan.startup.complete"})
494
+
495
+
def web_server_proxy(host: str, port: int):
    """Return an ASGI app that proxies requests to a web server running on the same host.

    The returned app handles "lifespan" scopes (managing the shared client session), and
    forwards "http" and "websocket" scopes to ``http://{host}:{port}``. If the upstream server
    cannot be connected to, a warning is logged and the container stops fetching new inputs
    instead of raising.

    Raises:
        InvalidError: If ``port`` is not in the range 1-65535.
    """
    if not 0 < port < 65536:
        raise InvalidError(f"Invalid port number: {port}")

    base_url = f"http://{host}:{port}"

    async def web_server_proxy_app(scope, receive, send):
        try:
            scope_type = scope["type"]
            if scope_type == "lifespan":
                await _proxy_lifespan_request(base_url, scope, receive, send)
            elif scope_type in ("http", "websocket"):
                handler = _proxy_http_request if scope_type == "http" else _proxy_websocket_request
                await handler(scope["state"]["session"], scope, receive, send)
            else:
                raise NotImplementedError(f"Scope {scope} is not understood")
        except aiohttp.ClientConnectorError as exc:
            # If the server is not running or not reachable, we should stop fetching new inputs.
            logger.warning(f"Terminating runner due to @web_server connection issue: {exc}")
            stop_fetching_inputs()

    return web_server_proxy_app