modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__init__.py +13 -9
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +402 -398
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -60
- modal/_resources.py +26 -7
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1025 -0
- modal/{execution_context.py → _runtime/execution_context.py} +11 -2
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +123 -6
- modal/_traceback.py +47 -187
- modal/_tunnel.py +50 -14
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +386 -104
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +299 -98
- modal/_utils/grpc_testing.py +47 -34
- modal/_utils/grpc_utils.py +54 -21
- modal/_utils/hash_utils.py +51 -10
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +3 -3
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +12 -10
- modal/app.py +561 -323
- modal/app.pyi +474 -262
- modal/call_graph.py +7 -6
- modal/cli/_download.py +22 -6
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +203 -42
- modal/cli/config.py +12 -5
- modal/cli/container.py +61 -13
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +21 -48
- modal/cli/launch.py +28 -14
- modal/cli/network_file_system.py +57 -21
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +34 -9
- modal/cli/programs/vscode.py +58 -8
- modal/cli/queues.py +131 -0
- modal/cli/run.py +199 -96
- modal/cli/secret.py +5 -4
- modal/cli/token.py +7 -2
- modal/cli/utils.py +74 -8
- modal/cli/volume.py +97 -56
- modal/client.py +248 -144
- modal/client.pyi +156 -124
- modal/cloud_bucket_mount.py +43 -30
- modal/cloud_bucket_mount.pyi +32 -25
- modal/cls.py +528 -141
- modal/cls.pyi +189 -145
- modal/config.py +32 -15
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +50 -54
- modal/dict.pyi +120 -164
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +30 -43
- modal/experimental.py +62 -2
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +196 -0
- modal/functions.py +846 -428
- modal/functions.pyi +446 -387
- modal/gpu.py +57 -44
- modal/image.py +943 -417
- modal/image.pyi +584 -245
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +223 -90
- modal/mount.pyi +241 -243
- modal/network_file_system.py +85 -86
- modal/network_file_system.pyi +151 -110
- modal/object.py +66 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +73 -47
- modal/parallel_map.pyi +51 -63
- modal/partial_function.py +272 -107
- modal/partial_function.pyi +219 -120
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +96 -72
- modal/queue.pyi +210 -135
- modal/requirements/2024.04.txt +2 -1
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +45 -4
- modal/runner.py +325 -203
- modal/runner.pyi +124 -110
- modal/running_app.py +27 -4
- modal/sandbox.py +509 -231
- modal/sandbox.pyi +396 -169
- modal/schedule.py +2 -2
- modal/scheduler_placement.py +20 -3
- modal/secret.py +41 -25
- modal/secret.pyi +62 -42
- modal/serving.py +39 -49
- modal/serving.pyi +37 -43
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +123 -137
- modal/volume.pyi +228 -221
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
- modal-0.72.13.dist-info/RECORD +174 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +1 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1231 -531
- modal_proto/api_grpc.py +750 -430
- modal_proto/api_pb2.py +2102 -1176
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1329 -675
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_exec.py +0 -128
- modal/_container_io_manager.py +0 -646
- modal/_container_io_manager.pyi +0 -412
- modal/_sandbox_shell.py +0 -49
- modal/app_utils.py +0 -20
- modal/app_utils.pyi +0 -17
- modal/execution_context.pyi +0 -37
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal-0.62.115.dist-info/RECORD +0 -207
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -279
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -674
- test/client_test.py +0 -203
- test/cloud_bucket_mount_test.py +0 -22
- test/cls_test.py +0 -636
- test/config_test.py +0 -149
- test/conftest.py +0 -1485
- test/container_app_test.py +0 -50
- test/container_test.py +0 -1405
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -51
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -791
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -82
- test/helpers.py +0 -47
- test/image_test.py +0 -814
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -327
- test/network_file_system_test.py +0 -188
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -115
- test/resolver_test.py +0 -59
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -57
- test/secret_test.py +0 -89
- test/serialization_test.py +0 -50
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -361
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -397
- test/watcher_test.py +0 -58
- test/webhook_test.py +0 -145
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
modal/_runtime/asgi.py
ADDED
@@ -0,0 +1,519 @@
|
|
1
|
+
# Copyright Modal Labs 2022
|
2
|
+
|
3
|
+
# Note: this module isn't imported unless it's needed.
|
4
|
+
# This is because aiohttp is a pretty big dependency that adds significant latency when imported
|
5
|
+
|
6
|
+
import asyncio
|
7
|
+
from collections.abc import AsyncGenerator
|
8
|
+
from typing import Any, Callable, NoReturn, Optional, cast
|
9
|
+
|
10
|
+
import aiohttp
|
11
|
+
|
12
|
+
from modal._utils.async_utils import TaskContext
|
13
|
+
from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES
|
14
|
+
from modal._utils.package_utils import parse_major_minor_version
|
15
|
+
from modal.config import logger
|
16
|
+
from modal.exception import ExecutionError, InvalidError
|
17
|
+
from modal.experimental import stop_fetching_inputs
|
18
|
+
|
19
|
+
from .execution_context import current_function_call_id
|
20
|
+
|
21
|
+
# Seconds that asgi_app_wrapper's fetch_data_in() waits for the first input message
# ("http.request" / "websocket.connect") before treating the call as expired or cancelled.
FIRST_MESSAGE_TIMEOUT_SECONDS = 5.0
|
22
|
+
|
23
|
+
|
24
|
+
class LifespanManager:
    """Drive the ASGI lifespan protocol for an app from a background task.

    ``background_task()`` calls the app once with a ``"lifespan"`` scope.
    ``lifespan_startup()`` and ``lifespan_shutdown()`` enqueue the matching event for the
    app and then wait until the app reports completion or failure.

    If the app never calls ``receive()``, lifespan is treated as unsupported and both
    phases resolve successfully.
    """

    _startup: asyncio.Future  # settled by "lifespan.startup.complete" / "lifespan.startup.failed"
    _shutdown: asyncio.Future  # settled by "lifespan.shutdown.complete" / "lifespan.shutdown.failed"
    _queue: asyncio.Queue  # lifespan events handed to the app's receive()
    _has_run_init: bool = False
    _lifespan_supported: bool = False  # flipped to True the first time the app calls receive()

    def __init__(self, asgi_app, state):
        self.asgi_app = asgi_app
        self.state = state  # dict passed to the app as the lifespan scope's "state"

    async def ensure_init(self):
        # making this async even though
        # no async code since it has to run inside
        # the event loop to tie the
        # objects to the correct loop in python 3.9
        if not self._has_run_init:
            self._queue = asyncio.Queue()
            self._startup = asyncio.Future()
            self._shutdown = asyncio.Future()
            self._has_run_init = True

    def _settle(
        self,
        startup_exc: Optional[BaseException] = None,
        shutdown_exc: Optional[BaseException] = None,
    ) -> None:
        """Resolve whichever of the startup/shutdown futures is still pending.

        A pending future is failed with its exception argument, or succeeds with ``None``
        when that argument is ``None``. Futures that are already done are left untouched.
        """
        for future, exc in ((self._startup, startup_exc), (self._shutdown, shutdown_exc)):
            if future.done():
                continue
            if exc is None:
                future.set_result(None)
            else:
                future.set_exception(exc)

    async def background_task(self):
        """Run the app's lifespan handler until it returns or raises.

        Afterwards, any startup/shutdown future that is still pending is settled so
        callers waiting on it are never left hanging.
        """
        await self.ensure_init()

        async def receive():
            self._lifespan_supported = True
            return await self._queue.get()

        async def send(message):
            if message["type"] == "lifespan.startup.complete":
                self._startup.set_result(None)
            elif message["type"] == "lifespan.startup.failed":
                self._startup.set_exception(ExecutionError("ASGI lifespan startup failed"))
            elif message["type"] == "lifespan.shutdown.complete":
                self._shutdown.set_result(None)
            elif message["type"] == "lifespan.shutdown.failed":
                self._shutdown.set_exception(ExecutionError("ASGI lifespan shutdown failed"))
            else:
                raise ExecutionError(f"Unexpected message type: {message['type']}")

        try:
            await self.asgi_app({"type": "lifespan", "state": self.state}, receive, send)
        except Exception as e:
            if not self._lifespan_supported:
                # Raising before reading any event is how many apps reject the lifespan scope.
                logger.info(f"ASGI lifespan task exited before receiving any messages with exception:\n{e}")
                self._settle()
                return

            logger.error(f"Error in ASGI lifespan task: {e}", exc_info=True)
            self._settle(
                ExecutionError("ASGI lifespan task exited startup"),
                ExecutionError("ASGI lifespan task exited shutdown"),
            )
        else:
            # Returning normally is expected after a full startup/shutdown cycle. Only an app
            # that never read a lifespan event is flagged as not supporting the protocol.
            if self._lifespan_supported:
                logger.debug("ASGI lifespan task exited")
            else:
                logger.info("ASGI Lifespan protocol is probably not supported by this library")
            self._settle()

    async def lifespan_startup(self):
        """Send "lifespan.startup" to the app and wait until startup is acknowledged.

        Raises the exception stored on the startup future if startup failed.
        """
        await self.ensure_init()
        self._queue.put_nowait({"type": "lifespan.startup"})
        await self._startup

    async def lifespan_shutdown(self):
        """Send "lifespan.shutdown" to the app and wait until shutdown is acknowledged.

        Raises the exception stored on the shutdown future if shutdown failed.
        """
        await self.ensure_init()
        self._queue.put_nowait({"type": "lifespan.shutdown"})
        await self._shutdown
