modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__init__.py +17 -13
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +420 -937
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -59
- modal/_resources.py +51 -0
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1036 -0
- modal/_runtime/execution_context.py +89 -0
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +134 -9
- modal/_traceback.py +47 -187
- modal/_tunnel.py +52 -16
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +479 -100
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +460 -171
- modal/_utils/grpc_testing.py +47 -31
- modal/_utils/grpc_utils.py +62 -109
- modal/_utils/hash_utils.py +61 -19
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +5 -7
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +14 -12
- modal/app.py +1003 -314
- modal/app.pyi +540 -264
- modal/call_graph.py +7 -6
- modal/cli/_download.py +63 -53
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +205 -45
- modal/cli/config.py +12 -5
- modal/cli/container.py +62 -14
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +64 -58
- modal/cli/launch.py +32 -18
- modal/cli/network_file_system.py +64 -83
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +35 -10
- modal/cli/programs/vscode.py +60 -10
- modal/cli/queues.py +131 -0
- modal/cli/run.py +234 -131
- modal/cli/secret.py +8 -7
- modal/cli/token.py +7 -2
- modal/cli/utils.py +79 -10
- modal/cli/volume.py +110 -109
- modal/client.py +250 -144
- modal/client.pyi +157 -118
- modal/cloud_bucket_mount.py +108 -34
- modal/cloud_bucket_mount.pyi +32 -38
- modal/cls.py +535 -148
- modal/cls.pyi +190 -146
- modal/config.py +41 -19
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +111 -65
- modal/dict.pyi +136 -131
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +34 -43
- modal/experimental.py +61 -2
- modal/extensions/ipython.py +5 -5
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +197 -0
- modal/functions.py +906 -911
- modal/functions.pyi +466 -430
- modal/gpu.py +57 -44
- modal/image.py +1089 -479
- modal/image.pyi +584 -228
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +314 -101
- modal/mount.pyi +241 -235
- modal/network_file_system.py +92 -92
- modal/network_file_system.pyi +152 -110
- modal/object.py +67 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +434 -0
- modal/parallel_map.pyi +75 -0
- modal/partial_function.py +282 -117
- modal/partial_function.pyi +222 -129
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +182 -65
- modal/queue.pyi +218 -118
- modal/requirements/2024.04.txt +29 -0
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +48 -7
- modal/runner.py +459 -156
- modal/runner.pyi +135 -71
- modal/running_app.py +38 -0
- modal/sandbox.py +514 -236
- modal/sandbox.pyi +397 -169
- modal/schedule.py +4 -4
- modal/scheduler_placement.py +20 -3
- modal/secret.py +56 -31
- modal/secret.pyi +62 -42
- modal/serving.py +51 -56
- modal/serving.pyi +44 -36
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +285 -157
- modal/volume.pyi +249 -184
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
- modal-0.72.11.dist-info/RECORD +174 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +5 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1288 -533
- modal_proto/api_grpc.py +856 -456
- modal_proto/api_pb2.py +2165 -1157
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1674 -855
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_entrypoint.pyi +0 -378
- modal/_container_exec.py +0 -128
- modal/_sandbox_shell.py +0 -49
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal/stub.py +0 -783
- modal/stub.pyi +0 -332
- modal-0.62.16.dist-info/RECORD +0 -198
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -262
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -659
- test/client_test.py +0 -194
- test/cls_test.py +0 -630
- test/config_test.py +0 -137
- test/conftest.py +0 -1420
- test/container_app_test.py +0 -32
- test/container_test.py +0 -1389
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -33
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -653
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -141
- test/helpers.py +0 -42
- test/image_test.py +0 -669
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -329
- test/network_file_system_test.py +0 -181
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -97
- test/resolver_test.py +0 -58
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -29
- test/secret_test.py +0 -78
- test/serialization_test.py +0 -42
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -360
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -341
- test/watcher_test.py +0 -30
- test/webhook_test.py +0 -146
- /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
- /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
modal/_container_entrypoint.py
CHANGED
@@ -1,63 +1,103 @@
 # Copyright Modal Labs 2022
-
+# ruff: noqa: E402
+import os
+
+from modal._runtime.user_code_imports import Service, import_class_service, import_single_function_service
+
+telemetry_socket = os.environ.get("MODAL_TELEMETRY_SOCKET")
+if telemetry_socket:
+    from ._runtime.telemetry import instrument_imports
+
+    instrument_imports(telemetry_socket)
 
 import asyncio
-import
-import contextlib
-import importlib
+import concurrent.futures
 import inspect
-import
-import math
-import os
+import queue
 import signal
 import sys
 import threading
 import time
-import
-from
-from dataclasses import dataclass
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Callable, List, Optional, Set, Tuple, Type
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
-from
+from google.protobuf.message import Message
 
+from modal._clustered_functions import initialize_clustered_function
+from modal._proxy_tunnel import proxy_tunnel
+from modal._serialization import deserialize, deserialize_proto_params
+from modal._utils.async_utils import TaskContext, synchronizer
+from modal._utils.function_utils import (
+    callable_has_non_self_params,
+)
+from modal.app import App, _App
+from modal.client import Client, _Client
+from modal.config import logger
+from modal.exception import ExecutionError, InputCancellation, InvalidError
+from modal.partial_function import (
+    _find_callables_for_obj,
+    _PartialFunctionFlags,
+)
+from modal.running_app import RunningApp
 from modal_proto import api_pb2
 
-from .
-
-
-
-
-    webhook_asgi_app,
-    wsgi_app_wrapper,
+from ._runtime.container_io_manager import (
+    ContainerIOManager,
+    IOContext,
+    UserException,
+    _ContainerIOManager,
 )
-from .
-from ._serialization import deserialize, deserialize_data_format, serialize, serialize_data_format
-from ._traceback import extract_traceback
-from ._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer
-from ._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
-from ._utils.function_utils import LocalFunctionError, is_async as get_is_async, is_global_function, method_has_params
-from ._utils.grpc_utils import retry_transient_errors
-from .app import ContainerApp, _container_app, _ContainerApp, interact
-from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, Client, _Client
-from .cls import Cls
-from .config import config, logger
-from .exception import InputCancellation, InvalidError
-from .functions import Function, _Function, _set_current_context_ids, _stream_function_call_data
-from .partial_function import _find_callables_for_obj, _PartialFunctionFlags
-from .stub import _Stub
+from ._runtime.execution_context import _set_current_context_ids
 
 if TYPE_CHECKING:
-
+    import modal._runtime.container_io_manager
+    import modal.object
+
+
+class DaemonizedThreadPool:
+    # Used instead of ThreadPoolExecutor, since the latter won't allow
+    # the interpreter to shut down before the currently running tasks
+    # have finished
+    def __init__(self, max_threads: int):
+        self.max_threads = max_threads
+
+    def __enter__(self):
+        self.spawned_workers = 0
+        self.inputs: queue.Queue[Any] = queue.Queue()
+        self.finished = threading.Event()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.finished.set()
 
-
+        if exc_type is None:
+            self.inputs.join()
+        else:
+            # special case - allows us to exit the
+            if self.inputs.unfinished_tasks:
+                logger.info(
+                    f"Exiting DaemonizedThreadPool with {self.inputs.unfinished_tasks} active "
+                    f"inputs due to exception: {repr(exc_type)}"
+                )
 
-
+    def submit(self, func, *args):
+        def worker_thread():
+            while not self.finished.is_set():
+                try:
+                    _func, _args = self.inputs.get(timeout=1)
+                except queue.Empty:
+                    continue
+                try:
+                    _func(*_args)
+                except BaseException:
+                    logger.exception(f"Exception raised by {_func} in DaemonizedThreadPool worker!")
+                self.inputs.task_done()
 
+        if self.spawned_workers < self.max_threads:
+            threading.Thread(target=worker_thread, daemon=True).start()
+            self.spawned_workers += 1
 
-
-    # Used to shut down the task gracefully
-    pass
+        self.inputs.put((func, args))
 
 
 class UserCodeEventLoop:
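
Note on the `DaemonizedThreadPool` introduced above: it deliberately starts its workers with `daemon=True` so the interpreter can shut down even while inputs are still running, which `concurrent.futures.ThreadPoolExecutor` does not allow (its non-daemon workers are joined at shutdown). A minimal standalone sketch of the same pattern follows; the names are illustrative, not Modal APIs:

import queue
import threading
import time

work: "queue.Queue[float]" = queue.Queue()

def worker() -> None:
    while True:
        delay = work.get()  # blocks until an item is available
        try:
            time.sleep(delay)  # stand-in for user code
        finally:
            work.task_done()  # lets work.join() account for finished items

# daemon=True lets the interpreter exit even if this thread is mid-task;
# a ThreadPoolExecutor's workers would block interpreter exit instead.
threading.Thread(target=worker, daemon=True).start()

work.put(0.01)
work.join()  # graceful path: block until all queued work is done

The graceful path mirrors `__exit__` above: `join()` waits for matching `task_done()` calls, while an exceptional exit simply abandons the daemon workers.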
@@ -76,14 +116,25 @@ class UserCodeEventLoop:
 
     def __enter__(self):
         self.loop = asyncio.new_event_loop()
+        self.tasks = set()
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         self.loop.run_until_complete(self.loop.shutdown_asyncgens())
         if sys.version_info[:2] >= (3, 9):
             self.loop.run_until_complete(self.loop.shutdown_default_executor())  # Introduced in Python 3.9
+
+        for task in self.tasks:
+            task.cancel()
+
         self.loop.close()
 
+    def create_task(self, coro):
+        task = self.loop.create_task(coro)
+        self.tasks.add(task)
+        task.add_done_callback(self.tasks.discard)
+        return task
+
     def run(self, coro):
         task = asyncio.ensure_future(coro, loop=self.loop)
         self._sigints = 0
@@ -99,7 +150,9 @@ class UserCodeEventLoop:
                 # first sigint is graceful
                 task.cancel()
                 return
-
+
+            # this should normally not happen, but the second sigint would "hard kill" the event loop!
+            raise KeyboardInterrupt()
 
         ignore_sigint = signal.getsignal(signal.SIGINT) == signal.SIG_IGN
         if not ignore_sigint:
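
The task bookkeeping added to `UserCodeEventLoop` (`create_task` plus `add_done_callback(self.tasks.discard)`) is the standard asyncio pattern for holding strong references to in-flight tasks — the event loop itself keeps only weak references — and it also gives `__exit__` a set of tasks to cancel on shutdown. A self-contained sketch of that pattern, under the assumption that this is its purpose here:

import asyncio

async def main() -> None:
    tasks: set = set()

    async def background() -> None:
        await asyncio.sleep(3600)  # stand-in for long-running user code

    task = asyncio.get_running_loop().create_task(background())
    tasks.add(task)                        # strong reference until the task is done
    task.add_done_callback(tasks.discard)  # drop the reference once it finishes

    task.cancel()  # what __exit__ does for each task still pending at shutdown
    try:
        await task
    except asyncio.CancelledError:
        pass
    await asyncio.sleep(0)  # let the done-callback run
    assert not tasks        # the callback removed the cancelled task

asyncio.run(main())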
@@ -122,972 +175,381 @@ class UserCodeEventLoop:
|
|
122
175
|
self.loop.remove_signal_handler(signal.SIGINT)
|
123
176
|
|
124
177
|
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
self.calls_completed = 0
|
143
|
-
self.total_user_time: float = 0.0
|
144
|
-
self.current_input_id: Optional[str] = None
|
145
|
-
self.current_input_started_at: Optional[float] = None
|
146
|
-
|
147
|
-
self._input_concurrency: Optional[int] = None
|
148
|
-
|
149
|
-
self._semaphore: Optional[asyncio.Semaphore] = None
|
150
|
-
self._environment_name = container_args.environment_name
|
151
|
-
self._waiting_for_checkpoint = False
|
152
|
-
self._heartbeat_loop = None
|
153
|
-
|
154
|
-
self._client = client
|
155
|
-
assert isinstance(self._client, _Client)
|
156
|
-
|
157
|
-
async def initialize_app(self) -> _ContainerApp:
|
158
|
-
await _container_app.init(self._client, self.app_id, self._environment_name, self.function_def)
|
159
|
-
return _container_app
|
160
|
-
|
161
|
-
async def _run_heartbeat_loop(self):
|
162
|
-
while 1:
|
163
|
-
t0 = time.monotonic()
|
164
|
-
try:
|
165
|
-
if await self._heartbeat_handle_cancellations():
|
166
|
-
# got a cancellation event, fine to start another heartbeat immediately
|
167
|
-
# since the cancellation queue should be empty on the worker server
|
168
|
-
# however, we wait at least 1s to prevent short-circuiting the heartbeat loop
|
169
|
-
# in case there is ever a bug. This means it will take at least 1s between
|
170
|
-
# two subsequent cancellations on the same task at the moment
|
171
|
-
await asyncio.sleep(1.0)
|
172
|
-
continue
|
173
|
-
except Exception as exc:
|
174
|
-
# don't stop heartbeat loop if there are transient exceptions!
|
175
|
-
time_elapsed = time.monotonic() - t0
|
176
|
-
error = exc
|
177
|
-
logger.warning(f"Heartbeat attempt failed ({time_elapsed=}, {error=})")
|
178
|
-
|
179
|
-
heartbeat_duration = time.monotonic() - t0
|
180
|
-
time_until_next_hearbeat = max(0.0, HEARTBEAT_INTERVAL - heartbeat_duration)
|
181
|
-
await asyncio.sleep(time_until_next_hearbeat)
|
182
|
-
|
183
|
-
async def _heartbeat_handle_cancellations(self) -> bool:
|
184
|
-
# Return True if a cancellation event was received, in that case we shouldn't wait too long for another heartbeat
|
185
|
-
|
186
|
-
# Don't send heartbeats for tasks waiting to be checkpointed.
|
187
|
-
# Calling gRPC methods open new connections which block the
|
188
|
-
# checkpointing process.
|
189
|
-
if self._waiting_for_checkpoint:
|
190
|
-
return False
|
191
|
-
|
192
|
-
request = api_pb2.ContainerHeartbeatRequest(supports_graceful_input_cancellation=True)
|
193
|
-
if self.current_input_id is not None:
|
194
|
-
request.current_input_id = self.current_input_id
|
195
|
-
if self.current_input_started_at is not None:
|
196
|
-
request.current_input_started_at = self.current_input_started_at
|
197
|
-
|
198
|
-
# TODO(erikbern): capture exceptions?
|
199
|
-
response = await retry_transient_errors(
|
200
|
-
self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
|
201
|
-
)
|
178
|
+
def call_function(
|
179
|
+
user_code_event_loop: UserCodeEventLoop,
|
180
|
+
container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager",
|
181
|
+
finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
|
182
|
+
batch_max_size: int,
|
183
|
+
batch_wait_ms: int,
|
184
|
+
):
|
185
|
+
async def run_input_async(io_context: IOContext) -> None:
|
186
|
+
started_at = time.time()
|
187
|
+
input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids
|
188
|
+
reset_context = _set_current_context_ids(input_ids, function_call_ids)
|
189
|
+
async with container_io_manager.handle_input_exception.aio(io_context, started_at):
|
190
|
+
res = io_context.call_finalized_function()
|
191
|
+
# TODO(erikbern): any exception below shouldn't be considered a user exception
|
192
|
+
if io_context.finalized_function.is_generator:
|
193
|
+
if not inspect.isasyncgen(res):
|
194
|
+
raise InvalidError(f"Async generator function returned value of type {type(res)}")
|
202
195
|
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
196
|
+
# Send up to this many outputs at a time.
|
197
|
+
generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024)
|
198
|
+
generator_output_task = asyncio.create_task(
|
199
|
+
container_io_manager.generator_output_task.aio(
|
200
|
+
function_call_ids[0],
|
201
|
+
io_context.finalized_function.data_format,
|
202
|
+
generator_queue,
|
210
203
|
)
|
211
|
-
# This is equivalent to a task cancellation or preemption from worker code,
|
212
|
-
# except we do not send a SIGKILL to forcefully exit after 30 seconds.
|
213
|
-
#
|
214
|
-
# SIGINT always interrupts the main thread, but not any auxiliary threads. On a
|
215
|
-
# sync function without concurrent inputs, this raises a KeyboardInterrupt. When
|
216
|
-
# there are concurrent inputs, we cannot interrupt the thread pool, but the
|
217
|
-
# interpreter stops waiting for daemon threads and exits. On async functions,
|
218
|
-
# this signal lands outside the event loop, stopping `run_until_complete()`.
|
219
|
-
os.kill(os.getpid(), signal.SIGINT)
|
220
|
-
|
221
|
-
elif self.current_input_id in input_ids_to_cancel:
|
222
|
-
# This goes to a registered signal handler for sync Modal functions, or to the
|
223
|
-
# `SignalHandlingEventLoop` for async functions.
|
224
|
-
#
|
225
|
-
# We only send this signal on functions that do not have concurrent inputs enabled.
|
226
|
-
# This allows us to do fine-grained input cancellation. On sync functions, the
|
227
|
-
# SIGUSR1 signal should interrupt the main thread where user code is running,
|
228
|
-
# raising an InputCancellation() exception. On async functions, the signal should
|
229
|
-
# reach a handler in SignalHandlingEventLoop, which cancels the task.
|
230
|
-
os.kill(os.getpid(), signal.SIGUSR1)
|
231
|
-
return True
|
232
|
-
return False
|
233
|
-
|
234
|
-
@contextlib.asynccontextmanager
|
235
|
-
async def heartbeats(self):
|
236
|
-
async with TaskContext() as tc:
|
237
|
-
self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
|
238
|
-
t.set_name("heartbeat loop")
|
239
|
-
try:
|
240
|
-
yield
|
241
|
-
finally:
|
242
|
-
t.cancel()
|
243
|
-
|
244
|
-
def stop_heartbeat(self):
|
245
|
-
if self._heartbeat_loop:
|
246
|
-
self._heartbeat_loop.cancel()
|
247
|
-
|
248
|
-
async def get_serialized_function(self) -> Tuple[Optional[Any], Callable]:
|
249
|
-
# Fetch the serialized function definition
|
250
|
-
request = api_pb2.FunctionGetSerializedRequest(function_id=self.function_id)
|
251
|
-
response = await self._client.stub.FunctionGetSerialized(request)
|
252
|
-
fun = self.deserialize(response.function_serialized)
|
253
|
-
|
254
|
-
if response.class_serialized:
|
255
|
-
cls = self.deserialize(response.class_serialized)
|
256
|
-
else:
|
257
|
-
cls = None
|
258
|
-
|
259
|
-
return cls, fun
|
260
|
-
|
261
|
-
def serialize(self, obj: Any) -> bytes:
|
262
|
-
return serialize(obj)
|
263
|
-
|
264
|
-
def deserialize(self, data: bytes) -> Any:
|
265
|
-
return deserialize(data, self._client)
|
266
|
-
|
267
|
-
@synchronizer.no_io_translation
|
268
|
-
def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
|
269
|
-
return serialize_data_format(obj, data_format)
|
270
|
-
|
271
|
-
def deserialize_data_format(self, data: bytes, data_format: int) -> Any:
|
272
|
-
return deserialize_data_format(data, data_format, self._client)
|
273
|
-
|
274
|
-
async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
|
275
|
-
"""Read from the `data_in` stream of a function call."""
|
276
|
-
async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
|
277
|
-
yield data
|
278
|
-
|
279
|
-
async def put_data_out(
|
280
|
-
self,
|
281
|
-
function_call_id: str,
|
282
|
-
start_index: int,
|
283
|
-
data_format: int,
|
284
|
-
messages_bytes: List[Any],
|
285
|
-
) -> None:
|
286
|
-
"""Put data onto the `data_out` stream of a function call.
|
287
|
-
|
288
|
-
This is used for generator outputs, which includes web endpoint responses. Note that this
|
289
|
-
was introduced as a performance optimization in client version 0.57, so older clients will
|
290
|
-
still use the previous Postgres-backed system based on `FunctionPutOutputs()`.
|
291
|
-
"""
|
292
|
-
data_chunks: List[api_pb2.DataChunk] = []
|
293
|
-
for i, message_bytes in enumerate(messages_bytes):
|
294
|
-
chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i) # type: ignore
|
295
|
-
if len(message_bytes) > MAX_OBJECT_SIZE_BYTES:
|
296
|
-
chunk.data_blob_id = await blob_upload(message_bytes, self._client.stub)
|
297
|
-
else:
|
298
|
-
chunk.data = message_bytes
|
299
|
-
data_chunks.append(chunk)
|
300
|
-
|
301
|
-
req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks)
|
302
|
-
await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req)
|
303
|
-
|
304
|
-
async def generator_output_task(self, function_call_id: str, data_format: int, message_rx: asyncio.Queue) -> None:
|
305
|
-
"""Task that feeds generator outputs into a function call's `data_out` stream."""
|
306
|
-
index = 1
|
307
|
-
received_sentinel = False
|
308
|
-
while not received_sentinel:
|
309
|
-
message = await message_rx.get()
|
310
|
-
if message is self._GENERATOR_STOP_SENTINEL:
|
311
|
-
break
|
312
|
-
# ASGI 'http.response.start' and 'http.response.body' msgs are observed to be separated by 1ms.
|
313
|
-
# If we don't sleep here for 1ms we end up with an extra call to .put_data_out().
|
314
|
-
if index == 1:
|
315
|
-
await asyncio.sleep(0.001)
|
316
|
-
messages_bytes = [serialize_data_format(message, data_format)]
|
317
|
-
total_size = len(messages_bytes[0]) + 512
|
318
|
-
while total_size < 16 * 1024 * 1024: # 16 MiB, maximum size in a single message
|
319
|
-
try:
|
320
|
-
message = message_rx.get_nowait()
|
321
|
-
except asyncio.QueueEmpty:
|
322
|
-
break
|
323
|
-
if message is self._GENERATOR_STOP_SENTINEL:
|
324
|
-
received_sentinel = True
|
325
|
-
break
|
326
|
-
else:
|
327
|
-
messages_bytes.append(serialize_data_format(message, data_format))
|
328
|
-
total_size += len(messages_bytes[-1]) + 512 # 512 bytes for estimated framing overhead
|
329
|
-
await self.put_data_out(function_call_id, index, data_format, messages_bytes)
|
330
|
-
index += len(messages_bytes)
|
331
|
-
|
332
|
-
async def _queue_create(self, size: int) -> asyncio.Queue:
|
333
|
-
"""Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
|
334
|
-
return asyncio.Queue(size)
|
335
|
-
|
336
|
-
async def _queue_put(self, queue: asyncio.Queue, value: Any) -> None:
|
337
|
-
"""Put a value onto a queue, using the synchronicity event loop."""
|
338
|
-
await queue.put(value)
|
339
|
-
|
340
|
-
async def populate_input_blobs(self, item: api_pb2.FunctionInput):
|
341
|
-
args = await blob_download(item.args_blob_id, self._client.stub)
|
342
|
-
|
343
|
-
# Mutating
|
344
|
-
item.ClearField("args_blob_id")
|
345
|
-
item.args = args
|
346
|
-
return item
|
347
|
-
|
348
|
-
def get_average_call_time(self) -> float:
|
349
|
-
if self.calls_completed == 0:
|
350
|
-
return 0
|
351
|
-
|
352
|
-
return self.total_user_time / self.calls_completed
|
353
|
-
|
354
|
-
def get_max_inputs_to_fetch(self):
|
355
|
-
if self.calls_completed == 0:
|
356
|
-
return 1
|
357
|
-
|
358
|
-
return math.ceil(RTT_S / max(self.get_average_call_time(), 1e-6))
|
359
|
-
|
360
|
-
@synchronizer.no_io_translation
|
361
|
-
async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.FunctionInput]]:
|
362
|
-
request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
|
363
|
-
eof_received = False
|
364
|
-
iteration = 0
|
365
|
-
while not eof_received and _container_app.fetching_inputs:
|
366
|
-
request.average_call_time = self.get_average_call_time()
|
367
|
-
request.max_values = self.get_max_inputs_to_fetch() # Deprecated; remove.
|
368
|
-
request.input_concurrency = self._input_concurrency
|
369
|
-
|
370
|
-
await self._semaphore.acquire()
|
371
|
-
yielded = False
|
372
|
-
try:
|
373
|
-
# If number of active inputs is at max queue size, this will block.
|
374
|
-
iteration += 1
|
375
|
-
response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors(
|
376
|
-
self._client.stub.FunctionGetInputs, request
|
377
204
|
)
|
378
205
|
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
)
|
384
|
-
await asyncio.sleep(response.rate_limit_sleep_duration)
|
385
|
-
elif response.inputs:
|
386
|
-
# for input cancellations and concurrency logic we currently assume
|
387
|
-
# that there is no input buffering in the container
|
388
|
-
assert len(response.inputs) == 1
|
389
|
-
|
390
|
-
for item in response.inputs:
|
391
|
-
if item.kill_switch:
|
392
|
-
logger.debug(f"Task {self.task_id} input kill signal input.")
|
393
|
-
eof_received = True
|
394
|
-
break
|
395
|
-
if item.input_id in self.cancelled_input_ids:
|
396
|
-
continue
|
397
|
-
|
398
|
-
# If we got a pointer to a blob, download it from S3.
|
399
|
-
if item.input.WhichOneof("args_oneof") == "args_blob_id":
|
400
|
-
input_pb = await self.populate_input_blobs(item.input)
|
401
|
-
else:
|
402
|
-
input_pb = item.input
|
403
|
-
|
404
|
-
# If yielded, allow semaphore to be released via complete_call
|
405
|
-
yield (item.input_id, item.function_call_id, input_pb)
|
406
|
-
yielded = True
|
407
|
-
|
408
|
-
# We only support max_inputs = 1 at the moment
|
409
|
-
if item.input.final_input or self.function_def.max_inputs == 1:
|
410
|
-
eof_received = True
|
411
|
-
break
|
412
|
-
finally:
|
413
|
-
if not yielded:
|
414
|
-
self._semaphore.release()
|
415
|
-
|
416
|
-
@synchronizer.no_io_translation
|
417
|
-
async def run_inputs_outputs(self, input_concurrency: int = 1) -> AsyncIterator[Tuple[str, str, Any, Any]]:
|
418
|
-
# Ensure we do not fetch new inputs when container is too busy.
|
419
|
-
# Before trying to fetch an input, acquire the semaphore:
|
420
|
-
# - if no input is fetched, release the semaphore.
|
421
|
-
# - or, when the output for the fetched input is sent, release the semaphore.
|
422
|
-
self._input_concurrency = input_concurrency
|
423
|
-
self._semaphore = asyncio.Semaphore(input_concurrency)
|
424
|
-
|
425
|
-
try:
|
426
|
-
async for input_id, function_call_id, input_pb in self._generate_inputs():
|
427
|
-
args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {})
|
428
|
-
self.current_input_id, self.current_input_started_at = (input_id, time.time())
|
429
|
-
yield input_id, function_call_id, args, kwargs
|
430
|
-
self.current_input_id, self.current_input_started_at = (None, None)
|
431
|
-
finally:
|
432
|
-
# collect all active input slots, meaning all inputs have wrapped up.
|
433
|
-
for _ in range(input_concurrency):
|
434
|
-
await self._semaphore.acquire()
|
435
|
-
|
436
|
-
async def _push_output(self, input_id, started_at: float, data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs):
|
437
|
-
# upload data to S3 if too big.
|
438
|
-
if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES:
|
439
|
-
data_blob_id = await blob_upload(kwargs["data"], self._client.stub)
|
440
|
-
# mutating kwargs.
|
441
|
-
del kwargs["data"]
|
442
|
-
kwargs["data_blob_id"] = data_blob_id
|
443
|
-
|
444
|
-
output = api_pb2.FunctionPutOutputsItem(
|
445
|
-
input_id=input_id,
|
446
|
-
input_started_at=started_at,
|
447
|
-
output_created_at=time.time(),
|
448
|
-
result=api_pb2.GenericResult(**kwargs),
|
449
|
-
data_format=data_format,
|
450
|
-
)
|
451
|
-
|
452
|
-
await retry_transient_errors(
|
453
|
-
self._client.stub.FunctionPutOutputs,
|
454
|
-
api_pb2.FunctionPutOutputsRequest(outputs=[output]),
|
455
|
-
additional_status_codes=[Status.RESOURCE_EXHAUSTED],
|
456
|
-
max_retries=None, # Retry indefinitely, trying every 1s.
|
457
|
-
)
|
458
|
-
|
459
|
-
def serialize_exception(self, exc: BaseException) -> Optional[bytes]:
|
460
|
-
try:
|
461
|
-
return self.serialize(exc)
|
462
|
-
except Exception as serialization_exc:
|
463
|
-
logger.info(f"Failed to serialize exception {exc}: {serialization_exc}")
|
464
|
-
# We can't always serialize exceptions.
|
465
|
-
return None
|
466
|
-
|
467
|
-
def serialize_traceback(self, exc: BaseException) -> Tuple[Optional[bytes], Optional[bytes]]:
|
468
|
-
serialized_tb, tb_line_cache = None, None
|
469
|
-
|
470
|
-
try:
|
471
|
-
tb_dict, line_cache = extract_traceback(exc, self.task_id)
|
472
|
-
serialized_tb = self.serialize(tb_dict)
|
473
|
-
tb_line_cache = self.serialize(line_cache)
|
474
|
-
except Exception:
|
475
|
-
logger.info("Failed to serialize exception traceback.")
|
476
|
-
|
477
|
-
return serialized_tb, tb_line_cache
|
478
|
-
|
479
|
-
@contextlib.asynccontextmanager
|
480
|
-
async def handle_user_exception(self) -> AsyncGenerator[None, None]:
|
481
|
-
"""Sets the task as failed in a way where it's not retried.
|
482
|
-
|
483
|
-
Used for handling exceptions from container lifecycle methods at the moment, which should
|
484
|
-
trigger a task failure state.
|
485
|
-
"""
|
486
|
-
try:
|
487
|
-
yield
|
488
|
-
except KeyboardInterrupt:
|
489
|
-
# Send no task result in case we get sigint:ed by the runner
|
490
|
-
# The status of the input should have been handled externally already in that case
|
491
|
-
raise
|
492
|
-
except BaseException as exc:
|
493
|
-
# Since this is on a different thread, sys.exc_info() can't find the exception in the stack.
|
494
|
-
traceback.print_exception(type(exc), exc, exc.__traceback__)
|
495
|
-
|
496
|
-
serialized_tb, tb_line_cache = self.serialize_traceback(exc)
|
497
|
-
|
498
|
-
result = api_pb2.GenericResult(
|
499
|
-
status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
|
500
|
-
data=self.serialize_exception(exc),
|
501
|
-
exception=repr(exc),
|
502
|
-
traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
|
503
|
-
serialized_tb=serialized_tb,
|
504
|
-
tb_line_cache=tb_line_cache,
|
505
|
-
)
|
506
|
-
|
507
|
-
req = api_pb2.TaskResultRequest(result=result)
|
508
|
-
await retry_transient_errors(self._client.stub.TaskResult, req)
|
509
|
-
|
510
|
-
# Shut down the task gracefully
|
511
|
-
raise UserException()
|
512
|
-
|
513
|
-
@contextlib.asynccontextmanager
|
514
|
-
async def handle_input_exception(self, input_id, started_at: float) -> AsyncGenerator[None, None]:
|
515
|
-
"""Handle an exception while processing a function input."""
|
516
|
-
try:
|
517
|
-
yield
|
518
|
-
except KeyboardInterrupt:
|
519
|
-
raise
|
520
|
-
except (InputCancellation, asyncio.CancelledError):
|
521
|
-
# just skip creating any output for this input and keep going with the next instead
|
522
|
-
# it should have been marked as cancelled already in the backend at this point so it
|
523
|
-
# won't be retried
|
524
|
-
logger.warning(f"The current input ({input_id=}) was cancelled by a user request")
|
525
|
-
await self.complete_call(started_at)
|
526
|
-
return
|
527
|
-
except BaseException as exc:
|
528
|
-
# print exception so it's logged
|
529
|
-
traceback.print_exc()
|
530
|
-
serialized_tb, tb_line_cache = self.serialize_traceback(exc)
|
531
|
-
|
532
|
-
# Note: we're not serializing the traceback since it contains
|
533
|
-
# local references that means we can't unpickle it. We *are*
|
534
|
-
# serializing the exception, which may have some issues (there
|
535
|
-
# was an earlier note about it that it might not be possible
|
536
|
-
# to unpickle it in some cases). Let's watch out for issues.
|
537
|
-
await self._push_output(
|
538
|
-
input_id,
|
539
|
-
started_at=started_at,
|
540
|
-
data_format=api_pb2.DATA_FORMAT_PICKLE,
|
541
|
-
status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
|
542
|
-
data=self.serialize_exception(exc),
|
543
|
-
exception=repr(exc),
|
544
|
-
traceback=traceback.format_exc(),
|
545
|
-
serialized_tb=serialized_tb,
|
546
|
-
tb_line_cache=tb_line_cache,
|
547
|
-
)
|
548
|
-
await self.complete_call(started_at)
|
549
|
-
|
550
|
-
async def complete_call(self, started_at):
|
551
|
-
self.total_user_time += time.time() - started_at
|
552
|
-
self.calls_completed += 1
|
553
|
-
self._semaphore.release()
|
554
|
-
|
555
|
-
@synchronizer.no_io_translation
|
556
|
-
async def push_output(self, input_id, started_at: float, data: Any, data_format: int) -> None:
|
557
|
-
await self._push_output(
|
558
|
-
input_id,
|
559
|
-
started_at=started_at,
|
560
|
-
data_format=data_format,
|
561
|
-
status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
|
562
|
-
data=self.serialize_data_format(data, data_format),
|
563
|
-
)
|
564
|
-
await self.complete_call(started_at)
|
565
|
-
|
566
|
-
async def restore(self) -> None:
|
567
|
-
# Busy-wait for restore. `/__modal/restore-state.json` is created
|
568
|
-
# by the worker process with updates to the container config.
|
569
|
-
restored_path = Path(config.get("restore_state_path"))
|
570
|
-
start = time.perf_counter()
|
571
|
-
while not restored_path.exists():
|
572
|
-
logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)")
|
573
|
-
await asyncio.sleep(0.01)
|
574
|
-
continue
|
575
|
-
|
576
|
-
logger.debug("Container: restored")
|
577
|
-
|
578
|
-
# Look for state file and create new client with updated credentials.
|
579
|
-
# State data is serialized with key-value pairs, example: {"task_id": "tk-000"}
|
580
|
-
with restored_path.open("r") as file:
|
581
|
-
restored_state = json.load(file)
|
582
|
-
|
583
|
-
# Local FunctionIOManager state.
|
584
|
-
for key in ["task_id", "function_id"]:
|
585
|
-
if value := restored_state.get(key):
|
586
|
-
logger.debug(f"Updating FunctionIOManager.{key} = {value}")
|
587
|
-
setattr(self, key, restored_state[key])
|
588
|
-
|
589
|
-
# Env vars and global state.
|
590
|
-
for key, value in restored_state.items():
|
591
|
-
# Empty string indicates that value does not need to be updated.
|
592
|
-
if value != "":
|
593
|
-
config.override_locally(key, value)
|
594
|
-
|
595
|
-
# Restore input to default state.
|
596
|
-
self.current_input_id = None
|
597
|
-
self.current_input_started_at = None
|
598
|
-
|
599
|
-
self._client = await _Client.from_env()
|
600
|
-
self._waiting_for_checkpoint = False
|
601
|
-
|
602
|
-
async def checkpoint(self) -> None:
|
603
|
-
"""Message server indicating that function is ready to be checkpointed."""
|
604
|
-
if self.checkpoint_id:
|
605
|
-
logger.debug(f"Checkpoint ID: {self.checkpoint_id}")
|
606
|
-
|
607
|
-
await self._client.stub.ContainerCheckpoint(
|
608
|
-
api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
|
609
|
-
)
|
206
|
+
item_count = 0
|
207
|
+
async for value in res:
|
208
|
+
await container_io_manager._queue_put.aio(generator_queue, value)
|
209
|
+
item_count += 1
|
610
210
|
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
Perform volume commit for given `volume_ids`.
|
620
|
-
Only used on container exit to persist uncommitted changes on behalf of user.
|
621
|
-
"""
|
622
|
-
if not volume_ids:
|
623
|
-
return
|
624
|
-
await asyncify(os.sync)()
|
625
|
-
results = await asyncio.gather(
|
626
|
-
*[
|
627
|
-
retry_transient_errors(
|
628
|
-
self._client.stub.VolumeCommit,
|
629
|
-
api_pb2.VolumeCommitRequest(volume_id=v_id),
|
630
|
-
max_retries=9,
|
631
|
-
base_delay=0.25,
|
632
|
-
max_delay=256,
|
633
|
-
delay_factor=2,
|
211
|
+
await container_io_manager._queue_put.aio(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
|
212
|
+
await generator_output_task # Wait to finish sending generator outputs.
|
213
|
+
message = api_pb2.GeneratorDone(items_total=item_count)
|
214
|
+
await container_io_manager.push_outputs.aio(
|
215
|
+
io_context,
|
216
|
+
started_at,
|
217
|
+
message,
|
218
|
+
api_pb2.DATA_FORMAT_GENERATOR_DONE,
|
634
219
|
)
|
635
|
-
for v_id in volume_ids
|
636
|
-
],
|
637
|
-
return_exceptions=True,
|
638
|
-
)
|
639
|
-
for volume_id, res in zip(volume_ids, results):
|
640
|
-
if isinstance(res, Exception):
|
641
|
-
logger.error(f"modal.Volume background commit failed for {volume_id}. Exception: {res}")
|
642
220
|
else:
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
221
|
+
if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
|
222
|
+
raise InvalidError(
|
223
|
+
f"Async (non-generator) function returned value of type {type(res)}"
|
224
|
+
" You might need to use @app.function(..., is_generator=True)."
|
225
|
+
)
|
226
|
+
value = await res
|
227
|
+
await container_io_manager.push_outputs.aio(
|
228
|
+
io_context,
|
229
|
+
started_at,
|
230
|
+
value,
|
231
|
+
io_context.finalized_function.data_format,
|
232
|
+
)
|
233
|
+
reset_context()
|
648
234
|
|
649
|
-
def
|
650
|
-
function_io_manager, #: FunctionIOManager, TODO: this type is generated at runtime
|
651
|
-
imp_fun: ImportedFunction,
|
652
|
-
):
|
653
|
-
def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
|
235
|
+
def run_input_sync(io_context: IOContext) -> None:
|
654
236
|
started_at = time.time()
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
res =
|
659
|
-
logger.debug(f"Finished input {input_id} (sync)")
|
237
|
+
input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids
|
238
|
+
reset_context = _set_current_context_ids(input_ids, function_call_ids)
|
239
|
+
with container_io_manager.handle_input_exception(io_context, started_at):
|
240
|
+
res = io_context.call_finalized_function()
|
660
241
|
|
661
242
|
# TODO(erikbern): any exception below shouldn't be considered a user exception
|
662
|
-
if
|
243
|
+
if io_context.finalized_function.is_generator:
|
663
244
|
if not inspect.isgenerator(res):
|
664
245
|
raise InvalidError(f"Generator function returned value of type {type(res)}")
|
665
246
|
|
666
247
|
# Send up to this many outputs at a time.
|
667
|
-
generator_queue: asyncio.Queue[Any] =
|
668
|
-
generator_output_task =
|
669
|
-
|
670
|
-
|
248
|
+
generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
|
249
|
+
generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore
|
250
|
+
function_call_ids[0],
|
251
|
+
io_context.finalized_function.data_format,
|
671
252
|
generator_queue,
|
672
|
-
_future=True, # Synchronicity magic to return a future.
|
253
|
+
_future=True, # type: ignore # Synchronicity magic to return a future.
|
673
254
|
)
|
674
255
|
|
675
256
|
item_count = 0
|
676
257
|
for value in res:
|
677
|
-
|
258
|
+
container_io_manager._queue_put(generator_queue, value)
|
678
259
|
item_count += 1
|
679
260
|
|
680
|
-
|
261
|
+
container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
|
681
262
|
generator_output_task.result() # Wait to finish sending generator outputs.
|
682
263
|
message = api_pb2.GeneratorDone(items_total=item_count)
|
683
|
-
|
264
|
+
container_io_manager.push_outputs(io_context, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
|
684
265
|
else:
|
685
266
|
if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
|
686
267
|
raise InvalidError(
|
687
268
|
f"Sync (non-generator) function return value of type {type(res)}."
|
688
|
-
" You might need to use @
|
269
|
+
" You might need to use @app.function(..., is_generator=True)."
|
689
270
|
)
|
690
|
-
|
271
|
+
container_io_manager.push_outputs(
|
272
|
+
io_context, started_at, res, io_context.finalized_function.data_format
|
273
|
+
)
|
691
274
|
reset_context()
|
692
275
|
|
693
|
-
if
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
pass
|
717
|
-
inputs.task_done()
|
718
|
-
|
719
|
-
for input_id, function_call_id, args, kwargs in function_io_manager.run_inputs_outputs(
|
720
|
-
imp_fun.input_concurrency
|
721
|
-
):
|
722
|
-
if spawned_workers < imp_fun.input_concurrency:
|
723
|
-
threading.Thread(target=worker_thread, daemon=True).start()
|
724
|
-
spawned_workers += 1
|
725
|
-
inputs.put((input_id, function_call_id, args, kwargs))
|
726
|
-
|
727
|
-
finished.set()
|
728
|
-
inputs.join()
|
729
|
-
|
730
|
-
else:
|
731
|
-
for input_id, function_call_id, args, kwargs in function_io_manager.run_inputs_outputs(
|
732
|
-
imp_fun.input_concurrency
|
733
|
-
):
|
734
|
-
try:
|
735
|
-
run_input(input_id, function_call_id, args, kwargs)
|
736
|
-
except:
|
737
|
-
raise
|
738
|
-
|
739
|
-
|
740
|
-
async def call_function_async(
|
741
|
-
function_io_manager, #: FunctionIOManager, TODO: this type is generated at runtime
|
742
|
-
imp_fun: ImportedFunction,
|
743
|
-
):
|
744
|
-
async def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
|
745
|
-
started_at = time.time()
|
746
|
-
reset_context = _set_current_context_ids(input_id, function_call_id)
|
747
|
-
async with function_io_manager.handle_input_exception.aio(input_id, started_at):
|
748
|
-
logger.debug(f"Starting input {input_id} (async)")
|
749
|
-
res = imp_fun.fun(*args, **kwargs)
|
750
|
-
logger.debug(f"Finished input {input_id} (async)")
|
751
|
-
|
752
|
-
# TODO(erikbern): any exception below shouldn't be considered a user exception
|
753
|
-
if imp_fun.is_generator:
|
754
|
-
if not inspect.isasyncgen(res):
|
755
|
-
raise InvalidError(f"Async generator function returned value of type {type(res)}")
|
756
|
-
|
757
|
-
# Send up to this many outputs at a time.
|
758
|
-
generator_queue: asyncio.Queue[Any] = await function_io_manager._queue_create.aio(1024)
|
759
|
-
generator_output_task = asyncio.create_task(
|
760
|
-
function_io_manager.generator_output_task.aio(
|
761
|
-
function_call_id,
|
762
|
-
imp_fun.data_format,
|
763
|
-
generator_queue,
|
276
|
+
if container_io_manager.target_concurrency > 1:
|
277
|
+
with DaemonizedThreadPool(max_threads=container_io_manager.max_concurrency) as thread_pool:
|
278
|
+
|
279
|
+
def make_async_cancel_callback(task):
|
280
|
+
def f():
|
281
|
+
user_code_event_loop.loop.call_soon_threadsafe(task.cancel)
|
282
|
+
|
283
|
+
return f
|
284
|
+
|
285
|
+
did_sigint = False
|
286
|
+
|
287
|
+
def cancel_callback_sync():
|
288
|
+
nonlocal did_sigint
|
289
|
+
# We only want one sigint even if multiple inputs are cancelled
|
290
|
+
# A second sigint would forcibly shut down the event loop and spew
|
291
|
+
# out a bunch of tracebacks, which we only want to happen in case
|
292
|
+
# the worker kills this process after a failed self-termination
|
293
|
+
if not did_sigint:
|
294
|
+
did_sigint = True
|
295
|
+
logger.warning(
|
296
|
+
"User cancelling input of non-async functions with allow_concurrent_inputs > 1.\n"
|
297
|
+
"This shuts down the container, causing concurrently running inputs to be "
|
298
|
+
"rescheduled in other containers."
|
764
299
|
)
|
765
|
-
|
766
|
-
|
767
|
-
item_count = 0
|
768
|
-
async for value in res:
|
769
|
-
await function_io_manager._queue_put.aio(generator_queue, value)
|
770
|
-
item_count += 1
|
300
|
+
os.kill(os.getpid(), signal.SIGINT)
|
771
301
|
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
302
|
+
async def run_concurrent_inputs():
|
303
|
+
# all run_input coroutines will have completed by the time we leave the execution context
|
304
|
+
# but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
|
305
|
+
# for them to resolve gracefully:
|
306
|
+
async with TaskContext(0.01) as task_context:
|
307
|
+
async for io_context in container_io_manager.run_inputs_outputs.aio(
|
308
|
+
finalized_functions, batch_max_size, batch_wait_ms
|
309
|
+
):
|
310
|
+
# Note that run_inputs_outputs will not return until all the input slots are released
|
311
|
+
# so that they can be acquired by the run_inputs_outputs finalizer
|
312
|
+
# This prevents leaving the task_context before outputs have been created
|
313
|
+
# TODO: refactor to make this a bit more easy to follow?
|
314
|
+
if io_context.finalized_function.is_async:
|
315
|
+
input_task = task_context.create_task(run_input_async(io_context))
|
316
|
+
io_context.set_cancel_callback(make_async_cancel_callback(input_task))
|
317
|
+
else:
|
318
|
+
# run sync input in thread
|
319
|
+
thread_pool.submit(run_input_sync, io_context)
|
320
|
+
io_context.set_cancel_callback(cancel_callback_sync)
|
787
321
|
|
788
|
-
|
789
|
-
# all run_input coroutines will have completed by the time we leave the execution context
|
790
|
-
# but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
|
791
|
-
# for them to resolve gracefully:
|
792
|
-
async with TaskContext(0.01) as execution_context:
|
793
|
-
async for input_id, function_call_id, args, kwargs in function_io_manager.run_inputs_outputs.aio(
|
794
|
-
imp_fun.input_concurrency
|
795
|
-
):
|
796
|
-
# Note that run_inputs_outputs will not return until the concurrency semaphore has
|
797
|
-
# released all its slots so that they can be acquired by the run_inputs_outputs finalizer
|
798
|
-
# This prevents leaving the execution_context before outputs have been created
|
799
|
-
# TODO: refactor to make this a bit more easy to follow?
|
800
|
-
execution_context.create_task(run_input(input_id, function_call_id, args, kwargs))
|
801
|
-
else:
|
802
|
-
async for input_id, function_call_id, args, kwargs in function_io_manager.run_inputs_outputs.aio(
|
803
|
-
imp_fun.input_concurrency
|
804
|
-
):
|
805
|
-
await run_input(input_id, function_call_id, args, kwargs)
|
806
|
-
|
807
|
-
|
808
|
-
@dataclass
|
809
|
-
class ImportedFunction:
|
810
|
-
obj: Any
|
811
|
-
fun: Callable
|
812
|
-
stub: Optional[_Stub]
|
813
|
-
is_async: bool
|
814
|
-
is_generator: bool
|
815
|
-
data_format: int # api_pb2.DataFormat
|
816
|
-
input_concurrency: int
|
817
|
-
is_auto_snapshot: bool
|
818
|
-
function: _Function
|
819
|
-
|
820
|
-
|
821
|
-
def import_function(
|
822
|
-
function_def: api_pb2.Function,
|
823
|
-
ser_cls,
|
824
|
-
ser_fun,
|
825
|
-
ser_params: Optional[bytes],
|
826
|
-
function_io_manager,
|
827
|
-
client: Client,
|
828
|
-
) -> ImportedFunction:
|
829
|
-
"""Imports a function dynamically, and locates the stub.
|
830
|
-
|
831
|
-
This is somewhat complex because we're dealing with 3 quite different type of functions:
|
832
|
-
1. Functions defined in global scope and decorated in global scope (Function objects)
|
833
|
-
2. Functions defined in global scope but decorated elsewhere (these will be raw callables)
|
834
|
-
3. Serialized functions
|
835
|
-
|
836
|
-
In addition, we also need to handle
|
837
|
-
* Normal functions
|
838
|
-
* Methods on classes (in which case we need to instantiate the object)
|
839
|
-
|
840
|
-
This helper also handles web endpoints, ASGI/WSGI servers, and HTTP servers.
|
841
|
-
|
842
|
-
In order to locate the stub, we try two things:
|
843
|
-
* If the function is a Function, we can get the stub directly from it
|
844
|
-
* Otherwise, use the stub name and look it up from a global list of stubs: this
|
845
|
-
typically only happens in case 2 above, or in sometimes for case 3
|
846
|
-
|
847
|
-
Note that `import_function` is *not* synchronized, becase we need it to run on the main
|
848
|
-
thread. This is so that any user code running in global scope (which executes as a part of
|
849
|
-
the import) runs on the right thread.
|
850
|
-
"""
|
851
|
-
module: Optional[ModuleType] = None
|
852
|
-
cls: Optional[Type] = None
|
853
|
-
fun: Callable
|
854
|
-
function: Optional[_Function] = None
|
855
|
-
active_stub: Optional[_Stub] = None
|
856
|
-
pty_info: api_pb2.PTYInfo = function_def.pty_info
|
857
|
-
|
858
|
-
if ser_fun is not None:
|
859
|
-
# This is a serialized function we already fetched from the server
|
860
|
-
cls, fun = ser_cls, ser_fun
|
322
|
+
user_code_event_loop.run(run_concurrent_inputs())
|
861
323
|
else:
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
if not is_global_function(qual_name):
|
867
|
-
raise LocalFunctionError("Attempted to load a function defined in a function scope")
|
868
|
-
|
869
|
-
parts = qual_name.split(".")
|
870
|
-
if len(parts) == 1:
|
871
|
-
# This is a function
|
872
|
-
cls = None
|
873
|
-
f = getattr(module, qual_name)
|
874
|
-
if isinstance(f, Function):
|
875
|
-
function = synchronizer._translate_in(f)
|
876
|
-
fun = function.get_raw_f()
|
877
|
-
active_stub = function._stub
|
878
|
-
else:
|
879
|
-
fun = f
|
880
|
-
elif len(parts) == 2:
|
881
|
-
# This is a method on a class
|
882
|
-
cls_name, fun_name = parts
|
883
|
-
cls = getattr(module, cls_name)
|
884
|
-
if isinstance(cls, Cls):
|
885
|
-
# The cls decorator is in global scope
|
886
|
-
_cls = synchronizer._translate_in(cls)
|
887
|
-
fun = _cls._callables[fun_name]
|
888
|
-
function = _cls._functions.get(fun_name)
|
889
|
-
active_stub = _cls._stub
|
890
|
-
else:
|
891
|
-
# This is a raw class
|
892
|
-
fun = getattr(cls, fun_name)
|
893
|
-
else:
|
894
|
-
raise InvalidError(f"Invalid function qualname {qual_name}")
|
895
|
-
|
896
|
-
# If the cls/function decorator was applied in local scope, but the stub is global, we can look it up
|
897
|
-
if active_stub is None:
|
898
|
-
# This branch is reached in the special case that the imported function is 1) not serialized, and 2) isn't a FunctionHandle - i.e, not decorated at definition time
|
899
|
-
# Look at all instantiated stubs - if there is only one with the indicated name, use that one
|
900
|
-
stub_name: Optional[str] = function_def.stub_name or None # coalesce protobuf field to None
|
901
|
-
matching_stubs = _Stub._all_stubs.get(stub_name, [])
|
902
|
-
if len(matching_stubs) > 1:
|
903
|
-
if stub_name is not None:
|
904
|
-
warning_sub_message = f"stub with the same name ('{stub_name}')"
|
324
|
+
for io_context in container_io_manager.run_inputs_outputs(finalized_functions, batch_max_size, batch_wait_ms):
|
325
|
+
if io_context.finalized_function.is_async:
|
326
|
+
user_code_event_loop.run(run_input_async(io_context))
|
905
327
|
else:
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
|
910
|
-
elif len(matching_stubs) == 1:
|
911
|
-
(active_stub,) = matching_stubs
|
912
|
-
# there could also technically be zero found stubs, but that should probably never be an issue since that would mean user won't use is_inside or other function handles anyway
|
913
|
-
|
914
|
-
# Check this property before we turn it into a method (overriden by webhooks)
|
915
|
-
is_async = get_is_async(fun)
|
916
|
-
|
917
|
-
# Use the function definition for whether this is a generator (overriden by webhooks)
|
918
|
-
is_generator = function_def.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
|
919
|
-
|
920
|
-
# What data format is used for function inputs and outputs
|
921
|
-
data_format = api_pb2.DATA_FORMAT_PICKLE
|
922
|
-
|
923
|
-
# Container can fetch multiple inputs simultaneously
|
924
|
-
if pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
|
925
|
-
# Concurrency doesn't apply for `modal shell`.
|
926
|
-
input_concurrency = 1
|
927
|
-
else:
|
928
|
-
input_concurrency = function_def.allow_concurrent_inputs or 1
|
328
|
+
# Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation
|
329
|
+
# during function execution. This is sent to cancel inputs from the user
|
330
|
+
def _cancel_input_signal_handler(signum, stackframe):
|
331
|
+
raise InputCancellation("Input was cancelled by user")
|
929
332
|
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
333
|
+
usr1_handler = signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
|
334
|
+
# run this sync code in the main thread, blocking the "userland" event loop
|
335
|
+
# this lets us cancel it using a signal handler that raises an exception
|
336
|
+
try:
|
337
|
+
run_input_sync(io_context)
|
338
|
+
finally:
|
339
|
+
signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler
|
340
|
+
|
341
|
+
|
```diff
+def get_active_app_fallback(function_def: api_pb2.Function) -> _App:
+    # This branch is reached in the special case that the imported function/class is:
+    # 1) not serialized, and
+    # 2) isn't a FunctionHandle - i.e, not decorated at definition time
+    # Look at all instantiated apps - if there is only one with the indicated name, use that one
+    app_name: Optional[str] = function_def.app_name or None  # coalesce protobuf field to None
+    matching_apps = _App._all_apps.get(app_name, [])
+    if len(matching_apps) == 1:
+        active_app: _App = matching_apps[0]
+        return active_app
+
+    if len(matching_apps) > 1:
+        if app_name is not None:
+            warning_sub_message = f"app with the same name ('{app_name}')"
         else:
-
-
-
-
-
-            fun = fun.__get__(obj)
-    else:
-        obj = None
-
-    if function_def.webhook_config.type:
-        is_async = True
-        is_generator = True
-        data_format = api_pb2.DATA_FORMAT_ASGI
-
-        if function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP:
-            # Function returns an asgi_app, which we can use as a callable.
-            fun = asgi_app_wrapper(fun(), function_io_manager)
-
-        elif function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP:
-            # Function returns an wsgi_app, which we can use as a callable.
-            fun = wsgi_app_wrapper(fun(), function_io_manager)
-
-        elif function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_FUNCTION:
-            # Function is a webhook without an ASGI app. Create one for it.
-            fun = asgi_app_wrapper(
-                webhook_asgi_app(fun, function_def.webhook_config.method),
-                function_io_manager,
-            )
+            warning_sub_message = "unnamed app"
+        logger.warning(
+            f"You have more than one {warning_sub_message}. "
+            "It's recommended to name all your Apps uniquely when using multiple apps"
+        )

-
-
-
-
-            # We intentionally try to connect to the external interface instead of the loopback
-            # interface here so users are forced to expose the server. This allows us to potentially
-            # change the implementation to use an external bridge in the future.
-            host = get_ip_address(b"eth0")
-            port = function_def.webhook_config.web_server_port
-            startup_timeout = function_def.webhook_config.web_server_startup_timeout
-            wait_for_web_server(host, port, timeout=startup_timeout)
-            fun = asgi_app_wrapper(web_server_proxy(host, port), function_io_manager)
-
-    else:
-        raise InvalidError(f"Unrecognized web endpoint type {function_def.webhook_config.type}")
-
-    return ImportedFunction(
-        obj,
-        fun,
-        active_stub,
-        is_async,
-        is_generator,
-        data_format,
-        input_concurrency,
-        function_def.is_auto_snapshot,
-        function,
-    )
+    # If we don't have an active app, create one on the fly
+    # The app object is used to carry the app layout etc
+    return _App()


```
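For context on the lookup above: `_App._all_apps` is a name-keyed registry of instantiated apps, so a single match means the user's own app object can be reused. A toy mirror of that mechanism, assuming each app registers itself at construction (the class and names below are illustrative, not Modal's):

```python
from collections import defaultdict
from typing import Optional


class TinyApp:
    """Toy stand-in for modal's _App; not the real class."""

    _all_apps: dict[Optional[str], list["TinyApp"]] = defaultdict(list)

    def __init__(self, name: Optional[str] = None):
        self.name = name
        TinyApp._all_apps[name].append(self)  # register under the (possibly None) name


def pick_active(app_name: Optional[str]) -> TinyApp:
    matching = TinyApp._all_apps.get(app_name, [])
    if len(matching) == 1:
        return matching[0]  # unambiguous: reuse the user's app
    return TinyApp(app_name)  # zero or several matches: fall back to a fresh app


trainer = TinyApp("trainer")
assert pick_active("trainer") is trainer
assert pick_active("missing") is not trainer
```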
```diff
 def call_lifecycle_functions(
     event_loop: UserCodeEventLoop,
-
-    funcs:
+    container_io_manager,  #: ContainerIOManager, TODO: this type is generated at runtime
+    funcs: Sequence[Callable[..., Any]],
 ) -> None:
     """Call function(s), can be sync or async, but any return values are ignored."""
-    with
+    with container_io_manager.handle_user_exception():
         for func in funcs:
             # We are deprecating parameterized exit methods but want to gracefully handle old code.
             # We can remove this once the deprecation in the actual @exit decorator is enforced.
-            args = (None, None, None) if
-
-
-
+            args = (None, None, None) if callable_has_non_self_params(func) else ()
+            # in case func is non-async, it's executed here and sigint will by default
+            # interrupt it using a KeyboardInterrupt exception
+            res = func(*args)
             if inspect.iscoroutine(res):
                 # if however func is async, we have to jump through some hoops
                 event_loop.run(res)


```
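The `inspect.iscoroutine` check above is what lets one loop call both sync and async lifecycle hooks. A reduced sketch of the same dispatch, with a plain asyncio loop standing in for `UserCodeEventLoop`:

```python
import asyncio
import inspect
from typing import Any, Callable, Sequence


def call_lifecycle(funcs: Sequence[Callable[..., Any]]) -> None:
    loop = asyncio.new_event_loop()
    try:
        for func in funcs:
            res = func()  # sync hooks execute right here
            if inspect.iscoroutine(res):
                loop.run_until_complete(res)  # async hooks are driven to completion
    finally:
        loop.close()


async def async_enter():
    print("async enter ran")


call_lifecycle([lambda: print("sync enter ran"), async_enter])
```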
```diff
+def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function, _client: "modal.client._Client"):
+    if function_def.class_parameter_info.format in (
+        api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_UNSPECIFIED,
+        api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PICKLE,
+    ):
+        # legacy serialization format - pickle of `(args, kwargs)` w/ support for modal object arguments
+        param_args, param_kwargs = deserialize(serialized_params, _client)
+    elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
+        param_args = ()
+        param_kwargs = deserialize_proto_params(serialized_params, list(function_def.class_parameter_info.schema))
+    else:
+        raise ExecutionError(
+            f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
+        )
+
+    return param_args, param_kwargs
+
+
```
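Two parameter wire formats coexist in this helper: the legacy path unpickles a whole `(args, kwargs)` tuple, while the proto path decodes kwargs against a typed schema. A hedged illustration of just the legacy shape (Modal's real `deserialize` additionally resolves Modal object references, which plain `pickle` cannot):

```python
import pickle

# Legacy format: one pickle blob holding the full (args, kwargs) pair.
serialized_params = pickle.dumps((("gpt-small",), {"replicas": 2}))

param_args, param_kwargs = pickle.loads(serialized_params)
assert param_args == ("gpt-small",)
assert param_kwargs == {"replicas": 2}
```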
```diff
 def main(container_args: api_pb2.ContainerArguments, client: Client):
-    # This is a bit weird but we need both the blocking and async versions of
+    # This is a bit weird but we need both the blocking and async versions of ContainerIOManager.
     # At some point, we should fix that by having built-in support for running "user code"
-
+    container_io_manager = ContainerIOManager(container_args, client)
+    active_app: _App
+    service: Service
+    function_def = container_args.function_def
+    is_auto_snapshot: bool = function_def.is_auto_snapshot
+    # The worker sets this flag to "1" for snapshot and restore tasks. Otherwise, this flag is unset,
+    # in which case snapshots should be disabled.
+    is_snapshotting_function = (
+        function_def.is_checkpointing_function and os.environ.get("MODAL_ENABLE_SNAP_RESTORE", "0") == "1"
+    )
+
+    _client: _Client = synchronizer._translate_in(client)  # TODO(erikbern): ugly

-    #
-
+    # Call ContainerHello - currently a noop but might be used later for things
+    container_io_manager.hello()

-    with
+    with container_io_manager.heartbeats(is_snapshotting_function), UserCodeEventLoop() as event_loop:
         # If this is a serialized function, fetch the definition from the server
-        if
-            ser_cls, ser_fun =
+        if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
+            ser_cls, ser_fun = container_io_manager.get_serialized_function()
         else:
             ser_cls, ser_fun = None, None

         # Initialize the function, importing user code.
-        with
-
-            container_args.function_def,
-
-
-
-
-
-
+        with container_io_manager.handle_user_exception():
+            if container_args.serialized_params:
+                param_args, param_kwargs = deserialize_params(container_args.serialized_params, function_def, _client)
+            else:
+                param_args = ()
+                param_kwargs = {}
+
+            if function_def.is_class:
+                service = import_class_service(
+                    function_def,
+                    ser_cls,
+                    param_args,
+                    param_kwargs,
+                )
+            else:
+                service = import_single_function_service(
+                    function_def,
+                    ser_cls,
+                    ser_fun,
+                    param_args,
+                    param_kwargs,
+                )
+
+            # If the cls/function decorator was applied in local scope, but the app is global, we can look it up
+            if service.app is not None:
+                active_app = service.app
+            else:
+                # if the app can't be inferred by the imported function, use name-based fallback
+                active_app = get_active_app_fallback(function_def)
+
+        if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
+            # Concurrency and batching doesn't apply for `modal shell`.
+            batch_max_size = 0
+            batch_wait_ms = 0
+        else:
+            batch_max_size = function_def.batch_max_size or 0
+            batch_wait_ms = function_def.batch_linger_ms or 0

-        #
-
-
+        # Get ids and metadata for objects (primarily functions and classes) on the app
+        container_app: RunningApp = container_io_manager.get_app_objects(container_args.app_layout)
+
+        # Initialize objects on the app.
+        # This is basically only functions and classes - anything else is deprecated and will be unsupported soon
+        app: App = synchronizer._translate_out(active_app)
+        app._init_container(client, container_app)

         # Hydrate all function dependencies.
         # TODO(erikbern): we an remove this once we
         # 1. Enable lazy hydration for all objects
         # 2. Fully deprecate .new() objects
-        if
-            dep_object_ids:
-
+        if service.code_deps is not None:  # this is not set for serialized or non-global scope functions
+            dep_object_ids: list[str] = [dep.object_id for dep in function_def.object_dependencies]
+            if len(service.code_deps) != len(dep_object_ids):
+                raise ExecutionError(
+                    f"Function has {len(service.code_deps)} dependencies"
+                    f" but container got {len(dep_object_ids)} object ids.\n"
+                    f"Code deps: {service.code_deps}\n"
+                    f"Object ids: {dep_object_ids}"
+                )
+            for object_id, obj in zip(dep_object_ids, service.code_deps):
+                metadata: Message = container_app.object_handle_metadata[object_id]
+                obj._hydrate(object_id, _client, metadata)
+
+        # Initialize clustered functions.
+        if function_def._experimental_group_size > 0:
+            initialize_clustered_function(
+                client,
+                container_args.task_id,
+                function_def._experimental_group_size,
+            )

```
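The hydration loop above assumes the server sends object ids in the exact order of the function's code-level dependencies; the length guard turns any mismatch into a loud `ExecutionError` rather than a silent mis-pairing. The pairing itself is ordinary `zip`, sketched with placeholder data:

```python
# Placeholder dependency objects and server-assigned ids (illustrative only).
code_deps = ["helper_fn", "OtherCls"]
dep_object_ids = ["fu-abc123", "cs-def456"]

if len(code_deps) != len(dep_object_ids):
    raise RuntimeError(
        f"Function has {len(code_deps)} dependencies but container got {len(dep_object_ids)} object ids."
    )

# zip pairs positionally, so the order must match on both sides.
for object_id, obj in zip(dep_object_ids, code_deps):
    print(f"hydrating {obj} with {object_id}")
```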
```diff
-        # Identify all "enter" methods that need to run before we
-        if
-
-
+        # Identify all "enter" methods that need to run before we snapshot.
+        if service.user_cls_instance is not None and not is_auto_snapshot:
+            pre_snapshot_methods = _find_callables_for_obj(
+                service.user_cls_instance, _PartialFunctionFlags.ENTER_PRE_SNAPSHOT
+            )
+            call_lifecycle_functions(event_loop, container_io_manager, list(pre_snapshot_methods.values()))

         # If this container is being used to create a checkpoint, checkpoint the container after
-        # global imports and
-        if
-
+        # global imports and initialization. Checkpointed containers run from this point onwards.
+        if is_snapshotting_function:
+            container_io_manager.memory_snapshot()

         # Install hooks for interactive functions.
-
+        def breakpoint_wrapper():
+            # note: it would be nice to not have breakpoint_wrapper() included in the backtrace
+            container_io_manager.interact(from_breakpoint=True)
+            import pdb

-
-            # note: it would be nice to not have breakpoint_wrapper() included in the backtrace
-            interact()
-            import pdb
+            frame = inspect.currentframe().f_back

-
+            pdb.Pdb().set_trace(frame)

-
+        sys.breakpointhook = breakpoint_wrapper

```
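`sys.breakpointhook` is the standard hook behind the built-in `breakpoint()`, so the assignment above reroutes every user `breakpoint()` call through Modal's interactive session first. The same frame trick without the `interact()` plumbing, runnable anywhere:

```python
import inspect
import sys


def breakpoint_wrapper():
    import pdb

    # Debug one frame up so the wrapper itself stays out of the session.
    frame = inspect.currentframe().f_back
    pdb.Pdb().set_trace(frame)


sys.breakpointhook = breakpoint_wrapper
# From here on, any breakpoint() call drops into pdb at the caller's line.
```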
```diff
-        # Identify the "enter" methods to run after resuming from a
-        if
-
-
+        # Identify the "enter" methods to run after resuming from a snapshot.
+        if service.user_cls_instance is not None and not is_auto_snapshot:
+            post_snapshot_methods = _find_callables_for_obj(
+                service.user_cls_instance, _PartialFunctionFlags.ENTER_POST_SNAPSHOT
+            )
+            call_lifecycle_functions(event_loop, container_io_manager, list(post_snapshot_methods.values()))

+        with container_io_manager.handle_user_exception():
+            finalized_functions = service.get_finalized_functions(function_def, container_io_manager)
         # Execute the function.
+        lifespan_background_tasks = []
         try:
-
-
-
-
-
-
-
-
-
-
-
+            for finalized_function in finalized_functions.values():
+                if finalized_function.lifespan_manager:
+                    lifespan_background_tasks.append(
+                        event_loop.create_task(finalized_function.lifespan_manager.background_task())
+                    )
+                    with container_io_manager.handle_user_exception():
+                        event_loop.run(finalized_function.lifespan_manager.lifespan_startup())
+            call_function(
+                event_loop,
+                container_io_manager,
+                finalized_functions,
+                batch_max_size,
+                batch_wait_ms,
+            )
         finally:
             # Run exit handlers. From this point onward, ignore all SIGINT signals that come from
             # graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that
@@ -1096,15 +558,27 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
             usr1_handler = signal.signal(signal.SIGUSR1, signal.SIG_IGN)

             try:
-
-
-
-
+                try:
+                    # run lifespan shutdown for asgi apps
+                    for finalized_function in finalized_functions.values():
+                        if finalized_function.lifespan_manager:
+                            with container_io_manager.handle_user_exception():
+                                event_loop.run(finalized_function.lifespan_manager.lifespan_shutdown())
+                finally:
+                    # no need to keep the lifespan asgi call around - we send it no more messages
+                    for lifespan_background_task in lifespan_background_tasks:
+                        lifespan_background_task.cancel()  # prevent dangling tasks
+
+                    # Identify "exit" methods and run them.
+                    # want to make sure this is called even if the lifespan manager fails
+                    if service.user_cls_instance is not None and not is_auto_snapshot:
+                        exit_methods = _find_callables_for_obj(service.user_cls_instance, _PartialFunctionFlags.EXIT)
+                        call_lifecycle_functions(event_loop, container_io_manager, list(exit_methods.values()))

                 # Finally, commit on exit to catch uncommitted volume changes and surface background
                 # commit errors.
-
-                [v.volume_id for v in
+                container_io_manager.volume_commit(
+                    [v.volume_id for v in function_def.volume_mounts if v.allow_background_commits]
                 )
             finally:
                 # Restore the original signal handler, needed for container_test hygiene since the
```
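The hunks above bracket `call_function` with ASGI-style lifespan handling: startup runs before any inputs are processed, shutdown runs in the outer `finally`, and the background task is cancelled last so nothing dangles. A compressed asyncio sketch of that ordering, with a hypothetical manager standing in for `finalized_function.lifespan_manager`:

```python
import asyncio


class FakeLifespanManager:
    """Hypothetical stand-in exposing the three hooks used above."""

    async def background_task(self):
        await asyncio.Event().wait()  # runs until cancelled

    async def lifespan_startup(self):
        print("lifespan startup")

    async def lifespan_shutdown(self):
        print("lifespan shutdown")


async def container_body():
    manager = FakeLifespanManager()
    background = asyncio.create_task(manager.background_task())
    try:
        await manager.lifespan_startup()
        print("processing inputs")  # call_function(...) happens here
    finally:
        try:
            await manager.lifespan_shutdown()
        finally:
            background.cancel()  # prevent dangling tasks, mirroring the diff


asyncio.run(container_body())
```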
```diff
@@ -1117,7 +591,15 @@ if __name__ == "__main__":
     logger.debug("Container: starting")

     container_args = api_pb2.ContainerArguments()
-
+
+    container_arguments_path: Optional[str] = os.environ.get("MODAL_CONTAINER_ARGUMENTS_PATH")
+    if container_arguments_path is None:
+        # TODO(erikbern): this fallback is for old workers and we can remove it very soon (days)
+        import base64
+
+        container_args.ParseFromString(base64.b64decode(sys.argv[1]))
+    else:
+        container_args.ParseFromString(open(container_arguments_path, "rb").read())

     # Note that we're creating the client in a synchronous context, but it will be running in a separate thread.
     # This is good because if the function is long running then we the client can still send heartbeats
```
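The new entrypoint reads the `ContainerArguments` proto from a file named by `MODAL_CONTAINER_ARGUMENTS_PATH`, keeping base64-in-argv only as a fallback for old workers; plausibly this avoids argv size limits and keeps payloads out of process listings, though the diff doesn't say. The same branching in isolation, using a context manager for the file handle (the diff relies on GC to close it):

```python
import base64
import os
import sys

from modal_proto import api_pb2  # assumed import path for modal's generated protobufs

container_args = api_pb2.ContainerArguments()
container_arguments_path = os.environ.get("MODAL_CONTAINER_ARGUMENTS_PATH")
if container_arguments_path is None:
    # legacy workers: proto bytes arrive base64-encoded in argv[1]
    container_args.ParseFromString(base64.b64decode(sys.argv[1]))
else:
    # newer workers: raw proto bytes live in a file
    with open(container_arguments_path, "rb") as f:
        container_args.ParseFromString(f.read())
```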
```diff
@@ -1137,7 +619,7 @@ if __name__ == "__main__":
     # from shutting down. The sleep(0) here is needed for finished ThreadPoolExecutor resources to
     # shut down without triggering this warning (e.g., `@wsgi_app()`).
     time.sleep(0)
-    lingering_threads:
+    lingering_threads: list[threading.Thread] = []
     for thread in threading.enumerate():
         current_thread = threading.get_ident()
         if thread.ident is not None and thread.ident != current_thread and not thread.daemon and thread.is_alive():
@@ -1145,7 +627,8 @@ if __name__ == "__main__":
     if lingering_threads:
         thread_names = ", ".join(t.name for t in lingering_threads)
         logger.warning(
-            f"Detected {len(lingering_threads)} background thread(s) [{thread_names}] still running
+            f"Detected {len(lingering_threads)} background thread(s) [{thread_names}] still running "
+            "after container exit. This will prevent runner shutdown for up to 30 seconds."
         )

     logger.debug("Container: done")
```
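The lingering-thread check above is reproducible standalone: filter `threading.enumerate()` down to live, non-daemon threads other than the current one.

```python
import threading

current_thread = threading.get_ident()
lingering_threads = [
    thread
    for thread in threading.enumerate()
    if thread.ident is not None and thread.ident != current_thread and not thread.daemon and thread.is_alive()
]
if lingering_threads:
    thread_names = ", ".join(t.name for t in lingering_threads)
    print(f"Detected {len(lingering_threads)} background thread(s) [{thread_names}] still running")
```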