modal 1.1.5.dev43__py3-none-any.whl → 1.1.5.dev45__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of modal might be problematic.
- modal/_container_entrypoint.py +19 -41
- modal/_functions.py +21 -14
- modal/_runtime/container_io_manager.py +252 -150
- modal/_runtime/container_io_manager.pyi +32 -48
- modal/_runtime/user_code_imports.py +15 -5
- modal/_serialization.py +57 -1
- modal/_utils/blob_utils.py +4 -0
- modal/_utils/function_utils.py +22 -8
- modal/app.py +4 -0
- modal/app.pyi +4 -0
- modal/client.pyi +2 -2
- modal/config.py +5 -0
- modal/dict.py +15 -2
- modal/dict.pyi +22 -6
- modal/functions.pyi +7 -6
- modal/image.py +10 -3
- modal/parallel_map.py +2 -4
- {modal-1.1.5.dev43.dist-info → modal-1.1.5.dev45.dist-info}/METADATA +2 -1
- {modal-1.1.5.dev43.dist-info → modal-1.1.5.dev45.dist-info}/RECORD +24 -24
- modal_version/__init__.py +1 -1
- {modal-1.1.5.dev43.dist-info → modal-1.1.5.dev45.dist-info}/WHEEL +0 -0
- {modal-1.1.5.dev43.dist-info → modal-1.1.5.dev45.dist-info}/entry_points.txt +0 -0
- {modal-1.1.5.dev43.dist-info → modal-1.1.5.dev45.dist-info}/licenses/LICENSE +0 -0
- {modal-1.1.5.dev43.dist-info → modal-1.1.5.dev45.dist-info}/top_level.txt +0 -0
modal/_runtime/container_io_manager.py

@@ -16,6 +16,7 @@ from typing import (
     Any,
     Callable,
     ClassVar,
+    Generator,
     Optional,
     cast,
 )

@@ -24,18 +25,22 @@ from google.protobuf.empty_pb2 import Empty
 from grpclib import Status
 from synchronicity.async_wrap import asynccontextmanager
 
-import modal_proto.api_pb2
 from modal._runtime import gpu_memory_snapshot
-from modal._serialization import
-
-
-
+from modal._serialization import (
+    deserialize_data_format,
+    pickle_exception,
+    pickle_traceback,
+    serialize_data_format,
+)
+from modal._traceback import print_exception
+from modal._utils.async_utils import TaskContext, aclosing, asyncify, synchronize_api, synchronizer
+from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload, format_blob_data
 from modal._utils.function_utils import _stream_function_call_data
 from modal._utils.grpc_utils import retry_transient_errors
 from modal._utils.package_utils import parse_major_minor_version
 from modal.client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
 from modal.config import config, logger
-from modal.exception import ClientClosed, InputCancellation, InvalidError
+from modal.exception import ClientClosed, InputCancellation, InvalidError
 from modal_proto import api_pb2
 
 if TYPE_CHECKING:

@@ -151,9 +156,13 @@ class IOContext:
         # deserializing here instead of the constructor
         # to make sure we handle user exceptions properly
         # and don't retry
-        deserialized_args = [
-
-
+        deserialized_args = []
+        for input in self.function_inputs:
+            if input.args:
+                data_format = input.data_format
+                deserialized_args.append(deserialize_data_format(input.args, data_format, self._client))
+            else:
+                deserialized_args.append(((), {}))
         if not self._is_batched:
             return deserialized_args[0]
 
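Note on the new deserialization loop: an input whose `args` payload is empty is now treated as a call with no arguments. A minimal sketch of that shape, with `inputs` and `deser` standing in for `self.function_inputs` and `deserialize_data_format` (hypothetical names, plain Python):

def deserialize_args(inputs, deser):
    # Decode each input's (args, kwargs) payload in its own data format,
    # falling back to an empty call signature when no payload was sent.
    deserialized = []
    for inp in inputs:
        if inp.args:
            deserialized.append(deser(inp.args, inp.data_format))
        else:
            deserialized.append(((), {}))
    return deserialized
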
@@ -191,25 +200,225 @@ class IOContext:
         }
         return (), formatted_kwargs
 
-    def
+    def _generator_output_format(self) -> "api_pb2.DataFormat.ValueType":
+        return self._determine_output_format(self.function_inputs[0].data_format)
+
+    def _prepare_batch_output(self, data: Any) -> list[Any]:
+        # validate that output is valid for batch
+        if self._is_batched:
+            # assert data is list etc.
+            function_name = self.finalized_function.callable.__name__
+
+            if not isinstance(data, list):
+                raise InvalidError(f"Output of batched function {function_name} must be a list.")
+            if len(data) != len(self.input_ids):
+                raise InvalidError(
+                    f"Output of batched function {function_name} must be a list of equal length as its inputs."
+                )
+            return data
+        else:
+            return [data]
+
+    def call_function_sync(self) -> list[Any]:
         logger.debug(f"Starting input {self.input_ids}")
         args, kwargs = self._args_and_kwargs()
-
+        expected_value_or_values = self.finalized_function.callable(*args, **kwargs)
+        if (
+            inspect.iscoroutine(expected_value_or_values)
+            or inspect.isgenerator(expected_value_or_values)
+            or inspect.isasyncgen(expected_value_or_values)
+        ):
+            raise InvalidError(
+                f"Sync (non-generator) function return value of type {type(expected_value_or_values)}."
+                " You might need to use @app.function(..., is_generator=True)."
+            )
         logger.debug(f"Finished input {self.input_ids}")
-        return
+        return self._prepare_batch_output(expected_value_or_values)
 
-    def
-
-
-
-
-
-
+    async def call_function_async(self) -> list[Any]:
+        logger.debug(f"Starting input {self.input_ids}")
+        args, kwargs = self._args_and_kwargs()
+        expected_coro = self.finalized_function.callable(*args, **kwargs)
+        if (
+            not inspect.iscoroutine(expected_coro)
+            or inspect.isgenerator(expected_coro)
+            or inspect.isasyncgen(expected_coro)
+        ):
             raise InvalidError(
-                f"
+                f"Async (non-generator) function returned value of type {type(expected_coro)}"
+                " You might need to use @app.function(..., is_generator=True)."
             )
-
+        value = await expected_coro
+        logger.debug(f"Finished input {self.input_ids}")
+        return self._prepare_batch_output(value)
+
+    def call_generator_sync(self) -> Generator[Any, None, None]:
+        assert not self._is_batched
+        logger.debug(f"Starting generator input {self.input_ids}")
+        args, kwargs = self._args_and_kwargs()
+        expected_gen = self.finalized_function.callable(*args, **kwargs)
+        if not inspect.isgenerator(expected_gen):
+            raise InvalidError(f"Generator function returned value of type {type(expected_gen)}")
+
+        for result in expected_gen:
+            yield result
+        logger.debug(f"Finished generator input {self.input_ids}")
+
+    async def call_generator_async(self) -> AsyncGenerator[Any, None]:
+        assert not self._is_batched
+        logger.debug(f"Starting generator input {self.input_ids}")
+        args, kwargs = self._args_and_kwargs()
+        expected_async_gen = self.finalized_function.callable(*args, **kwargs)
+        if not inspect.isasyncgen(expected_async_gen):
+            raise InvalidError(f"Async generator function returned value of type {type(expected_async_gen)}")
+
+        async with aclosing(expected_async_gen) as gen:
+            async for result in gen:
+                yield result
+        logger.debug(f"Finished generator input {self.input_ids}")
+
+    async def output_items_cancellation(self, started_at: float):
+        # Create terminated outputs for these inputs to signal that the cancellations have been completed.
+        return [
+            api_pb2.FunctionPutOutputsItem(
+                input_id=input_id,
+                input_started_at=started_at,
+                result=api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED),
+                retry_count=retry_count,
+            )
+            for input_id, retry_count in zip(self.input_ids, self.retry_counts)
+        ]
+
+    def _determine_output_format(self, input_format: "api_pb2.DataFormat.ValueType") -> "api_pb2.DataFormat.ValueType":
+        if input_format in self.finalized_function.supported_output_formats:
+            return input_format
+        elif self.finalized_function.supported_output_formats:
+            # This branch would normally be hit when calling a restricted_output function with Pickle input,
+            # but we enforce cbor output at function definition level. In the future we might send the intended
+            # output format along with the input to make this distinction in the calling client instead
+            logger.debug(
+                f"Got an input with format {input_format}, but can only produce output"
+                f" using formats {self.finalized_function.supported_output_formats}"
+            )
+            return self.finalized_function.supported_output_formats[0]
+        else:
+            # This should never happen since self.finalized_function.supported_output_formats should be
+            # populated with defaults in case it's empty; log a warning
+            logger.warning(f"Got an input with format {input_format}, but the function has no defined output formats")
+            return api_pb2.DATA_FORMAT_PICKLE
+
+    async def output_items_exception(
+        self, started_at: float, task_id: str, exc: BaseException
+    ) -> list[api_pb2.FunctionPutOutputsItem]:
+        # Note: we're not pickling the traceback since it contains
+        # local references that means we can't unpickle it. We *are*
+        # pickling the exception, which may have some issues (there
+        # was an earlier note about it that it might not be possible
+        # to unpickle it in some cases). Let's watch out for issues.
+        repr_exc = repr(exc)
+        if len(repr_exc) >= MAX_OBJECT_SIZE_BYTES:
+            # We prevent large exception messages to avoid
+            # unhandled exceptions causing inf loops
+            # and just send back a trimmed version
+            trimmed_bytes = len(repr_exc) - MAX_OBJECT_SIZE_BYTES - 1000
+            repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
+            repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
+
+        data: bytes = pickle_exception(exc)
+        data_result_part = await format_blob_data(data, self._client.stub)
+        serialized_tb, tb_line_cache = pickle_traceback(exc, task_id)
+
+        # Failure outputs for when input exceptions occur
+        def data_format_specific_output(input_format: "api_pb2.DataFormat.ValueType") -> dict:
+            output_format = self._determine_output_format(input_format)
+            if output_format == api_pb2.DATA_FORMAT_PICKLE:
+                return {
+                    "data_format": output_format,
+                    "result": api_pb2.GenericResult(
+                        status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
+                        exception=repr_exc,
+                        traceback=traceback.format_exc(),
+                        serialized_tb=serialized_tb,
+                        tb_line_cache=tb_line_cache,
+                        **data_result_part,
+                    ),
+                }
+            else:
+                return {
+                    "data_format": output_format,
+                    "result": api_pb2.GenericResult(
+                        status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
+                        exception=repr_exc,
+                        traceback=traceback.format_exc(),
+                    ),
+                }
+
+        # all inputs in the batch get the same failure:
+        return [
+            api_pb2.FunctionPutOutputsItem(
+                input_id=input_id,
+                input_started_at=started_at,
+                retry_count=retry_count,
+                **data_format_specific_output(function_input.data_format),
+            )
+            for input_id, retry_count, function_input in zip(self.input_ids, self.retry_counts, self.function_inputs)
+        ]
+
+    def output_items_generator_done(self, started_at: float, items_total: int) -> list[api_pb2.FunctionPutOutputsItem]:
+        assert not self._is_batched, "generators are not supported with batched inputs"
+        assert len(self.function_inputs) == 1, "generators are expected to have 1 input"
+        # Serialize and format the data
+        serialized_bytes = serialize_data_format(
+            api_pb2.GeneratorDone(items_total=items_total), data_format=api_pb2.DATA_FORMAT_GENERATOR_DONE
+        )
+        return [
+            api_pb2.FunctionPutOutputsItem(
+                input_id=self.input_ids[0],
+                input_started_at=started_at,
+                output_created_at=time.time(),
+                result=api_pb2.GenericResult(
+                    status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
+                    data=serialized_bytes,
+                ),
+                data_format=api_pb2.DATA_FORMAT_GENERATOR_DONE,
+                retry_count=self.retry_counts[0],
+            )
+        ]
+
+    async def output_items(self, started_at: float, data: list[Any]) -> list[api_pb2.FunctionPutOutputsItem]:
+        output_created_at = time.time()
+
+        # Process all items concurrently and create output items directly
+        async def package_output(
+            item: Any, input_id: str, retry_count: int, input_format: "api_pb2.DataFormat.ValueType"
+        ) -> api_pb2.FunctionPutOutputsItem:
+            output_format = self._determine_output_format(input_format)
+
+            serialized_bytes = serialize_data_format(item, data_format=output_format)
+            formatted = await format_blob_data(serialized_bytes, self._client.stub)
+            # Create the result
+            result = api_pb2.GenericResult(
+                status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
+                **formatted,
+            )
+            return api_pb2.FunctionPutOutputsItem(
+                input_id=input_id,
+                input_started_at=started_at,
+                output_created_at=output_created_at,
+                result=result,
+                data_format=output_format,
+                retry_count=retry_count,
+            )
+
+        # Process all items concurrently
+        return await asyncio.gather(
+            *[
+                package_output(item, input_id, retry_count, function_input.data_format)
+                for item, input_id, retry_count, function_input in zip(
+                    data, self.input_ids, self.retry_counts, self.function_inputs
+                )
+            ]
+        )
 
 
 class InputSlots:

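Taken together, these new `IOContext` methods move output packaging out of the manager: the `call_*` helpers return one value per input, `_prepare_batch_output` enforces the batched-list contract, and `output_items` serializes each value in a format negotiated by `_determine_output_format`. A runnable sketch of just the batching contract, using plain Python stand-ins rather than the real classes:

def prepare_batch_output(data, input_ids, is_batched):
    # Non-batched functions always produce exactly one output.
    if not is_batched:
        return [data]
    # Batched functions must return one output per input, as a list.
    if not isinstance(data, list):
        raise ValueError("Output of a batched function must be a list.")
    if len(data) != len(input_ids):
        raise ValueError("Batched output must be a list of equal length as its inputs.")
    return data

# A batch of three inputs must yield exactly three outputs:
assert prepare_batch_output([1, 2, 3], ["in-1", "in-2", "in-3"], True) == [1, 2, 3]
assert prepare_batch_output(42, ["in-1"], False) == [42]
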
@@ -472,17 +681,6 @@ class _ContainerIOManager:
 
         await asyncio.sleep(DYNAMIC_CONCURRENCY_INTERVAL_SECS)
 
-    @synchronizer.no_io_translation
-    def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
-        return serialize_data_format(obj, data_format)
-
-    async def format_blob_data(self, data: bytes) -> dict[str, Any]:
-        return (
-            {"data_blob_id": await blob_upload(data, self._client.stub)}
-            if len(data) > MAX_OBJECT_SIZE_BYTES
-            else {"data": data}
-        )
-
     async def get_data_in(self, function_call_id: str, attempt_token: Optional[str]) -> AsyncIterator[Any]:
         """Read from the `data_in` stream of a function call."""
         stub = self._client.stub

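The inline `format_blob_data` helper deleted above has a module-level counterpart: the new import block pulls `format_blob_data` from `modal._utils.blob_utils` (which gains four lines in this release). Judging from the deleted body, the threshold rule is presumably unchanged; a sketch, assuming `blob_upload` and `MAX_OBJECT_SIZE_BYTES` come from that module:

async def format_blob_data(data: bytes, stub) -> dict:
    # Payloads above the size cap are uploaded as blobs and referenced by id;
    # smaller payloads are inlined directly into the result message.
    if len(data) > MAX_OBJECT_SIZE_BYTES:
        return {"data_blob_id": await blob_upload(data, stub)}
    return {"data": data}
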
@@ -499,7 +697,7 @@ class _ContainerIOManager:
         function_call_id: str,
         attempt_token: str,
         start_index: int,
-        data_format:
+        data_format: "api_pb2.DataFormat.ValueType",
         serialized_messages: list[Any],
     ) -> None:
         """Put data onto the `data_out` stream of a function call.

@@ -529,7 +727,11 @@ class _ContainerIOManager:
 
     @asynccontextmanager
     async def generator_output_sender(
-        self,
+        self,
+        function_call_id: str,
+        attempt_token: str,
+        data_format: "api_pb2.DataFormat.ValueType",
+        message_rx: asyncio.Queue,
     ) -> AsyncGenerator[None, None]:
         """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
         GENERATOR_STOP_SENTINEL = Sentinel()

@@ -672,31 +874,11 @@ class _ContainerIOManager:
             self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time()
             yield io_context
             self.current_input_id, self.current_input_started_at = (None, None)
-
         # collect all active input slots, meaning all inputs have wrapped up.
         await self._input_slots.close()
 
-
-
-        self,
-        io_context: IOContext,
-        started_at: float,
-        data_format: "modal_proto.api_pb2.DataFormat.ValueType",
-        results: list[api_pb2.GenericResult],
-    ) -> None:
-        output_created_at = time.time()
-        outputs = [
-            api_pb2.FunctionPutOutputsItem(
-                input_id=input_id,
-                input_started_at=started_at,
-                output_created_at=output_created_at,
-                result=result,
-                data_format=data_format,
-                retry_count=retry_count,
-            )
-            for input_id, retry_count, result in zip(io_context.input_ids, io_context.retry_counts, results)
-        ]
-
+    async def _send_outputs(self, started_at: float, outputs: list[api_pb2.FunctionPutOutputsItem]) -> None:
+        """Send pre-built output items with retry and chunking."""
         # There are multiple outputs for a single IOContext in the case of @modal.batched.
         # Limit the batch size to 20 to stay within message size limits and buffer size limits.
         output_batch_size = 20

@@ -707,27 +889,8 @@ class _ContainerIOManager:
             additional_status_codes=[Status.RESOURCE_EXHAUSTED],
             max_retries=None,  # Retry indefinitely, trying every 1s.
         )
-
-
-        try:
-            return serialize(exc)
-        except Exception as serialization_exc:
-            # We can't always serialize exceptions.
-            err = f"Failed to serialize exception {exc} of type {type(exc)}: {serialization_exc}"
-            logger.info(err)
-            return serialize(SerializationError(err))
-
-    def serialize_traceback(self, exc: BaseException) -> tuple[Optional[bytes], Optional[bytes]]:
-        serialized_tb, tb_line_cache = None, None
-
-        try:
-            tb_dict, line_cache = extract_traceback(exc, self.task_id)
-            serialized_tb = serialize(tb_dict)
-            tb_line_cache = serialize(line_cache)
-        except Exception:
-            logger.info("Failed to serialize exception traceback.")
-
-        return serialized_tb, tb_line_cache
+        input_ids = [output.input_id for output in outputs]
+        self.exit_context(started_at, input_ids)
 
     @asynccontextmanager
     async def handle_user_exception(self) -> AsyncGenerator[None, None]:

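The `serialize_exception` and `serialize_traceback` methods deleted above are superseded by the module-level `pickle_exception` and `pickle_traceback` helpers imported from `modal._serialization` (which grows by 57 lines in this diff). Based on the deleted body, `pickle_exception` presumably keeps the same fallback; a sketch, assuming `serialize` and `SerializationError` from the modal codebase:

def pickle_exception(exc: BaseException) -> bytes:
    try:
        return serialize(exc)
    except Exception as serialization_exc:
        # Not every exception pickles cleanly; fall back to a picklable error.
        err = f"Failed to serialize exception {exc} of type {type(exc)}: {serialization_exc}"
        return serialize(SerializationError(err))
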
@@ -750,11 +913,11 @@ class _ContainerIOManager:
             # Since this is on a different thread, sys.exc_info() can't find the exception in the stack.
             print_exception(type(exc), exc, exc.__traceback__)
 
-            serialized_tb, tb_line_cache = self.
+            serialized_tb, tb_line_cache = pickle_traceback(exc, self.task_id)
 
             result = api_pb2.GenericResult(
                 status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
-                data=
+                data=pickle_exception(exc),
                 exception=repr(exc),
                 traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
                 serialized_tb=serialized_tb or b"",

@@ -784,18 +947,8 @@ class _ContainerIOManager:
             # for the yield. Typically on event loop shutdown
             raise
         except (InputCancellation, asyncio.CancelledError):
-
-
-                api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED)
-                for _ in io_context.input_ids
-            ]
-            await self._push_outputs(
-                io_context=io_context,
-                started_at=started_at,
-                data_format=api_pb2.DATA_FORMAT_PICKLE,
-                results=results,
-            )
-            self.exit_context(started_at, io_context.input_ids)
+            outputs = await io_context.output_items_cancellation(started_at)
+            await self._send_outputs(started_at, outputs)
             logger.warning(f"Successfully canceled input {io_context.input_ids}")
             return
         except BaseException as exc:

@@ -805,44 +958,8 @@ class _ContainerIOManager:
 
             # print exception so it's logged
             print_exception(*sys.exc_info())
-
-
-
-            # Note: we're not serializing the traceback since it contains
-            # local references that means we can't unpickle it. We *are*
-            # serializing the exception, which may have some issues (there
-            # was an earlier note about it that it might not be possible
-            # to unpickle it in some cases). Let's watch out for issues.
-
-            repr_exc = repr(exc)
-            if len(repr_exc) >= MAX_OBJECT_SIZE_BYTES:
-                # We prevent large exception messages to avoid
-                # unhandled exceptions causing inf loops
-                # and just send back a trimmed version
-                trimmed_bytes = len(repr_exc) - MAX_OBJECT_SIZE_BYTES - 1000
-                repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
-                repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
-
-            data: bytes = self.serialize_exception(exc) or b""
-            data_result_part = await self.format_blob_data(data)
-            results = [
-                api_pb2.GenericResult(
-                    status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
-                    exception=repr_exc,
-                    traceback=traceback.format_exc(),
-                    serialized_tb=serialized_tb or b"",
-                    tb_line_cache=tb_line_cache or b"",
-                    **data_result_part,
-                )
-                for _ in io_context.input_ids
-            ]
-            await self._push_outputs(
-                io_context=io_context,
-                started_at=started_at,
-                data_format=api_pb2.DATA_FORMAT_PICKLE,
-                results=results,
-            )
-            self.exit_context(started_at, io_context.input_ids)
+            outputs = await io_context.output_items_exception(started_at, self.task_id, exc)
+            await self._send_outputs(started_at, outputs)
 
     def exit_context(self, started_at, input_ids: list[str]):
         self.total_user_time += time.time() - started_at

@@ -853,32 +970,17 @@ class _ContainerIOManager:
 
         self._input_slots.release()
 
+    # skip inspection of user-generated output_data for synchronicity input translation
     @synchronizer.no_io_translation
     async def push_outputs(
         self,
         io_context: IOContext,
         started_at: float,
-
-        data_format: "modal_proto.api_pb2.DataFormat.ValueType",
+        output_data: list[Any],  # one per output
     ) -> None:
-
-
-
-        )
-        results = [
-            api_pb2.GenericResult(
-                status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
-                **d,
-            )
-            for d in formatted_data
-        ]
-        await self._push_outputs(
-            io_context=io_context,
-            started_at=started_at,
-            data_format=data_format,
-            results=results,
-        )
-        self.exit_context(started_at, io_context.input_ids)
+        # The standard output encoding+sending method for successful function outputs
+        outputs = await io_context.output_items(started_at, output_data)
+        await self._send_outputs(started_at, outputs)
 
     async def memory_restore(self) -> None:
         # Busy-wait for restore. `/__modal/restore-state.json` is created

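With this signature change, a caller hands `push_outputs` the already-computed per-input values and no longer chooses a `data_format`; the `IOContext` negotiates the wire format itself. A hypothetical call site, for illustration only:

async def run_one(io_manager, io_context, started_at):
    # One value per input; IOContext now picks the output format per input.
    output_data = await io_context.call_function_async()
    await io_manager.push_outputs(io_context, started_at, output_data)
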
modal/_runtime/container_io_manager.pyi

@@ -58,8 +58,23 @@ class IOContext:
     def set_cancel_callback(self, cb: collections.abc.Callable[[], None]): ...
     def cancel(self): ...
     def _args_and_kwargs(self) -> tuple[tuple[typing.Any, ...], dict[str, list[typing.Any]]]: ...
-    def
-    def
+    def _generator_output_format(self) -> int: ...
+    def _prepare_batch_output(self, data: typing.Any) -> list[typing.Any]: ...
+    def call_function_sync(self) -> list[typing.Any]: ...
+    async def call_function_async(self) -> list[typing.Any]: ...
+    def call_generator_sync(self) -> typing.Generator[typing.Any, None, None]: ...
+    def call_generator_async(self) -> collections.abc.AsyncGenerator[typing.Any, None]: ...
+    async def output_items_cancellation(self, started_at: float): ...
+    def _determine_output_format(self, input_format: int) -> int: ...
+    async def output_items_exception(
+        self, started_at: float, task_id: str, exc: BaseException
+    ) -> list[modal_proto.api_pb2.FunctionPutOutputsItem]: ...
+    def output_items_generator_done(
+        self, started_at: float, items_total: int
+    ) -> list[modal_proto.api_pb2.FunctionPutOutputsItem]: ...
+    async def output_items(
+        self, started_at: float, data: list[typing.Any]
+    ) -> list[modal_proto.api_pb2.FunctionPutOutputsItem]: ...
 
 class InputSlots:
     """A semaphore that allows dynamically adjusting the concurrency."""

@@ -133,8 +148,6 @@ class _ContainerIOManager:
     def stop_heartbeat(self): ...
     def dynamic_concurrency_manager(self) -> typing.AsyncContextManager[None]: ...
     async def _dynamic_concurrency_loop(self): ...
-    def serialize_data_format(self, obj: typing.Any, data_format: int) -> bytes: ...
-    async def format_blob_data(self, data: bytes) -> dict[str, typing.Any]: ...
     def get_data_in(
         self, function_call_id: str, attempt_token: typing.Optional[str]
     ) -> collections.abc.AsyncIterator[typing.Any]:

@@ -182,15 +195,10 @@ class _ContainerIOManager:
         batch_max_size: int = 0,
         batch_wait_ms: int = 0,
     ) -> collections.abc.AsyncIterator[IOContext]: ...
-    async def
-
-
-
-        data_format: int,
-        results: list[modal_proto.api_pb2.GenericResult],
-    ) -> None: ...
-    def serialize_exception(self, exc: BaseException) -> bytes: ...
-    def serialize_traceback(self, exc: BaseException) -> tuple[typing.Optional[bytes], typing.Optional[bytes]]: ...
+    async def _send_outputs(self, started_at: float, outputs: list[modal_proto.api_pb2.FunctionPutOutputsItem]) -> None:
+        """Send pre-built output items with retry and chunking."""
+        ...
+
     def handle_user_exception(self) -> typing.AsyncContextManager[None]:
         """Sets the task as failed in a way where it's not retried.
 
@@ -204,9 +212,7 @@ class _ContainerIOManager:
         ...
 
     def exit_context(self, started_at, input_ids: list[str]): ...
-    async def push_outputs(
-        self, io_context: IOContext, started_at: float, data: typing.Any, data_format: int
-    ) -> None: ...
+    async def push_outputs(self, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...
     async def memory_restore(self) -> None: ...
     async def memory_snapshot(self) -> None:
         """Message server indicating that function is ready to be checkpointed."""

@@ -332,14 +338,6 @@ class ContainerIOManager:
 
     _dynamic_concurrency_loop: ___dynamic_concurrency_loop_spec[typing_extensions.Self]
 
-    def serialize_data_format(self, obj: typing.Any, data_format: int) -> bytes: ...
-
-    class __format_blob_data_spec(typing_extensions.Protocol[SUPERSELF]):
-        def __call__(self, /, data: bytes) -> dict[str, typing.Any]: ...
-        async def aio(self, /, data: bytes) -> dict[str, typing.Any]: ...
-
-    format_blob_data: __format_blob_data_spec[typing_extensions.Self]
-
     class __get_data_in_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
             self, /, function_call_id: str, attempt_token: typing.Optional[str]

@@ -460,28 +458,16 @@ class ContainerIOManager:
 
     run_inputs_outputs: __run_inputs_outputs_spec[typing_extensions.Self]
 
-    class
-        def __call__(
-
-
-            io_context: IOContext,
-            started_at: float,
-            data_format: int,
-            results: list[modal_proto.api_pb2.GenericResult],
-        ) -> None: ...
-        async def aio(
-            self,
-            /,
-            io_context: IOContext,
-            started_at: float,
-            data_format: int,
-            results: list[modal_proto.api_pb2.GenericResult],
-        ) -> None: ...
+    class ___send_outputs_spec(typing_extensions.Protocol[SUPERSELF]):
+        def __call__(self, /, started_at: float, outputs: list[modal_proto.api_pb2.FunctionPutOutputsItem]) -> None:
+            """Send pre-built output items with retry and chunking."""
+            ...
 
-
+        async def aio(self, /, started_at: float, outputs: list[modal_proto.api_pb2.FunctionPutOutputsItem]) -> None:
+            """Send pre-built output items with retry and chunking."""
+            ...
 
-
-    def serialize_traceback(self, exc: BaseException) -> tuple[typing.Optional[bytes], typing.Optional[bytes]]: ...
+    _send_outputs: ___send_outputs_spec[typing_extensions.Self]
 
     class __handle_user_exception_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(self, /) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]:

|
@@ -518,10 +504,8 @@ class ContainerIOManager:
|
|
|
518
504
|
def exit_context(self, started_at, input_ids: list[str]): ...
|
|
519
505
|
|
|
520
506
|
class __push_outputs_spec(typing_extensions.Protocol[SUPERSELF]):
|
|
521
|
-
def __call__(self, /, io_context: IOContext, started_at: float,
|
|
522
|
-
async def aio(
|
|
523
|
-
self, /, io_context: IOContext, started_at: float, data: typing.Any, data_format: int
|
|
524
|
-
) -> None: ...
|
|
507
|
+
def __call__(self, /, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...
|
|
508
|
+
async def aio(self, /, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...
|
|
525
509
|
|
|
526
510
|
push_outputs: __push_outputs_spec[typing_extensions.Self]
|
|
527
511
|
|