modal 1.0.3.dev10__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modal might be problematic. Click here for more details.
- modal/__init__.py +0 -2
- modal/__main__.py +3 -4
- modal/_billing.py +80 -0
- modal/_clustered_functions.py +7 -3
- modal/_clustered_functions.pyi +15 -3
- modal/_container_entrypoint.py +51 -69
- modal/_functions.py +508 -240
- modal/_grpc_client.py +171 -0
- modal/_load_context.py +105 -0
- modal/_object.py +81 -21
- modal/_output.py +58 -45
- modal/_partial_function.py +48 -73
- modal/_pty.py +7 -3
- modal/_resolver.py +26 -46
- modal/_runtime/asgi.py +4 -3
- modal/_runtime/container_io_manager.py +358 -220
- modal/_runtime/container_io_manager.pyi +296 -101
- modal/_runtime/execution_context.py +18 -2
- modal/_runtime/execution_context.pyi +64 -7
- modal/_runtime/gpu_memory_snapshot.py +262 -57
- modal/_runtime/user_code_imports.py +28 -58
- modal/_serialization.py +90 -6
- modal/_traceback.py +42 -1
- modal/_tunnel.pyi +380 -12
- modal/_utils/async_utils.py +84 -29
- modal/_utils/auth_token_manager.py +111 -0
- modal/_utils/blob_utils.py +181 -58
- modal/_utils/deprecation.py +19 -0
- modal/_utils/function_utils.py +91 -47
- modal/_utils/grpc_utils.py +89 -66
- modal/_utils/mount_utils.py +26 -1
- modal/_utils/name_utils.py +17 -3
- modal/_utils/task_command_router_client.py +536 -0
- modal/_utils/time_utils.py +34 -6
- modal/app.py +256 -88
- modal/app.pyi +909 -92
- modal/billing.py +5 -0
- modal/builder/2025.06.txt +18 -0
- modal/builder/PREVIEW.txt +18 -0
- modal/builder/base-images.json +58 -0
- modal/cli/_download.py +19 -3
- modal/cli/_traceback.py +3 -2
- modal/cli/app.py +4 -4
- modal/cli/cluster.py +15 -7
- modal/cli/config.py +5 -3
- modal/cli/container.py +7 -6
- modal/cli/dict.py +22 -16
- modal/cli/entry_point.py +12 -5
- modal/cli/environment.py +5 -4
- modal/cli/import_refs.py +3 -3
- modal/cli/launch.py +102 -5
- modal/cli/network_file_system.py +11 -12
- modal/cli/profile.py +3 -2
- modal/cli/programs/launch_instance_ssh.py +94 -0
- modal/cli/programs/run_jupyter.py +1 -1
- modal/cli/programs/run_marimo.py +95 -0
- modal/cli/programs/vscode.py +1 -1
- modal/cli/queues.py +57 -26
- modal/cli/run.py +91 -23
- modal/cli/secret.py +48 -22
- modal/cli/token.py +7 -8
- modal/cli/utils.py +4 -7
- modal/cli/volume.py +31 -25
- modal/client.py +15 -85
- modal/client.pyi +183 -62
- modal/cloud_bucket_mount.py +5 -3
- modal/cloud_bucket_mount.pyi +197 -5
- modal/cls.py +200 -126
- modal/cls.pyi +446 -68
- modal/config.py +29 -11
- modal/container_process.py +319 -19
- modal/container_process.pyi +190 -20
- modal/dict.py +290 -71
- modal/dict.pyi +835 -83
- modal/environments.py +15 -27
- modal/environments.pyi +46 -24
- modal/exception.py +14 -2
- modal/experimental/__init__.py +194 -40
- modal/experimental/flash.py +618 -0
- modal/experimental/flash.pyi +380 -0
- modal/experimental/ipython.py +11 -7
- modal/file_io.py +29 -36
- modal/file_io.pyi +251 -53
- modal/file_pattern_matcher.py +56 -16
- modal/functions.pyi +673 -92
- modal/gpu.py +1 -1
- modal/image.py +528 -176
- modal/image.pyi +1572 -145
- modal/io_streams.py +458 -128
- modal/io_streams.pyi +433 -52
- modal/mount.py +216 -151
- modal/mount.pyi +225 -78
- modal/network_file_system.py +45 -62
- modal/network_file_system.pyi +277 -56
- modal/object.pyi +93 -17
- modal/parallel_map.py +942 -129
- modal/parallel_map.pyi +294 -15
- modal/partial_function.py +0 -2
- modal/partial_function.pyi +234 -19
- modal/proxy.py +17 -8
- modal/proxy.pyi +36 -3
- modal/queue.py +270 -65
- modal/queue.pyi +817 -57
- modal/runner.py +115 -101
- modal/runner.pyi +205 -49
- modal/sandbox.py +512 -136
- modal/sandbox.pyi +845 -111
- modal/schedule.py +1 -1
- modal/secret.py +300 -70
- modal/secret.pyi +589 -34
- modal/serving.py +7 -11
- modal/serving.pyi +7 -8
- modal/snapshot.py +11 -8
- modal/snapshot.pyi +25 -4
- modal/token_flow.py +4 -4
- modal/token_flow.pyi +28 -8
- modal/volume.py +416 -158
- modal/volume.pyi +1117 -121
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +10 -9
- modal-1.2.3.dev7.dist-info/RECORD +195 -0
- modal_docs/mdmd/mdmd.py +17 -4
- modal_proto/api.proto +534 -79
- modal_proto/api_grpc.py +337 -1
- modal_proto/api_pb2.py +1522 -968
- modal_proto/api_pb2.pyi +1619 -134
- modal_proto/api_pb2_grpc.py +699 -4
- modal_proto/api_pb2_grpc.pyi +226 -14
- modal_proto/modal_api_grpc.py +175 -154
- modal_proto/sandbox_router.proto +145 -0
- modal_proto/sandbox_router_grpc.py +105 -0
- modal_proto/sandbox_router_pb2.py +149 -0
- modal_proto/sandbox_router_pb2.pyi +333 -0
- modal_proto/sandbox_router_pb2_grpc.py +203 -0
- modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
- modal_proto/task_command_router.proto +144 -0
- modal_proto/task_command_router_grpc.py +105 -0
- modal_proto/task_command_router_pb2.py +149 -0
- modal_proto/task_command_router_pb2.pyi +333 -0
- modal_proto/task_command_router_pb2_grpc.py +203 -0
- modal_proto/task_command_router_pb2_grpc.pyi +75 -0
- modal_version/__init__.py +1 -1
- modal/requirements/PREVIEW.txt +0 -16
- modal/requirements/base-images.json +0 -26
- modal-1.0.3.dev10.dist-info/RECORD +0 -179
- modal_proto/modal_options_grpc.py +0 -3
- modal_proto/options.proto +0 -19
- modal_proto/options_grpc.py +0 -3
- modal_proto/options_pb2.py +0 -35
- modal_proto/options_pb2.pyi +0 -20
- modal_proto/options_pb2_grpc.py +0 -4
- modal_proto/options_pb2_grpc.pyi +0 -7
- /modal/{requirements → builder}/2023.12.312.txt +0 -0
- /modal/{requirements → builder}/2023.12.txt +0 -0
- /modal/{requirements → builder}/2024.04.txt +0 -0
- /modal/{requirements → builder}/2024.10.txt +0 -0
- /modal/{requirements → builder}/README.md +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
|
@@ -16,6 +16,7 @@ from typing import (
|
|
|
16
16
|
Any,
|
|
17
17
|
Callable,
|
|
18
18
|
ClassVar,
|
|
19
|
+
Generator,
|
|
19
20
|
Optional,
|
|
20
21
|
cast,
|
|
21
22
|
)
|
|
@@ -24,22 +25,25 @@ from google.protobuf.empty_pb2 import Empty
|
|
|
24
25
|
from grpclib import Status
|
|
25
26
|
from synchronicity.async_wrap import asynccontextmanager
|
|
26
27
|
|
|
27
|
-
import modal_proto.api_pb2
|
|
28
28
|
from modal._runtime import gpu_memory_snapshot
|
|
29
|
-
from modal._serialization import
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
29
|
+
from modal._serialization import (
|
|
30
|
+
deserialize_data_format,
|
|
31
|
+
pickle_exception,
|
|
32
|
+
pickle_traceback,
|
|
33
|
+
serialize_data_format,
|
|
34
|
+
)
|
|
35
|
+
from modal._traceback import print_exception
|
|
36
|
+
from modal._utils.async_utils import TaskContext, aclosing, asyncify, synchronize_api, synchronizer
|
|
37
|
+
from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload, format_blob_data
|
|
33
38
|
from modal._utils.function_utils import _stream_function_call_data
|
|
34
|
-
from modal._utils.grpc_utils import
|
|
39
|
+
from modal._utils.grpc_utils import Retry
|
|
35
40
|
from modal._utils.package_utils import parse_major_minor_version
|
|
36
41
|
from modal.client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
|
|
37
42
|
from modal.config import config, logger
|
|
38
|
-
from modal.exception import ClientClosed, InputCancellation, InvalidError
|
|
43
|
+
from modal.exception import ClientClosed, InputCancellation, InvalidError
|
|
39
44
|
from modal_proto import api_pb2
|
|
40
45
|
|
|
41
46
|
if TYPE_CHECKING:
|
|
42
|
-
import modal._runtime.asgi
|
|
43
47
|
import modal._runtime.user_code_imports
|
|
44
48
|
|
|
45
49
|
|
|
@@ -66,6 +70,7 @@ class IOContext:
|
|
|
66
70
|
input_ids: list[str]
|
|
67
71
|
retry_counts: list[int]
|
|
68
72
|
function_call_ids: list[str]
|
|
73
|
+
attempt_tokens: list[str]
|
|
69
74
|
function_inputs: list[api_pb2.FunctionInput]
|
|
70
75
|
finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"
|
|
71
76
|
|
|
@@ -77,6 +82,7 @@ class IOContext:
|
|
|
77
82
|
input_ids: list[str],
|
|
78
83
|
retry_counts: list[int],
|
|
79
84
|
function_call_ids: list[str],
|
|
85
|
+
attempt_tokens: list[str],
|
|
80
86
|
finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
|
|
81
87
|
function_inputs: list[api_pb2.FunctionInput],
|
|
82
88
|
is_batched: bool,
|
|
@@ -85,6 +91,7 @@ class IOContext:
|
|
|
85
91
|
self.input_ids = input_ids
|
|
86
92
|
self.retry_counts = retry_counts
|
|
87
93
|
self.function_call_ids = function_call_ids
|
|
94
|
+
self.attempt_tokens = attempt_tokens
|
|
88
95
|
self.finalized_function = finalized_function
|
|
89
96
|
self.function_inputs = function_inputs
|
|
90
97
|
self._is_batched = is_batched
|
|
@@ -95,11 +102,11 @@ class IOContext:
|
|
|
95
102
|
cls,
|
|
96
103
|
client: _Client,
|
|
97
104
|
finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
|
|
98
|
-
inputs: list[tuple[str, int, str, api_pb2.FunctionInput]],
|
|
105
|
+
inputs: list[tuple[str, int, str, str, api_pb2.FunctionInput]],
|
|
99
106
|
is_batched: bool,
|
|
100
107
|
) -> "IOContext":
|
|
101
108
|
assert len(inputs) >= 1 if is_batched else len(inputs) == 1
|
|
102
|
-
input_ids, retry_counts, function_call_ids, function_inputs = zip(*inputs)
|
|
109
|
+
input_ids, retry_counts, function_call_ids, attempt_tokens, function_inputs = zip(*inputs)
|
|
103
110
|
|
|
104
111
|
async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput:
|
|
105
112
|
# If we got a pointer to a blob, download it from S3.
|
|
@@ -121,6 +128,7 @@ class IOContext:
|
|
|
121
128
|
cast(list[str], input_ids),
|
|
122
129
|
cast(list[int], retry_counts),
|
|
123
130
|
cast(list[str], function_call_ids),
|
|
131
|
+
cast(list[str], attempt_tokens),
|
|
124
132
|
finalized_function,
|
|
125
133
|
cast(list[api_pb2.FunctionInput], function_inputs),
|
|
126
134
|
is_batched,
|
|
@@ -148,9 +156,13 @@ class IOContext:
|
|
|
148
156
|
# deserializing here instead of the constructor
|
|
149
157
|
# to make sure we handle user exceptions properly
|
|
150
158
|
# and don't retry
|
|
151
|
-
deserialized_args = [
|
|
152
|
-
|
|
153
|
-
|
|
159
|
+
deserialized_args = []
|
|
160
|
+
for input in self.function_inputs:
|
|
161
|
+
if input.args:
|
|
162
|
+
data_format = input.data_format
|
|
163
|
+
deserialized_args.append(deserialize_data_format(input.args, data_format, self._client))
|
|
164
|
+
else:
|
|
165
|
+
deserialized_args.append(((), {}))
|
|
154
166
|
if not self._is_batched:
|
|
155
167
|
return deserialized_args[0]
|
|
156
168
|
|
|
@@ -188,25 +200,229 @@ class IOContext:
|
|
|
188
200
|
}
|
|
189
201
|
return (), formatted_kwargs
|
|
190
202
|
|
|
191
|
-
def
|
|
203
|
+
def _generator_output_format(self) -> "api_pb2.DataFormat.ValueType":
|
|
204
|
+
return self._determine_output_format(self.function_inputs[0].data_format)
|
|
205
|
+
|
|
206
|
+
def _prepare_batch_output(self, data: Any) -> list[Any]:
|
|
207
|
+
# validate that output is valid for batch
|
|
208
|
+
if self._is_batched:
|
|
209
|
+
# assert data is list etc.
|
|
210
|
+
function_name = self.finalized_function.callable.__name__
|
|
211
|
+
|
|
212
|
+
if not isinstance(data, list):
|
|
213
|
+
raise InvalidError(f"Output of batched function {function_name} must be a list.")
|
|
214
|
+
if len(data) != len(self.input_ids):
|
|
215
|
+
raise InvalidError(
|
|
216
|
+
f"Output of batched function {function_name} must be a list of equal length as its inputs."
|
|
217
|
+
)
|
|
218
|
+
return data
|
|
219
|
+
else:
|
|
220
|
+
return [data]
|
|
221
|
+
|
|
222
|
+
def call_function_sync(self) -> list[Any]:
|
|
192
223
|
logger.debug(f"Starting input {self.input_ids}")
|
|
193
224
|
args, kwargs = self._args_and_kwargs()
|
|
194
|
-
|
|
225
|
+
expected_value_or_values = self.finalized_function.callable(*args, **kwargs)
|
|
226
|
+
if (
|
|
227
|
+
inspect.iscoroutine(expected_value_or_values)
|
|
228
|
+
or inspect.isgenerator(expected_value_or_values)
|
|
229
|
+
or inspect.isasyncgen(expected_value_or_values)
|
|
230
|
+
):
|
|
231
|
+
raise InvalidError(
|
|
232
|
+
f"Sync (non-generator) function return value of type {type(expected_value_or_values)}."
|
|
233
|
+
" You might need to use @app.function(..., is_generator=True)."
|
|
234
|
+
)
|
|
195
235
|
logger.debug(f"Finished input {self.input_ids}")
|
|
196
|
-
return
|
|
236
|
+
return self._prepare_batch_output(expected_value_or_values)
|
|
197
237
|
|
|
198
|
-
def
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
238
|
+
async def call_function_async(self) -> list[Any]:
|
|
239
|
+
logger.debug(f"Starting input {self.input_ids}")
|
|
240
|
+
args, kwargs = self._args_and_kwargs()
|
|
241
|
+
expected_coro = self.finalized_function.callable(*args, **kwargs)
|
|
242
|
+
if (
|
|
243
|
+
not inspect.iscoroutine(expected_coro)
|
|
244
|
+
or inspect.isgenerator(expected_coro)
|
|
245
|
+
or inspect.isasyncgen(expected_coro)
|
|
246
|
+
):
|
|
206
247
|
raise InvalidError(
|
|
207
|
-
f"
|
|
248
|
+
f"Async (non-generator) function returned value of type {type(expected_coro)}"
|
|
249
|
+
" You might need to use @app.function(..., is_generator=True)."
|
|
208
250
|
)
|
|
209
|
-
|
|
251
|
+
value = await expected_coro
|
|
252
|
+
logger.debug(f"Finished input {self.input_ids}")
|
|
253
|
+
return self._prepare_batch_output(value)
|
|
254
|
+
|
|
255
|
+
def call_generator_sync(self) -> Generator[Any, None, None]:
|
|
256
|
+
assert not self._is_batched
|
|
257
|
+
logger.debug(f"Starting generator input {self.input_ids}")
|
|
258
|
+
args, kwargs = self._args_and_kwargs()
|
|
259
|
+
expected_gen = self.finalized_function.callable(*args, **kwargs)
|
|
260
|
+
if not inspect.isgenerator(expected_gen):
|
|
261
|
+
raise InvalidError(f"Generator function returned value of type {type(expected_gen)}")
|
|
262
|
+
|
|
263
|
+
for result in expected_gen:
|
|
264
|
+
yield result
|
|
265
|
+
logger.debug(f"Finished generator input {self.input_ids}")
|
|
266
|
+
|
|
267
|
+
async def call_generator_async(self) -> AsyncGenerator[Any, None]:
|
|
268
|
+
assert not self._is_batched
|
|
269
|
+
logger.debug(f"Starting generator input {self.input_ids}")
|
|
270
|
+
args, kwargs = self._args_and_kwargs()
|
|
271
|
+
expected_async_gen = self.finalized_function.callable(*args, **kwargs)
|
|
272
|
+
if not inspect.isasyncgen(expected_async_gen):
|
|
273
|
+
raise InvalidError(f"Async generator function returned value of type {type(expected_async_gen)}")
|
|
274
|
+
|
|
275
|
+
async with aclosing(expected_async_gen) as gen:
|
|
276
|
+
async for result in gen:
|
|
277
|
+
yield result
|
|
278
|
+
logger.debug(f"Finished generator input {self.input_ids}")
|
|
279
|
+
|
|
280
|
+
async def output_items_cancellation(self, started_at: float):
|
|
281
|
+
output_created_at = time.time()
|
|
282
|
+
# Create terminated outputs for these inputs to signal that the cancellations have been completed.
|
|
283
|
+
return [
|
|
284
|
+
api_pb2.FunctionPutOutputsItem(
|
|
285
|
+
input_id=input_id,
|
|
286
|
+
input_started_at=started_at,
|
|
287
|
+
output_created_at=output_created_at,
|
|
288
|
+
result=api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED),
|
|
289
|
+
retry_count=retry_count,
|
|
290
|
+
)
|
|
291
|
+
for input_id, retry_count in zip(self.input_ids, self.retry_counts)
|
|
292
|
+
]
|
|
293
|
+
|
|
294
|
+
def _determine_output_format(self, input_format: "api_pb2.DataFormat.ValueType") -> "api_pb2.DataFormat.ValueType":
|
|
295
|
+
if input_format in self.finalized_function.supported_output_formats:
|
|
296
|
+
return input_format
|
|
297
|
+
elif self.finalized_function.supported_output_formats:
|
|
298
|
+
# This branch would normally be hit when calling a restricted_output function with Pickle input
|
|
299
|
+
# but we enforce cbor output at function definition level. In the future we might send the intended
|
|
300
|
+
# output format along with the input to make this disitinction in the calling client instead
|
|
301
|
+
logger.debug(
|
|
302
|
+
f"Got an input with format {input_format}, but can only produce output"
|
|
303
|
+
f" using formats {self.finalized_function.supported_output_formats}"
|
|
304
|
+
)
|
|
305
|
+
return self.finalized_function.supported_output_formats[0]
|
|
306
|
+
else:
|
|
307
|
+
# This should never happen since self.finalized_function.supported_output_formats should be
|
|
308
|
+
# populated with defaults in case it's empty, log a warning
|
|
309
|
+
logger.warning(f"Got an input with format {input_format}, but the function has no defined output formats")
|
|
310
|
+
return api_pb2.DATA_FORMAT_PICKLE
|
|
311
|
+
|
|
312
|
+
async def output_items_exception(
|
|
313
|
+
self, started_at: float, task_id: str, exc: BaseException
|
|
314
|
+
) -> list[api_pb2.FunctionPutOutputsItem]:
|
|
315
|
+
# Note: we're not pickling the traceback since it contains
|
|
316
|
+
# local references that means we can't unpickle it. We *are*
|
|
317
|
+
# pickling the exception, which may have some issues (there
|
|
318
|
+
# was an earlier note about it that it might not be possible
|
|
319
|
+
# to unpickle it in some cases). Let's watch out for issues.
|
|
320
|
+
repr_exc = repr(exc)
|
|
321
|
+
if len(repr_exc) >= MAX_OBJECT_SIZE_BYTES:
|
|
322
|
+
# We prevent large exception messages to avoid
|
|
323
|
+
# unhandled exceptions causing inf loops
|
|
324
|
+
# and just send backa trimmed version
|
|
325
|
+
trimmed_bytes = len(repr_exc) - MAX_OBJECT_SIZE_BYTES - 1000
|
|
326
|
+
repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
|
|
327
|
+
repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
|
|
328
|
+
|
|
329
|
+
data: bytes = pickle_exception(exc)
|
|
330
|
+
data_result_part = await format_blob_data(data, self._client.stub)
|
|
331
|
+
serialized_tb, tb_line_cache = pickle_traceback(exc, task_id)
|
|
332
|
+
|
|
333
|
+
# Failure outputs for when input exceptions occur
|
|
334
|
+
def data_format_specific_output(input_format: "api_pb2.DataFormat.ValueType") -> dict:
|
|
335
|
+
output_format = self._determine_output_format(input_format)
|
|
336
|
+
if output_format == api_pb2.DATA_FORMAT_PICKLE:
|
|
337
|
+
return {
|
|
338
|
+
"data_format": output_format,
|
|
339
|
+
"result": api_pb2.GenericResult(
|
|
340
|
+
status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
|
|
341
|
+
exception=repr_exc,
|
|
342
|
+
traceback=traceback.format_exc(),
|
|
343
|
+
serialized_tb=serialized_tb,
|
|
344
|
+
tb_line_cache=tb_line_cache,
|
|
345
|
+
**data_result_part,
|
|
346
|
+
),
|
|
347
|
+
}
|
|
348
|
+
else:
|
|
349
|
+
return {
|
|
350
|
+
"data_format": output_format,
|
|
351
|
+
"result": api_pb2.GenericResult(
|
|
352
|
+
status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
|
|
353
|
+
exception=repr_exc,
|
|
354
|
+
traceback=traceback.format_exc(),
|
|
355
|
+
),
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
# all inputs in the batch get the same failure:
|
|
359
|
+
output_created_at = time.time()
|
|
360
|
+
return [
|
|
361
|
+
api_pb2.FunctionPutOutputsItem(
|
|
362
|
+
input_id=input_id,
|
|
363
|
+
input_started_at=started_at,
|
|
364
|
+
output_created_at=output_created_at,
|
|
365
|
+
retry_count=retry_count,
|
|
366
|
+
**data_format_specific_output(function_input.data_format),
|
|
367
|
+
)
|
|
368
|
+
for input_id, retry_count, function_input in zip(self.input_ids, self.retry_counts, self.function_inputs)
|
|
369
|
+
]
|
|
370
|
+
|
|
371
|
+
def output_items_generator_done(self, started_at: float, items_total: int) -> list[api_pb2.FunctionPutOutputsItem]:
|
|
372
|
+
assert not self._is_batched, "generators are not supported with batched inputs"
|
|
373
|
+
assert len(self.function_inputs) == 1, "generators are expected to have 1 input"
|
|
374
|
+
# Serialize and format the data
|
|
375
|
+
serialized_bytes = serialize_data_format(
|
|
376
|
+
api_pb2.GeneratorDone(items_total=items_total), data_format=api_pb2.DATA_FORMAT_GENERATOR_DONE
|
|
377
|
+
)
|
|
378
|
+
return [
|
|
379
|
+
api_pb2.FunctionPutOutputsItem(
|
|
380
|
+
input_id=self.input_ids[0],
|
|
381
|
+
input_started_at=started_at,
|
|
382
|
+
output_created_at=time.time(),
|
|
383
|
+
result=api_pb2.GenericResult(
|
|
384
|
+
status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
|
|
385
|
+
data=serialized_bytes,
|
|
386
|
+
),
|
|
387
|
+
data_format=api_pb2.DATA_FORMAT_GENERATOR_DONE,
|
|
388
|
+
retry_count=self.retry_counts[0],
|
|
389
|
+
)
|
|
390
|
+
]
|
|
391
|
+
|
|
392
|
+
async def output_items(self, started_at: float, data: list[Any]) -> list[api_pb2.FunctionPutOutputsItem]:
|
|
393
|
+
output_created_at = time.time()
|
|
394
|
+
|
|
395
|
+
# Process all items concurrently and create output items directly
|
|
396
|
+
async def package_output(
|
|
397
|
+
item: Any, input_id: str, retry_count: int, input_format: "api_pb2.DataFormat.ValueType"
|
|
398
|
+
) -> api_pb2.FunctionPutOutputsItem:
|
|
399
|
+
output_format = self._determine_output_format(input_format)
|
|
400
|
+
|
|
401
|
+
serialized_bytes = serialize_data_format(item, data_format=output_format)
|
|
402
|
+
formatted = await format_blob_data(serialized_bytes, self._client.stub)
|
|
403
|
+
# Create the result
|
|
404
|
+
result = api_pb2.GenericResult(
|
|
405
|
+
status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
|
|
406
|
+
**formatted,
|
|
407
|
+
)
|
|
408
|
+
return api_pb2.FunctionPutOutputsItem(
|
|
409
|
+
input_id=input_id,
|
|
410
|
+
input_started_at=started_at,
|
|
411
|
+
output_created_at=output_created_at,
|
|
412
|
+
result=result,
|
|
413
|
+
data_format=output_format,
|
|
414
|
+
retry_count=retry_count,
|
|
415
|
+
)
|
|
416
|
+
|
|
417
|
+
# Process all items concurrently
|
|
418
|
+
return await asyncio.gather(
|
|
419
|
+
*[
|
|
420
|
+
package_output(item, input_id, retry_count, function_input.data_format)
|
|
421
|
+
for item, input_id, retry_count, function_input in zip(
|
|
422
|
+
data, self.input_ids, self.retry_counts, self.function_inputs
|
|
423
|
+
)
|
|
424
|
+
]
|
|
425
|
+
)
|
|
210
426
|
|
|
211
427
|
|
|
212
428
|
class InputSlots:
|
|
@@ -267,6 +483,7 @@ class _ContainerIOManager:
|
|
|
267
483
|
app_id: str
|
|
268
484
|
function_def: api_pb2.Function
|
|
269
485
|
checkpoint_id: Optional[str]
|
|
486
|
+
input_plane_server_url: Optional[str]
|
|
270
487
|
|
|
271
488
|
calls_completed: int
|
|
272
489
|
total_user_time: float
|
|
@@ -290,7 +507,6 @@ class _ContainerIOManager:
|
|
|
290
507
|
|
|
291
508
|
_client: _Client
|
|
292
509
|
|
|
293
|
-
_GENERATOR_STOP_SENTINEL: ClassVar[Sentinel] = Sentinel()
|
|
294
510
|
_singleton: ClassVar[Optional["_ContainerIOManager"]] = None
|
|
295
511
|
|
|
296
512
|
def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):
|
|
@@ -300,6 +516,8 @@ class _ContainerIOManager:
|
|
|
300
516
|
self.function_def = container_args.function_def
|
|
301
517
|
self.checkpoint_id = container_args.checkpoint_id or None
|
|
302
518
|
|
|
519
|
+
self.input_plane_server_url = container_args.input_plane_server_url
|
|
520
|
+
|
|
303
521
|
self.calls_completed = 0
|
|
304
522
|
self.total_user_time = 0.0
|
|
305
523
|
self.current_input_id = None
|
|
@@ -323,6 +541,7 @@ class _ContainerIOManager:
|
|
|
323
541
|
self._heartbeat_loop = None
|
|
324
542
|
self._heartbeat_condition = None
|
|
325
543
|
self._waiting_for_memory_snapshot = False
|
|
544
|
+
self._cuda_checkpoint_session = None
|
|
326
545
|
|
|
327
546
|
self._is_interactivity_enabled = False
|
|
328
547
|
self._fetching_inputs = True
|
|
@@ -404,8 +623,8 @@ class _ContainerIOManager:
|
|
|
404
623
|
await self.heartbeat_condition.wait()
|
|
405
624
|
|
|
406
625
|
request = api_pb2.ContainerHeartbeatRequest(canceled_inputs_return_outputs_v2=True)
|
|
407
|
-
response = await
|
|
408
|
-
|
|
626
|
+
response = await self._client.stub.ContainerHeartbeat(
|
|
627
|
+
request, retry=Retry(attempt_timeout=HEARTBEAT_TIMEOUT)
|
|
409
628
|
)
|
|
410
629
|
|
|
411
630
|
if response.HasField("cancel_input_event"):
|
|
@@ -452,10 +671,9 @@ class _ContainerIOManager:
|
|
|
452
671
|
target_concurrency=self._target_concurrency,
|
|
453
672
|
max_concurrency=self._max_concurrency,
|
|
454
673
|
)
|
|
455
|
-
resp = await
|
|
456
|
-
self._client.stub.FunctionGetDynamicConcurrency,
|
|
674
|
+
resp = await self._client.stub.FunctionGetDynamicConcurrency(
|
|
457
675
|
request,
|
|
458
|
-
attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS,
|
|
676
|
+
retry=Retry(attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS),
|
|
459
677
|
)
|
|
460
678
|
if resp.concurrency != self._input_slots.value and not self._stop_concurrency_loop:
|
|
461
679
|
logger.debug(f"Dynamic concurrency set from {self._input_slots.value} to {resp.concurrency}")
|
|
@@ -466,27 +684,23 @@ class _ContainerIOManager:
|
|
|
466
684
|
|
|
467
685
|
await asyncio.sleep(DYNAMIC_CONCURRENCY_INTERVAL_SECS)
|
|
468
686
|
|
|
469
|
-
|
|
470
|
-
def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
|
|
471
|
-
return serialize_data_format(obj, data_format)
|
|
472
|
-
|
|
473
|
-
async def format_blob_data(self, data: bytes) -> dict[str, Any]:
|
|
474
|
-
return (
|
|
475
|
-
{"data_blob_id": await blob_upload(data, self._client.stub)}
|
|
476
|
-
if len(data) > MAX_OBJECT_SIZE_BYTES
|
|
477
|
-
else {"data": data}
|
|
478
|
-
)
|
|
479
|
-
|
|
480
|
-
async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
|
|
687
|
+
async def get_data_in(self, function_call_id: str, attempt_token: Optional[str]) -> AsyncIterator[Any]:
|
|
481
688
|
"""Read from the `data_in` stream of a function call."""
|
|
482
|
-
|
|
689
|
+
stub = self._client.stub
|
|
690
|
+
if self.input_plane_server_url:
|
|
691
|
+
stub = await self._client.get_stub(self.input_plane_server_url)
|
|
692
|
+
|
|
693
|
+
async for data in _stream_function_call_data(
|
|
694
|
+
self._client, stub, function_call_id, variant="data_in", attempt_token=attempt_token
|
|
695
|
+
):
|
|
483
696
|
yield data
|
|
484
697
|
|
|
485
698
|
async def put_data_out(
|
|
486
699
|
self,
|
|
487
700
|
function_call_id: str,
|
|
701
|
+
attempt_token: str,
|
|
488
702
|
start_index: int,
|
|
489
|
-
data_format:
|
|
703
|
+
data_format: "api_pb2.DataFormat.ValueType",
|
|
490
704
|
serialized_messages: list[Any],
|
|
491
705
|
) -> None:
|
|
492
706
|
"""Put data onto the `data_out` stream of a function call.
|
|
@@ -505,35 +719,60 @@ class _ContainerIOManager:
|
|
|
505
719
|
data_chunks.append(chunk)
|
|
506
720
|
|
|
507
721
|
req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks)
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
722
|
+
if attempt_token:
|
|
723
|
+
req.attempt_token = attempt_token # oneof clears function_call_id.
|
|
724
|
+
|
|
725
|
+
if self.input_plane_server_url:
|
|
726
|
+
stub = await self._client.get_stub(self.input_plane_server_url)
|
|
727
|
+
await stub.FunctionCallPutDataOut(req)
|
|
728
|
+
else:
|
|
729
|
+
await self._client.stub.FunctionCallPutDataOut(req)
|
|
730
|
+
|
|
731
|
+
@asynccontextmanager
|
|
732
|
+
async def generator_output_sender(
|
|
733
|
+
self,
|
|
734
|
+
function_call_id: str,
|
|
735
|
+
attempt_token: str,
|
|
736
|
+
data_format: "api_pb2.DataFormat.ValueType",
|
|
737
|
+
message_rx: asyncio.Queue,
|
|
738
|
+
) -> AsyncGenerator[None, None]:
|
|
739
|
+
"""Runs background task that feeds generator outputs into a function call's `data_out` stream."""
|
|
740
|
+
GENERATOR_STOP_SENTINEL = Sentinel()
|
|
741
|
+
|
|
742
|
+
async def generator_output_task():
|
|
743
|
+
index = 1
|
|
744
|
+
received_sentinel = False
|
|
745
|
+
while not received_sentinel:
|
|
746
|
+
message = await message_rx.get()
|
|
747
|
+
if message is GENERATOR_STOP_SENTINEL:
|
|
531
748
|
break
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
749
|
+
# ASGI 'http.response.start' and 'http.response.body' msgs are observed to be separated by 1ms.
|
|
750
|
+
# If we don't sleep here for 1ms we end up with an extra call to .put_data_out().
|
|
751
|
+
if index == 1:
|
|
752
|
+
await asyncio.sleep(0.001)
|
|
753
|
+
serialized_messages = [serialize_data_format(message, data_format)]
|
|
754
|
+
total_size = len(serialized_messages[0]) + 512
|
|
755
|
+
while total_size < 16 * 1024 * 1024: # 16 MiB, maximum size in a single message
|
|
756
|
+
try:
|
|
757
|
+
message = message_rx.get_nowait()
|
|
758
|
+
except asyncio.QueueEmpty:
|
|
759
|
+
break
|
|
760
|
+
if message is GENERATOR_STOP_SENTINEL:
|
|
761
|
+
received_sentinel = True
|
|
762
|
+
break
|
|
763
|
+
else:
|
|
764
|
+
serialized_messages.append(serialize_data_format(message, data_format))
|
|
765
|
+
total_size += len(serialized_messages[-1]) + 512 # 512 bytes for estimated framing overhead
|
|
766
|
+
await self.put_data_out(function_call_id, attempt_token, index, data_format, serialized_messages)
|
|
767
|
+
index += len(serialized_messages)
|
|
768
|
+
|
|
769
|
+
task = asyncio.create_task(generator_output_task())
|
|
770
|
+
try:
|
|
771
|
+
yield
|
|
772
|
+
finally:
|
|
773
|
+
# gracefully stop the task after all current inputs have been sent
|
|
774
|
+
await message_rx.put(GENERATOR_STOP_SENTINEL)
|
|
775
|
+
await task
|
|
537
776
|
|
|
538
777
|
async def _queue_create(self, size: int) -> asyncio.Queue:
|
|
539
778
|
"""Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
|
|
@@ -560,7 +799,7 @@ class _ContainerIOManager:
|
|
|
560
799
|
self,
|
|
561
800
|
batch_max_size: int,
|
|
562
801
|
batch_wait_ms: int,
|
|
563
|
-
) -> AsyncIterator[list[tuple[str, int, str, api_pb2.FunctionInput]]]:
|
|
802
|
+
) -> AsyncIterator[list[tuple[str, int, str, str, api_pb2.FunctionInput]]]:
|
|
564
803
|
request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
|
|
565
804
|
iteration = 0
|
|
566
805
|
while self._fetching_inputs:
|
|
@@ -575,9 +814,7 @@ class _ContainerIOManager:
|
|
|
575
814
|
try:
|
|
576
815
|
# If number of active inputs is at max queue size, this will block.
|
|
577
816
|
iteration += 1
|
|
578
|
-
response: api_pb2.FunctionGetInputsResponse = await
|
|
579
|
-
self._client.stub.FunctionGetInputs, request
|
|
580
|
-
)
|
|
817
|
+
response: api_pb2.FunctionGetInputsResponse = await self._client.stub.FunctionGetInputs(request)
|
|
581
818
|
|
|
582
819
|
if response.rate_limit_sleep_duration:
|
|
583
820
|
logger.info(
|
|
@@ -595,7 +832,9 @@ class _ContainerIOManager:
|
|
|
595
832
|
if item.kill_switch:
|
|
596
833
|
logger.debug(f"Task {self.task_id} input kill signal input.")
|
|
597
834
|
return
|
|
598
|
-
inputs.append(
|
|
835
|
+
inputs.append(
|
|
836
|
+
(item.input_id, item.retry_count, item.function_call_id, item.attempt_token, item.input)
|
|
837
|
+
)
|
|
599
838
|
if item.input.final_input:
|
|
600
839
|
if request.batch_max_size > 0:
|
|
601
840
|
logger.debug(f"Task {self.task_id} Final input not expected in batch input stream")
|
|
@@ -636,62 +875,24 @@ class _ContainerIOManager:
|
|
|
636
875
|
self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time()
|
|
637
876
|
yield io_context
|
|
638
877
|
self.current_input_id, self.current_input_started_at = (None, None)
|
|
639
|
-
|
|
640
878
|
# collect all active input slots, meaning all inputs have wrapped up.
|
|
641
879
|
await self._input_slots.close()
|
|
642
880
|
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
self,
|
|
646
|
-
io_context: IOContext,
|
|
647
|
-
started_at: float,
|
|
648
|
-
data_format: "modal_proto.api_pb2.DataFormat.ValueType",
|
|
649
|
-
results: list[api_pb2.GenericResult],
|
|
650
|
-
) -> None:
|
|
651
|
-
output_created_at = time.time()
|
|
652
|
-
outputs = [
|
|
653
|
-
api_pb2.FunctionPutOutputsItem(
|
|
654
|
-
input_id=input_id,
|
|
655
|
-
input_started_at=started_at,
|
|
656
|
-
output_created_at=output_created_at,
|
|
657
|
-
result=result,
|
|
658
|
-
data_format=data_format,
|
|
659
|
-
retry_count=retry_count,
|
|
660
|
-
)
|
|
661
|
-
for input_id, retry_count, result in zip(io_context.input_ids, io_context.retry_counts, results)
|
|
662
|
-
]
|
|
663
|
-
|
|
881
|
+
async def _send_outputs(self, started_at: float, outputs: list[api_pb2.FunctionPutOutputsItem]) -> None:
|
|
882
|
+
"""Send pre-built output items with retry and chunking."""
|
|
664
883
|
# There are multiple outputs for a single IOContext in the case of @modal.batched.
|
|
665
884
|
# Limit the batch size to 20 to stay within message size limits and buffer size limits.
|
|
666
885
|
output_batch_size = 20
|
|
667
886
|
for i in range(0, len(outputs), output_batch_size):
|
|
668
|
-
await
|
|
669
|
-
self._client.stub.FunctionPutOutputs,
|
|
887
|
+
await self._client.stub.FunctionPutOutputs(
|
|
670
888
|
api_pb2.FunctionPutOutputsRequest(outputs=outputs[i : i + output_batch_size]),
|
|
671
|
-
|
|
672
|
-
|
|
889
|
+
retry=Retry(
|
|
890
|
+
additional_status_codes=[Status.RESOURCE_EXHAUSTED],
|
|
891
|
+
max_retries=None, # Retry indefinitely, trying every 1s.
|
|
892
|
+
),
|
|
673
893
|
)
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
try:
|
|
677
|
-
return serialize(exc)
|
|
678
|
-
except Exception as serialization_exc:
|
|
679
|
-
# We can't always serialize exceptions.
|
|
680
|
-
err = f"Failed to serialize exception {exc} of type {type(exc)}: {serialization_exc}"
|
|
681
|
-
logger.info(err)
|
|
682
|
-
return serialize(SerializationError(err))
|
|
683
|
-
|
|
684
|
-
def serialize_traceback(self, exc: BaseException) -> tuple[Optional[bytes], Optional[bytes]]:
|
|
685
|
-
serialized_tb, tb_line_cache = None, None
|
|
686
|
-
|
|
687
|
-
try:
|
|
688
|
-
tb_dict, line_cache = extract_traceback(exc, self.task_id)
|
|
689
|
-
serialized_tb = serialize(tb_dict)
|
|
690
|
-
tb_line_cache = serialize(line_cache)
|
|
691
|
-
except Exception:
|
|
692
|
-
logger.info("Failed to serialize exception traceback.")
|
|
693
|
-
|
|
694
|
-
return serialized_tb, tb_line_cache
|
|
894
|
+
input_ids = [output.input_id for output in outputs]
|
|
895
|
+
self.exit_context(started_at, input_ids)
|
|
695
896
|
|
|
696
897
|
@asynccontextmanager
|
|
697
898
|
async def handle_user_exception(self) -> AsyncGenerator[None, None]:
|
|
@@ -714,11 +915,14 @@ class _ContainerIOManager:
|
|
|
714
915
|
# Since this is on a different thread, sys.exc_info() can't find the exception in the stack.
|
|
715
916
|
print_exception(type(exc), exc, exc.__traceback__)
|
|
716
917
|
|
|
717
|
-
serialized_tb, tb_line_cache = self.
|
|
918
|
+
serialized_tb, tb_line_cache = pickle_traceback(exc, self.task_id)
|
|
718
919
|
|
|
920
|
+
data_or_blob = await format_blob_data(pickle_exception(exc), self._client.stub)
|
|
719
921
|
result = api_pb2.GenericResult(
|
|
720
922
|
status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
|
|
721
|
-
|
|
923
|
+
**data_or_blob,
|
|
924
|
+
# TODO: there is no way to communicate the data format here
|
|
925
|
+
# since it usually goes on the envelope outside of GenericResult
|
|
722
926
|
exception=repr(exc),
|
|
723
927
|
traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
|
|
724
928
|
serialized_tb=serialized_tb or b"",
|
|
@@ -726,7 +930,7 @@ class _ContainerIOManager:
|
|
|
726
930
|
)
|
|
727
931
|
|
|
728
932
|
req = api_pb2.TaskResultRequest(result=result)
|
|
729
|
-
await
|
|
933
|
+
await self._client.stub.TaskResult(req)
|
|
730
934
|
|
|
731
935
|
# Shut down the task gracefully
|
|
732
936
|
raise UserException()
|
|
@@ -748,18 +952,8 @@ class _ContainerIOManager:
|
|
|
748
952
|
# for the yield. Typically on event loop shutdown
|
|
749
953
|
raise
|
|
750
954
|
except (InputCancellation, asyncio.CancelledError):
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED)
|
|
754
|
-
for _ in io_context.input_ids
|
|
755
|
-
]
|
|
756
|
-
await self._push_outputs(
|
|
757
|
-
io_context=io_context,
|
|
758
|
-
started_at=started_at,
|
|
759
|
-
data_format=api_pb2.DATA_FORMAT_PICKLE,
|
|
760
|
-
results=results,
|
|
761
|
-
)
|
|
762
|
-
self.exit_context(started_at, io_context.input_ids)
|
|
955
|
+
outputs = await io_context.output_items_cancellation(started_at)
|
|
956
|
+
await self._send_outputs(started_at, outputs)
|
|
763
957
|
logger.warning(f"Successfully canceled input {io_context.input_ids}")
|
|
764
958
|
return
|
|
765
959
|
except BaseException as exc:
|
|
@@ -769,44 +963,8 @@ class _ContainerIOManager:
|
|
|
769
963
|
|
|
770
964
|
# print exception so it's logged
|
|
771
965
|
print_exception(*sys.exc_info())
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
# Note: we're not serializing the traceback since it contains
|
|
776
|
-
# local references that means we can't unpickle it. We *are*
|
|
777
|
-
# serializing the exception, which may have some issues (there
|
|
778
|
-
# was an earlier note about it that it might not be possible
|
|
779
|
-
# to unpickle it in some cases). Let's watch out for issues.
|
|
780
|
-
|
|
781
|
-
repr_exc = repr(exc)
|
|
782
|
-
if len(repr_exc) >= MAX_OBJECT_SIZE_BYTES:
|
|
783
|
-
# We prevent large exception messages to avoid
|
|
784
|
-
# unhandled exceptions causing inf loops
|
|
785
|
-
# and just send backa trimmed version
|
|
786
|
-
trimmed_bytes = len(repr_exc) - MAX_OBJECT_SIZE_BYTES - 1000
|
|
787
|
-
repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
|
|
788
|
-
repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
|
|
789
|
-
|
|
790
|
-
data: bytes = self.serialize_exception(exc) or b""
|
|
791
|
-
data_result_part = await self.format_blob_data(data)
|
|
792
|
-
results = [
|
|
793
|
-
api_pb2.GenericResult(
|
|
794
|
-
status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
|
|
795
|
-
exception=repr_exc,
|
|
796
|
-
traceback=traceback.format_exc(),
|
|
797
|
-
serialized_tb=serialized_tb or b"",
|
|
798
|
-
tb_line_cache=tb_line_cache or b"",
|
|
799
|
-
**data_result_part,
|
|
800
|
-
)
|
|
801
|
-
for _ in io_context.input_ids
|
|
802
|
-
]
|
|
803
|
-
await self._push_outputs(
|
|
804
|
-
io_context=io_context,
|
|
805
|
-
started_at=started_at,
|
|
806
|
-
data_format=api_pb2.DATA_FORMAT_PICKLE,
|
|
807
|
-
results=results,
|
|
808
|
-
)
|
|
809
|
-
self.exit_context(started_at, io_context.input_ids)
|
|
966
|
+
outputs = await io_context.output_items_exception(started_at, self.task_id, exc)
|
|
967
|
+
await self._send_outputs(started_at, outputs)
|
|
810
968
|
|
|
811
969
|
def exit_context(self, started_at, input_ids: list[str]):
|
|
812
970
|
self.total_user_time += time.time() - started_at
|
|
@@ -817,32 +975,17 @@ class _ContainerIOManager:
|
|
|
817
975
|
|
|
818
976
|
self._input_slots.release()
|
|
819
977
|
|
|
978
|
+
# skip inspection of user-generated output_data for synchronicity input translation
|
|
820
979
|
@synchronizer.no_io_translation
|
|
821
980
|
async def push_outputs(
|
|
822
981
|
self,
|
|
823
982
|
io_context: IOContext,
|
|
824
983
|
started_at: float,
|
|
825
|
-
|
|
826
|
-
data_format: "modal_proto.api_pb2.DataFormat.ValueType",
|
|
984
|
+
output_data: list[Any], # one per output
|
|
827
985
|
) -> None:
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
)
|
|
832
|
-
results = [
|
|
833
|
-
api_pb2.GenericResult(
|
|
834
|
-
status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
|
|
835
|
-
**d,
|
|
836
|
-
)
|
|
837
|
-
for d in formatted_data
|
|
838
|
-
]
|
|
839
|
-
await self._push_outputs(
|
|
840
|
-
io_context=io_context,
|
|
841
|
-
started_at=started_at,
|
|
842
|
-
data_format=data_format,
|
|
843
|
-
results=results,
|
|
844
|
-
)
|
|
845
|
-
self.exit_context(started_at, io_context.input_ids)
|
|
986
|
+
# The standard output encoding+sending method for successful function outputs
|
|
987
|
+
outputs = await io_context.output_items(started_at, output_data)
|
|
988
|
+
await self._send_outputs(started_at, outputs)
|
|
846
989
|
|
|
847
990
|
async def memory_restore(self) -> None:
|
|
848
991
|
# Busy-wait for restore. `/__modal/restore-state.json` is created
|
|
@@ -881,13 +1024,11 @@ class _ContainerIOManager:
|
|
|
881
1024
|
# Restore GPU memory.
|
|
882
1025
|
if self.function_def._experimental_enable_gpu_snapshot and self.function_def.resources.gpu_config.gpu_type:
|
|
883
1026
|
logger.debug("GPU memory snapshot enabled. Attempting to restore GPU memory.")
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
)
|
|
890
|
-
gpu_memory_snapshot.toggle()
|
|
1027
|
+
|
|
1028
|
+
assert self._cuda_checkpoint_session, (
|
|
1029
|
+
"CudaCheckpointSession not found when attempting to restore GPU memory"
|
|
1030
|
+
)
|
|
1031
|
+
self._cuda_checkpoint_session.restore()
|
|
891
1032
|
|
|
892
1033
|
# Restore input to default state.
|
|
893
1034
|
self.current_input_id = None
|
|
@@ -907,14 +1048,9 @@ class _ContainerIOManager:
|
|
|
907
1048
|
# Snapshot GPU memory.
|
|
908
1049
|
if self.function_def._experimental_enable_gpu_snapshot and self.function_def.resources.gpu_config.gpu_type:
|
|
909
1050
|
logger.debug("GPU memory snapshot enabled. Attempting to snapshot GPU memory.")
|
|
910
|
-
gpu_process_state = gpu_memory_snapshot.get_state()
|
|
911
|
-
if gpu_process_state != gpu_memory_snapshot.CudaCheckpointState.RUNNING:
|
|
912
|
-
raise ValueError(
|
|
913
|
-
f"Cannot snapshot GPU state if it isn't running. Current GPU state: {gpu_process_state}"
|
|
914
|
-
)
|
|
915
1051
|
|
|
916
|
-
gpu_memory_snapshot.
|
|
917
|
-
|
|
1052
|
+
self._cuda_checkpoint_session = gpu_memory_snapshot.CudaCheckpointSession()
|
|
1053
|
+
self._cuda_checkpoint_session.checkpoint()
|
|
918
1054
|
|
|
919
1055
|
# Notify the heartbeat loop that the snapshot phase has begun in order to
|
|
920
1056
|
# prevent it from sending heartbeat RPCs
|
|
@@ -944,13 +1080,14 @@ class _ContainerIOManager:
|
|
|
944
1080
|
await asyncify(os.sync)()
|
|
945
1081
|
results = await asyncio.gather(
|
|
946
1082
|
*[
|
|
947
|
-
|
|
948
|
-
self._client.stub.VolumeCommit,
|
|
1083
|
+
self._client.stub.VolumeCommit(
|
|
949
1084
|
api_pb2.VolumeCommitRequest(volume_id=v_id),
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
1085
|
+
retry=Retry(
|
|
1086
|
+
max_retries=9,
|
|
1087
|
+
base_delay=0.25,
|
|
1088
|
+
max_delay=256,
|
|
1089
|
+
delay_factor=2,
|
|
1090
|
+
),
|
|
954
1091
|
)
|
|
955
1092
|
for v_id in volume_ids
|
|
956
1093
|
],
|
|
@@ -1019,7 +1156,8 @@ class _ContainerIOManager:
|
|
|
1019
1156
|
|
|
1020
1157
|
@classmethod
|
|
1021
1158
|
def stop_fetching_inputs(cls):
|
|
1022
|
-
|
|
1159
|
+
if not cls._singleton:
|
|
1160
|
+
raise RuntimeError("Must be called from within a Modal container.")
|
|
1023
1161
|
cls._singleton._fetching_inputs = False
|
|
1024
1162
|
|
|
1025
1163
|
|