|
97
|
+
|
98
|
+
|
99
|
+
def asgi_app_wrapper(asgi_app, container_io_manager) -> tuple[Callable[..., AsyncGenerator], LifespanManager]:
    """Wrap an ASGI app so each Modal input drives one ASGI call.

    Returns ``(fn, lifespan_manager)``:

    - ``fn(scope)`` is an async generator. It runs ``asgi_app`` for ``scope`` and feeds it
      the messages returned by ``container_io_manager.get_data_in`` for the current
      function call. It yields every message the app sends, with oversized HTTP bodies
      split into chunks.
    - ``lifespan_manager`` is a ``LifespanManager`` that shares the same ``state`` dict
      used for every request scope.
    """
    state: dict[str, Any] = {}  # used for lifespan state

    async def fn(scope):
        if "state" in scope:
            # we don't expect users to set state in ASGI scope
            # this should be handled internally by the LifespanManager
            raise ExecutionError("Unexpected state in ASGI scope")
        scope["state"] = state
        function_call_id = current_function_call_id()
        assert function_call_id, "internal error: function_call_id not set in asgi_app() scope"

        # Single-slot queues: each put waits until the other side has taken the previous message.
        messages_from_app: asyncio.Queue[dict[str, Any]] = asyncio.Queue(1)
        messages_to_app: asyncio.Queue[dict[str, Any]] = asyncio.Queue(1)

        async def disconnect_app():
            # Deliver the disconnect event matching the scope type to the app.
            if scope["type"] == "http":
                await messages_to_app.put({"type": "http.disconnect"})
            elif scope["type"] == "websocket":
                await messages_to_app.put({"type": "websocket.disconnect"})

        async def handle_first_input_timeout():
            # Answer the caller on the app's behalf (HTTP 502 / WebSocket close 1011),
            # then tell the app its client is gone.
            if scope["type"] == "http":
                await messages_from_app.put({"type": "http.response.start", "status": 502})
                await messages_from_app.put(
                    {
                        "type": "http.response.body",
                        "body": b"Missing request, possibly due to expiry or cancellation",
                    }
                )
            elif scope["type"] == "websocket":
                await messages_from_app.put(
                    {
                        "type": "websocket.close",
                        "code": 1011,
                        "reason": "Missing request, possibly due to expiry or cancellation",
                    }
                )
            await disconnect_app()

        async def fetch_data_in():
            # Cancel an ASGI app call if the initial message is not received within a short timeout.
            #
            # This initial message, "http.request" or "websocket.connect", should be sent
            # immediately after starting the ASGI app's function call. If it is not received, that
            # indicates a request cancellation or other abnormal circumstance.
            message_gen = container_io_manager.get_data_in.aio(function_call_id)
            first_message_task = asyncio.create_task(message_gen.__anext__())

            try:
                # we are intentionally shielding + manually cancelling first_message_task, since cancellations
                # can otherwise get ignored in case the cancellation and an awaited future resolve gets
                # triggered in the same sequence before handing back control to the event loop.
                first_message = await asyncio.shield(
                    asyncio.wait_for(first_message_task, FIRST_MESSAGE_TIMEOUT_SECONDS)
                )
            except asyncio.CancelledError:
                if not first_message_task.done():
                    # see comment above about manual cancellation
                    first_message_task.cancel()
                raise
            except (asyncio.TimeoutError, StopAsyncIteration):
                # About `StopAsyncIteration` above: The generator shouldn't typically exit,
                # but if it does, we handle it like a timeout in that case.
                await handle_first_input_timeout()
                return
            except Exception:
                logger.exception("Internal error in asgi_app_wrapper")
                await disconnect_app()
                return

            await messages_to_app.put(first_message)
            async for message in message_gen:
                await messages_to_app.put(message)

        async def send(msg):
            # Automatically split body chunks that are greater than the output size limit, to
            # prevent them from being uploaded to S3.
            if msg["type"] == "http.response.body":
                body_chunk_size = MAX_OBJECT_SIZE_BYTES - 1024  # reserve 1 KiB for framing
                body_chunk_limit = 20 * body_chunk_size
                s3_chunk_size = 50 * body_chunk_size

                size = len(msg.get("body", b""))
                if size <= body_chunk_limit:
                    chunk_size = body_chunk_size
                else:
                    # If the body is _very large_, we should still split it up to avoid sending all
                    # of the data in a huge chunk in S3.
                    chunk_size = s3_chunk_size

                if size > chunk_size:
                    indices = list(range(0, size, chunk_size))
                    for i in indices[:-1]:
                        chunk = msg["body"][i : i + chunk_size]
                        await messages_from_app.put({"type": "http.response.body", "body": chunk, "more_body": True})
                    # Build a new message for the tail rather than mutating the app's dict in place.
                    msg = {**msg, "body": msg["body"][indices[-1] :]}

            await messages_from_app.put(msg)

        # Run the ASGI app, while draining the send message queue at the same time,
        # and yielding results.
        async with TaskContext() as tc:
            tc.create_task(fetch_data_in())

            async def receive():
                return await messages_to_app.get()

            app_task = tc.create_task(asgi_app(scope, receive, send))
            while True:
                pop_task = tc.create_task(messages_from_app.get())

                try:
                    done, _ = await asyncio.wait([pop_task, app_task], return_when=asyncio.FIRST_COMPLETED)
                except asyncio.CancelledError:
                    break

                if pop_task in done:
                    yield pop_task.result()
                else:
                    # clean up the popping task, or we will leak unresolved tasks every loop iteration
                    pop_task.cancel()

                if app_task in done:
                    # Flush anything the app queued right before finishing.
                    while not messages_from_app.empty():
                        yield messages_from_app.get_nowait()
                    app_task.result()  # consume/raise exceptions if there are any!
                    break

    return fn, LifespanManager(asgi_app, state)
