modal 1.0.6.dev58__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of modal might be problematic.

Files changed (147)
  1. modal/__main__.py +3 -4
  2. modal/_billing.py +80 -0
  3. modal/_clustered_functions.py +7 -3
  4. modal/_clustered_functions.pyi +4 -2
  5. modal/_container_entrypoint.py +41 -49
  6. modal/_functions.py +424 -195
  7. modal/_grpc_client.py +171 -0
  8. modal/_load_context.py +105 -0
  9. modal/_object.py +68 -20
  10. modal/_output.py +58 -45
  11. modal/_partial_function.py +36 -11
  12. modal/_pty.py +7 -3
  13. modal/_resolver.py +21 -35
  14. modal/_runtime/asgi.py +4 -3
  15. modal/_runtime/container_io_manager.py +301 -186
  16. modal/_runtime/container_io_manager.pyi +70 -61
  17. modal/_runtime/execution_context.py +18 -2
  18. modal/_runtime/execution_context.pyi +4 -1
  19. modal/_runtime/gpu_memory_snapshot.py +170 -63
  20. modal/_runtime/user_code_imports.py +28 -58
  21. modal/_serialization.py +57 -1
  22. modal/_utils/async_utils.py +33 -12
  23. modal/_utils/auth_token_manager.py +2 -5
  24. modal/_utils/blob_utils.py +110 -53
  25. modal/_utils/function_utils.py +49 -42
  26. modal/_utils/grpc_utils.py +80 -50
  27. modal/_utils/mount_utils.py +26 -1
  28. modal/_utils/name_utils.py +17 -3
  29. modal/_utils/task_command_router_client.py +536 -0
  30. modal/_utils/time_utils.py +34 -6
  31. modal/app.py +219 -83
  32. modal/app.pyi +229 -56
  33. modal/billing.py +5 -0
  34. modal/{requirements → builder}/2025.06.txt +1 -0
  35. modal/{requirements → builder}/PREVIEW.txt +1 -0
  36. modal/cli/_download.py +19 -3
  37. modal/cli/_traceback.py +3 -2
  38. modal/cli/app.py +4 -4
  39. modal/cli/cluster.py +15 -7
  40. modal/cli/config.py +5 -3
  41. modal/cli/container.py +7 -6
  42. modal/cli/dict.py +22 -16
  43. modal/cli/entry_point.py +12 -5
  44. modal/cli/environment.py +5 -4
  45. modal/cli/import_refs.py +3 -3
  46. modal/cli/launch.py +102 -5
  47. modal/cli/network_file_system.py +9 -13
  48. modal/cli/profile.py +3 -2
  49. modal/cli/programs/launch_instance_ssh.py +94 -0
  50. modal/cli/programs/run_jupyter.py +1 -1
  51. modal/cli/programs/run_marimo.py +95 -0
  52. modal/cli/programs/vscode.py +1 -1
  53. modal/cli/queues.py +57 -26
  54. modal/cli/run.py +58 -16
  55. modal/cli/secret.py +48 -22
  56. modal/cli/utils.py +3 -4
  57. modal/cli/volume.py +28 -25
  58. modal/client.py +13 -116
  59. modal/client.pyi +9 -91
  60. modal/cloud_bucket_mount.py +5 -3
  61. modal/cloud_bucket_mount.pyi +5 -1
  62. modal/cls.py +130 -102
  63. modal/cls.pyi +45 -85
  64. modal/config.py +29 -10
  65. modal/container_process.py +291 -13
  66. modal/container_process.pyi +95 -32
  67. modal/dict.py +282 -63
  68. modal/dict.pyi +423 -73
  69. modal/environments.py +15 -27
  70. modal/environments.pyi +5 -15
  71. modal/exception.py +8 -0
  72. modal/experimental/__init__.py +143 -38
  73. modal/experimental/flash.py +247 -78
  74. modal/experimental/flash.pyi +137 -9
  75. modal/file_io.py +14 -28
  76. modal/file_io.pyi +2 -2
  77. modal/file_pattern_matcher.py +25 -16
  78. modal/functions.pyi +134 -61
  79. modal/image.py +255 -86
  80. modal/image.pyi +300 -62
  81. modal/io_streams.py +436 -126
  82. modal/io_streams.pyi +236 -171
  83. modal/mount.py +62 -157
  84. modal/mount.pyi +45 -172
  85. modal/network_file_system.py +30 -53
  86. modal/network_file_system.pyi +16 -76
  87. modal/object.pyi +42 -8
  88. modal/parallel_map.py +821 -113
  89. modal/parallel_map.pyi +134 -0
  90. modal/partial_function.pyi +4 -1
  91. modal/proxy.py +16 -7
  92. modal/proxy.pyi +10 -2
  93. modal/queue.py +263 -61
  94. modal/queue.pyi +409 -66
  95. modal/runner.py +112 -92
  96. modal/runner.pyi +45 -27
  97. modal/sandbox.py +451 -124
  98. modal/sandbox.pyi +513 -67
  99. modal/secret.py +291 -67
  100. modal/secret.pyi +425 -19
  101. modal/serving.py +7 -11
  102. modal/serving.pyi +7 -8
  103. modal/snapshot.py +11 -8
  104. modal/token_flow.py +4 -4
  105. modal/volume.py +344 -98
  106. modal/volume.pyi +464 -68
  107. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +9 -8
  108. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  109. modal_docs/mdmd/mdmd.py +11 -1
  110. modal_proto/api.proto +399 -67
  111. modal_proto/api_grpc.py +241 -1
  112. modal_proto/api_pb2.py +1395 -1000
  113. modal_proto/api_pb2.pyi +1239 -79
  114. modal_proto/api_pb2_grpc.py +499 -4
  115. modal_proto/api_pb2_grpc.pyi +162 -14
  116. modal_proto/modal_api_grpc.py +175 -160
  117. modal_proto/sandbox_router.proto +145 -0
  118. modal_proto/sandbox_router_grpc.py +105 -0
  119. modal_proto/sandbox_router_pb2.py +149 -0
  120. modal_proto/sandbox_router_pb2.pyi +333 -0
  121. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  122. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  123. modal_proto/task_command_router.proto +144 -0
  124. modal_proto/task_command_router_grpc.py +105 -0
  125. modal_proto/task_command_router_pb2.py +149 -0
  126. modal_proto/task_command_router_pb2.pyi +333 -0
  127. modal_proto/task_command_router_pb2_grpc.py +203 -0
  128. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  129. modal_version/__init__.py +1 -1
  130. modal-1.0.6.dev58.dist-info/RECORD +0 -183
  131. modal_proto/modal_options_grpc.py +0 -3
  132. modal_proto/options.proto +0 -19
  133. modal_proto/options_grpc.py +0 -3
  134. modal_proto/options_pb2.py +0 -35
  135. modal_proto/options_pb2.pyi +0 -20
  136. modal_proto/options_pb2_grpc.py +0 -4
  137. modal_proto/options_pb2_grpc.pyi +0 -7
  138. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  139. /modal/{requirements → builder}/2023.12.txt +0 -0
  140. /modal/{requirements → builder}/2024.04.txt +0 -0
  141. /modal/{requirements → builder}/2024.10.txt +0 -0
  142. /modal/{requirements → builder}/README.md +0 -0
  143. /modal/{requirements → builder}/base-images.json +0 -0
  144. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  145. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  146. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  147. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ from typing import (
     Any,
     Callable,
     ClassVar,
+    Generator,
     Optional,
     cast,
 )
@@ -24,22 +25,25 @@ from google.protobuf.empty_pb2 import Empty
 from grpclib import Status
 from synchronicity.async_wrap import asynccontextmanager
 
-import modal_proto.api_pb2
 from modal._runtime import gpu_memory_snapshot
-from modal._serialization import deserialize, serialize, serialize_data_format
-from modal._traceback import extract_traceback, print_exception
-from modal._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer
-from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
+from modal._serialization import (
+    deserialize_data_format,
+    pickle_exception,
+    pickle_traceback,
+    serialize_data_format,
+)
+from modal._traceback import print_exception
+from modal._utils.async_utils import TaskContext, aclosing, asyncify, synchronize_api, synchronizer
+from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload, format_blob_data
 from modal._utils.function_utils import _stream_function_call_data
-from modal._utils.grpc_utils import retry_transient_errors
+from modal._utils.grpc_utils import Retry
 from modal._utils.package_utils import parse_major_minor_version
 from modal.client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
 from modal.config import config, logger
-from modal.exception import ClientClosed, InputCancellation, InvalidError, SerializationError
+from modal.exception import ClientClosed, InputCancellation, InvalidError
 from modal_proto import api_pb2
 
 if TYPE_CHECKING:
-    import modal._runtime.asgi
     import modal._runtime.user_code_imports
 
 
@@ -66,6 +70,7 @@ class IOContext:
     input_ids: list[str]
     retry_counts: list[int]
     function_call_ids: list[str]
+    attempt_tokens: list[str]
     function_inputs: list[api_pb2.FunctionInput]
     finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"
 
@@ -77,6 +82,7 @@ class IOContext:
         input_ids: list[str],
         retry_counts: list[int],
         function_call_ids: list[str],
+        attempt_tokens: list[str],
         finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
         function_inputs: list[api_pb2.FunctionInput],
         is_batched: bool,
@@ -85,6 +91,7 @@ class IOContext:
         self.input_ids = input_ids
         self.retry_counts = retry_counts
         self.function_call_ids = function_call_ids
+        self.attempt_tokens = attempt_tokens
         self.finalized_function = finalized_function
         self.function_inputs = function_inputs
         self._is_batched = is_batched
@@ -95,11 +102,11 @@ class IOContext:
         cls,
         client: _Client,
         finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
-        inputs: list[tuple[str, int, str, api_pb2.FunctionInput]],
+        inputs: list[tuple[str, int, str, str, api_pb2.FunctionInput]],
         is_batched: bool,
     ) -> "IOContext":
         assert len(inputs) >= 1 if is_batched else len(inputs) == 1
-        input_ids, retry_counts, function_call_ids, function_inputs = zip(*inputs)
+        input_ids, retry_counts, function_call_ids, attempt_tokens, function_inputs = zip(*inputs)
 
         async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput:
             # If we got a pointer to a blob, download it from S3.
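The widened tuple threads the new attempt token through from_inputs via the same zip(*inputs) transposition: zip(*rows) turns a list of per-input tuples into parallel columns, so adding a fourth element is a one-line change. A standalone sketch with fabricated values (not real Modal identifiers):

    # Fabricated stand-ins for (input_id, retry_count, function_call_id, attempt_token, function_input)
    inputs = [
        ("in-1", 0, "fc-1", "at-1", b"args-1"),
        ("in-2", 1, "fc-1", "at-2", b"args-2"),
    ]
    input_ids, retry_counts, function_call_ids, attempt_tokens, function_inputs = zip(*inputs)
    assert input_ids == ("in-1", "in-2")
    assert attempt_tokens == ("at-1", "at-2")  # the new column rides along unchanged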
@@ -121,6 +128,7 @@ class IOContext:
             cast(list[str], input_ids),
             cast(list[int], retry_counts),
             cast(list[str], function_call_ids),
+            cast(list[str], attempt_tokens),
             finalized_function,
             cast(list[api_pb2.FunctionInput], function_inputs),
             is_batched,
@@ -148,9 +156,13 @@ class IOContext:
         # deserializing here instead of the constructor
         # to make sure we handle user exceptions properly
         # and don't retry
-        deserialized_args = [
-            deserialize(input.args, self._client) if input.args else ((), {}) for input in self.function_inputs
-        ]
+        deserialized_args = []
+        for input in self.function_inputs:
+            if input.args:
+                data_format = input.data_format
+                deserialized_args.append(deserialize_data_format(input.args, data_format, self._client))
+            else:
+                deserialized_args.append(((), {}))
         if not self._is_batched:
             return deserialized_args[0]
 
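The rewrite above swaps the pickle-only deserialize for deserialize_data_format, which picks a decoder based on each input's declared wire format. A toy sketch of that dispatch shape, with made-up constants (the real enum values and non-pickle decoders live in modal/_serialization.py and the proto definitions):

    import pickle

    TOY_FORMAT_PICKLE = 1  # stand-in, not the real api_pb2 enum value
    TOY_FORMAT_CBOR = 2    # stand-in; the real decoder would be CBOR-based

    def toy_deserialize(raw: bytes, data_format: int):
        # Dispatch on the per-input format instead of assuming pickle everywhere
        if data_format == TOY_FORMAT_PICKLE:
            return pickle.loads(raw)
        raise NotImplementedError(f"no decoder wired up for format {data_format}")

    args, kwargs = toy_deserialize(pickle.dumps(((1, 2), {})), TOY_FORMAT_PICKLE)
    assert args == (1, 2) and kwargs == {}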
@@ -188,25 +200,229 @@ class IOContext:
         }
         return (), formatted_kwargs
 
-    def call_finalized_function(self) -> Any:
+    def _generator_output_format(self) -> "api_pb2.DataFormat.ValueType":
+        return self._determine_output_format(self.function_inputs[0].data_format)
+
+    def _prepare_batch_output(self, data: Any) -> list[Any]:
+        # validate that output is valid for batch
+        if self._is_batched:
+            # assert data is list etc.
+            function_name = self.finalized_function.callable.__name__
+
+            if not isinstance(data, list):
+                raise InvalidError(f"Output of batched function {function_name} must be a list.")
+            if len(data) != len(self.input_ids):
+                raise InvalidError(
+                    f"Output of batched function {function_name} must be a list of equal length as its inputs."
+                )
+            return data
+        else:
+            return [data]
+
+    def call_function_sync(self) -> list[Any]:
         logger.debug(f"Starting input {self.input_ids}")
         args, kwargs = self._args_and_kwargs()
-        res = self.finalized_function.callable(*args, **kwargs)
+        expected_value_or_values = self.finalized_function.callable(*args, **kwargs)
+        if (
+            inspect.iscoroutine(expected_value_or_values)
+            or inspect.isgenerator(expected_value_or_values)
+            or inspect.isasyncgen(expected_value_or_values)
+        ):
+            raise InvalidError(
+                f"Sync (non-generator) function return value of type {type(expected_value_or_values)}."
+                " You might need to use @app.function(..., is_generator=True)."
+            )
         logger.debug(f"Finished input {self.input_ids}")
-        return res
-
-    def validate_output_data(self, data: Any) -> list[Any]:
-        if not self._is_batched:
-            return [data]
+        return self._prepare_batch_output(expected_value_or_values)
 
-        function_name = self.finalized_function.callable.__name__
-        if not isinstance(data, list):
-            raise InvalidError(f"Output of batched function {function_name} must be a list.")
-        if len(data) != len(self.input_ids):
+    async def call_function_async(self) -> list[Any]:
+        logger.debug(f"Starting input {self.input_ids}")
+        args, kwargs = self._args_and_kwargs()
+        expected_coro = self.finalized_function.callable(*args, **kwargs)
+        if (
+            not inspect.iscoroutine(expected_coro)
+            or inspect.isgenerator(expected_coro)
+            or inspect.isasyncgen(expected_coro)
+        ):
             raise InvalidError(
-                f"Output of batched function {function_name} must be a list of equal length as its inputs."
+                f"Async (non-generator) function returned value of type {type(expected_coro)}"
+                " You might need to use @app.function(..., is_generator=True)."
             )
-        return data
+        value = await expected_coro
+        logger.debug(f"Finished input {self.input_ids}")
+        return self._prepare_batch_output(value)
+
+    def call_generator_sync(self) -> Generator[Any, None, None]:
+        assert not self._is_batched
+        logger.debug(f"Starting generator input {self.input_ids}")
+        args, kwargs = self._args_and_kwargs()
+        expected_gen = self.finalized_function.callable(*args, **kwargs)
+        if not inspect.isgenerator(expected_gen):
+            raise InvalidError(f"Generator function returned value of type {type(expected_gen)}")
+
+        for result in expected_gen:
+            yield result
+        logger.debug(f"Finished generator input {self.input_ids}")
+
+    async def call_generator_async(self) -> AsyncGenerator[Any, None]:
+        assert not self._is_batched
+        logger.debug(f"Starting generator input {self.input_ids}")
+        args, kwargs = self._args_and_kwargs()
+        expected_async_gen = self.finalized_function.callable(*args, **kwargs)
+        if not inspect.isasyncgen(expected_async_gen):
+            raise InvalidError(f"Async generator function returned value of type {type(expected_async_gen)}")
+
+        async with aclosing(expected_async_gen) as gen:
+            async for result in gen:
+                yield result
+        logger.debug(f"Finished generator input {self.input_ids}")
+
+    async def output_items_cancellation(self, started_at: float):
+        output_created_at = time.time()
+        # Create terminated outputs for these inputs to signal that the cancellations have been completed.
+        return [
+            api_pb2.FunctionPutOutputsItem(
+                input_id=input_id,
+                input_started_at=started_at,
+                output_created_at=output_created_at,
+                result=api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED),
+                retry_count=retry_count,
+            )
+            for input_id, retry_count in zip(self.input_ids, self.retry_counts)
+        ]
+
+    def _determine_output_format(self, input_format: "api_pb2.DataFormat.ValueType") -> "api_pb2.DataFormat.ValueType":
+        if input_format in self.finalized_function.supported_output_formats:
+            return input_format
+        elif self.finalized_function.supported_output_formats:
+            # This branch would normally be hit when calling a restricted_output function with Pickle input,
+            # but we enforce cbor output at the function definition level. In the future we might send the
+            # intended output format along with the input to make this distinction in the calling client instead.
+            logger.debug(
+                f"Got an input with format {input_format}, but can only produce output"
+                f" using formats {self.finalized_function.supported_output_formats}"
+            )
+            return self.finalized_function.supported_output_formats[0]
+        else:
+            # This should never happen, since self.finalized_function.supported_output_formats should be
+            # populated with defaults in case it's empty; log a warning.
+            logger.warning(f"Got an input with format {input_format}, but the function has no defined output formats")
+            return api_pb2.DATA_FORMAT_PICKLE
+
+    async def output_items_exception(
+        self, started_at: float, task_id: str, exc: BaseException
+    ) -> list[api_pb2.FunctionPutOutputsItem]:
+        # Note: we're not pickling the traceback since it contains
+        # local references that mean we can't unpickle it. We *are*
+        # pickling the exception, which may have some issues (there
+        # was an earlier note that it might not be possible
+        # to unpickle it in some cases). Let's watch out for issues.
+        repr_exc = repr(exc)
+        if len(repr_exc) >= MAX_OBJECT_SIZE_BYTES:
+            # We trim large exception messages to avoid
+            # unhandled exceptions causing infinite loops,
+            # and just send back a trimmed version
+            trimmed_bytes = len(repr_exc) - MAX_OBJECT_SIZE_BYTES - 1000
+            repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
+            repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
+
+        data: bytes = pickle_exception(exc)
+        data_result_part = await format_blob_data(data, self._client.stub)
+        serialized_tb, tb_line_cache = pickle_traceback(exc, task_id)
+
+        # Failure outputs for when input exceptions occur
+        def data_format_specific_output(input_format: "api_pb2.DataFormat.ValueType") -> dict:
+            output_format = self._determine_output_format(input_format)
+            if output_format == api_pb2.DATA_FORMAT_PICKLE:
+                return {
+                    "data_format": output_format,
+                    "result": api_pb2.GenericResult(
+                        status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
+                        exception=repr_exc,
+                        traceback=traceback.format_exc(),
+                        serialized_tb=serialized_tb,
+                        tb_line_cache=tb_line_cache,
+                        **data_result_part,
+                    ),
+                }
+            else:
+                return {
+                    "data_format": output_format,
+                    "result": api_pb2.GenericResult(
+                        status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
+                        exception=repr_exc,
+                        traceback=traceback.format_exc(),
+                    ),
+                }
+
+        # all inputs in the batch get the same failure:
+        output_created_at = time.time()
+        return [
+            api_pb2.FunctionPutOutputsItem(
+                input_id=input_id,
+                input_started_at=started_at,
+                output_created_at=output_created_at,
+                retry_count=retry_count,
+                **data_format_specific_output(function_input.data_format),
+            )
+            for input_id, retry_count, function_input in zip(self.input_ids, self.retry_counts, self.function_inputs)
+        ]
+
+    def output_items_generator_done(self, started_at: float, items_total: int) -> list[api_pb2.FunctionPutOutputsItem]:
+        assert not self._is_batched, "generators are not supported with batched inputs"
+        assert len(self.function_inputs) == 1, "generators are expected to have 1 input"
+        # Serialize and format the data
+        serialized_bytes = serialize_data_format(
+            api_pb2.GeneratorDone(items_total=items_total), data_format=api_pb2.DATA_FORMAT_GENERATOR_DONE
+        )
+        return [
+            api_pb2.FunctionPutOutputsItem(
+                input_id=self.input_ids[0],
+                input_started_at=started_at,
+                output_created_at=time.time(),
+                result=api_pb2.GenericResult(
+                    status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
+                    data=serialized_bytes,
+                ),
+                data_format=api_pb2.DATA_FORMAT_GENERATOR_DONE,
+                retry_count=self.retry_counts[0],
+            )
+        ]
+
+    async def output_items(self, started_at: float, data: list[Any]) -> list[api_pb2.FunctionPutOutputsItem]:
+        output_created_at = time.time()
+
+        # Process all items concurrently and create output items directly
+        async def package_output(
+            item: Any, input_id: str, retry_count: int, input_format: "api_pb2.DataFormat.ValueType"
+        ) -> api_pb2.FunctionPutOutputsItem:
+            output_format = self._determine_output_format(input_format)
+
+            serialized_bytes = serialize_data_format(item, data_format=output_format)
+            formatted = await format_blob_data(serialized_bytes, self._client.stub)
+            # Create the result
+            result = api_pb2.GenericResult(
+                status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
+                **formatted,
+            )
+            return api_pb2.FunctionPutOutputsItem(
+                input_id=input_id,
+                input_started_at=started_at,
+                output_created_at=output_created_at,
+                result=result,
+                data_format=output_format,
+                retry_count=retry_count,
+            )
+
+        # Process all items concurrently
+        return await asyncio.gather(
+            *[
+                package_output(item, input_id, retry_count, function_input.data_format)
+                for item, input_id, retry_count, function_input in zip(
+                    data, self.input_ids, self.retry_counts, self.function_inputs
+                )
+            ]
+        )
 
 
 class InputSlots:
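The single call_finalized_function is split into four entrypoints so the runner can treat sync/async and function/generator cases distinctly. A hedged sketch of the selection step, using hypothetical is_async/is_generator flags (the actual flags and selection logic live on the finalized function and in the container entrypoint, not in this diff):

    # Hypothetical dispatch; the flag names are illustrative only.
    def pick_call_path(io_context, is_async: bool, is_generator: bool):
        if is_generator:
            # Generator paths yield items one at a time; batching is asserted off inside.
            return io_context.call_generator_async() if is_async else io_context.call_generator_sync()
        # Plain functions return a list of outputs, one per (possibly batched) input.
        return io_context.call_function_async() if is_async else io_context.call_function_sync()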
@@ -300,11 +516,7 @@ class _ContainerIOManager:
         self.function_def = container_args.function_def
         self.checkpoint_id = container_args.checkpoint_id or None
 
-        # We could also have the worker pass this in explicitly.
-        self.input_plane_server_url = None
-        for obj in container_args.app_layout.objects:
-            if obj.object_id == self.function_id:
-                self.input_plane_server_url = obj.function_handle_metadata.input_plane_url
+        self.input_plane_server_url = container_args.input_plane_server_url
 
         self.calls_completed = 0
         self.total_user_time = 0.0
@@ -411,8 +623,8 @@
                 await self.heartbeat_condition.wait()
 
             request = api_pb2.ContainerHeartbeatRequest(canceled_inputs_return_outputs_v2=True)
-            response = await retry_transient_errors(
-                self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
+            response = await self._client.stub.ContainerHeartbeat(
+                request, retry=Retry(attempt_timeout=HEARTBEAT_TIMEOUT)
             )
 
             if response.HasField("cancel_input_event"):
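This is the recurring mechanical change in this file: instead of wrapping stub methods with the free function retry_transient_errors(...), retry parameters now travel with the call as a Retry policy object. A minimal sketch of the pattern with a simplified stand-in class (the real Retry is defined in modal/_utils/grpc_utils.py and its exact fields and semantics may differ):

    import asyncio
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToyRetry:
        # Simplified stand-in for modal._utils.grpc_utils.Retry
        max_retries: Optional[int] = 3   # None would mean retry indefinitely
        base_delay: float = 0.1
        delay_factor: float = 2.0
        max_delay: float = 10.0
        attempt_timeout: Optional[float] = None

    async def call_with_retry(fn, retry: ToyRetry):
        attempt, delay = 0, retry.base_delay
        while True:
            try:
                return await asyncio.wait_for(fn(), timeout=retry.attempt_timeout)
            except (asyncio.TimeoutError, ConnectionError):
                attempt += 1
                if retry.max_retries is not None and attempt > retry.max_retries:
                    raise
                await asyncio.sleep(delay)
                delay = min(delay * retry.delay_factor, retry.max_delay)

Attaching the policy to the call site keeps per-RPC tuning (timeouts, extra retryable status codes) next to the request rather than spread across wrapper arguments.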
@@ -459,10 +671,9 @@
                 target_concurrency=self._target_concurrency,
                 max_concurrency=self._max_concurrency,
             )
-            resp = await retry_transient_errors(
-                self._client.stub.FunctionGetDynamicConcurrency,
+            resp = await self._client.stub.FunctionGetDynamicConcurrency(
                 request,
-                attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS,
+                retry=Retry(attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS),
             )
             if resp.concurrency != self._input_slots.value and not self._stop_concurrency_loop:
                 logger.debug(f"Dynamic concurrency set from {self._input_slots.value} to {resp.concurrency}")
@@ -473,31 +684,23 @@
 
         await asyncio.sleep(DYNAMIC_CONCURRENCY_INTERVAL_SECS)
 
-    @synchronizer.no_io_translation
-    def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
-        return serialize_data_format(obj, data_format)
-
-    async def format_blob_data(self, data: bytes) -> dict[str, Any]:
-        return (
-            {"data_blob_id": await blob_upload(data, self._client.stub)}
-            if len(data) > MAX_OBJECT_SIZE_BYTES
-            else {"data": data}
-        )
-
-    async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
+    async def get_data_in(self, function_call_id: str, attempt_token: Optional[str]) -> AsyncIterator[Any]:
         """Read from the `data_in` stream of a function call."""
         stub = self._client.stub
         if self.input_plane_server_url:
             stub = await self._client.get_stub(self.input_plane_server_url)
 
-        async for data in _stream_function_call_data(self._client, stub, function_call_id, "data_in"):
+        async for data in _stream_function_call_data(
+            self._client, stub, function_call_id, variant="data_in", attempt_token=attempt_token
+        ):
             yield data
 
     async def put_data_out(
         self,
         function_call_id: str,
+        attempt_token: str,
         start_index: int,
-        data_format: int,
+        data_format: "api_pb2.DataFormat.ValueType",
         serialized_messages: list[Any],
     ) -> None:
         """Put data onto the `data_out` stream of a function call.
@@ -516,16 +719,22 @@
             data_chunks.append(chunk)
 
         req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks)
+        if attempt_token:
+            req.attempt_token = attempt_token  # oneof clears function_call_id.
 
         if self.input_plane_server_url:
             stub = await self._client.get_stub(self.input_plane_server_url)
-            await retry_transient_errors(stub.FunctionCallPutDataOut, req)
+            await stub.FunctionCallPutDataOut(req)
         else:
-            await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req)
+            await self._client.stub.FunctionCallPutDataOut(req)
 
     @asynccontextmanager
     async def generator_output_sender(
-        self, function_call_id: str, data_format: int, message_rx: asyncio.Queue
+        self,
+        function_call_id: str,
+        attempt_token: str,
+        data_format: "api_pb2.DataFormat.ValueType",
+        message_rx: asyncio.Queue,
     ) -> AsyncGenerator[None, None]:
         """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
         GENERATOR_STOP_SENTINEL = Sentinel()
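The "# oneof clears function_call_id" comment in the hunk above leans on protobuf oneof semantics: assigning one member of a oneof group implicitly clears the others, which is why the request is built with function_call_id first and only overwritten when an attempt token exists. A plain-Python mimic of that behavior (no protobuf dependency; field names mirror the request above, values are fabricated):

    class ToyOneof:
        """Mimics a protobuf oneof: setting one member clears the other."""
        def __init__(self):
            self._which, self._value = None, None

        def set(self, field: str, value: str):
            self._which, self._value = field, value  # previous member is implicitly cleared

        def which_oneof(self):
            return self._which

    req = ToyOneof()
    req.set("function_call_id", "fc-123")   # fabricated ID
    req.set("attempt_token", "tok-456")     # fabricated token
    assert req.which_oneof() == "attempt_token"  # function_call_id is no longer set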
@@ -554,7 +763,7 @@
                 else:
                     serialized_messages.append(serialize_data_format(message, data_format))
                     total_size += len(serialized_messages[-1]) + 512  # 512 bytes for estimated framing overhead
-            await self.put_data_out(function_call_id, index, data_format, serialized_messages)
+            await self.put_data_out(function_call_id, attempt_token, index, data_format, serialized_messages)
             index += len(serialized_messages)
 
         task = asyncio.create_task(generator_output_task())
@@ -590,7 +799,7 @@
         self,
         batch_max_size: int,
         batch_wait_ms: int,
-    ) -> AsyncIterator[list[tuple[str, int, str, api_pb2.FunctionInput]]]:
+    ) -> AsyncIterator[list[tuple[str, int, str, str, api_pb2.FunctionInput]]]:
         request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
         iteration = 0
         while self._fetching_inputs:
@@ -605,9 +814,7 @@
             try:
                 # If number of active inputs is at max queue size, this will block.
                 iteration += 1
-                response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors(
-                    self._client.stub.FunctionGetInputs, request
-                )
+                response: api_pb2.FunctionGetInputsResponse = await self._client.stub.FunctionGetInputs(request)
 
                 if response.rate_limit_sleep_duration:
                     logger.info(
@@ -625,7 +832,9 @@
                     if item.kill_switch:
                         logger.debug(f"Task {self.task_id} input kill signal input.")
                         return
-                    inputs.append((item.input_id, item.retry_count, item.function_call_id, item.input))
+                    inputs.append(
+                        (item.input_id, item.retry_count, item.function_call_id, item.attempt_token, item.input)
+                    )
                     if item.input.final_input:
                         if request.batch_max_size > 0:
                             logger.debug(f"Task {self.task_id} Final input not expected in batch input stream")
@@ -666,62 +875,24 @@
             self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time()
             yield io_context
             self.current_input_id, self.current_input_started_at = (None, None)
-
         # collect all active input slots, meaning all inputs have wrapped up.
         await self._input_slots.close()
 
-    @synchronizer.no_io_translation
-    async def _push_outputs(
-        self,
-        io_context: IOContext,
-        started_at: float,
-        data_format: "modal_proto.api_pb2.DataFormat.ValueType",
-        results: list[api_pb2.GenericResult],
-    ) -> None:
-        output_created_at = time.time()
-        outputs = [
-            api_pb2.FunctionPutOutputsItem(
-                input_id=input_id,
-                input_started_at=started_at,
-                output_created_at=output_created_at,
-                result=result,
-                data_format=data_format,
-                retry_count=retry_count,
-            )
-            for input_id, retry_count, result in zip(io_context.input_ids, io_context.retry_counts, results)
-        ]
-
+    async def _send_outputs(self, started_at: float, outputs: list[api_pb2.FunctionPutOutputsItem]) -> None:
+        """Send pre-built output items with retry and chunking."""
         # There are multiple outputs for a single IOContext in the case of @modal.batched.
         # Limit the batch size to 20 to stay within message size limits and buffer size limits.
         output_batch_size = 20
         for i in range(0, len(outputs), output_batch_size):
-            await retry_transient_errors(
-                self._client.stub.FunctionPutOutputs,
+            await self._client.stub.FunctionPutOutputs(
                 api_pb2.FunctionPutOutputsRequest(outputs=outputs[i : i + output_batch_size]),
-                additional_status_codes=[Status.RESOURCE_EXHAUSTED],
-                max_retries=None,  # Retry indefinitely, trying every 1s.
+                retry=Retry(
+                    additional_status_codes=[Status.RESOURCE_EXHAUSTED],
+                    max_retries=None,  # Retry indefinitely, trying every 1s.
+                ),
             )
-
-    def serialize_exception(self, exc: BaseException) -> bytes:
-        try:
-            return serialize(exc)
-        except Exception as serialization_exc:
-            # We can't always serialize exceptions.
-            err = f"Failed to serialize exception {exc} of type {type(exc)}: {serialization_exc}"
-            logger.info(err)
-            return serialize(SerializationError(err))
-
-    def serialize_traceback(self, exc: BaseException) -> tuple[Optional[bytes], Optional[bytes]]:
-        serialized_tb, tb_line_cache = None, None
-
-        try:
-            tb_dict, line_cache = extract_traceback(exc, self.task_id)
-            serialized_tb = serialize(tb_dict)
-            tb_line_cache = serialize(line_cache)
-        except Exception:
-            logger.info("Failed to serialize exception traceback.")
-
-        return serialized_tb, tb_line_cache
+        input_ids = [output.input_id for output in outputs]
+        self.exit_context(started_at, input_ids)
 
     @asynccontextmanager
     async def handle_user_exception(self) -> AsyncGenerator[None, None]:
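_send_outputs slices the output list into batches of at most 20 FunctionPutOutputsItems per request to stay under message-size and buffer limits. The slicing pattern in isolation:

    def chunked(items: list, size: int = 20):
        # Yield fixed-size slices; the final slice may be shorter.
        for i in range(0, len(items), size):
            yield items[i : i + size]

    assert [len(c) for c in chunked(list(range(45)))] == [20, 20, 5]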
@@ -744,11 +915,14 @@
             # Since this is on a different thread, sys.exc_info() can't find the exception in the stack.
             print_exception(type(exc), exc, exc.__traceback__)
 
-            serialized_tb, tb_line_cache = self.serialize_traceback(exc)
+            serialized_tb, tb_line_cache = pickle_traceback(exc, self.task_id)
 
+            data_or_blob = await format_blob_data(pickle_exception(exc), self._client.stub)
             result = api_pb2.GenericResult(
                 status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
-                data=self.serialize_exception(exc),
+                **data_or_blob,
+                # TODO: there is no way to communicate the data format here
+                # since it usually goes on the envelope outside of GenericResult
                 exception=repr(exc),
                 traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
                 serialized_tb=serialized_tb or b"",
@@ -756,7 +930,7 @@
             )
 
             req = api_pb2.TaskResultRequest(result=result)
-            await retry_transient_errors(self._client.stub.TaskResult, req)
+            await self._client.stub.TaskResult(req)
 
             # Shut down the task gracefully
             raise UserException()
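The format_blob_data call in the two hunks above (now imported from modal/_utils/blob_utils.py; its old inline body appears as removed lines earlier in this diff) decides between embedding a payload and uploading it: small results ride in the data field, oversized ones become a data_blob_id reference. A standalone rendering of that threshold rule, with an illustrative limit (see blob_utils for the real MAX_OBJECT_SIZE_BYTES):

    TOY_MAX_OBJECT_SIZE_BYTES = 2 * 1024 * 1024  # illustrative threshold only

    async def toy_format_blob_data(data: bytes, upload) -> dict:
        # Inline small payloads; upload large ones and pass a reference instead.
        if len(data) > TOY_MAX_OBJECT_SIZE_BYTES:
            return {"data_blob_id": await upload(data)}
        return {"data": data}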
@@ -778,18 +952,8 @@
             # for the yield. Typically on event loop shutdown
             raise
         except (InputCancellation, asyncio.CancelledError):
-            # Create terminated outputs for these inputs to signal that the cancellations have been completed.
-            results = [
-                api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED)
-                for _ in io_context.input_ids
-            ]
-            await self._push_outputs(
-                io_context=io_context,
-                started_at=started_at,
-                data_format=api_pb2.DATA_FORMAT_PICKLE,
-                results=results,
-            )
-            self.exit_context(started_at, io_context.input_ids)
+            outputs = await io_context.output_items_cancellation(started_at)
+            await self._send_outputs(started_at, outputs)
             logger.warning(f"Successfully canceled input {io_context.input_ids}")
             return
         except BaseException as exc:
@@ -799,44 +963,8 @@
 
             # print exception so it's logged
             print_exception(*sys.exc_info())
-
-            serialized_tb, tb_line_cache = self.serialize_traceback(exc)
-
-            # Note: we're not serializing the traceback since it contains
-            # local references that means we can't unpickle it. We *are*
-            # serializing the exception, which may have some issues (there
-            # was an earlier note about it that it might not be possible
-            # to unpickle it in some cases). Let's watch out for issues.
-
-            repr_exc = repr(exc)
-            if len(repr_exc) >= MAX_OBJECT_SIZE_BYTES:
-                # We prevent large exception messages to avoid
-                # unhandled exceptions causing inf loops
-                # and just send backa trimmed version
-                trimmed_bytes = len(repr_exc) - MAX_OBJECT_SIZE_BYTES - 1000
-                repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
-                repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
-
-            data: bytes = self.serialize_exception(exc) or b""
-            data_result_part = await self.format_blob_data(data)
-            results = [
-                api_pb2.GenericResult(
-                    status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
-                    exception=repr_exc,
-                    traceback=traceback.format_exc(),
-                    serialized_tb=serialized_tb or b"",
-                    tb_line_cache=tb_line_cache or b"",
-                    **data_result_part,
-                )
-                for _ in io_context.input_ids
-            ]
-            await self._push_outputs(
-                io_context=io_context,
-                started_at=started_at,
-                data_format=api_pb2.DATA_FORMAT_PICKLE,
-                results=results,
-            )
-            self.exit_context(started_at, io_context.input_ids)
+            outputs = await io_context.output_items_exception(started_at, self.task_id, exc)
+            await self._send_outputs(started_at, outputs)
 
     def exit_context(self, started_at, input_ids: list[str]):
         self.total_user_time += time.time() - started_at
@@ -847,32 +975,17 @@
 
         self._input_slots.release()
 
+    # skip inspection of user-generated output_data for synchronicity input translation
     @synchronizer.no_io_translation
     async def push_outputs(
         self,
         io_context: IOContext,
         started_at: float,
-        data: Any,
-        data_format: "modal_proto.api_pb2.DataFormat.ValueType",
+        output_data: list[Any],  # one per output
     ) -> None:
-        data = io_context.validate_output_data(data)
-        formatted_data = await asyncio.gather(
-            *[self.format_blob_data(self.serialize_data_format(d, data_format)) for d in data]
-        )
-        results = [
-            api_pb2.GenericResult(
-                status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
-                **d,
-            )
-            for d in formatted_data
-        ]
-        await self._push_outputs(
-            io_context=io_context,
-            started_at=started_at,
-            data_format=data_format,
-            results=results,
-        )
-        self.exit_context(started_at, io_context.input_ids)
+        # The standard output encoding+sending method for successful function outputs
+        outputs = await io_context.output_items(started_at, output_data)
+        await self._send_outputs(started_at, outputs)
 
     async def memory_restore(self) -> None:
         # Busy-wait for restore. `/__modal/restore-state.json` is created
@@ -967,13 +1080,14 @@
         await asyncify(os.sync)()
         results = await asyncio.gather(
             *[
-                retry_transient_errors(
-                    self._client.stub.VolumeCommit,
+                self._client.stub.VolumeCommit(
                     api_pb2.VolumeCommitRequest(volume_id=v_id),
-                    max_retries=9,
-                    base_delay=0.25,
-                    max_delay=256,
-                    delay_factor=2,
+                    retry=Retry(
+                        max_retries=9,
+                        base_delay=0.25,
+                        max_delay=256,
+                        delay_factor=2,
+                    ),
                 )
                 for v_id in volume_ids
             ],
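With base_delay=0.25, delay_factor=2, and max_delay=256, the VolumeCommit retry delays grow geometrically; assuming the delay doubles after each failed attempt (the natural reading of those parameter names), nine retries wait roughly two minutes in total and never reach the cap:

    base_delay, delay_factor, max_delay, max_retries = 0.25, 2, 256, 9

    delays, delay = [], base_delay
    for _ in range(max_retries):
        delays.append(delay)
        delay = min(delay * delay_factor, max_delay)

    # 0.25, 0.5, 1, 2, 4, 8, 16, 32, 64 seconds -> 127.75s across nine retries
    assert delays[-1] == 64.0 and sum(delays) == 127.75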
@@ -1042,7 +1156,8 @@
 
     @classmethod
     def stop_fetching_inputs(cls):
-        assert cls._singleton
+        if not cls._singleton:
+            raise RuntimeError("Must be called from within a Modal container.")
         cls._singleton._fetching_inputs = False
 