modal 1.0.3.dev10__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic.

Files changed (160)
  1. modal/__init__.py +0 -2
  2. modal/__main__.py +3 -4
  3. modal/_billing.py +80 -0
  4. modal/_clustered_functions.py +7 -3
  5. modal/_clustered_functions.pyi +15 -3
  6. modal/_container_entrypoint.py +51 -69
  7. modal/_functions.py +508 -240
  8. modal/_grpc_client.py +171 -0
  9. modal/_load_context.py +105 -0
  10. modal/_object.py +81 -21
  11. modal/_output.py +58 -45
  12. modal/_partial_function.py +48 -73
  13. modal/_pty.py +7 -3
  14. modal/_resolver.py +26 -46
  15. modal/_runtime/asgi.py +4 -3
  16. modal/_runtime/container_io_manager.py +358 -220
  17. modal/_runtime/container_io_manager.pyi +296 -101
  18. modal/_runtime/execution_context.py +18 -2
  19. modal/_runtime/execution_context.pyi +64 -7
  20. modal/_runtime/gpu_memory_snapshot.py +262 -57
  21. modal/_runtime/user_code_imports.py +28 -58
  22. modal/_serialization.py +90 -6
  23. modal/_traceback.py +42 -1
  24. modal/_tunnel.pyi +380 -12
  25. modal/_utils/async_utils.py +84 -29
  26. modal/_utils/auth_token_manager.py +111 -0
  27. modal/_utils/blob_utils.py +181 -58
  28. modal/_utils/deprecation.py +19 -0
  29. modal/_utils/function_utils.py +91 -47
  30. modal/_utils/grpc_utils.py +89 -66
  31. modal/_utils/mount_utils.py +26 -1
  32. modal/_utils/name_utils.py +17 -3
  33. modal/_utils/task_command_router_client.py +536 -0
  34. modal/_utils/time_utils.py +34 -6
  35. modal/app.py +256 -88
  36. modal/app.pyi +909 -92
  37. modal/billing.py +5 -0
  38. modal/builder/2025.06.txt +18 -0
  39. modal/builder/PREVIEW.txt +18 -0
  40. modal/builder/base-images.json +58 -0
  41. modal/cli/_download.py +19 -3
  42. modal/cli/_traceback.py +3 -2
  43. modal/cli/app.py +4 -4
  44. modal/cli/cluster.py +15 -7
  45. modal/cli/config.py +5 -3
  46. modal/cli/container.py +7 -6
  47. modal/cli/dict.py +22 -16
  48. modal/cli/entry_point.py +12 -5
  49. modal/cli/environment.py +5 -4
  50. modal/cli/import_refs.py +3 -3
  51. modal/cli/launch.py +102 -5
  52. modal/cli/network_file_system.py +11 -12
  53. modal/cli/profile.py +3 -2
  54. modal/cli/programs/launch_instance_ssh.py +94 -0
  55. modal/cli/programs/run_jupyter.py +1 -1
  56. modal/cli/programs/run_marimo.py +95 -0
  57. modal/cli/programs/vscode.py +1 -1
  58. modal/cli/queues.py +57 -26
  59. modal/cli/run.py +91 -23
  60. modal/cli/secret.py +48 -22
  61. modal/cli/token.py +7 -8
  62. modal/cli/utils.py +4 -7
  63. modal/cli/volume.py +31 -25
  64. modal/client.py +15 -85
  65. modal/client.pyi +183 -62
  66. modal/cloud_bucket_mount.py +5 -3
  67. modal/cloud_bucket_mount.pyi +197 -5
  68. modal/cls.py +200 -126
  69. modal/cls.pyi +446 -68
  70. modal/config.py +29 -11
  71. modal/container_process.py +319 -19
  72. modal/container_process.pyi +190 -20
  73. modal/dict.py +290 -71
  74. modal/dict.pyi +835 -83
  75. modal/environments.py +15 -27
  76. modal/environments.pyi +46 -24
  77. modal/exception.py +14 -2
  78. modal/experimental/__init__.py +194 -40
  79. modal/experimental/flash.py +618 -0
  80. modal/experimental/flash.pyi +380 -0
  81. modal/experimental/ipython.py +11 -7
  82. modal/file_io.py +29 -36
  83. modal/file_io.pyi +251 -53
  84. modal/file_pattern_matcher.py +56 -16
  85. modal/functions.pyi +673 -92
  86. modal/gpu.py +1 -1
  87. modal/image.py +528 -176
  88. modal/image.pyi +1572 -145
  89. modal/io_streams.py +458 -128
  90. modal/io_streams.pyi +433 -52
  91. modal/mount.py +216 -151
  92. modal/mount.pyi +225 -78
  93. modal/network_file_system.py +45 -62
  94. modal/network_file_system.pyi +277 -56
  95. modal/object.pyi +93 -17
  96. modal/parallel_map.py +942 -129
  97. modal/parallel_map.pyi +294 -15
  98. modal/partial_function.py +0 -2
  99. modal/partial_function.pyi +234 -19
  100. modal/proxy.py +17 -8
  101. modal/proxy.pyi +36 -3
  102. modal/queue.py +270 -65
  103. modal/queue.pyi +817 -57
  104. modal/runner.py +115 -101
  105. modal/runner.pyi +205 -49
  106. modal/sandbox.py +512 -136
  107. modal/sandbox.pyi +845 -111
  108. modal/schedule.py +1 -1
  109. modal/secret.py +300 -70
  110. modal/secret.pyi +589 -34
  111. modal/serving.py +7 -11
  112. modal/serving.pyi +7 -8
  113. modal/snapshot.py +11 -8
  114. modal/snapshot.pyi +25 -4
  115. modal/token_flow.py +4 -4
  116. modal/token_flow.pyi +28 -8
  117. modal/volume.py +416 -158
  118. modal/volume.pyi +1117 -121
  119. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +10 -9
  120. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  121. modal_docs/mdmd/mdmd.py +17 -4
  122. modal_proto/api.proto +534 -79
  123. modal_proto/api_grpc.py +337 -1
  124. modal_proto/api_pb2.py +1522 -968
  125. modal_proto/api_pb2.pyi +1619 -134
  126. modal_proto/api_pb2_grpc.py +699 -4
  127. modal_proto/api_pb2_grpc.pyi +226 -14
  128. modal_proto/modal_api_grpc.py +175 -154
  129. modal_proto/sandbox_router.proto +145 -0
  130. modal_proto/sandbox_router_grpc.py +105 -0
  131. modal_proto/sandbox_router_pb2.py +149 -0
  132. modal_proto/sandbox_router_pb2.pyi +333 -0
  133. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  134. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  135. modal_proto/task_command_router.proto +144 -0
  136. modal_proto/task_command_router_grpc.py +105 -0
  137. modal_proto/task_command_router_pb2.py +149 -0
  138. modal_proto/task_command_router_pb2.pyi +333 -0
  139. modal_proto/task_command_router_pb2_grpc.py +203 -0
  140. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  141. modal_version/__init__.py +1 -1
  142. modal/requirements/PREVIEW.txt +0 -16
  143. modal/requirements/base-images.json +0 -26
  144. modal-1.0.3.dev10.dist-info/RECORD +0 -179
  145. modal_proto/modal_options_grpc.py +0 -3
  146. modal_proto/options.proto +0 -19
  147. modal_proto/options_grpc.py +0 -3
  148. modal_proto/options_pb2.py +0 -35
  149. modal_proto/options_pb2.pyi +0 -20
  150. modal_proto/options_pb2_grpc.py +0 -4
  151. modal_proto/options_pb2_grpc.pyi +0 -7
  152. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  153. /modal/{requirements → builder}/2023.12.txt +0 -0
  154. /modal/{requirements → builder}/2024.04.txt +0 -0
  155. /modal/{requirements → builder}/2024.10.txt +0 -0
  156. /modal/{requirements → builder}/README.md +0 -0
  157. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  158. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  159. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  160. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ from typing import (
     Any,
     Callable,
     ClassVar,
+    Generator,
     Optional,
     cast,
 )
@@ -24,22 +25,25 @@ from google.protobuf.empty_pb2 import Empty
 from grpclib import Status
 from synchronicity.async_wrap import asynccontextmanager
 
-import modal_proto.api_pb2
 from modal._runtime import gpu_memory_snapshot
-from modal._serialization import deserialize, serialize, serialize_data_format
-from modal._traceback import extract_traceback, print_exception
-from modal._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer
-from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
+from modal._serialization import (
+    deserialize_data_format,
+    pickle_exception,
+    pickle_traceback,
+    serialize_data_format,
+)
+from modal._traceback import print_exception
+from modal._utils.async_utils import TaskContext, aclosing, asyncify, synchronize_api, synchronizer
+from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload, format_blob_data
 from modal._utils.function_utils import _stream_function_call_data
-from modal._utils.grpc_utils import retry_transient_errors
+from modal._utils.grpc_utils import Retry
 from modal._utils.package_utils import parse_major_minor_version
 from modal.client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
 from modal.config import config, logger
-from modal.exception import ClientClosed, InputCancellation, InvalidError, SerializationError
+from modal.exception import ClientClosed, InputCancellation, InvalidError
 from modal_proto import api_pb2
 
 if TYPE_CHECKING:
-    import modal._runtime.asgi
     import modal._runtime.user_code_imports
 
 
@@ -66,6 +70,7 @@ class IOContext:
     input_ids: list[str]
     retry_counts: list[int]
     function_call_ids: list[str]
+    attempt_tokens: list[str]
     function_inputs: list[api_pb2.FunctionInput]
     finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"
 
@@ -77,6 +82,7 @@ class IOContext:
         input_ids: list[str],
         retry_counts: list[int],
         function_call_ids: list[str],
+        attempt_tokens: list[str],
         finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
         function_inputs: list[api_pb2.FunctionInput],
         is_batched: bool,
@@ -85,6 +91,7 @@ class IOContext:
         self.input_ids = input_ids
         self.retry_counts = retry_counts
         self.function_call_ids = function_call_ids
+        self.attempt_tokens = attempt_tokens
         self.finalized_function = finalized_function
         self.function_inputs = function_inputs
         self._is_batched = is_batched
@@ -95,11 +102,11 @@ class IOContext:
         cls,
         client: _Client,
         finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
-        inputs: list[tuple[str, int, str, api_pb2.FunctionInput]],
+        inputs: list[tuple[str, int, str, str, api_pb2.FunctionInput]],
         is_batched: bool,
     ) -> "IOContext":
         assert len(inputs) >= 1 if is_batched else len(inputs) == 1
-        input_ids, retry_counts, function_call_ids, function_inputs = zip(*inputs)
+        input_ids, retry_counts, function_call_ids, attempt_tokens, function_inputs = zip(*inputs)
 
         async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput:
             # If we got a pointer to a blob, download it from S3.
@@ -121,6 +128,7 @@ class IOContext:
             cast(list[str], input_ids),
             cast(list[int], retry_counts),
             cast(list[str], function_call_ids),
+            cast(list[str], attempt_tokens),
             finalized_function,
             cast(list[api_pb2.FunctionInput], function_inputs),
             is_batched,
@@ -148,9 +156,13 @@ class IOContext:
         # deserializing here instead of the constructor
         # to make sure we handle user exceptions properly
         # and don't retry
-        deserialized_args = [
-            deserialize(input.args, self._client) if input.args else ((), {}) for input in self.function_inputs
-        ]
+        deserialized_args = []
+        for input in self.function_inputs:
+            if input.args:
+                data_format = input.data_format
+                deserialized_args.append(deserialize_data_format(input.args, data_format, self._client))
+            else:
+                deserialized_args.append(((), {}))
         if not self._is_batched:
             return deserialized_args[0]
 
@@ -188,25 +200,229 @@ class IOContext:
         }
         return (), formatted_kwargs
 
-    def call_finalized_function(self) -> Any:
+    def _generator_output_format(self) -> "api_pb2.DataFormat.ValueType":
+        return self._determine_output_format(self.function_inputs[0].data_format)
+
+    def _prepare_batch_output(self, data: Any) -> list[Any]:
+        # validate that output is valid for batch
+        if self._is_batched:
+            # assert data is list etc.
+            function_name = self.finalized_function.callable.__name__
+
+            if not isinstance(data, list):
+                raise InvalidError(f"Output of batched function {function_name} must be a list.")
+            if len(data) != len(self.input_ids):
+                raise InvalidError(
+                    f"Output of batched function {function_name} must be a list of equal length as its inputs."
+                )
+            return data
+        else:
+            return [data]
+
+    def call_function_sync(self) -> list[Any]:
         logger.debug(f"Starting input {self.input_ids}")
         args, kwargs = self._args_and_kwargs()
-        res = self.finalized_function.callable(*args, **kwargs)
+        expected_value_or_values = self.finalized_function.callable(*args, **kwargs)
+        if (
+            inspect.iscoroutine(expected_value_or_values)
+            or inspect.isgenerator(expected_value_or_values)
+            or inspect.isasyncgen(expected_value_or_values)
+        ):
+            raise InvalidError(
+                f"Sync (non-generator) function return value of type {type(expected_value_or_values)}."
+                " You might need to use @app.function(..., is_generator=True)."
+            )
         logger.debug(f"Finished input {self.input_ids}")
-        return res
+        return self._prepare_batch_output(expected_value_or_values)
 
-    def validate_output_data(self, data: Any) -> list[Any]:
-        if not self._is_batched:
-            return [data]
-
-        function_name = self.finalized_function.callable.__name__
-        if not isinstance(data, list):
-            raise InvalidError(f"Output of batched function {function_name} must be a list.")
-        if len(data) != len(self.input_ids):
+    async def call_function_async(self) -> list[Any]:
+        logger.debug(f"Starting input {self.input_ids}")
+        args, kwargs = self._args_and_kwargs()
+        expected_coro = self.finalized_function.callable(*args, **kwargs)
+        if (
+            not inspect.iscoroutine(expected_coro)
+            or inspect.isgenerator(expected_coro)
+            or inspect.isasyncgen(expected_coro)
+        ):
             raise InvalidError(
-                f"Output of batched function {function_name} must be a list of equal length as its inputs."
+                f"Async (non-generator) function returned value of type {type(expected_coro)}"
+                " You might need to use @app.function(..., is_generator=True)."
             )
-        return data
+        value = await expected_coro
+        logger.debug(f"Finished input {self.input_ids}")
+        return self._prepare_batch_output(value)
+
+    def call_generator_sync(self) -> Generator[Any, None, None]:
+        assert not self._is_batched
+        logger.debug(f"Starting generator input {self.input_ids}")
+        args, kwargs = self._args_and_kwargs()
+        expected_gen = self.finalized_function.callable(*args, **kwargs)
+        if not inspect.isgenerator(expected_gen):
+            raise InvalidError(f"Generator function returned value of type {type(expected_gen)}")
+
+        for result in expected_gen:
+            yield result
+        logger.debug(f"Finished generator input {self.input_ids}")
+
+    async def call_generator_async(self) -> AsyncGenerator[Any, None]:
+        assert not self._is_batched
+        logger.debug(f"Starting generator input {self.input_ids}")
+        args, kwargs = self._args_and_kwargs()
+        expected_async_gen = self.finalized_function.callable(*args, **kwargs)
+        if not inspect.isasyncgen(expected_async_gen):
+            raise InvalidError(f"Async generator function returned value of type {type(expected_async_gen)}")
+
+        async with aclosing(expected_async_gen) as gen:
+            async for result in gen:
+                yield result
+        logger.debug(f"Finished generator input {self.input_ids}")
+
+    async def output_items_cancellation(self, started_at: float):
+        output_created_at = time.time()
+        # Create terminated outputs for these inputs to signal that the cancellations have been completed.
+        return [
+            api_pb2.FunctionPutOutputsItem(
+                input_id=input_id,
+                input_started_at=started_at,
+                output_created_at=output_created_at,
+                result=api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED),
+                retry_count=retry_count,
+            )
+            for input_id, retry_count in zip(self.input_ids, self.retry_counts)
+        ]
+
+    def _determine_output_format(self, input_format: "api_pb2.DataFormat.ValueType") -> "api_pb2.DataFormat.ValueType":
+        if input_format in self.finalized_function.supported_output_formats:
+            return input_format
+        elif self.finalized_function.supported_output_formats:
+            # This branch would normally be hit when calling a restricted_output function with Pickle input
+            # but we enforce cbor output at function definition level. In the future we might send the intended
+            # output format along with the input to make this disitinction in the calling client instead
+            logger.debug(
+                f"Got an input with format {input_format}, but can only produce output"
+                f" using formats {self.finalized_function.supported_output_formats}"
+            )
+            return self.finalized_function.supported_output_formats[0]
+        else:
+            # This should never happen since self.finalized_function.supported_output_formats should be
+            # populated with defaults in case it's empty, log a warning
+            logger.warning(f"Got an input with format {input_format}, but the function has no defined output formats")
+            return api_pb2.DATA_FORMAT_PICKLE
+
+    async def output_items_exception(
+        self, started_at: float, task_id: str, exc: BaseException
+    ) -> list[api_pb2.FunctionPutOutputsItem]:
+        # Note: we're not pickling the traceback since it contains
+        # local references that means we can't unpickle it. We *are*
+        # pickling the exception, which may have some issues (there
+        # was an earlier note about it that it might not be possible
+        # to unpickle it in some cases). Let's watch out for issues.
+        repr_exc = repr(exc)
+        if len(repr_exc) >= MAX_OBJECT_SIZE_BYTES:
+            # We prevent large exception messages to avoid
+            # unhandled exceptions causing inf loops
+            # and just send backa trimmed version
+            trimmed_bytes = len(repr_exc) - MAX_OBJECT_SIZE_BYTES - 1000
+            repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
+            repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
+
+        data: bytes = pickle_exception(exc)
+        data_result_part = await format_blob_data(data, self._client.stub)
+        serialized_tb, tb_line_cache = pickle_traceback(exc, task_id)
+
+        # Failure outputs for when input exceptions occur
+        def data_format_specific_output(input_format: "api_pb2.DataFormat.ValueType") -> dict:
+            output_format = self._determine_output_format(input_format)
+            if output_format == api_pb2.DATA_FORMAT_PICKLE:
+                return {
+                    "data_format": output_format,
+                    "result": api_pb2.GenericResult(
+                        status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
+                        exception=repr_exc,
+                        traceback=traceback.format_exc(),
+                        serialized_tb=serialized_tb,
+                        tb_line_cache=tb_line_cache,
+                        **data_result_part,
+                    ),
+                }
+            else:
+                return {
+                    "data_format": output_format,
+                    "result": api_pb2.GenericResult(
+                        status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
+                        exception=repr_exc,
+                        traceback=traceback.format_exc(),
+                    ),
+                }
+
+        # all inputs in the batch get the same failure:
+        output_created_at = time.time()
+        return [
+            api_pb2.FunctionPutOutputsItem(
+                input_id=input_id,
+                input_started_at=started_at,
+                output_created_at=output_created_at,
+                retry_count=retry_count,
+                **data_format_specific_output(function_input.data_format),
+            )
+            for input_id, retry_count, function_input in zip(self.input_ids, self.retry_counts, self.function_inputs)
+        ]
+
+    def output_items_generator_done(self, started_at: float, items_total: int) -> list[api_pb2.FunctionPutOutputsItem]:
+        assert not self._is_batched, "generators are not supported with batched inputs"
+        assert len(self.function_inputs) == 1, "generators are expected to have 1 input"
+        # Serialize and format the data
+        serialized_bytes = serialize_data_format(
+            api_pb2.GeneratorDone(items_total=items_total), data_format=api_pb2.DATA_FORMAT_GENERATOR_DONE
+        )
+        return [
+            api_pb2.FunctionPutOutputsItem(
+                input_id=self.input_ids[0],
+                input_started_at=started_at,
+                output_created_at=time.time(),
+                result=api_pb2.GenericResult(
+                    status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
+                    data=serialized_bytes,
+                ),
+                data_format=api_pb2.DATA_FORMAT_GENERATOR_DONE,
+                retry_count=self.retry_counts[0],
+            )
+        ]
+
+    async def output_items(self, started_at: float, data: list[Any]) -> list[api_pb2.FunctionPutOutputsItem]:
+        output_created_at = time.time()
+
+        # Process all items concurrently and create output items directly
+        async def package_output(
+            item: Any, input_id: str, retry_count: int, input_format: "api_pb2.DataFormat.ValueType"
+        ) -> api_pb2.FunctionPutOutputsItem:
+            output_format = self._determine_output_format(input_format)
+
+            serialized_bytes = serialize_data_format(item, data_format=output_format)
+            formatted = await format_blob_data(serialized_bytes, self._client.stub)
+            # Create the result
+            result = api_pb2.GenericResult(
+                status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
+                **formatted,
+            )
+            return api_pb2.FunctionPutOutputsItem(
+                input_id=input_id,
+                input_started_at=started_at,
+                output_created_at=output_created_at,
+                result=result,
+                data_format=output_format,
+                retry_count=retry_count,
+            )
+
+        # Process all items concurrently
+        return await asyncio.gather(
+            *[
+                package_output(item, input_id, retry_count, function_input.data_format)
+                for item, input_id, retry_count, function_input in zip(
+                    data, self.input_ids, self.retry_counts, self.function_inputs
+                )
+            ]
+        )
 
 
 class InputSlots:
@@ -267,6 +483,7 @@ class _ContainerIOManager:
     app_id: str
     function_def: api_pb2.Function
     checkpoint_id: Optional[str]
+    input_plane_server_url: Optional[str]
 
     calls_completed: int
     total_user_time: float
@@ -290,7 +507,6 @@ class _ContainerIOManager:
 
     _client: _Client
 
-    _GENERATOR_STOP_SENTINEL: ClassVar[Sentinel] = Sentinel()
     _singleton: ClassVar[Optional["_ContainerIOManager"]] = None
 
     def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):
@@ -300,6 +516,8 @@ class _ContainerIOManager:
         self.function_def = container_args.function_def
         self.checkpoint_id = container_args.checkpoint_id or None
 
+        self.input_plane_server_url = container_args.input_plane_server_url
+
         self.calls_completed = 0
         self.total_user_time = 0.0
         self.current_input_id = None
@@ -323,6 +541,7 @@ class _ContainerIOManager:
         self._heartbeat_loop = None
         self._heartbeat_condition = None
         self._waiting_for_memory_snapshot = False
+        self._cuda_checkpoint_session = None
 
         self._is_interactivity_enabled = False
         self._fetching_inputs = True
@@ -404,8 +623,8 @@ class _ContainerIOManager:
                 await self.heartbeat_condition.wait()
 
         request = api_pb2.ContainerHeartbeatRequest(canceled_inputs_return_outputs_v2=True)
-        response = await retry_transient_errors(
-            self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
+        response = await self._client.stub.ContainerHeartbeat(
+            request, retry=Retry(attempt_timeout=HEARTBEAT_TIMEOUT)
         )
 
         if response.HasField("cancel_input_event"):
@@ -452,10 +671,9 @@ class _ContainerIOManager:
                     target_concurrency=self._target_concurrency,
                     max_concurrency=self._max_concurrency,
                 )
-                resp = await retry_transient_errors(
-                    self._client.stub.FunctionGetDynamicConcurrency,
+                resp = await self._client.stub.FunctionGetDynamicConcurrency(
                     request,
-                    attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS,
+                    retry=Retry(attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS),
                 )
                 if resp.concurrency != self._input_slots.value and not self._stop_concurrency_loop:
                     logger.debug(f"Dynamic concurrency set from {self._input_slots.value} to {resp.concurrency}")
@@ -466,27 +684,23 @@ class _ContainerIOManager:
 
             await asyncio.sleep(DYNAMIC_CONCURRENCY_INTERVAL_SECS)
 
-    @synchronizer.no_io_translation
-    def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
-        return serialize_data_format(obj, data_format)
-
-    async def format_blob_data(self, data: bytes) -> dict[str, Any]:
-        return (
-            {"data_blob_id": await blob_upload(data, self._client.stub)}
-            if len(data) > MAX_OBJECT_SIZE_BYTES
-            else {"data": data}
-        )
-
-    async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
+    async def get_data_in(self, function_call_id: str, attempt_token: Optional[str]) -> AsyncIterator[Any]:
         """Read from the `data_in` stream of a function call."""
-        async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
+        stub = self._client.stub
+        if self.input_plane_server_url:
+            stub = await self._client.get_stub(self.input_plane_server_url)
+
+        async for data in _stream_function_call_data(
+            self._client, stub, function_call_id, variant="data_in", attempt_token=attempt_token
+        ):
             yield data
 
     async def put_data_out(
         self,
         function_call_id: str,
+        attempt_token: str,
         start_index: int,
-        data_format: int,
+        data_format: "api_pb2.DataFormat.ValueType",
         serialized_messages: list[Any],
     ) -> None:
         """Put data onto the `data_out` stream of a function call.
@@ -505,35 +719,60 @@ class _ContainerIOManager:
             data_chunks.append(chunk)
 
         req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks)
-        await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req)
-
-    async def generator_output_task(self, function_call_id: str, data_format: int, message_rx: asyncio.Queue) -> None:
-        """Task that feeds generator outputs into a function call's `data_out` stream."""
-        index = 1
-        received_sentinel = False
-        while not received_sentinel:
-            message = await message_rx.get()
-            if message is self._GENERATOR_STOP_SENTINEL:
-                break
-            # ASGI 'http.response.start' and 'http.response.body' msgs are observed to be separated by 1ms.
-            # If we don't sleep here for 1ms we end up with an extra call to .put_data_out().
-            if index == 1:
-                await asyncio.sleep(0.001)
-            serialized_messages = [serialize_data_format(message, data_format)]
-            total_size = len(serialized_messages[0]) + 512
-            while total_size < 16 * 1024 * 1024:  # 16 MiB, maximum size in a single message
-                try:
-                    message = message_rx.get_nowait()
-                except asyncio.QueueEmpty:
-                    break
-                if message is self._GENERATOR_STOP_SENTINEL:
-                    received_sentinel = True
+        if attempt_token:
+            req.attempt_token = attempt_token  # oneof clears function_call_id.
+
+        if self.input_plane_server_url:
+            stub = await self._client.get_stub(self.input_plane_server_url)
+            await stub.FunctionCallPutDataOut(req)
+        else:
+            await self._client.stub.FunctionCallPutDataOut(req)
+
+    @asynccontextmanager
+    async def generator_output_sender(
+        self,
+        function_call_id: str,
+        attempt_token: str,
+        data_format: "api_pb2.DataFormat.ValueType",
+        message_rx: asyncio.Queue,
+    ) -> AsyncGenerator[None, None]:
+        """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
+        GENERATOR_STOP_SENTINEL = Sentinel()
+
+        async def generator_output_task():
+            index = 1
+            received_sentinel = False
+            while not received_sentinel:
+                message = await message_rx.get()
+                if message is GENERATOR_STOP_SENTINEL:
                     break
-                else:
-                    serialized_messages.append(serialize_data_format(message, data_format))
-                    total_size += len(serialized_messages[-1]) + 512  # 512 bytes for estimated framing overhead
-            await self.put_data_out(function_call_id, index, data_format, serialized_messages)
-            index += len(serialized_messages)
+                # ASGI 'http.response.start' and 'http.response.body' msgs are observed to be separated by 1ms.
+                # If we don't sleep here for 1ms we end up with an extra call to .put_data_out().
+                if index == 1:
+                    await asyncio.sleep(0.001)
+                serialized_messages = [serialize_data_format(message, data_format)]
+                total_size = len(serialized_messages[0]) + 512
+                while total_size < 16 * 1024 * 1024:  # 16 MiB, maximum size in a single message
+                    try:
+                        message = message_rx.get_nowait()
+                    except asyncio.QueueEmpty:
+                        break
+                    if message is GENERATOR_STOP_SENTINEL:
+                        received_sentinel = True
+                        break
+                    else:
+                        serialized_messages.append(serialize_data_format(message, data_format))
+                        total_size += len(serialized_messages[-1]) + 512  # 512 bytes for estimated framing overhead
+                await self.put_data_out(function_call_id, attempt_token, index, data_format, serialized_messages)
+                index += len(serialized_messages)
+
+        task = asyncio.create_task(generator_output_task())
+        try:
+            yield
+        finally:
+            # gracefully stop the task after all current inputs have been sent
+            await message_rx.put(GENERATOR_STOP_SENTINEL)
+            await task
 
     async def _queue_create(self, size: int) -> asyncio.Queue:
         """Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
@@ -560,7 +799,7 @@ class _ContainerIOManager:
         self,
         batch_max_size: int,
         batch_wait_ms: int,
-    ) -> AsyncIterator[list[tuple[str, int, str, api_pb2.FunctionInput]]]:
+    ) -> AsyncIterator[list[tuple[str, int, str, str, api_pb2.FunctionInput]]]:
         request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
         iteration = 0
         while self._fetching_inputs:
@@ -575,9 +814,7 @@ class _ContainerIOManager:
             try:
                 # If number of active inputs is at max queue size, this will block.
                 iteration += 1
-                response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors(
-                    self._client.stub.FunctionGetInputs, request
-                )
+                response: api_pb2.FunctionGetInputsResponse = await self._client.stub.FunctionGetInputs(request)
 
                 if response.rate_limit_sleep_duration:
                     logger.info(
@@ -595,7 +832,9 @@ class _ContainerIOManager:
                     if item.kill_switch:
                         logger.debug(f"Task {self.task_id} input kill signal input.")
                         return
-                    inputs.append((item.input_id, item.retry_count, item.function_call_id, item.input))
+                    inputs.append(
+                        (item.input_id, item.retry_count, item.function_call_id, item.attempt_token, item.input)
+                    )
                     if item.input.final_input:
                         if request.batch_max_size > 0:
                             logger.debug(f"Task {self.task_id} Final input not expected in batch input stream")
@@ -636,62 +875,24 @@ class _ContainerIOManager:
             self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time()
             yield io_context
             self.current_input_id, self.current_input_started_at = (None, None)
-
         # collect all active input slots, meaning all inputs have wrapped up.
         await self._input_slots.close()
 
-    @synchronizer.no_io_translation
-    async def _push_outputs(
-        self,
-        io_context: IOContext,
-        started_at: float,
-        data_format: "modal_proto.api_pb2.DataFormat.ValueType",
-        results: list[api_pb2.GenericResult],
-    ) -> None:
-        output_created_at = time.time()
-        outputs = [
-            api_pb2.FunctionPutOutputsItem(
-                input_id=input_id,
-                input_started_at=started_at,
-                output_created_at=output_created_at,
-                result=result,
-                data_format=data_format,
-                retry_count=retry_count,
-            )
-            for input_id, retry_count, result in zip(io_context.input_ids, io_context.retry_counts, results)
-        ]
-
+    async def _send_outputs(self, started_at: float, outputs: list[api_pb2.FunctionPutOutputsItem]) -> None:
+        """Send pre-built output items with retry and chunking."""
         # There are multiple outputs for a single IOContext in the case of @modal.batched.
         # Limit the batch size to 20 to stay within message size limits and buffer size limits.
         output_batch_size = 20
         for i in range(0, len(outputs), output_batch_size):
-            await retry_transient_errors(
-                self._client.stub.FunctionPutOutputs,
+            await self._client.stub.FunctionPutOutputs(
                 api_pb2.FunctionPutOutputsRequest(outputs=outputs[i : i + output_batch_size]),
-                additional_status_codes=[Status.RESOURCE_EXHAUSTED],
-                max_retries=None,  # Retry indefinitely, trying every 1s.
+                retry=Retry(
+                    additional_status_codes=[Status.RESOURCE_EXHAUSTED],
+                    max_retries=None,  # Retry indefinitely, trying every 1s.
+                ),
             )
-
-    def serialize_exception(self, exc: BaseException) -> bytes:
-        try:
-            return serialize(exc)
-        except Exception as serialization_exc:
-            # We can't always serialize exceptions.
-            err = f"Failed to serialize exception {exc} of type {type(exc)}: {serialization_exc}"
-            logger.info(err)
-            return serialize(SerializationError(err))
-
-    def serialize_traceback(self, exc: BaseException) -> tuple[Optional[bytes], Optional[bytes]]:
-        serialized_tb, tb_line_cache = None, None
-
-        try:
-            tb_dict, line_cache = extract_traceback(exc, self.task_id)
-            serialized_tb = serialize(tb_dict)
-            tb_line_cache = serialize(line_cache)
-        except Exception:
-            logger.info("Failed to serialize exception traceback.")
-
-        return serialized_tb, tb_line_cache
+        input_ids = [output.input_id for output in outputs]
+        self.exit_context(started_at, input_ids)
 
     @asynccontextmanager
@@ -714,11 +915,14 @@ class _ContainerIOManager:
             # Since this is on a different thread, sys.exc_info() can't find the exception in the stack.
             print_exception(type(exc), exc, exc.__traceback__)
 
-            serialized_tb, tb_line_cache = self.serialize_traceback(exc)
+            serialized_tb, tb_line_cache = pickle_traceback(exc, self.task_id)
 
+            data_or_blob = await format_blob_data(pickle_exception(exc), self._client.stub)
             result = api_pb2.GenericResult(
                 status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
-                data=self.serialize_exception(exc),
+                **data_or_blob,
+                # TODO: there is no way to communicate the data format here
+                # since it usually goes on the envelope outside of GenericResult
                 exception=repr(exc),
                 traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
                 serialized_tb=serialized_tb or b"",
@@ -726,7 +930,7 @@ class _ContainerIOManager:
             )
 
             req = api_pb2.TaskResultRequest(result=result)
-            await retry_transient_errors(self._client.stub.TaskResult, req)
+            await self._client.stub.TaskResult(req)
 
             # Shut down the task gracefully
             raise UserException()
@@ -748,18 +952,8 @@ class _ContainerIOManager:
             # for the yield. Typically on event loop shutdown
             raise
         except (InputCancellation, asyncio.CancelledError):
-            # Create terminated outputs for these inputs to signal that the cancellations have been completed.
-            results = [
-                api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED)
-                for _ in io_context.input_ids
-            ]
-            await self._push_outputs(
-                io_context=io_context,
-                started_at=started_at,
-                data_format=api_pb2.DATA_FORMAT_PICKLE,
-                results=results,
-            )
-            self.exit_context(started_at, io_context.input_ids)
+            outputs = await io_context.output_items_cancellation(started_at)
+            await self._send_outputs(started_at, outputs)
             logger.warning(f"Successfully canceled input {io_context.input_ids}")
             return
         except BaseException as exc:
@@ -769,44 +963,8 @@ class _ContainerIOManager:
 
             # print exception so it's logged
             print_exception(*sys.exc_info())
-
-            serialized_tb, tb_line_cache = self.serialize_traceback(exc)
-
-            # Note: we're not serializing the traceback since it contains
-            # local references that means we can't unpickle it. We *are*
-            # serializing the exception, which may have some issues (there
-            # was an earlier note about it that it might not be possible
-            # to unpickle it in some cases). Let's watch out for issues.
-
-            repr_exc = repr(exc)
-            if len(repr_exc) >= MAX_OBJECT_SIZE_BYTES:
-                # We prevent large exception messages to avoid
-                # unhandled exceptions causing inf loops
-                # and just send backa trimmed version
-                trimmed_bytes = len(repr_exc) - MAX_OBJECT_SIZE_BYTES - 1000
-                repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
-                repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
-
-            data: bytes = self.serialize_exception(exc) or b""
-            data_result_part = await self.format_blob_data(data)
-            results = [
-                api_pb2.GenericResult(
-                    status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
-                    exception=repr_exc,
-                    traceback=traceback.format_exc(),
-                    serialized_tb=serialized_tb or b"",
-                    tb_line_cache=tb_line_cache or b"",
-                    **data_result_part,
-                )
-                for _ in io_context.input_ids
-            ]
-            await self._push_outputs(
-                io_context=io_context,
-                started_at=started_at,
-                data_format=api_pb2.DATA_FORMAT_PICKLE,
-                results=results,
-            )
-            self.exit_context(started_at, io_context.input_ids)
+            outputs = await io_context.output_items_exception(started_at, self.task_id, exc)
+            await self._send_outputs(started_at, outputs)
 
     def exit_context(self, started_at, input_ids: list[str]):
         self.total_user_time += time.time() - started_at
@@ -817,32 +975,17 @@ class _ContainerIOManager:
 
         self._input_slots.release()
 
+    # skip inspection of user-generated output_data for synchronicity input translation
     @synchronizer.no_io_translation
     async def push_outputs(
         self,
         io_context: IOContext,
         started_at: float,
-        data: Any,
-        data_format: "modal_proto.api_pb2.DataFormat.ValueType",
+        output_data: list[Any],  # one per output
     ) -> None:
-        data = io_context.validate_output_data(data)
-        formatted_data = await asyncio.gather(
-            *[self.format_blob_data(self.serialize_data_format(d, data_format)) for d in data]
-        )
-        results = [
-            api_pb2.GenericResult(
-                status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
-                **d,
-            )
-            for d in formatted_data
-        ]
-        await self._push_outputs(
-            io_context=io_context,
-            started_at=started_at,
-            data_format=data_format,
-            results=results,
-        )
-        self.exit_context(started_at, io_context.input_ids)
+        # The standard output encoding+sending method for successful function outputs
+        outputs = await io_context.output_items(started_at, output_data)
+        await self._send_outputs(started_at, outputs)
 
     async def memory_restore(self) -> None:
         # Busy-wait for restore. `/__modal/restore-state.json` is created
@@ -881,13 +1024,11 @@ class _ContainerIOManager:
         # Restore GPU memory.
         if self.function_def._experimental_enable_gpu_snapshot and self.function_def.resources.gpu_config.gpu_type:
             logger.debug("GPU memory snapshot enabled. Attempting to restore GPU memory.")
-            gpu_process_state = gpu_memory_snapshot.get_state()
-            if gpu_process_state != gpu_memory_snapshot.CudaCheckpointState.CHECKPOINTED:
-                raise ValueError(
-                    "Cannot restore GPU state if GPU isn't in a 'checkpointed' state. "
-                    f"Current GPU state: {gpu_process_state}"
-                )
-            gpu_memory_snapshot.toggle()
+
+            assert self._cuda_checkpoint_session, (
+                "CudaCheckpointSession not found when attempting to restore GPU memory"
+            )
+            self._cuda_checkpoint_session.restore()
 
         # Restore input to default state.
         self.current_input_id = None
@@ -907,14 +1048,9 @@ class _ContainerIOManager:
         # Snapshot GPU memory.
         if self.function_def._experimental_enable_gpu_snapshot and self.function_def.resources.gpu_config.gpu_type:
             logger.debug("GPU memory snapshot enabled. Attempting to snapshot GPU memory.")
-            gpu_process_state = gpu_memory_snapshot.get_state()
-            if gpu_process_state != gpu_memory_snapshot.CudaCheckpointState.RUNNING:
-                raise ValueError(
-                    f"Cannot snapshot GPU state if it isn't running. Current GPU state: {gpu_process_state}"
-                )
 
-            gpu_memory_snapshot.toggle()
-            gpu_memory_snapshot.wait_for_state(gpu_memory_snapshot.CudaCheckpointState.CHECKPOINTED)
+            self._cuda_checkpoint_session = gpu_memory_snapshot.CudaCheckpointSession()
+            self._cuda_checkpoint_session.checkpoint()
 
         # Notify the heartbeat loop that the snapshot phase has begun in order to
         # prevent it from sending heartbeat RPCs
@@ -944,13 +1080,14 @@ class _ContainerIOManager:
         await asyncify(os.sync)()
         results = await asyncio.gather(
             *[
-                retry_transient_errors(
-                    self._client.stub.VolumeCommit,
+                self._client.stub.VolumeCommit(
                     api_pb2.VolumeCommitRequest(volume_id=v_id),
-                    max_retries=9,
-                    base_delay=0.25,
-                    max_delay=256,
-                    delay_factor=2,
+                    retry=Retry(
+                        max_retries=9,
+                        base_delay=0.25,
+                        max_delay=256,
+                        delay_factor=2,
+                    ),
                 )
                 for v_id in volume_ids
             ],
@@ -1019,7 +1156,8 @@ class _ContainerIOManager:
 
     @classmethod
     def stop_fetching_inputs(cls):
-        assert cls._singleton
+        if not cls._singleton:
+            raise RuntimeError("Must be called from within a Modal container.")
         cls._singleton._fetching_inputs = False
 
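A pattern repeated throughout this file (ContainerHeartbeat, FunctionGetDynamicConcurrency, FunctionGetInputs, FunctionPutOutputs, TaskResult, VolumeCommit) is the removal of the retry_transient_errors(stub.Method, request, ...) wrapper in favor of passing a Retry(...) policy directly to the stub call, e.g. stub.VolumeCommit(request, retry=Retry(max_retries=9, base_delay=0.25, ...)). The sketch below illustrates only that calling-convention shift with generic stand-ins; the Retry dataclass and call_with_retry helper here are assumptions for illustration and do not reproduce modal's grpc_utils implementation.

import asyncio
import random
from dataclasses import dataclass
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")


@dataclass
class Retry:
    # Illustrative policy fields mirroring the parameters visible in the diff.
    max_retries: Optional[int] = 3
    base_delay: float = 0.25
    max_delay: float = 256.0
    delay_factor: float = 2.0
    attempt_timeout: Optional[float] = None


async def call_with_retry(fn: Callable[[], Awaitable[T]], retry: Retry) -> T:
    # Run fn, retrying transient failures according to the given policy.
    delay = retry.base_delay
    attempt = 0
    while True:
        try:
            return await asyncio.wait_for(fn(), timeout=retry.attempt_timeout)
        except (asyncio.TimeoutError, ConnectionError):
            attempt += 1
            if retry.max_retries is not None and attempt > retry.max_retries:
                raise
            await asyncio.sleep(delay)
            delay = min(delay * retry.delay_factor, retry.max_delay)


async def demo() -> None:
    async def flaky_rpc() -> str:
        # Stand-in for a unary gRPC call that sometimes fails transiently.
        if random.random() < 0.5:
            raise ConnectionError("transient failure")
        return "ok"

    # Old style (roughly): await retry_transient_errors(stub.VolumeCommit, request, max_retries=9, ...)
    # New style (roughly): await stub.VolumeCommit(request, retry=Retry(max_retries=9, base_delay=0.25, ...))
    print(await call_with_retry(flaky_rpc, Retry(max_retries=9, base_delay=0.25, max_delay=256, delay_factor=2)))


if __name__ == "__main__":
    asyncio.run(demo())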