|
230
|
+
|
231
|
+
|
232
|
+
def wsgi_app_wrapper(wsgi_app, container_io_manager):
    """Serve a WSGI app by adapting it to ASGI and handing it to ``asgi_app_wrapper``.

    Returns exactly what ``asgi_app_wrapper`` returns for the adapted app.
    """
    # Function-scope import: the vendored adapter is only loaded when a WSGI app is served.
    from modal._vendor.a2wsgi_wsgi import WSGIMiddleware

    adapted_app = WSGIMiddleware(
        wsgi_app,
        workers=10000,  # unlimited workers
        send_queue_size=1,
    )
    return asgi_app_wrapper(adapted_app, container_io_manager)
|
237
|
+
|
238
|
+
|
239
|
+
def webhook_asgi_app(fn: Callable[..., Any], method: str, docs: bool):
    """Return a FastAPI app wrapping a function handler.

    ``fn`` is mounted at ``/`` for the single HTTP ``method`` given. CORS is fully open.
    The OpenAPI schema (and every docs page built from it) is served only when ``docs``
    is true.

    Raises:
        InvalidError: if FastAPI is not importable in the running image.
    """
    try:
        from fastapi import FastAPI
        from fastapi.middleware.cors import CORSMiddleware
    except ImportError as exc:
        raise InvalidError(
            "Modal web_endpoint functions require FastAPI to be installed in the modal.Image."
            ' Please update your Image definition code, e.g. with `.pip_install("fastapi[standard]")`.'
        ) from exc

    # disabling openapi spec disables all docs
    openapi_url = "/openapi.json" if docs else None
    app = FastAPI(openapi_url=openapi_url)

    cors_options = dict(
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.add_middleware(CORSMiddleware, **cors_options)
    app.add_api_route("/", fn, methods=[method])
    return app
|
261
|
+
|
262
|
+
|
263
|
+
def get_ip_address(ifname: bytes):
    """Get the IP address associated with a network interface in Linux.

    Issues the ``SIOCGIFADDR`` ioctl on a temporary UDP socket.

    Args:
        ifname: Interface name as bytes, e.g. ``b"eth0"``; only the first 15 bytes are used.

    Returns:
        The interface's IPv4 address as a dotted-quad string.

    Raises:
        OSError: if the ioctl fails (for example, an unknown interface).
    """
    import fcntl
    import socket
    import struct

    # The socket exists only as a handle for the ioctl; close it so repeated calls don't leak fds.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        ifreq = fcntl.ioctl(
            s.fileno(),
            0x8915,  # SIOCGIFADDR
            struct.pack("256s", ifname[:15]),
        )
    # Bytes 20..24 of the returned ifreq hold the IPv4 address.
    return socket.inet_ntoa(ifreq[20:24])
