modal-0.62.115-py3-none-any.whl → modal-0.72.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__init__.py +13 -9
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +402 -398
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -60
- modal/_resources.py +26 -7
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1025 -0
- modal/{execution_context.py → _runtime/execution_context.py} +11 -2
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +123 -6
- modal/_traceback.py +47 -187
- modal/_tunnel.py +50 -14
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +386 -104
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +299 -98
- modal/_utils/grpc_testing.py +47 -34
- modal/_utils/grpc_utils.py +54 -21
- modal/_utils/hash_utils.py +51 -10
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +3 -3
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +12 -10
- modal/app.py +561 -323
- modal/app.pyi +474 -262
- modal/call_graph.py +7 -6
- modal/cli/_download.py +22 -6
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +203 -42
- modal/cli/config.py +12 -5
- modal/cli/container.py +61 -13
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +21 -48
- modal/cli/launch.py +28 -14
- modal/cli/network_file_system.py +57 -21
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +34 -9
- modal/cli/programs/vscode.py +58 -8
- modal/cli/queues.py +131 -0
- modal/cli/run.py +199 -96
- modal/cli/secret.py +5 -4
- modal/cli/token.py +7 -2
- modal/cli/utils.py +74 -8
- modal/cli/volume.py +97 -56
- modal/client.py +248 -144
- modal/client.pyi +156 -124
- modal/cloud_bucket_mount.py +43 -30
- modal/cloud_bucket_mount.pyi +32 -25
- modal/cls.py +528 -141
- modal/cls.pyi +189 -145
- modal/config.py +32 -15
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +50 -54
- modal/dict.pyi +120 -164
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +30 -43
- modal/experimental.py +62 -2
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +196 -0
- modal/functions.py +846 -428
- modal/functions.pyi +446 -387
- modal/gpu.py +57 -44
- modal/image.py +943 -417
- modal/image.pyi +584 -245
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +223 -90
- modal/mount.pyi +241 -243
- modal/network_file_system.py +85 -86
- modal/network_file_system.pyi +151 -110
- modal/object.py +66 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +73 -47
- modal/parallel_map.pyi +51 -63
- modal/partial_function.py +272 -107
- modal/partial_function.pyi +219 -120
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +96 -72
- modal/queue.pyi +210 -135
- modal/requirements/2024.04.txt +2 -1
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +45 -4
- modal/runner.py +325 -203
- modal/runner.pyi +124 -110
- modal/running_app.py +27 -4
- modal/sandbox.py +509 -231
- modal/sandbox.pyi +396 -169
- modal/schedule.py +2 -2
- modal/scheduler_placement.py +20 -3
- modal/secret.py +41 -25
- modal/secret.pyi +62 -42
- modal/serving.py +39 -49
- modal/serving.pyi +37 -43
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +123 -137
- modal/volume.pyi +228 -221
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
- modal-0.72.13.dist-info/RECORD +174 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +1 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1231 -531
- modal_proto/api_grpc.py +750 -430
- modal_proto/api_pb2.py +2102 -1176
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1329 -675
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_exec.py +0 -128
- modal/_container_io_manager.py +0 -646
- modal/_container_io_manager.pyi +0 -412
- modal/_sandbox_shell.py +0 -49
- modal/app_utils.py +0 -20
- modal/app_utils.pyi +0 -17
- modal/execution_context.pyi +0 -37
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal-0.62.115.dist-info/RECORD +0 -207
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -279
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -674
- test/client_test.py +0 -203
- test/cloud_bucket_mount_test.py +0 -22
- test/cls_test.py +0 -636
- test/config_test.py +0 -149
- test/conftest.py +0 -1485
- test/container_app_test.py +0 -50
- test/container_test.py +0 -1405
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -51
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -791
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -82
- test/helpers.py +0 -47
- test/image_test.py +0 -814
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -327
- test/network_file_system_test.py +0 -188
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -115
- test/resolver_test.py +0 -59
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -57
- test/secret_test.py +0 -89
- test/serialization_test.py +0 -50
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -361
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -397
- test/watcher_test.py +0 -58
- test/webhook_test.py +0 -145
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
modal/_runtime/container_io_manager.py (new file)

@@ -0,0 +1,1025 @@
# Copyright Modal Labs 2024
import asyncio
import importlib.metadata
import inspect
import json
import math
import os
import signal
import sys
import time
import traceback
from collections.abc import AsyncGenerator, AsyncIterator
from contextlib import AsyncExitStack
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Optional,
)

from google.protobuf.empty_pb2 import Empty
from grpclib import Status
from synchronicity.async_wrap import asynccontextmanager

import modal_proto.api_pb2
from modal._serialization import deserialize, serialize, serialize_data_format
from modal._traceback import extract_traceback, print_exception
from modal._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer
from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
from modal._utils.function_utils import _stream_function_call_data
from modal._utils.grpc_utils import retry_transient_errors
from modal._utils.package_utils import parse_major_minor_version
from modal.client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
from modal.config import config, logger
from modal.exception import ClientClosed, InputCancellation, InvalidError, SerializationError
from modal_proto import api_pb2

if TYPE_CHECKING:
    import modal._runtime.asgi
    import modal._runtime.user_code_imports


DYNAMIC_CONCURRENCY_INTERVAL_SECS = 3
DYNAMIC_CONCURRENCY_TIMEOUT_SECS = 10
MAX_OUTPUT_BATCH_SIZE: int = 49

RTT_S: float = 0.5  # conservative estimate of RTT in seconds.


class UserException(Exception):
    """Used to shut down the task gracefully."""


class Sentinel:
    """Used to get type-stubs to work with this object."""


class IOContext:
    """Context object for managing input, function calls, and function executions
    in a batched or single input context.
    """

    input_ids: list[str]
    function_call_ids: list[str]
    finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"

    _cancel_issued: bool = False
    _cancel_callback: Optional[Callable[[], None]] = None

    def __init__(
        self,
        input_ids: list[str],
        function_call_ids: list[str],
        finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
        function_inputs: list[api_pb2.FunctionInput],
        is_batched: bool,
        client: _Client,
    ):
        self.input_ids = input_ids
        self.function_call_ids = function_call_ids
        self.finalized_function = finalized_function
        self._function_inputs = function_inputs
        self._is_batched = is_batched
        self._client = client

    @classmethod
    async def create(
        cls,
        client: _Client,
        finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
        inputs: list[tuple[str, str, api_pb2.FunctionInput]],
        is_batched: bool,
    ) -> "IOContext":
        assert len(inputs) >= 1 if is_batched else len(inputs) == 1
        input_ids, function_call_ids, function_inputs = zip(*inputs)

        async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput:
            # If we got a pointer to a blob, download it from S3.
            if input.WhichOneof("args_oneof") == "args_blob_id":
                args = await blob_download(input.args_blob_id, client.stub)
                # Mutating
                input.ClearField("args_blob_id")
                input.args = args

            return input

        function_inputs = await asyncio.gather(*[_populate_input_blobs(client, input) for input in function_inputs])
        # check every input in batch executes the same function
        method_name = function_inputs[0].method_name
        assert all(method_name == input.method_name for input in function_inputs)
        finalized_function = finalized_functions[method_name]
        return cls(input_ids, function_call_ids, finalized_function, function_inputs, is_batched, client)

    def set_cancel_callback(self, cb: Callable[[], None]):
        self._cancel_callback = cb

    def cancel(self):
        # Ensure we only issue the cancellation once.
        if self._cancel_issued:
            return

        if self._cancel_callback:
            logger.warning(f"Received a cancellation signal while processing input {self.input_ids}")
            self._cancel_issued = True
            self._cancel_callback()
        else:
            # TODO (elias): This should not normally happen but there is a small chance of a race
            # between creating a new task for an input and attaching the cancellation callback
            logger.warning("Unexpected: Could not cancel input")

    def _args_and_kwargs(self) -> tuple[tuple[Any, ...], dict[str, list[Any]]]:
        # deserializing here instead of the constructor
        # to make sure we handle user exceptions properly
        # and don't retry
        deserialized_args = [
            deserialize(input.args, self._client) if input.args else ((), {}) for input in self._function_inputs
        ]
        if not self._is_batched:
            return deserialized_args[0]

        func_name = self.finalized_function.callable.__name__

        param_names = []
        for param in inspect.signature(self.finalized_function.callable).parameters.values():
            param_names.append(param.name)

        # aggregate args and kwargs of all inputs into a kwarg dict
        kwargs_by_inputs: list[dict[str, Any]] = [{} for _ in range(len(self.input_ids))]

        for i, (args, kwargs) in enumerate(deserialized_args):
            # check that all batched inputs should have the same number of args and kwargs
            if (num_params := len(args) + len(kwargs)) != len(param_names):
                raise InvalidError(
                    f"Modal batched function {func_name} takes {len(param_names)} positional arguments, but one invocation in the batch has {num_params}."  # noqa
                )

            for j, arg in enumerate(args):
                kwargs_by_inputs[i][param_names[j]] = arg
            for k, v in kwargs.items():
                if k not in param_names:
                    raise InvalidError(
                        f"Modal batched function {func_name} got unexpected keyword argument {k} in one invocation in the batch."  # noqa
                    )
                if k in kwargs_by_inputs[i]:
                    raise InvalidError(
                        f"Modal batched function {func_name} got multiple values for argument {k} in one invocation in the batch."  # noqa
                    )
                kwargs_by_inputs[i][k] = v

        formatted_kwargs = {
            param_name: [kwargs[param_name] for kwargs in kwargs_by_inputs] for param_name in param_names
        }
        return (), formatted_kwargs

    def call_finalized_function(self) -> Any:
        logger.debug(f"Starting input {self.input_ids}")
        args, kwargs = self._args_and_kwargs()
        res = self.finalized_function.callable(*args, **kwargs)
        logger.debug(f"Finished input {self.input_ids}")
        return res

    def validate_output_data(self, data: Any) -> list[Any]:
        if not self._is_batched:
            return [data]

        function_name = self.finalized_function.callable.__name__
        if not isinstance(data, list):
            raise InvalidError(f"Output of batched function {function_name} must be a list.")
        if len(data) != len(self.input_ids):
            raise InvalidError(
                f"Output of batched function {function_name} must be a list of equal length as its inputs."
            )
        return data


class InputSlots:
    """A semaphore that allows dynamically adjusting the concurrency."""

    active: int
    value: int
    waiter: Optional[asyncio.Future]
    closed: bool

    def __init__(self, value: int) -> None:
        self.active = 0
        self.value = value
        self.waiter = None
        self.closed = False

    async def acquire(self) -> None:
        if self.active < self.value:
            self.active += 1
        elif self.waiter is None:
            self.waiter = asyncio.get_running_loop().create_future()
            await self.waiter
        else:
            raise RuntimeError("Concurrent waiters are not supported.")

    def _wake_waiter(self) -> None:
        if self.active < self.value and self.waiter is not None:
            if not self.waiter.cancelled():  # could have been cancelled during interpreter shutdown
                self.waiter.set_result(None)
            self.waiter = None
            self.active += 1

    def release(self) -> None:
        self.active -= 1
        self._wake_waiter()

    def set_value(self, value: int) -> None:
        if self.closed:
            return
        self.value = value
        self._wake_waiter()

    async def close(self) -> None:
        self.closed = True
        for _ in range(self.value):
            await self.acquire()


class _ContainerIOManager:
    """Synchronizes all RPC calls and network operations for a running container.

    TODO: maybe we shouldn't synchronize the whole class.
    Then we could potentially move a bunch of the global functions onto it.
    """

    task_id: str
    function_id: str
    app_id: str
    function_def: api_pb2.Function
    checkpoint_id: Optional[str]

    calls_completed: int
    total_user_time: float
    current_input_id: Optional[str]
    current_inputs: dict[str, IOContext]  # input_id -> IOContext
    current_input_started_at: Optional[float]

    _target_concurrency: int
    _max_concurrency: int
    _concurrency_loop: Optional[asyncio.Task]
    _input_slots: InputSlots

    _environment_name: str
    _heartbeat_loop: Optional[asyncio.Task]
    _heartbeat_condition: Optional[asyncio.Condition]
    _waiting_for_memory_snapshot: bool

    _is_interactivity_enabled: bool
    _fetching_inputs: bool

    _client: _Client

    _GENERATOR_STOP_SENTINEL: ClassVar[Sentinel] = Sentinel()
    _singleton: ClassVar[Optional["_ContainerIOManager"]] = None

    def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):
        self.task_id = container_args.task_id
        self.function_id = container_args.function_id
        self.app_id = container_args.app_id
        self.function_def = container_args.function_def
        self.checkpoint_id = container_args.checkpoint_id or None

        self.calls_completed = 0
        self.total_user_time = 0.0
        self.current_input_id = None
        self.current_inputs = {}
        self.current_input_started_at = None

        if container_args.function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
            target_concurrency = 1
            max_concurrency = 1
        else:
            target_concurrency = container_args.function_def.target_concurrent_inputs or 1
            max_concurrency = container_args.function_def.max_concurrent_inputs or target_concurrency

        self._target_concurrency = target_concurrency
        self._max_concurrency = max_concurrency
        self._concurrency_loop = None
        self._stop_concurrency_loop = False
        self._input_slots = InputSlots(target_concurrency)

        self._environment_name = container_args.environment_name
        self._heartbeat_loop = None
        self._heartbeat_condition = None
        self._waiting_for_memory_snapshot = False

        self._is_interactivity_enabled = False
        self._fetching_inputs = True

        self._client = client
        assert isinstance(self._client, _Client)

    @property
    def heartbeat_condition(self) -> asyncio.Condition:
        # ensures that heartbeat condition isn't assigned to an event loop until it's used for the first time
        # (On Python 3.9 and below it would be assigned to the current thread's event loop on creation)
        if self._heartbeat_condition is None:
            self._heartbeat_condition = asyncio.Condition()
        return self._heartbeat_condition

    def __new__(cls, container_args: api_pb2.ContainerArguments, client: _Client) -> "_ContainerIOManager":
        cls._singleton = super().__new__(cls)
        cls._singleton._init(container_args, client)
        return cls._singleton

    @classmethod
    def _reset_singleton(cls):
        """Only used for tests."""
        cls._singleton = None

    async def hello(self):
        await self._client.stub.ContainerHello(Empty())

    async def _run_heartbeat_loop(self):
        while 1:
            t0 = time.monotonic()
            try:
                if await self._heartbeat_handle_cancellations():
                    # got a cancellation event, fine to start another heartbeat immediately
                    # since the cancellation queue should be empty on the worker server
                    # however, we wait at least 1s to prevent short-circuiting the heartbeat loop
                    # in case there is ever a bug. This means it will take at least 1s between
                    # two subsequent cancellations on the same task at the moment
                    await asyncio.sleep(1.0)
                    continue
            except ClientClosed:
                logger.info("Stopping heartbeat loop due to client shutdown")
                break
            except Exception as exc:
                # don't stop heartbeat loop if there are transient exceptions!
                time_elapsed = time.monotonic() - t0
                error = exc
                logger.warning(f"Heartbeat attempt failed ({time_elapsed=}, {error=})")

            heartbeat_duration = time.monotonic() - t0
            time_until_next_hearbeat = max(0.0, HEARTBEAT_INTERVAL - heartbeat_duration)
            await asyncio.sleep(time_until_next_hearbeat)

    async def _heartbeat_handle_cancellations(self) -> bool:
        # Return True if a cancellation event was received, in that case
        # we shouldn't wait too long for another heartbeat
        async with self.heartbeat_condition:
            # Continuously wait until `waiting_for_memory_snapshot` is false.
            # TODO(matt): Verify that a `while` is necessary over an `if`. Spurious
            # wakeups could allow execution to continue despite `_waiting_for_memory_snapshot`
            # being true.
            while self._waiting_for_memory_snapshot:
                await self.heartbeat_condition.wait()

            request = api_pb2.ContainerHeartbeatRequest(canceled_inputs_return_outputs_v2=True)
            response = await retry_transient_errors(
                self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
            )

        if response.HasField("cancel_input_event"):
            # response.cancel_input_event.terminate_containers is never set, the server gets the worker to handle it.
            input_ids_to_cancel = response.cancel_input_event.input_ids
            if input_ids_to_cancel:
                if self._max_concurrency > 1:
                    for input_id in input_ids_to_cancel:
                        if input_id in self.current_inputs:
                            self.current_inputs[input_id].cancel()

                elif self.current_input_id and self.current_input_id in input_ids_to_cancel:
                    # This goes to a registered signal handler for sync Modal functions, or to the
                    # `SignalHandlingEventLoop` for async functions.
                    #
                    # We only send this signal on functions that do not have concurrent inputs enabled.
                    # This allows us to do fine-grained input cancellation. On sync functions, the
                    # SIGUSR1 signal should interrupt the main thread where user code is running,
                    # raising an InputCancellation() exception. On async functions, the signal should
                    # reach a handler in SignalHandlingEventLoop, which cancels the task.
                    logger.warning(f"Received a cancellation signal while processing input {self.current_input_id}")
                    os.kill(os.getpid(), signal.SIGUSR1)
            return True
        return False

    @asynccontextmanager
    async def heartbeats(self, wait_for_mem_snap: bool) -> AsyncGenerator[None, None]:
        async with TaskContext() as tc:
            self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
            t.set_name("heartbeat loop")
            self._waiting_for_memory_snapshot = wait_for_mem_snap
            try:
                yield
            finally:
                t.cancel()

    def stop_heartbeat(self):
        if self._heartbeat_loop:
            self._heartbeat_loop.cancel()

    @asynccontextmanager
    async def dynamic_concurrency_manager(self) -> AsyncGenerator[None, None]:
        async with TaskContext() as tc:
            self._concurrency_loop = t = tc.create_task(self._dynamic_concurrency_loop())
            t.set_name("dynamic concurrency loop")
            try:
                yield
            finally:
                t.cancel()

    async def _dynamic_concurrency_loop(self):
        logger.debug(f"Starting dynamic concurrency loop for task {self.task_id}")
        while not self._stop_concurrency_loop:
            try:
                request = api_pb2.FunctionGetDynamicConcurrencyRequest(
                    function_id=self.function_id,
                    target_concurrency=self._target_concurrency,
                    max_concurrency=self._max_concurrency,
                )
                resp = await retry_transient_errors(
                    self._client.stub.FunctionGetDynamicConcurrency,
                    request,
                    attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS,
                )
                if resp.concurrency != self._input_slots.value and not self._stop_concurrency_loop:
                    logger.debug(f"Dynamic concurrency set from {self._input_slots.value} to {resp.concurrency}")
                    self._input_slots.set_value(resp.concurrency)

            except Exception as exc:
                logger.debug(f"Failed to get dynamic concurrency for task {self.task_id}, {exc}")

            await asyncio.sleep(DYNAMIC_CONCURRENCY_INTERVAL_SECS)

    async def get_serialized_function(self) -> tuple[Optional[Any], Optional[Callable[..., Any]]]:
        # Fetch the serialized function definition
        request = api_pb2.FunctionGetSerializedRequest(function_id=self.function_id)
        response = await self._client.stub.FunctionGetSerialized(request)
        if response.function_serialized:
            fun = self.deserialize(response.function_serialized)
        else:
            fun = None

        if response.class_serialized:
            cls = self.deserialize(response.class_serialized)
        else:
            cls = None

        return cls, fun

    def serialize(self, obj: Any) -> bytes:
        return serialize(obj)

    def deserialize(self, data: bytes) -> Any:
        return deserialize(data, self._client)

    @synchronizer.no_io_translation
    def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
        return serialize_data_format(obj, data_format)

    async def format_blob_data(self, data: bytes) -> dict[str, Any]:
        return (
            {"data_blob_id": await blob_upload(data, self._client.stub)}
            if len(data) > MAX_OBJECT_SIZE_BYTES
            else {"data": data}
        )

    async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
        """Read from the `data_in` stream of a function call."""
        async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
            yield data

    async def put_data_out(
        self,
        function_call_id: str,
        start_index: int,
        data_format: int,
        messages_bytes: list[Any],
    ) -> None:
        """Put data onto the `data_out` stream of a function call.

        This is used for generator outputs, which includes web endpoint responses. Note that this
        was introduced as a performance optimization in client version 0.57, so older clients will
        still use the previous Postgres-backed system based on `FunctionPutOutputs()`.
        """
        data_chunks: list[api_pb2.DataChunk] = []
        for i, message_bytes in enumerate(messages_bytes):
            chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i)  # type: ignore
            if len(message_bytes) > MAX_OBJECT_SIZE_BYTES:
                chunk.data_blob_id = await blob_upload(message_bytes, self._client.stub)
            else:
                chunk.data = message_bytes
            data_chunks.append(chunk)

        req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks)
        await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req)

    async def generator_output_task(self, function_call_id: str, data_format: int, message_rx: asyncio.Queue) -> None:
        """Task that feeds generator outputs into a function call's `data_out` stream."""
        index = 1
        received_sentinel = False
        while not received_sentinel:
            message = await message_rx.get()
            if message is self._GENERATOR_STOP_SENTINEL:
                break
            # ASGI 'http.response.start' and 'http.response.body' msgs are observed to be separated by 1ms.
            # If we don't sleep here for 1ms we end up with an extra call to .put_data_out().
            if index == 1:
                await asyncio.sleep(0.001)
            messages_bytes = [serialize_data_format(message, data_format)]
            total_size = len(messages_bytes[0]) + 512
            while total_size < 16 * 1024 * 1024:  # 16 MiB, maximum size in a single message
                try:
                    message = message_rx.get_nowait()
                except asyncio.QueueEmpty:
                    break
                if message is self._GENERATOR_STOP_SENTINEL:
                    received_sentinel = True
                    break
                else:
                    messages_bytes.append(serialize_data_format(message, data_format))
                    total_size += len(messages_bytes[-1]) + 512  # 512 bytes for estimated framing overhead
            await self.put_data_out(function_call_id, index, data_format, messages_bytes)
            index += len(messages_bytes)

    async def _queue_create(self, size: int) -> asyncio.Queue:
        """Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
        return asyncio.Queue(size)

    async def _queue_put(self, queue: asyncio.Queue, value: Any) -> None:
        """Put a value onto a queue, using the synchronicity event loop."""
        await queue.put(value)

    def get_average_call_time(self) -> float:
        if self.calls_completed == 0:
            return 0

        return self.total_user_time / self.calls_completed

    def get_max_inputs_to_fetch(self):
        if self.calls_completed == 0:
            return 1

        return math.ceil(RTT_S / max(self.get_average_call_time(), 1e-6))

    @synchronizer.no_io_translation
    async def _generate_inputs(
        self,
        batch_max_size: int,
        batch_wait_ms: int,
    ) -> AsyncIterator[list[tuple[str, str, api_pb2.FunctionInput]]]:
        request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
        iteration = 0
        while self._fetching_inputs:
            await self._input_slots.acquire()

            request.average_call_time = self.get_average_call_time()
            request.max_values = self.get_max_inputs_to_fetch()  # Deprecated; remove.
            request.input_concurrency = self.get_input_concurrency()
            request.batch_max_size, request.batch_linger_ms = batch_max_size, batch_wait_ms

            yielded = False
            try:
                # If number of active inputs is at max queue size, this will block.
                iteration += 1
                response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors(
                    self._client.stub.FunctionGetInputs, request
                )

                if response.rate_limit_sleep_duration:
                    logger.info(
                        "Task exceeded rate limit, sleeping for %.2fs before trying again."
                        % response.rate_limit_sleep_duration
                    )
                    await asyncio.sleep(response.rate_limit_sleep_duration)
                elif response.inputs:
                    # for input cancellations and concurrency logic we currently assume
                    # that there is no input buffering in the container
                    assert 0 < len(response.inputs) <= max(1, request.batch_max_size)
                    inputs = []
                    final_input_received = False
                    for item in response.inputs:
                        if item.kill_switch:
                            logger.debug(f"Task {self.task_id} input kill signal input.")
                            return

                        inputs.append((item.input_id, item.function_call_id, item.input))
                        if item.input.final_input:
                            if request.batch_max_size > 0:
                                logger.debug(f"Task {self.task_id} Final input not expected in batch input stream")
                            final_input_received = True
                            break

                    # If yielded, allow input slots to be released via exit_context
                    yield inputs
                    yielded = True

                    # We only support max_inputs = 1 at the moment
                    if final_input_received or self.function_def.max_inputs == 1:
                        return
            finally:
                if not yielded:
                    self._input_slots.release()

    @synchronizer.no_io_translation
    async def run_inputs_outputs(
        self,
        finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
        batch_max_size: int = 0,
        batch_wait_ms: int = 0,
    ) -> AsyncIterator[IOContext]:
        # Ensure we do not fetch new inputs when container is too busy.
        # Before trying to fetch an input, acquire an input slot:
        # - if no input is fetched, release the input slot.
        # - or, when the output for the fetched input is sent, release the input slot.
        dynamic_concurrency_manager = (
            self.dynamic_concurrency_manager() if self._max_concurrency > self._target_concurrency else AsyncExitStack()
        )
        async with dynamic_concurrency_manager:
            async for inputs in self._generate_inputs(batch_max_size, batch_wait_ms):
                io_context = await IOContext.create(self._client, finalized_functions, inputs, batch_max_size > 0)
                for input_id in io_context.input_ids:
                    self.current_inputs[input_id] = io_context

                self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time()
                yield io_context
                self.current_input_id, self.current_input_started_at = (None, None)

        # collect all active input slots, meaning all inputs have wrapped up.
        await self._input_slots.close()

    @synchronizer.no_io_translation
    async def _push_outputs(
        self,
        io_context: IOContext,
        started_at: float,
        data_format: "modal_proto.api_pb2.DataFormat.ValueType",
        results: list[api_pb2.GenericResult],
    ) -> None:
        output_created_at = time.time()
        outputs = [
            api_pb2.FunctionPutOutputsItem(
                input_id=input_id,
                input_started_at=started_at,
                output_created_at=output_created_at,
                result=result,
                data_format=data_format,
            )
            for input_id, result in zip(io_context.input_ids, results)
        ]
        await retry_transient_errors(
            self._client.stub.FunctionPutOutputs,
            api_pb2.FunctionPutOutputsRequest(outputs=outputs),
            additional_status_codes=[Status.RESOURCE_EXHAUSTED],
            max_retries=None,  # Retry indefinitely, trying every 1s.
        )

    def serialize_exception(self, exc: BaseException) -> bytes:
        try:
            return self.serialize(exc)
        except Exception as serialization_exc:
            # We can't always serialize exceptions.
            err = f"Failed to serialize exception {exc} of type {type(exc)}: {serialization_exc}"
            logger.info(err)
            return self.serialize(SerializationError(err))

    def serialize_traceback(self, exc: BaseException) -> tuple[Optional[bytes], Optional[bytes]]:
        serialized_tb, tb_line_cache = None, None

        try:
            tb_dict, line_cache = extract_traceback(exc, self.task_id)
            serialized_tb = self.serialize(tb_dict)
            tb_line_cache = self.serialize(line_cache)
        except Exception:
            logger.info("Failed to serialize exception traceback.")

        return serialized_tb, tb_line_cache

    @asynccontextmanager
    async def handle_user_exception(self) -> AsyncGenerator[None, None]:
        """Sets the task as failed in a way where it's not retried.

        Used for handling exceptions from container lifecycle methods at the moment, which should
        trigger a task failure state.
        """
        try:
            yield
        except KeyboardInterrupt:
            # Send no task result in case we get sigint:ed by the runner
            # The status of the input should have been handled externally already in that case
            raise
        except BaseException as exc:
            if isinstance(exc, ImportError):
                # Catches errors raised by global scope imports
                check_fastapi_pydantic_compatibility(exc)

            # Since this is on a different thread, sys.exc_info() can't find the exception in the stack.
            print_exception(type(exc), exc, exc.__traceback__)

            serialized_tb, tb_line_cache = self.serialize_traceback(exc)

            result = api_pb2.GenericResult(
                status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
                data=self.serialize_exception(exc),
                exception=repr(exc),
                traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
                serialized_tb=serialized_tb or b"",
                tb_line_cache=tb_line_cache or b"",
            )

            req = api_pb2.TaskResultRequest(result=result)
            await retry_transient_errors(self._client.stub.TaskResult, req)

            # Shut down the task gracefully
            raise UserException()

    @asynccontextmanager
    async def handle_input_exception(
        self,
        io_context: IOContext,
        started_at: float,
    ) -> AsyncGenerator[None, None]:
        """Handle an exception while processing a function input."""
        try:
            yield
        except (KeyboardInterrupt, GeneratorExit):
            # We need to explicitly reraise these BaseExceptions to not handle them in the catch-all:
            # 1. KeyboardInterrupt can end up here even though this runs on non-main thread, since the
            #    code block yielded to could be sending back a main thread exception
            # 2. GeneratorExit - raised if this (async) generator is garbage collected while waiting
            #    for the yield. Typically on event loop shutdown
            raise
        except (InputCancellation, asyncio.CancelledError):
            # Create terminated outputs for these inputs to signal that the cancellations have been completed.
            results = [
                api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED)
                for _ in io_context.input_ids
            ]
            await self._push_outputs(
                io_context=io_context,
                started_at=started_at,
                data_format=api_pb2.DATA_FORMAT_PICKLE,
                results=results,
            )
            self.exit_context(started_at, io_context.input_ids)
            logger.warning(f"Successfully canceled input {io_context.input_ids}")
            return
        except BaseException as exc:
            if isinstance(exc, ImportError):
                # Catches errors raised by imports from within function body
                check_fastapi_pydantic_compatibility(exc)

            # print exception so it's logged
            print_exception(*sys.exc_info())

            serialized_tb, tb_line_cache = self.serialize_traceback(exc)

            # Note: we're not serializing the traceback since it contains
            # local references that means we can't unpickle it. We *are*
            # serializing the exception, which may have some issues (there
            # was an earlier note about it that it might not be possible
            # to unpickle it in some cases). Let's watch out for issues.

            repr_exc = repr(exc)
            if len(repr_exc) >= MAX_OBJECT_SIZE_BYTES:
                # We prevent large exception messages to avoid
                # unhandled exceptions causing inf loops
                # and just send backa trimmed version
                trimmed_bytes = len(repr_exc) - MAX_OBJECT_SIZE_BYTES - 1000
                repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
                repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"

            data: bytes = self.serialize_exception(exc) or b""
            data_result_part = await self.format_blob_data(data)
            results = [
                api_pb2.GenericResult(
                    status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
                    exception=repr_exc,
                    traceback=traceback.format_exc(),
                    serialized_tb=serialized_tb or b"",
                    tb_line_cache=tb_line_cache or b"",
                    **data_result_part,
                )
                for _ in io_context.input_ids
            ]
            await self._push_outputs(
                io_context=io_context,
                started_at=started_at,
                data_format=api_pb2.DATA_FORMAT_PICKLE,
                results=results,
            )
            self.exit_context(started_at, io_context.input_ids)

    def exit_context(self, started_at, input_ids: list[str]):
        self.total_user_time += time.time() - started_at
        self.calls_completed += 1

        for input_id in input_ids:
            self.current_inputs.pop(input_id)

        self._input_slots.release()

    @synchronizer.no_io_translation
    async def push_outputs(
        self,
        io_context: IOContext,
        started_at: float,
        data: Any,
        data_format: "modal_proto.api_pb2.DataFormat.ValueType",
    ) -> None:
        data = io_context.validate_output_data(data)
        formatted_data = await asyncio.gather(
            *[self.format_blob_data(self.serialize_data_format(d, data_format)) for d in data]
        )
        results = [
            api_pb2.GenericResult(
                status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
                **d,
            )
            for d in formatted_data
        ]
        await self._push_outputs(
            io_context=io_context,
            started_at=started_at,
            data_format=data_format,
            results=results,
        )
        self.exit_context(started_at, io_context.input_ids)

    async def memory_restore(self) -> None:
        # Busy-wait for restore. `/__modal/restore-state.json` is created
        # by the worker process with updates to the container config.
        restored_path = Path(config.get("restore_state_path"))
        start = time.perf_counter()
        while not restored_path.exists():
            logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)")
            await asyncio.sleep(0.01)
            continue

        logger.debug("Container: restored")

        # Look for state file and create new client with updated credentials.
        # State data is serialized with key-value pairs, example: {"task_id": "tk-000"}
        with restored_path.open("r") as file:
            restored_state = json.load(file)

        # Start a debugger if the worker tells us to
        if int(restored_state.get("snapshot_debug", 0)):
            logger.debug("Entering snapshot debugger")
            breakpoint()

        # Local ContainerIOManager state.
        for key in ["task_id", "function_id"]:
            if value := restored_state.get(key):
                logger.debug(f"Updating ContainerIOManager.{key} = {value}")
                setattr(self, key, restored_state[key])

        # Env vars and global state.
        for key, value in restored_state.items():
            # Empty string indicates that value does not need to be updated.
            if value != "":
                config.override_locally(key, value)

        # Restore input to default state.
        self.current_input_id = None
        self.current_inputs = {}
        self.current_input_started_at = None
        self._client = await _Client.from_env()

    async def memory_snapshot(self) -> None:
        """Message server indicating that function is ready to be checkpointed."""
        if self.checkpoint_id:
            logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)")
        else:
            raise ValueError("No checkpoint ID provided for memory snapshot")

        # Pause heartbeats since they keep the client connection open which causes the snapshotter to crash
        async with self.heartbeat_condition:
            # Notify the heartbeat loop that the snapshot phase has begun in order to
            # prevent it from sending heartbeat RPCs
            self._waiting_for_memory_snapshot = True
            self.heartbeat_condition.notify_all()

            await self._client.stub.ContainerCheckpoint(
                api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
            )

            await self._client._close(prep_for_restore=True)

            logger.debug("Memory snapshot request sent. Connection closed.")
            await self.memory_restore()
            # Turn heartbeats back on. This is safe since the snapshot RPC
            # and the restore phase has finished.
            self._waiting_for_memory_snapshot = False
            self.heartbeat_condition.notify_all()

    async def volume_commit(self, volume_ids: list[str]) -> None:
        """
        Perform volume commit for given `volume_ids`.
        Only used on container exit to persist uncommitted changes on behalf of user.
        """
        if not volume_ids:
            return
        await asyncify(os.sync)()
        results = await asyncio.gather(
            *[
                retry_transient_errors(
                    self._client.stub.VolumeCommit,
                    api_pb2.VolumeCommitRequest(volume_id=v_id),
                    max_retries=9,
                    base_delay=0.25,
                    max_delay=256,
                    delay_factor=2,
                )
                for v_id in volume_ids
            ],
            return_exceptions=True,
        )
        for volume_id, res in zip(volume_ids, results):
            if isinstance(res, Exception):
                logger.error(f"modal.Volume background commit failed for {volume_id}. Exception: {res}")
            else:
                logger.debug(f"modal.Volume background commit success for {volume_id}.")

    async def interact(self, from_breakpoint: bool = False):
        if self._is_interactivity_enabled:
            # Currently, interactivity is enabled forever
            return
        self._is_interactivity_enabled = True

        if not self.function_def.pty_info.pty_type:
            trigger = "breakpoint()" if from_breakpoint else "modal.interact()"
            raise InvalidError(f"Cannot use {trigger} without running Modal in interactive mode.")

        try:
            await self._client.stub.FunctionStartPtyShell(Empty())
        except Exception as e:
            logger.error("Failed to start PTY shell.")
            raise e

    @property
    def target_concurrency(self) -> int:
        return self._target_concurrency

    @property
    def max_concurrency(self) -> int:
        return self._max_concurrency

    @classmethod
    def get_input_concurrency(cls) -> int:
        """
        Returns the number of usable input slots.

        If concurrency is reduced, active slots can exceed allotted slots. Returns the larger value
        in this case.
        """

        io_manager = cls._singleton
        assert io_manager
        return max(io_manager._input_slots.active, io_manager._input_slots.value)

    @classmethod
    def set_input_concurrency(cls, concurrency: int):
        """
        Edit the number of input slots.

        This disables the background loop which automatically adjusts concurrency
        within [target_concurrency, max_concurrency].
        """
        io_manager = cls._singleton
        assert io_manager
        io_manager._stop_concurrency_loop = True
        concurrency = min(concurrency, io_manager._max_concurrency)
        io_manager._input_slots.set_value(concurrency)

    @classmethod
    def stop_fetching_inputs(cls):
        assert cls._singleton
        cls._singleton._fetching_inputs = False


ContainerIOManager = synchronize_api(_ContainerIOManager)


def check_fastapi_pydantic_compatibility(exc: ImportError) -> None:
    """Add a helpful note to an exception that is likely caused by a pydantic<>fastapi version incompatibility.

    We need this becasue the legacy set of container requirements (image_builder_version=2023.12) contains a
    version of fastapi that is not forwards-compatible with pydantic 2.0+, and users commonly run into issues
    building an image that specifies a more recent version only for pydantic.
    """
    note = (
        "Please ensure that your Image contains compatible versions of fastapi and pydantic."
        " If using pydantic>=2.0, you must also install fastapi>=0.100."
    )
    name = exc.name or ""
    if name.startswith("pydantic"):
        try:
            fastapi_version = parse_major_minor_version(importlib.metadata.version("fastapi"))
            pydantic_version = parse_major_minor_version(importlib.metadata.version("pydantic"))
            if pydantic_version >= (2, 0) and fastapi_version < (0, 100):
                if sys.version_info < (3, 11):
                    # https://peps.python.org/pep-0678/
                    exc.__notes__ = [note]
                else:
                    exc.add_note(note)
        except Exception:
            # Since we're just trying to add a helpful message, don't fail here
            pass
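
For readers skimming the new runtime module above, `InputSlots` is the piece that lets a container change how many inputs it runs concurrently while inputs are already in flight: `acquire`/`release` behave like a semaphore (a plain `asyncio.Semaphore` cannot be resized at runtime, which is why the class exists), and `set_value` resizes the capacity and wakes a blocked waiter. The snippet below is a minimal illustrative driver, not part of the package; it assumes modal 0.72.13 is installed and imports the private `modal._runtime.container_io_manager` module directly, which is internal API and may change.

```python
# Illustrative sketch only (assumes modal==0.72.13; InputSlots is an internal API).
import asyncio

from modal._runtime.container_io_manager import InputSlots


async def run_input(slots: InputSlots, i: int) -> None:
    await slots.acquire()  # blocks while all slots are busy (only one waiter is supported)
    try:
        print(f"input {i}: {slots.active} active / {slots.value} allowed")
        await asyncio.sleep(0.1)  # stand-in for user code
    finally:
        slots.release()  # frees the slot and wakes the waiter, if any


async def main() -> None:
    slots = InputSlots(1)  # start at a target concurrency of 1
    tasks = [asyncio.create_task(run_input(slots, i)) for i in range(2)]
    await asyncio.sleep(0.05)  # input 0 is running, input 1 is parked in acquire()
    slots.set_value(2)  # a dynamic concurrency bump wakes the parked acquire() immediately
    await asyncio.gather(*tasks)
    await slots.close()  # drains all slots, as run_inputs_outputs() does on shutdown


asyncio.run(main())
```

In the package itself, `_ContainerIOManager._dynamic_concurrency_loop` is what calls `set_value` (between `target_concurrency` and `max_concurrency`), and `_generate_inputs`/`exit_context` are the acquire/release sites.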