|
277
|
+
|
278
|
+
|
279
|
+
def wait_for_web_server(host: str, port: int, *, timeout: float) -> None:
    """Wait until a web server port starts accepting TCP connections.

    Connection attempts are retried about every 10 ms until one succeeds or ``timeout``
    seconds have passed in total.

    Raises:
        TimeoutError: if no connection succeeded before the deadline; chained from the last
            connection error.
    """
    import socket
    import time

    deadline = time.monotonic() + timeout
    while True:
        # Bound each attempt by the time left, so one hanging connect can't push the
        # overall wait far beyond ``timeout``. Keep a small floor so the socket never
        # becomes non-blocking (timeout=0).
        remaining = deadline - time.monotonic()
        try:
            with socket.create_connection((host, port), timeout=max(remaining, 0.01)):
                return
        except OSError as ex:
            time.sleep(0.01)
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Waited too long for port {port} to start accepting connections. "
                    "Make sure the web server is bound to 0.0.0.0 (rather than localhost or 127.0.0.1), "
                    "or adjust `startup_timeout`."
                ) from ex
|
297
|
+
|
298
|
+
|
299
|
+
def _add_forwarded_for_header(scope):
|
300
|
+
# we strip X-Forwarded-For headers from the scope
|
301
|
+
# but we can add it back from the ASGI scope
|
302
|
+
# https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
|
303
|
+
|
304
|
+
# X-Forwarded-For headers is a comma separated list of IP addresses
|
305
|
+
# there may be multiple X-Forwarded-For headers
|
306
|
+
# we want to prepend the client IP to the first one
|
307
|
+
# but only if it doesn't exist in one of the headers already
|
308
|
+
|
309
|
+
first_x_forwarded_for_idx = None
|
310
|
+
if "headers" in scope and "client" in scope:
|
311
|
+
client_host = scope["client"][0]
|
312
|
+
|
313
|
+
for idx, header in enumerate(scope["headers"]):
|
314
|
+
if header[0] == b"X-Forwarded-For":
|
315
|
+
if first_x_forwarded_for_idx is None:
|
316
|
+
first_x_forwarded_for_idx = idx
|
317
|
+
values = header[1].decode().split(", ")
|
318
|
+
|
319
|
+
if client_host in values:
|
320
|
+
# we already have the client IP in this header
|
321
|
+
# return early
|
322
|
+
return scope
|
323
|
+
|
324
|
+
if first_x_forwarded_for_idx is not None:
|
325
|
+
# we have X-Forwarded-For headers but they don't have the client IP
|
326
|
+
# we need to prepend the client IP to the first one
|
327
|
+
values = [client_host] + scope["headers"][first_x_forwarded_for_idx][1].decode().split(", ")
|
328
|
+
scope["headers"][first_x_forwarded_for_idx] = (b"X-Forwarded-For", ", ".join(values).encode())
|
329
|
+
else:
|
330
|
+
# we don't have X-Forwarded-For headers, we need to add one
|
331
|
+
scope["headers"].append((b"X-Forwarded-For", client_host.encode()))
|
332
|
+
|
333
|
+
return scope
|
334
|
+
|
335
|
+
|
336
|
+
async def _proxy_http_request(session: aiohttp.ClientSession, scope, receive, send) -> None:
    """Proxy one ASGI HTTP request through ``session`` and stream the response back.

    - The request body is streamed from ``receive()``. Methods in
      ``aiohttp.ClientRequest.GET_METHODS`` are sent without a body.
    - If the client disconnects while the request is being sent, this returns quietly.
    - If the client disconnects while the response is streaming, the upstream transport
      is aborted.
    - Redirects are passed through to the client rather than followed.
    """
    proxy_response: aiohttp.ClientResponse

    scope = _add_forwarded_for_header(scope)

    async def request_generator() -> AsyncGenerator[bytes, None]:
        while True:
            message = await receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                if body:
                    yield body
                if not message.get("more_body", False):
                    break
            elif message["type"] == "http.disconnect":
                raise ConnectionAbortedError("Disconnect message received")
            else:
                raise ExecutionError(f"Unexpected message type: {message['type']}")

    path = scope["path"]
    if scope.get("query_string"):
        path += "?" + scope["query_string"].decode()

    try:
        proxy_response = await session.request(
            method=scope["method"],
            url=path,
            headers=[(k.decode(), v.decode()) for k, v in scope["headers"]],
            data=None if scope["method"] in aiohttp.ClientRequest.GET_METHODS else request_generator(),
            allow_redirects=False,
        )
    except ConnectionAbortedError:
        return
    except aiohttp.ClientConnectionError as e:  # some versions of aiohttp wrap the error
        if isinstance(e.__cause__, ConnectionAbortedError):
            return
        raise

    async def send_response() -> None:
        msg = {
            "type": "http.response.start",
            "status": proxy_response.status,
            "headers": [(k.encode(), v.encode()) for k, v in proxy_response.headers.items()],
        }
        await send(msg)
        async for data in proxy_response.content.iter_any():
            msg = {"type": "http.response.body", "body": data, "more_body": True}
            await send(msg)
        await send({"type": "http.response.body"})

    async def listen_for_disconnect() -> NoReturn:
        # Abort the upstream connection as soon as the client goes away.
        while True:
            message = await receive()
            if (
                message["type"] == "http.disconnect"
                and proxy_response.connection is not None
                and proxy_response.connection.transport is not None
            ):
                proxy_response.connection.transport.abort()

    try:
        async with TaskContext() as tc:
            send_response_task = tc.create_task(send_response())
            disconnect_task = tc.create_task(listen_for_disconnect())
            await asyncio.wait([send_response_task, disconnect_task], return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Close the upstream response on every exit path, so an aborted or failed stream
        # doesn't leave it open until garbage collection.
        proxy_response.close()
|
400
|
+
|
401
|
+
|
402
|
+
# Close codes that RFC 6455 (section 7.4.1) reserves; strict peers reject them in a Close frame.
_RESERVED_CLOSE_CODES = frozenset({1004, 1005, 1006, 1015})


def _sendable_close_code(code: Any) -> int:
    """Return ``code`` as an int if it may be sent in a WebSocket Close frame, else 1000.

    An ASGI ``websocket.disconnect`` without a code is treated as 1005 ("no status
    received"). 1005 is a reserved value that must not be put on the wire. Reserved codes,
    codes outside 1000-4999, and non-integer values are therefore replaced with 1000
    (normal closure).
    """
    try:
        value = int(code)
    except (TypeError, ValueError):
        return 1000
    if value in _RESERVED_CLOSE_CODES or not 1000 <= value <= 4999:
        return 1000
    return value


async def _proxy_websocket_request(session: aiohttp.ClientSession, scope, receive, send) -> None:
    """Proxy one ASGI WebSocket connection to the upstream server through ``session``.

    After the client's "websocket.connect", messages are relayed in both directions until
    either side closes. Close codes are passed through ``_sendable_close_code``, so
    reserved values are never forwarded.
    """
    first_message = await receive()  # Consume the initial "websocket.connect" message.
    if first_message["type"] == "websocket.disconnect":
        return
    elif first_message["type"] != "websocket.connect":
        raise ExecutionError(f"Unexpected message type: {first_message['type']}")

    path = scope["path"]
    if scope.get("query_string"):
        path += "?" + scope["query_string"].decode()

    async with session.ws_connect(
        url=path,
        headers=[(k.decode(), v.decode()) for k, v in scope["headers"]],  # type: ignore
        protocols=scope.get("subprotocols", []),
    ) as upstream_ws:

        async def client_to_upstream():
            while True:
                client_message = await receive()
                if client_message["type"] == "websocket.disconnect":
                    # The ASGI default (1005) may not be sent in a Close frame; map it to a legal code.
                    await upstream_ws.close(code=_sendable_close_code(client_message.get("code", 1005)))
                    break
                elif client_message["type"] == "websocket.receive":
                    if client_message.get("text") is not None:
                        await upstream_ws.send_str(client_message["text"])
                    elif client_message.get("bytes") is not None:
                        await upstream_ws.send_bytes(client_message["bytes"])
                else:
                    raise ExecutionError(f"Unexpected message type: {client_message['type']}")

        async def upstream_to_client():
            msg: dict[str, Any] = {
                "type": "websocket.accept",
                "subprotocol": upstream_ws.protocol,
            }
            await send(msg)

            while True:
                upstream_message = await upstream_ws.receive()
                if upstream_message.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED):
                    # Upstream closed: forward its close code (``data``) and reason (``extra``)
                    # when the message carries them.
                    msg = {"type": "websocket.close"}
                    if upstream_message.data is not None:
                        msg["code"] = _sendable_close_code(upstream_message.data)
                        msg["reason"] = upstream_message.extra or ""
                    await send(msg)
                    break
                elif upstream_message.type == aiohttp.WSMsgType.text:
                    await send({"type": "websocket.send", "text": upstream_message.data})
                elif upstream_message.type == aiohttp.WSMsgType.binary:
                    await send({"type": "websocket.send", "bytes": upstream_message.data})
                else:
                    pass  # Ignore all other upstream WebSocket message types.

        async with TaskContext() as tc:
            client_to_upstream_task = tc.create_task(client_to_upstream())
            upstream_to_client_task = tc.create_task(upstream_to_client())
            await asyncio.wait([client_to_upstream_task, upstream_to_client_task], return_when=asyncio.FIRST_COMPLETED)
|
460
|
+
|
461
|
+
|
462
|
+
async def _proxy_lifespan_request(base_url, scope, receive, send) -> None:
|
463
|
+
session: Optional[aiohttp.ClientSession] = None
|
464
|
+
while True:
|
465
|
+
message = await receive()
|
466
|
+
if message["type"] == "lifespan.startup":
|
467
|
+
if session is None:
|
468
|
+
session = aiohttp.ClientSession(
|
469
|
+
base_url,
|
470
|
+
cookie_jar=aiohttp.DummyCookieJar(),
|
471
|
+
timeout=aiohttp.ClientTimeout(total=3600),
|
472
|
+
auto_decompress=False,
|
473
|
+
read_bufsize=1024 * 1024, # 1 MiB
|
474
|
+
**(
|
475
|
+
# These options were introduced in aiohttp 3.9, and we can remove the
|
476
|
+
# conditional after deprecating image builder version 2023.12.
|
477
|
+
dict( # type: ignore
|
478
|
+
max_line_size=64 * 1024, # 64 KiB
|
479
|
+
max_field_size=64 * 1024, # 64 KiB
|
480
|
+
)
|
481
|
+
if parse_major_minor_version(aiohttp.__version__) >= (3, 9)
|
482
|
+
else {}
|
483
|
+
),
|
484
|
+
)
|
485
|
+
scope["state"]["session"] = session
|
486
|
+
await send({"type": "lifespan.startup.complete"})
|
487
|
+
elif message["type"] == "lifespan.shutdown":
|
488
|
+
if session is not None:
|
489
|
+
await session.close()
|
490
|
+
await send({"type": "lifespan.shutdown.complete"})
|
491
|
+
break
|
492
|
+
else:
|
493
|
+
raise ExecutionError(f"Unexpected message type: {message['type']}")
|
494
|
+
|
495
|
+
|
496
|
+
def web_server_proxy(host: str, port: int):
    """Return an ASGI app that proxies requests to a web server running on the same host.

    - Lifespan scopes manage the shared aiohttp session stored in ``scope["state"]``.
    - HTTP and WebSocket scopes are forwarded through that session.
    - If the upstream server can't be reached, the container stops fetching new inputs.

    Raises:
        InvalidError: if ``port`` is not in the range 1..65535.
    """
    if not (0 < port < 65536):
        raise InvalidError(f"Invalid port number: {port}")

    upstream_url = f"http://{host}:{port}"

    async def web_server_proxy_app(scope, receive, send):
        scope_type = scope["type"]
        try:
            if scope_type == "lifespan":
                await _proxy_lifespan_request(upstream_url, scope, receive, send)
            elif scope_type in ("http", "websocket"):
                proxy = _proxy_http_request if scope_type == "http" else _proxy_websocket_request
                await proxy(scope["state"]["session"], scope, receive, send)
            else:
                raise NotImplementedError(f"Scope {scope} is not understood")
        except aiohttp.ClientConnectorError as exc:
            # If the server is not running or not reachable, we should stop fetching new inputs.
            logger.warning(f"Terminating runner due to @web_server connection issue: {exc}")
            stop_fetching_inputs()

    return web_server_proxy_app
|