modal 1.0.6.dev58__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of modal might be problematic.

Files changed (147)
  1. modal/__main__.py +3 -4
  2. modal/_billing.py +80 -0
  3. modal/_clustered_functions.py +7 -3
  4. modal/_clustered_functions.pyi +4 -2
  5. modal/_container_entrypoint.py +41 -49
  6. modal/_functions.py +424 -195
  7. modal/_grpc_client.py +171 -0
  8. modal/_load_context.py +105 -0
  9. modal/_object.py +68 -20
  10. modal/_output.py +58 -45
  11. modal/_partial_function.py +36 -11
  12. modal/_pty.py +7 -3
  13. modal/_resolver.py +21 -35
  14. modal/_runtime/asgi.py +4 -3
  15. modal/_runtime/container_io_manager.py +301 -186
  16. modal/_runtime/container_io_manager.pyi +70 -61
  17. modal/_runtime/execution_context.py +18 -2
  18. modal/_runtime/execution_context.pyi +4 -1
  19. modal/_runtime/gpu_memory_snapshot.py +170 -63
  20. modal/_runtime/user_code_imports.py +28 -58
  21. modal/_serialization.py +57 -1
  22. modal/_utils/async_utils.py +33 -12
  23. modal/_utils/auth_token_manager.py +2 -5
  24. modal/_utils/blob_utils.py +110 -53
  25. modal/_utils/function_utils.py +49 -42
  26. modal/_utils/grpc_utils.py +80 -50
  27. modal/_utils/mount_utils.py +26 -1
  28. modal/_utils/name_utils.py +17 -3
  29. modal/_utils/task_command_router_client.py +536 -0
  30. modal/_utils/time_utils.py +34 -6
  31. modal/app.py +219 -83
  32. modal/app.pyi +229 -56
  33. modal/billing.py +5 -0
  34. modal/{requirements → builder}/2025.06.txt +1 -0
  35. modal/{requirements → builder}/PREVIEW.txt +1 -0
  36. modal/cli/_download.py +19 -3
  37. modal/cli/_traceback.py +3 -2
  38. modal/cli/app.py +4 -4
  39. modal/cli/cluster.py +15 -7
  40. modal/cli/config.py +5 -3
  41. modal/cli/container.py +7 -6
  42. modal/cli/dict.py +22 -16
  43. modal/cli/entry_point.py +12 -5
  44. modal/cli/environment.py +5 -4
  45. modal/cli/import_refs.py +3 -3
  46. modal/cli/launch.py +102 -5
  47. modal/cli/network_file_system.py +9 -13
  48. modal/cli/profile.py +3 -2
  49. modal/cli/programs/launch_instance_ssh.py +94 -0
  50. modal/cli/programs/run_jupyter.py +1 -1
  51. modal/cli/programs/run_marimo.py +95 -0
  52. modal/cli/programs/vscode.py +1 -1
  53. modal/cli/queues.py +57 -26
  54. modal/cli/run.py +58 -16
  55. modal/cli/secret.py +48 -22
  56. modal/cli/utils.py +3 -4
  57. modal/cli/volume.py +28 -25
  58. modal/client.py +13 -116
  59. modal/client.pyi +9 -91
  60. modal/cloud_bucket_mount.py +5 -3
  61. modal/cloud_bucket_mount.pyi +5 -1
  62. modal/cls.py +130 -102
  63. modal/cls.pyi +45 -85
  64. modal/config.py +29 -10
  65. modal/container_process.py +291 -13
  66. modal/container_process.pyi +95 -32
  67. modal/dict.py +282 -63
  68. modal/dict.pyi +423 -73
  69. modal/environments.py +15 -27
  70. modal/environments.pyi +5 -15
  71. modal/exception.py +8 -0
  72. modal/experimental/__init__.py +143 -38
  73. modal/experimental/flash.py +247 -78
  74. modal/experimental/flash.pyi +137 -9
  75. modal/file_io.py +14 -28
  76. modal/file_io.pyi +2 -2
  77. modal/file_pattern_matcher.py +25 -16
  78. modal/functions.pyi +134 -61
  79. modal/image.py +255 -86
  80. modal/image.pyi +300 -62
  81. modal/io_streams.py +436 -126
  82. modal/io_streams.pyi +236 -171
  83. modal/mount.py +62 -157
  84. modal/mount.pyi +45 -172
  85. modal/network_file_system.py +30 -53
  86. modal/network_file_system.pyi +16 -76
  87. modal/object.pyi +42 -8
  88. modal/parallel_map.py +821 -113
  89. modal/parallel_map.pyi +134 -0
  90. modal/partial_function.pyi +4 -1
  91. modal/proxy.py +16 -7
  92. modal/proxy.pyi +10 -2
  93. modal/queue.py +263 -61
  94. modal/queue.pyi +409 -66
  95. modal/runner.py +112 -92
  96. modal/runner.pyi +45 -27
  97. modal/sandbox.py +451 -124
  98. modal/sandbox.pyi +513 -67
  99. modal/secret.py +291 -67
  100. modal/secret.pyi +425 -19
  101. modal/serving.py +7 -11
  102. modal/serving.pyi +7 -8
  103. modal/snapshot.py +11 -8
  104. modal/token_flow.py +4 -4
  105. modal/volume.py +344 -98
  106. modal/volume.pyi +464 -68
  107. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +9 -8
  108. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  109. modal_docs/mdmd/mdmd.py +11 -1
  110. modal_proto/api.proto +399 -67
  111. modal_proto/api_grpc.py +241 -1
  112. modal_proto/api_pb2.py +1395 -1000
  113. modal_proto/api_pb2.pyi +1239 -79
  114. modal_proto/api_pb2_grpc.py +499 -4
  115. modal_proto/api_pb2_grpc.pyi +162 -14
  116. modal_proto/modal_api_grpc.py +175 -160
  117. modal_proto/sandbox_router.proto +145 -0
  118. modal_proto/sandbox_router_grpc.py +105 -0
  119. modal_proto/sandbox_router_pb2.py +149 -0
  120. modal_proto/sandbox_router_pb2.pyi +333 -0
  121. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  122. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  123. modal_proto/task_command_router.proto +144 -0
  124. modal_proto/task_command_router_grpc.py +105 -0
  125. modal_proto/task_command_router_pb2.py +149 -0
  126. modal_proto/task_command_router_pb2.pyi +333 -0
  127. modal_proto/task_command_router_pb2_grpc.py +203 -0
  128. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  129. modal_version/__init__.py +1 -1
  130. modal-1.0.6.dev58.dist-info/RECORD +0 -183
  131. modal_proto/modal_options_grpc.py +0 -3
  132. modal_proto/options.proto +0 -19
  133. modal_proto/options_grpc.py +0 -3
  134. modal_proto/options_pb2.py +0 -35
  135. modal_proto/options_pb2.pyi +0 -20
  136. modal_proto/options_pb2_grpc.py +0 -4
  137. modal_proto/options_pb2_grpc.pyi +0 -7
  138. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  139. /modal/{requirements → builder}/2023.12.txt +0 -0
  140. /modal/{requirements → builder}/2024.04.txt +0 -0
  141. /modal/{requirements → builder}/2024.10.txt +0 -0
  142. /modal/{requirements → builder}/README.md +0 -0
  143. /modal/{requirements → builder}/base-images.json +0 -0
  144. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  145. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  146. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  147. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
modal/_functions.py CHANGED
@@ -6,10 +6,10 @@ import textwrap
  import time
  import typing
  import warnings
- from collections.abc import AsyncGenerator, Sequence, Sized
+ from collections.abc import AsyncGenerator, Collection, Sequence, Sized
  from dataclasses import dataclass
  from pathlib import PurePosixPath
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+ from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Optional, Union

  import typing_extensions
  from google.protobuf.message import Message
@@ -19,7 +19,8 @@ from synchronicity.combined_types import MethodWithAio
  from modal_proto import api_pb2
  from modal_proto.modal_api_grpc import ModalClientModal

- from ._object import _get_environment_name, _Object, live_method, live_method_gen
+ from ._load_context import LoadContext
+ from ._object import _Object, live_method, live_method_gen
  from ._pty import get_pty_info
  from ._resolver import Resolver
  from ._resources import convert_fn_config_to_resources_config
@@ -47,15 +48,13 @@ from ._utils.function_utils import (
  OUTPUTS_TIMEOUT,
  FunctionCreationStatus,
  FunctionInfo,
- IncludeSourceMode,
  _create_input,
  _process_result,
  _stream_function_call_data,
  get_function_type,
- get_include_source_mode,
  is_async,
  )
- from ._utils.grpc_utils import RetryWarningMessage, retry_transient_errors
+ from ._utils.grpc_utils import Retry, RetryWarningMessage
  from ._utils.mount_utils import validate_network_file_systems, validate_volumes
  from .call_graph import InputInfo, _reconstruct_call_graph
  from .client import _Client
@@ -73,12 +72,16 @@ from .mount import _get_client_mount, _Mount
  from .network_file_system import _NetworkFileSystem, network_file_system_mount_protos
  from .output import _get_output_manager
  from .parallel_map import (
+ _experimental_spawn_map_async,
+ _experimental_spawn_map_sync,
  _for_each_async,
  _for_each_sync,
  _map_async,
  _map_invocation,
+ _map_invocation_inputplane,
  _map_sync,
  _spawn_map_async,
+ _spawn_map_invocation,
  _spawn_map_sync,
  _starmap_async,
  _starmap_sync,
@@ -92,12 +95,14 @@ from .secret import _Secret
  from .volume import _Volume

  if TYPE_CHECKING:
- import modal._partial_function
  import modal.app
  import modal.cls
- import modal.partial_function

  MAX_INTERNAL_FAILURE_COUNT = 8
+ TERMINAL_STATUSES = (
+ api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
+ api_pb2.GenericResult.GENERIC_STATUS_TERMINATED,
+ )


  @dataclasses.dataclass
@@ -146,8 +151,7 @@ class _Invocation:
  args,
  kwargs,
  stub,
- max_object_size_bytes=function._max_object_size_bytes,
- method_name=function._use_method_name,
+ function=function,
  function_call_invocation_type=function_call_invocation_type,
  )

@@ -161,21 +165,22 @@ class _Invocation:

  if from_spawn_map:
  request.from_spawn_map = True
- response = await retry_transient_errors(
- client.stub.FunctionMap,
+ response = await client.stub.FunctionMap(
  request,
- max_retries=None,
- max_delay=30.0,
- retry_warning_message=RetryWarningMessage(
- message="Warning: `.spawn_map(...)` for function `{self._function_name}` is waiting to create"
- "more function calls. This may be due to hitting rate limits or function backlog limits.",
- warning_interval=10,
- errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
+ retry=Retry(
+ max_retries=None,
+ max_delay=30.0,
+ warning_message=RetryWarningMessage(
+ message="Warning: `.spawn_map(...)` for function `{self._function_name}` is waiting to create"
+ "more function calls. This may be due to hitting rate limits or function backlog limits.",
+ warning_interval=10,
+ errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
+ ),
+ additional_status_codes=[Status.RESOURCE_EXHAUSTED],
  ),
- additional_status_codes=[Status.RESOURCE_EXHAUSTED],
  )
  else:
- response = await retry_transient_errors(client.stub.FunctionMap, request)
+ response = await client.stub.FunctionMap(request)

  function_call_id = response.function_call_id
  if response.pipelined_inputs:
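Throughout this file, the free-standing `retry_transient_errors(stub.Method, request, ...)` helper is replaced by calling the wrapped stub method directly and attaching a `Retry` configuration object (imported from `._utils.grpc_utils`). A minimal sketch of the new call convention; `SomeRpc` is a placeholder method name, not a real Modal RPC, and only the `Retry`/`RetryWarningMessage` fields exercised in the hunk above are shown:

```python
# Sketch of the new wrapped-stub call style used throughout this diff.
from modal._utils.grpc_utils import Retry, RetryWarningMessage

async def call_with_retry(stub, request):
    # Pre-1.2: response = await retry_transient_errors(stub.SomeRpc, request, max_retries=None, ...)
    return await stub.SomeRpc(
        request,
        retry=Retry(
            max_retries=None,   # keep retrying transient errors
            max_delay=30.0,     # cap the backoff between attempts
            warning_message=RetryWarningMessage(
                message="Still waiting on SomeRpc...",
                warning_interval=10,
                errors_to_warn_for=[],  # e.g. [Status.RESOURCE_EXHAUSTED]
            ),
        ),
    )
```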
@@ -195,10 +200,7 @@ class _Invocation:
  request_put = api_pb2.FunctionPutInputsRequest(
  function_id=function_id, inputs=[item], function_call_id=function_call_id
  )
- inputs_response: api_pb2.FunctionPutInputsResponse = await retry_transient_errors(
- client.stub.FunctionPutInputs,
- request_put,
- )
+ inputs_response: api_pb2.FunctionPutInputsResponse = await client.stub.FunctionPutInputs(request_put)
  processed_inputs = inputs_response.inputs
  if not processed_inputs:
  raise Exception("Could not create function call - the input queue seems to be full")
@@ -215,7 +217,11 @@ class _Invocation:
  return _Invocation(stub, function_call_id, client, retry_context)

  async def pop_function_call_outputs(
- self, timeout: Optional[float], clear_on_success: bool, input_jwts: Optional[list[str]] = None
+ self,
+ index: int = 0,
+ timeout: Optional[float] = None,
+ clear_on_success: bool = False,
+ input_jwts: Optional[list[str]] = None,
  ) -> api_pb2.FunctionGetOutputsResponse:
  t0 = time.time()
  if timeout is None:
@@ -233,11 +239,12 @@ class _Invocation:
  clear_on_success=clear_on_success,
  requested_at=time.time(),
  input_jwts=input_jwts,
+ start_idx=index,
+ end_idx=index,
  )
- response: api_pb2.FunctionGetOutputsResponse = await retry_transient_errors(
- self.stub.FunctionGetOutputs,
+ response: api_pb2.FunctionGetOutputsResponse = await self.stub.FunctionGetOutputs(
  request,
- attempt_timeout=backend_timeout + ATTEMPT_TIMEOUT_GRACE_PERIOD,
+ retry=Retry(attempt_timeout=backend_timeout + ATTEMPT_TIMEOUT_GRACE_PERIOD),
  )

  if len(response.outputs) > 0:
@@ -257,15 +264,13 @@ class _Invocation:

  item = api_pb2.FunctionRetryInputsItem(input_jwt=ctx.input_jwt, input=ctx.item.input)
  request = api_pb2.FunctionRetryInputsRequest(function_call_jwt=ctx.function_call_jwt, inputs=[item])
- await retry_transient_errors(
- self.stub.FunctionRetryInputs,
- request,
- )
+ await self.stub.FunctionRetryInputs(request)

  async def _get_single_output(self, expected_jwt: Optional[str] = None) -> api_pb2.FunctionGetOutputsItem:
  # waits indefinitely for a single result for the function, and clear the outputs buffer after
  item: api_pb2.FunctionGetOutputsItem = (
  await self.pop_function_call_outputs(
+ index=0,
  timeout=None,
  clear_on_success=True,
  input_jwts=[expected_jwt] if expected_jwt else None,
@@ -291,11 +296,7 @@ class _Invocation:

  while True:
  item = await self._get_single_output(ctx.input_jwt)
- if item.result.status in (
- api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
- api_pb2.GenericResult.GENERIC_STATUS_TERMINATED,
- ):
- # success or cancellations are "final" results
+ if item.result.status in TERMINAL_STATUSES:
  return await _process_result(item.result, item.data_format, self.stub, self.client)

  if item.result.status != api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE:
@@ -309,14 +310,16 @@ class _Invocation:

  await self._retry_input()

- async def poll_function(self, timeout: Optional[float] = None):
+ async def poll_function(self, timeout: Optional[float] = None, *, index: int = 0):
  """Waits up to timeout for a result from a function.

  If timeout is `None`, waits indefinitely. This function is not
  cancellation-safe.
  """
  response: api_pb2.FunctionGetOutputsResponse = await self.pop_function_call_outputs(
- timeout=timeout, clear_on_success=False
+ index=index,
+ timeout=timeout,
+ clear_on_success=False,
  )
  if len(response.outputs) == 0 and response.num_unfinished_inputs == 0:
  # if no unfinished inputs and no outputs, then function expired
@@ -349,11 +352,45 @@ class _Invocation:
  if items_total is not None and items_received >= items_total:
  break

+ async def enumerate(self, start_index: int, end_index: int):
+ """Iterate over the results of the function call in the range [start_index, end_index)."""
+ limit = 49
+ current_index = start_index
+ while current_index < end_index:
+ # batch_end_indx is inclusive, so we subtract 1 to get the last index in the batch.
+ batch_end_index = min(current_index + limit, end_index) - 1
+ request = api_pb2.FunctionGetOutputsRequest(
+ function_call_id=self.function_call_id,
+ timeout=0,
+ last_entry_id="0-0",
+ clear_on_success=False,
+ requested_at=time.time(),
+ start_idx=current_index,
+ end_idx=batch_end_index,
+ )
+ response: api_pb2.FunctionGetOutputsResponse = await self.stub.FunctionGetOutputs(
+ request, retry=Retry(attempt_timeout=ATTEMPT_TIMEOUT_GRACE_PERIOD)
+ )
+
+ outputs = list(response.outputs)
+ outputs.sort(key=lambda x: x.idx)
+ for output in outputs:
+ if output.idx != current_index:
+ break
+ result = await _process_result(output.result, output.data_format, self.stub, self.client)
+ yield output.idx, result
+ current_index += 1
+
+ # We're missing current_index, so we need to poll the function for the next result
+ if len(outputs) < (batch_end_index - current_index + 1):
+ result = await self.poll_function(index=current_index)
+ yield current_index, result
+ current_index += 1
+

  class _InputPlaneInvocation:
  """Internal client representation of a single-input call to a Modal Function using the input
- plane server API. As of 4/22/2025, this class is experimental and not used in production.
- It is OK to make breaking changes to this class."""
+ plane server API."""

  stub: ModalClientModal

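The new `_Invocation.enumerate` (added in the hunk above) fetches ordered outputs in windows of at most 49 indices per `FunctionGetOutputs` request, because the `end_idx` it sends is inclusive. A small sketch of just the window arithmetic; the real method advances one output at a time and falls back to `poll_function` when an index is missing:

```python
# Sketch: how [start_index, end_index) is sliced into inclusive request windows.
def windows(start_index: int, end_index: int, limit: int = 49):
    current = start_index
    while current < end_index:
        batch_end = min(current + limit, end_index) - 1  # inclusive end_idx sent to the server
        yield current, batch_end
        current = batch_end + 1

print(list(windows(0, 120)))  # [(0, 48), (49, 97), (98, 119)]
```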
@@ -364,6 +401,7 @@ class _InputPlaneInvocation:
  client: _Client,
  input_item: api_pb2.FunctionPutInputsItem,
  function_id: str,
+ retry_policy: api_pb2.FunctionRetryPolicy,
  input_plane_region: str,
  ):
  self.stub = stub
@@ -371,6 +409,7 @@ class _InputPlaneInvocation:
  self.attempt_token = attempt_token
  self.input_item = input_item
  self.function_id = function_id
+ self.retry_policy = retry_policy
  self.input_plane_region = input_plane_region

  @staticmethod
@@ -392,8 +431,7 @@ class _InputPlaneInvocation:
  args,
  kwargs,
  control_plane_stub,
- max_object_size_bytes=function._max_object_size_bytes,
- method_name=function._use_method_name,
+ function=function,
  )

  request = api_pb2.AttemptStartRequest(
@@ -401,15 +439,20 @@ class _InputPlaneInvocation:
  parent_input_id=current_input_id() or "",
  input=input_item,
  )
- metadata = await _InputPlaneInvocation._get_metadata(input_plane_region, client)
- response = await retry_transient_errors(stub.AttemptStart, request, metadata=metadata)
+
+ metadata = await client.get_input_plane_metadata(input_plane_region)
+ response = await stub.AttemptStart(request, metadata=metadata)
  attempt_token = response.attempt_token

- return _InputPlaneInvocation(stub, attempt_token, client, input_item, function_id, input_plane_region)
+ return _InputPlaneInvocation(
+ stub, attempt_token, client, input_item, function_id, response.retry_policy, input_plane_region
+ )

  async def run_function(self) -> Any:
+ # User errors including timeouts are managed by the user-specified retry policy.
+ user_retry_manager = RetryManager(self.retry_policy)
+
  # This will retry when the server returns GENERIC_STATUS_INTERNAL_FAILURE, i.e. lost inputs or worker preemption
- # TODO(ryan): add logic to retry for user defined retry policy
  internal_failure_count = 0
  while True:
  await_request = api_pb2.AttemptAwaitRequest(
@@ -417,42 +460,79 @@ class _InputPlaneInvocation:
  timeout_secs=OUTPUTS_TIMEOUT,
  requested_at=time.time(),
  )
- metadata = await self._get_metadata(self.input_plane_region, self.client)
- await_response: api_pb2.AttemptAwaitResponse = await retry_transient_errors(
- self.stub.AttemptAwait,
+ metadata = await self.client.get_input_plane_metadata(self.input_plane_region)
+ await_response: api_pb2.AttemptAwaitResponse = await self.stub.AttemptAwait(
  await_request,
- attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
+ retry=Retry(attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD),
  metadata=metadata,
  )

- if await_response.HasField("output"):
- if await_response.output.result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE:
- internal_failure_count += 1
- # Limit the number of times we retry
- if internal_failure_count < MAX_INTERNAL_FAILURE_COUNT:
- # For system failures on the server, we retry immediately,
- # and the failure does not count towards the retry policy.
- retry_request = api_pb2.AttemptRetryRequest(
- function_id=self.function_id,
- parent_input_id=current_input_id() or "",
- input=self.input_item,
- attempt_token=self.attempt_token,
- )
- # TODO(ryan): Add exponential backoff?
- retry_response = await retry_transient_errors(
- self.stub.AttemptRetry,
- retry_request,
- metadata=metadata,
- )
- self.attempt_token = retry_response.attempt_token
- continue
+ # Keep awaiting until we get an output.
+ if not await_response.HasField("output"):
+ continue

- control_plane_stub = self.client.stub
- # Note: Blob download is done on the control plane stub, not the input plane stub!
+ # If we have a final output, return.
+ if await_response.output.result.status in TERMINAL_STATUSES:
  return await _process_result(
- await_response.output.result, await_response.output.data_format, control_plane_stub, self.client
+ await_response.output.result, await_response.output.data_format, self.client.stub, self.client
  )

+ # We have a failure (internal or application), so see if there are any retries left, and if so, retry.
+ if await_response.output.result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE:
+ internal_failure_count += 1
+ # Limit the number of times we retry internal failures.
+ if internal_failure_count < MAX_INTERNAL_FAILURE_COUNT:
+ # We immediately retry internal failures and the failure doesn't count towards the retry policy.
+ self.attempt_token = await self._retry_input(metadata)
+ continue
+ elif (delay_ms := user_retry_manager.get_delay_ms()) is not None:
+ # We still have user retries left, so sleep and retry.
+ await asyncio.sleep(delay_ms / 1000)
+ self.attempt_token = await self._retry_input(metadata)
+ continue
+
+ # No more retries left.
+ return await _process_result(
+ await_response.output.result, await_response.output.data_format, self.client.stub, self.client
+ )
+
+ async def _retry_input(self, metadata: list[tuple[str, str]]) -> str:
+ retry_request = api_pb2.AttemptRetryRequest(
+ function_id=self.function_id,
+ parent_input_id=current_input_id() or "",
+ input=self.input_item,
+ attempt_token=self.attempt_token,
+ )
+ retry_response = await self.stub.AttemptRetry(retry_request, metadata=metadata)
+ return retry_response.attempt_token
+
+ async def run_generator(self):
+ items_received = 0
+ # populated when self.run_function() completes
+ items_total: Union[int, None] = None
+ async with aclosing(
+ async_merge(
+ _stream_function_call_data(
+ self.client,
+ self.stub,
+ function_call_id=None,
+ variant="data_out",
+ attempt_token=self.attempt_token,
+ ),
+ callable_to_agen(self.run_function),
+ )
+ ) as streamer:
+ async for item in streamer:
+ if isinstance(item, api_pb2.GeneratorDone):
+ items_total = item.items_total
+ else:
+ yield item
+ items_received += 1
+ # The comparison avoids infinite loops if a non-deterministic generator is retried
+ # and produces less data in the second run than what was already sent.
+ if items_total is not None and items_received >= items_total:
+ break
+
  @staticmethod
  async def _get_metadata(input_plane_region: str, client: _Client) -> list[tuple[str, str]]:
  if not input_plane_region:
@@ -500,7 +580,7 @@ class _FunctionSpec:

  image: Optional[_Image]
  mounts: Sequence[_Mount]
- secrets: Sequence[_Secret]
+ secrets: Collection[_Secret]
  network_file_systems: dict[Union[str, PurePosixPath], _NetworkFileSystem]
  volumes: dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]]
  # TODO(irfansharif): Somehow assert that it's the first kind, in sandboxes
@@ -513,6 +593,21 @@ class _FunctionSpec:
  proxy: Optional[_Proxy]


+ def _get_supported_input_output_formats(is_web_endpoint: bool, is_generator: bool, restrict_output: bool):
+ if is_web_endpoint:
+ supported_input_formats = [api_pb2.DATA_FORMAT_ASGI]
+ supported_output_formats = [api_pb2.DATA_FORMAT_ASGI, api_pb2.DATA_FORMAT_GENERATOR_DONE]
+ else:
+ supported_input_formats = [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
+ if restrict_output:
+ supported_output_formats = [api_pb2.DATA_FORMAT_CBOR]
+ else:
+ supported_output_formats = [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
+ if is_generator:
+ supported_output_formats.append(api_pb2.DATA_FORMAT_GENERATOR_DONE)
+ return supported_input_formats, supported_output_formats
+
+
  P = typing_extensions.ParamSpec("P")
  ReturnType = typing.TypeVar("ReturnType", covariant=True)
  OriginalReturnType = typing.TypeVar(
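The new module-level helper `_get_supported_input_output_formats` decides which serialization formats a function advertises to the server. Based only on the branches shown above, these are the expected results for the three interesting cases (plain function, restricted-output generator, web endpoint):

```python
# Sketch: expected return values of the private helper above.
from modal._functions import _get_supported_input_output_formats
from modal_proto import api_pb2

# Plain function: pickle and CBOR in both directions.
assert _get_supported_input_output_formats(False, False, False) == (
    [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
    [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
)

# Generator with restrict_output=True: CBOR-only outputs plus the generator-done sentinel.
assert _get_supported_input_output_formats(False, True, True) == (
    [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
    [api_pb2.DATA_FORMAT_CBOR, api_pb2.DATA_FORMAT_GENERATOR_DONE],
)

# Web endpoint: ASGI in, ASGI plus the generator-done sentinel out.
assert _get_supported_input_output_formats(True, False, False) == (
    [api_pb2.DATA_FORMAT_ASGI],
    [api_pb2.DATA_FORMAT_ASGI, api_pb2.DATA_FORMAT_GENERATOR_DONE],
)
```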
@@ -562,9 +657,10 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  @staticmethod
  def from_local(
  info: FunctionInfo,
- app,
+ app: Optional["modal.app._App"], # App here should only be None in case of Image.run_function
  image: _Image,
- secrets: Sequence[_Secret] = (),
+ env: Optional[dict[str, Optional[str]]] = None,
+ secrets: Optional[Collection[_Secret]] = None,
  schedule: Optional[Schedule] = None,
  is_generator: bool = False,
  gpu: Union[GPU_T, list[GPU_T]] = None,
@@ -576,7 +672,8 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  memory: Optional[Union[int, tuple[int, int]]] = None,
  proxy: Optional[_Proxy] = None,
  retries: Optional[Union[int, Retries]] = None,
- timeout: Optional[int] = None,
+ timeout: int = 300,
+ startup_timeout: Optional[int] = None,
  min_containers: Optional[int] = None,
  max_containers: Optional[int] = None,
  buffer_containers: Optional[int] = None,
@@ -598,14 +695,16 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  rdma: Optional[bool] = None,
  max_inputs: Optional[int] = None,
  ephemeral_disk: Optional[int] = None,
- # current default: first-party, future default: main-package
- include_source: Optional[bool] = None,
+ include_source: bool = True,
  experimental_options: Optional[dict[str, str]] = None,
  _experimental_proxy_ip: Optional[str] = None,
  _experimental_custom_scaling_factor: Optional[float] = None,
- _experimental_enable_gpu_snapshot: bool = False,
+ restrict_output: bool = False,
  ) -> "_Function":
- """mdmd:hidden"""
+ """mdmd:hidden
+
+ Note: This is not intended to be public API.
+ """
  # Needed to avoid circular imports
  from ._partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags

@@ -624,15 +723,10 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  assert not webhook_config
  assert not schedule

- include_source_mode = get_include_source_mode(include_source)
- if include_source_mode != IncludeSourceMode.INCLUDE_NOTHING:
- entrypoint_mounts = info.get_entrypoint_mount()
- else:
- entrypoint_mounts = {}
-
+ entrypoint_mount = info.get_entrypoint_mount() if include_source else {}
  all_mounts = [
  _get_client_mount(),
- *entrypoint_mounts.values(),
+ *entrypoint_mount.values(),
  ]

  retry_policy = _parse_retries(
@@ -645,6 +739,13 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  if is_generator:
  raise InvalidError("Generator functions do not support retries.")

+ if timeout is None: # type: ignore[unreachable] # Help users who aren't using type checkers
+ raise InvalidError("The `timeout` parameter cannot be set to None: https://modal.com/docs/guide/timeouts")
+
+ secrets = secrets or []
+ if env:
+ secrets = [*secrets, _Secret.from_dict(env)]
+
  function_spec = _FunctionSpec(
  mounts=all_mounts,
  secrets=secrets,
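Two behavioral changes land in the hunk above: `timeout=None` is now rejected outright, and the new `env` mapping is folded into the secret list. A sketch of the equivalence using only public `modal.Secret` constructors; the decorator-level plumbing that forwards `env` to `from_local` is assumed and not shown in this file:

```python
# Sketch of what from_local now does with `env` and `timeout` (illustration only).
import modal

env = {"MY_FLAG": "1"}
secrets = [modal.Secret.from_name("my-secret")]

# timeout must be an int; passing None raises
# InvalidError("The `timeout` parameter cannot be set to None: https://modal.com/docs/guide/timeouts")

# env variables become one extra ad-hoc Secret appended to the secret list:
secrets = [*secrets, modal.Secret.from_dict(env)]
```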
@@ -732,17 +833,23 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  for method_name, partial_function in interface_methods.items():
  function_type = get_function_type(partial_function.params.is_generator)
  function_name = f"{info.user_cls.__name__}.{method_name}"
+ is_web_endpoint = partial_function._is_web_endpoint()
  method_schema = get_callable_schema(
  partial_function._get_raw_f(),
- is_web_endpoint=partial_function._is_web_endpoint(),
+ is_web_endpoint=is_web_endpoint,
  ignore_first_argument=True,
  )
+ method_input_formats, method_output_formats = _get_supported_input_output_formats(
+ is_web_endpoint, partial_function.params.is_generator or False, restrict_output
+ )

  method_definition = api_pb2.MethodDefinition(
  webhook_config=partial_function.params.webhook_config,
  function_type=function_type,
  function_name=function_name,
  function_schema=method_schema,
+ supported_input_formats=method_input_formats,
+ supported_output_formats=method_output_formats,
  )
  method_definitions[method_name] = method_definition

@@ -766,18 +873,30 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type

  return deps

- async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
- assert resolver.client and resolver.client.stub
+ if info.is_service_class():
+ # classes don't have data formats themselves - input/output formats are set per method above
+ supported_input_formats = []
+ supported_output_formats = []
+ else:
+ is_web_endpoint = webhook_config is not None and webhook_config.type != api_pb2.WEBHOOK_TYPE_UNSPECIFIED
+ supported_input_formats, supported_output_formats = _get_supported_input_output_formats(
+ is_web_endpoint, is_generator, restrict_output
+ )

- assert resolver.app_id
+ async def _preload(
+ self: _Function, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
+ ):
+ assert load_context.app_id
  req = api_pb2.FunctionPrecreateRequest(
- app_id=resolver.app_id,
+ app_id=load_context.app_id,
  function_name=info.function_name,
  function_type=function_type,
  existing_function_id=existing_object_id or "",
  function_schema=get_callable_schema(info.raw_f, is_web_endpoint=bool(webhook_config))
  if info.raw_f
  else None,
+ supported_input_formats=supported_input_formats,
+ supported_output_formats=supported_output_formats,
  )
  if method_definitions:
  for method_name, method_definition in method_definitions.items():
@@ -785,11 +904,12 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  elif webhook_config:
  req.webhook_config.CopyFrom(webhook_config)

- response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req)
- self._hydrate(response.function_id, resolver.client, response.handle_metadata)
+ response = await load_context.client.stub.FunctionPrecreate(req)
+ self._hydrate(response.function_id, load_context.client, response.handle_metadata)

- async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
- assert resolver.client and resolver.client.stub
+ async def _load(
+ self: _Function, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
+ ):
  with FunctionCreationStatus(resolver, tag) as function_creation_status:
  timeout_secs = timeout

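The loader callbacks in the hunks above gain a `LoadContext` argument and stop pulling the client, app ID, and environment name off the `Resolver`. A condensed sketch of the new loader contract, using only the attributes exercised in this diff (`client`, `app_id`); `SomeRpc` and `make_request` are placeholders:

```python
# Sketch of the loader signature used throughout this file after the change.
from typing import Optional

from modal._load_context import LoadContext
from modal._resolver import Resolver

async def _load(self, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]):
    assert load_context.app_id  # the app id now comes from the LoadContext, not the Resolver
    request = make_request(app_id=load_context.app_id, existing_id=existing_object_id or "")  # placeholder
    response = await load_context.client.stub.SomeRpc(request)  # transient-error retries happen inside the stub
    self._hydrate(response.object_id, load_context.client, response.handle_metadata)
```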
@@ -856,6 +976,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  function_schema = (
  get_callable_schema(info.raw_f, is_web_endpoint=bool(webhook_config)) if info.raw_f else None
  )
+
  # Create function remotely
  function_definition = api_pb2.Function(
  module_name=info.module_name or "",
@@ -876,6 +997,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  proxy_id=(proxy.object_id if proxy else None),
  retry_policy=retry_policy,
  timeout_secs=timeout_secs or 0,
+ startup_timeout_secs=startup_timeout or timeout_secs,
  pty_info=pty_info,
  cloud_provider_str=cloud if cloud else "",
  runtime=config.get("function_runtime"),
@@ -909,7 +1031,6 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  _experimental_concurrent_cancellations=True,
  _experimental_proxy_ip=_experimental_proxy_ip,
  _experimental_custom_scaling=_experimental_custom_scaling_factor is not None,
- _experimental_enable_gpu_snapshot=_experimental_enable_gpu_snapshot,
  # --- These are deprecated in favor of autoscaler_settings
  warm_pool_size=min_containers or 0,
  concurrency_limit=max_containers or 0,
@@ -917,6 +1038,8 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  task_idle_timeout_secs=scaledown_window or 0,
  # ---
  function_schema=function_schema,
+ supported_input_formats=supported_input_formats,
+ supported_output_formats=supported_output_formats,
  )

  if isinstance(gpu, list):
@@ -930,6 +1053,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  autoscaler_settings=function_definition.autoscaler_settings,
  worker_id=function_definition.worker_id,
  timeout_secs=function_definition.timeout_secs,
+ startup_timeout_secs=function_definition.startup_timeout_secs,
  web_url=function_definition.web_url,
  web_url_info=function_definition.web_url_info,
  webhook_config=function_definition.webhook_config,
@@ -946,12 +1070,13 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  _experimental_group_size=function_definition._experimental_group_size,
  _experimental_buffer_containers=function_definition._experimental_buffer_containers,
  _experimental_custom_scaling=function_definition._experimental_custom_scaling,
- _experimental_enable_gpu_snapshot=_experimental_enable_gpu_snapshot,
  _experimental_proxy_ip=function_definition._experimental_proxy_ip,
  snapshot_debug=function_definition.snapshot_debug,
  runtime_perf_record=function_definition.runtime_perf_record,
  function_schema=function_schema,
  untrusted=function_definition.untrusted,
+ supported_input_formats=supported_input_formats,
+ supported_output_formats=supported_output_formats,
  )

  ranked_functions = []
@@ -980,18 +1105,16 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  ),
  )

- assert resolver.app_id
+ assert load_context.app_id
  assert (function_definition is None) != (function_data is None) # xor
  request = api_pb2.FunctionCreateRequest(
- app_id=resolver.app_id,
+ app_id=load_context.app_id,
  function=function_definition,
  function_data=function_data,
  existing_function_id=existing_object_id or "",
  )
  try:
- response: api_pb2.FunctionCreateResponse = await retry_transient_errors(
- resolver.client.stub.FunctionCreate, request
- )
+ response: api_pb2.FunctionCreateResponse = await load_context.client.stub.FunctionCreate(request)
  except GRPCError as exc:
  if exc.status == Status.INVALID_ARGUMENT:
  raise InvalidError(exc.message)
@@ -1006,10 +1129,14 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  serve_mounts = {m for m in all_mounts if m.is_local()}
  serve_mounts |= image._serve_mounts
  obj._serve_mounts = frozenset(serve_mounts)
- self._hydrate(response.function_id, resolver.client, response.handle_metadata)
+ self._hydrate(response.function_id, load_context.client, response.handle_metadata)

  rep = f"Function({tag})"
- obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps)
+ # Pass a *reference* to the App's LoadContext - this is important since the App is
+ # the only way to infer a LoadContext for an `@app.function`, and the App doesn't
+ # get its client until *after* the Function is created.
+ load_context = app._root_load_context if app else LoadContext.empty()
+ obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps, load_context_overrides=load_context)

  obj._raw_f = info.raw_f
  obj._info = info
@@ -1051,7 +1178,12 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type

  parent = self

- async def _load(param_bound_func: _Function, resolver: Resolver, existing_object_id: Optional[str]):
+ async def _load(
+ param_bound_func: _Function,
+ resolver: Resolver,
+ load_context: LoadContext,
+ existing_object_id: Optional[str],
+ ):
  if not parent.is_hydrated:
  # While the base Object.hydrate() method appears to be idempotent, it's not always safe
  await parent.hydrate()
@@ -1084,7 +1216,6 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  param_bound_func._hydrate_from_other(parent)
  return

- environment_name = _get_environment_name(None, resolver)
  assert parent is not None and parent.is_hydrated

  if options:
@@ -1102,6 +1233,8 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  replace_secret_ids=bool(options.secrets),
  replace_volume_mounts=len(volume_mounts) > 0,
  volume_mounts=volume_mounts,
+ cloud_bucket_mounts=cloud_bucket_mounts_to_proto(options.cloud_bucket_mounts),
+ replace_cloud_bucket_mounts=bool(options.cloud_bucket_mounts),
  resources=options.resources,
  retry_policy=options.retry_policy,
  concurrency_limit=options.max_containers,
@@ -1112,6 +1245,8 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  target_concurrent_inputs=options.target_concurrent_inputs,
  batch_max_size=options.batch_max_size,
  batch_linger_ms=options.batch_wait_ms,
+ scheduler_placement=options.scheduler_placement,
+ cloud_provider_str=options.cloud,
  )
  else:
  options_pb = None
@@ -1120,20 +1255,30 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  function_id=parent.object_id,
  serialized_params=serialized_params,
  function_options=options_pb,
- environment_name=environment_name
+ environment_name=load_context.environment_name
  or "", # TODO: investigate shouldn't environment name always be specified here?
  )

- response = await retry_transient_errors(parent._client.stub.FunctionBindParams, req)
+ response = await parent._client.stub.FunctionBindParams(req)
  param_bound_func._hydrate(response.bound_function_id, parent._client, response.handle_metadata)

  def _deps():
  if options:
- all_deps = [v for _, v in options.validated_volumes] + list(options.secrets)
+ all_deps = (
+ [v for _, v in options.validated_volumes]
+ + list(options.secrets)
+ + [mount.secret for _, mount in options.cloud_bucket_mounts if mount.secret]
+ )
  return [dep for dep in all_deps if not dep.is_hydrated]
  return []

- fun: _Function = _Function._from_loader(_load, "Function(parametrized)", hydrate_lazily=True, deps=_deps)
+ fun: _Function = _Function._from_loader(
+ _load,
+ "Function(parametrized)",
+ hydrate_lazily=True,
+ deps=_deps,
+ load_context_overrides=self._load_context_overrides,
+ )

  fun._info = self._info
  fun._obj = obj
@@ -1184,7 +1329,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  scaledown_window=scaledown_window,
  )
  request = api_pb2.FunctionUpdateSchedulingParamsRequest(function_id=self.object_id, settings=settings)
- await retry_transient_errors(self.client.stub.FunctionUpdateSchedulingParams, request)
+ await self.client.stub.FunctionUpdateSchedulingParams(request)

  # One idea would be for FunctionUpdateScheduleParams to return the current (coalesced) settings
  # and then we could return them here (would need some ad hoc dataclass, which I don't love)
@@ -1231,33 +1376,43 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  cls,
  app_name: str,
  name: str,
- namespace=None, # mdmd:line-hidden
- environment_name: Optional[str] = None,
+ *,
+ load_context_overrides: LoadContext,
  ):
  # internal function lookup implementation that allows lookup of class "service functions"
  # in addition to non-class functions
- async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
- assert resolver.client and resolver.client.stub
+ async def _load_remote(
+ self: _Function, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
+ ):
  request = api_pb2.FunctionGetRequest(
  app_name=app_name,
  object_tag=name,
- environment_name=_get_environment_name(environment_name, resolver) or "",
+ environment_name=load_context.environment_name,
  )
  try:
- response = await retry_transient_errors(resolver.client.stub.FunctionGet, request)
+ response = await load_context.client.stub.FunctionGet(request)
  except NotFoundError as exc:
  # refine the error message
- env_context = f" (in the '{environment_name}' environment)" if environment_name else ""
+ env_context = (
+ f" (in the '{load_context.environment_name}' environment)" if load_context.environment_name else ""
+ )
  raise NotFoundError(
  f"Lookup failed for Function '{name}' from the '{app_name}' app{env_context}: {exc}."
  ) from None

  print_server_warnings(response.server_warnings)

- self._hydrate(response.function_id, resolver.client, response.handle_metadata)
+ self._hydrate(response.function_id, load_context.client, response.handle_metadata)

- rep = f"Function.from_name('{app_name}', '{name}')"
- return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True)
+ environment_rep = (
+ f", environment_name={load_context_overrides.environment_name!r}"
+ if load_context_overrides._environment_name # slightly ugly - checking if _environment_name is overridden
+ else ""
+ )
+ rep = f"modal.Function.from_name('{app_name}', '{name}'{environment_rep})"
+ return cls._from_loader(
+ _load_remote, rep, is_another_app=True, hydrate_lazily=True, load_context_overrides=load_context_overrides
+ )

  @classmethod
  def from_name(
@@ -1267,6 +1422,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  *,
  namespace=None, # mdmd:line-hidden
  environment_name: Optional[str] = None,
+ client: Optional[_Client] = None,
  ) -> "_Function":
  """Reference a Function from a deployed App by its name.

@@ -1290,41 +1446,9 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  )

  warn_if_passing_namespace(namespace, "modal.Function.from_name")
- return cls._from_name(app_name, name, environment_name=environment_name)
-
- @staticmethod
- async def lookup(
- app_name: str,
- name: str,
- namespace=None, # mdmd:line-hidden
- client: Optional[_Client] = None,
- environment_name: Optional[str] = None,
- ) -> "_Function":
- """mdmd:hidden
- Lookup a Function from a deployed App by its name.
-
- DEPRECATED: This method is deprecated in favor of `modal.Function.from_name`.
-
- In contrast to `modal.Function.from_name`, this is an eager method
- that will hydrate the local object with metadata from Modal servers.
-
- ```python notest
- f = modal.Function.lookup("other-app", "function")
- ```
- """
- deprecation_warning(
- (2025, 1, 27),
- "`modal.Function.lookup` is deprecated and will be removed in a future release."
- " It can be replaced with `modal.Function.from_name`."
- "\n\nSee https://modal.com/docs/guide/modal-1-0-migration for more information.",
+ return cls._from_name(
+ app_name, name, load_context_overrides=LoadContext(environment_name=environment_name, client=client)
  )
- warn_if_passing_namespace(namespace, "modal.Function.lookup")
- obj = _Function.from_name(app_name, name, environment_name=environment_name)
- if client is None:
- client = await _Client.from_env()
- resolver = Resolver(client=client)
- await resolver.load(obj)
- return obj

  @property
  def tag(self) -> str:
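`Function.lookup` is removed in this release, and `Function.from_name` gains an optional `client` argument, with the environment carried through a `LoadContext` override. A sketch of the migration path (the app, function, and environment names are examples; hydration now happens lazily on first use):

```python
# Sketch: replacing the removed Function.lookup with Function.from_name.
import modal

# Old (removed): f = modal.Function.lookup("other-app", "function")
f = modal.Function.from_name("other-app", "function", environment_name="main")

result = f.remote(42)  # the handle hydrates itself on first remote use
```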
@@ -1380,6 +1504,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  self._info = None
  self._serve_mounts = frozenset()
  self._metadata = None
+ self._experimental_flash_urls = None

  def _hydrate_metadata(self, metadata: Optional[Message]):
  # Overridden concrete implementation of base class method
@@ -1407,6 +1532,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  self._max_object_size_bytes = (
  metadata.max_object_size_bytes if metadata.HasField("max_object_size_bytes") else MAX_OBJECT_SIZE_BYTES
  )
+ self._experimental_flash_urls = metadata._experimental_flash_urls

  def _get_metadata(self):
  # Overridden concrete implementation of base class method
@@ -1424,6 +1550,9 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  input_plane_url=self._input_plane_url,
  input_plane_region=self._input_plane_region,
  max_object_size_bytes=self._max_object_size_bytes,
+ _experimental_flash_urls=self._experimental_flash_urls,
+ supported_input_formats=self._metadata.supported_input_formats if self._metadata else [],
+ supported_output_formats=self._metadata.supported_output_formats if self._metadata else [],
  )

  def _check_no_web_url(self, fn_name: str):
@@ -1454,6 +1583,11 @@ Use the `Function.get_web_url()` method instead.
  """URL of a Function running as a web endpoint."""
  return self._web_url

+ @live_method
+ async def _experimental_get_flash_urls(self) -> Optional[list[str]]:
+ """URL of the flash service for the function."""
+ return list(self._experimental_flash_urls) if self._experimental_flash_urls else None
+
  @property
  async def is_generator(self) -> bool:
  """mdmd:hidden"""
@@ -1495,20 +1629,51 @@ Use the `Function.get_web_url()` method instead.
  else:
  count_update_callback = None

- async with aclosing(
- _map_invocation(
- self,
- input_queue,
- self.client,
- order_outputs,
- return_exceptions,
- wrap_returned_exceptions,
- count_update_callback,
- api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
- )
- ) as stream:
- async for item in stream:
- yield item
+ if self._input_plane_url:
+ async with aclosing(
+ _map_invocation_inputplane(
+ self,
+ input_queue,
+ self.client,
+ order_outputs,
+ return_exceptions,
+ wrap_returned_exceptions,
+ count_update_callback,
+ )
+ ) as stream:
+ async for item in stream:
+ yield item
+ else:
+ async with aclosing(
+ _map_invocation(
+ self,
+ input_queue,
+ self.client,
+ order_outputs,
+ return_exceptions,
+ wrap_returned_exceptions,
+ count_update_callback,
+ api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
+ )
+ ) as stream:
+ async for item in stream:
+ yield item
+
+ @live_method
+ async def _spawn_map(self, input_queue: _SynchronizedQueue) -> "_FunctionCall[ReturnType]":
+ self._check_no_web_url("spawn_map")
+ if self._is_generator:
+ raise InvalidError("A generator function cannot be called with `.spawn_map(...)`.")
+
+ assert self._function_name
+ function_call_id, num_inputs = await _spawn_map_invocation(
+ self,
+ input_queue,
+ self.client,
+ )
+ fc: _FunctionCall[ReturnType] = _FunctionCall._new_hydrated(function_call_id, self.client, None)
+ fc._num_inputs = num_inputs # set the cached value of num_inputs
+ return fc

  async def _call_function(self, args, kwargs) -> ReturnType:
  invocation: Union[_Invocation, _InputPlaneInvocation]
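In the hunk above, the mapping helper branches on `self._input_plane_url`, so input-plane-backed functions are routed through `_map_invocation_inputplane` with no change at the call site, and the new internal `_spawn_map` returns a `_FunctionCall` with its input count pre-cached. A usage sketch, assuming `my_func` is a deployed `modal.Function` (the call-site API is unchanged):

```python
# Sketch: user code looks the same whether or not the input plane is in use.
results = list(my_func.map([1, 2, 3, 4]))   # routed internally based on the function's metadata

# spawn_map-style fan-out hands back a FunctionCall handle (see the iter() docstring below).
fc = my_func.spawn_map([1, 2, 3, 4])
```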
@@ -1552,13 +1717,24 @@ Use the `Function.get_web_url()` method instead.
  @live_method_gen
  @synchronizer.no_input_translation
  async def _call_generator(self, args, kwargs):
- invocation = await _Invocation.create(
- self,
- args,
- kwargs,
- client=self.client,
- function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
- )
+ invocation: Union[_Invocation, _InputPlaneInvocation]
+ if self._input_plane_url:
+ invocation = await _InputPlaneInvocation.create(
+ self,
+ args,
+ kwargs,
+ client=self.client,
+ input_plane_url=self._input_plane_url,
+ input_plane_region=self._input_plane_region,
+ )
+ else:
+ invocation = await _Invocation.create(
+ self,
+ args,
+ kwargs,
+ client=self.client,
+ function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
+ )
  async for res in invocation.run_generator():
  yield res

@@ -1622,8 +1798,9 @@ Use the `Function.get_web_url()` method instead.
  # "user code" to run on the synchronicity thread, which seems bad
  if not self._is_local():
  msg = (
- "The definition for this function is missing here so it is not possible to invoke it locally. "
- "If this function was retrieved via `Function.lookup` you need to use `.remote()`."
+ "The definition for this Function is missing, so it is not possible to invoke it locally. "
+ "If this function was retrieved via `Function.from_name`, "
+ "you need to use one of the remote invocation methods instead."
  )
  raise ExecutionError(msg)

@@ -1724,10 +1901,9 @@ Use the `Function.get_web_url()` method instead.
  @live_method
  async def get_current_stats(self) -> FunctionStats:
  """Return a `FunctionStats` object describing the current function's queue and runner counts."""
- resp = await retry_transient_errors(
- self.client.stub.FunctionGetCurrentStats,
+ resp = await self.client.stub.FunctionGetCurrentStats(
  api_pb2.FunctionGetCurrentStatsRequest(function_id=self.object_id),
- total_timeout=10.0,
+ retry=Retry(total_timeout=10.0),
  )
  return FunctionStats(backlog=resp.backlog, num_total_runners=resp.num_total_tasks)

@@ -1745,6 +1921,7 @@ Use the `Function.get_web_url()` method instead.
  starmap = MethodWithAio(_starmap_sync, _starmap_async, synchronizer)
  for_each = MethodWithAio(_for_each_sync, _for_each_async, synchronizer)
  spawn_map = MethodWithAio(_spawn_map_sync, _spawn_map_async, synchronizer)
+ experimental_spawn_map = MethodWithAio(_experimental_spawn_map_sync, _experimental_spawn_map_async, synchronizer)


  class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
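`experimental_spawn_map` is newly bound here through the same `MethodWithAio` wrapper as `spawn_map`, so it should expose both a blocking call and an async `.aio` variant, following Modal's usual convention. A heavily hedged sketch; the exact argument shape and return value are provisional given the `experimental_` prefix, and the `_experimental_spawn_map_*` implementations live in `parallel_map.py`, which is not shown here:

```python
# Sketch only: experimental API, semantics may change.
fc = my_func.experimental_spawn_map([1, 2, 3])                 # blocking variant

async def main():
    fc = await my_func.experimental_spawn_map.aio([1, 2, 3])   # async variant via MethodWithAio
```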
@@ -1759,12 +1936,25 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
  """

  _is_generator: bool = False
+ _num_inputs: Optional[int] = None

  def _invocation(self):
  return _Invocation(self.client.stub, self.object_id, self.client)

- async def get(self, timeout: Optional[float] = None) -> ReturnType:
- """Get the result of the function call.
+ @live_method
+ async def num_inputs(self) -> int:
+ """Get the number of inputs in the function call."""
+ if self._num_inputs is None:
+ request = api_pb2.FunctionCallFromIdRequest(function_call_id=self.object_id)
+ resp = await self.client.stub.FunctionCallFromId(request)
+ self._num_inputs = resp.num_inputs # cached
+ return self._num_inputs
+
+ @live_method
+ async def get(self, timeout: Optional[float] = None, *, index: int = 0) -> ReturnType:
+ """Get the result of the index-th input of the function call.
+ `.spawn()` calls have a single output, so only specifying `index=0` is valid.
+ A non-zero index is useful when your function has multiple outputs, like via `.spawn_map()`.

  This function waits indefinitely by default. It takes an optional
  `timeout` argument that specifies the maximum number of seconds to wait,
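`FunctionCall.get` now takes a keyword-only `index`, and the new `num_inputs()` reports (and caches) how many inputs the call carries. A short sketch combining the two, following the `spawn_map` example used in the `iter()` docstring below:

```python
# Sketch: pull a single result out of a multi-input FunctionCall.
fc = my_func.spawn_map([1, 2, 3, 4])
print(fc.num_inputs())   # -> 4, fetched once via FunctionCallFromId and then cached
print(fc.get(index=2))   # result for the third input; index=0 is the only valid index for .spawn()
```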
@@ -1772,8 +1962,39 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):

  The returned coroutine is not cancellation-safe.
  """
- return await self._invocation().poll_function(timeout=timeout)
+ return await self._invocation().poll_function(timeout=timeout, index=index)
+
+ @live_method_gen
+ async def iter(self, *, start: int = 0, end: Optional[int] = None) -> AsyncIterator[ReturnType]:
+ """Iterate in-order over the results of the function call.
+
+ Optionally, specify a range [start, end) to iterate over.
+
+ Example:
+ ```python
+ @app.function()
+ def my_func(a):
+ return a ** 2
+

+ @app.local_entrypoint()
+ def main():
+ fc = my_func.spawn_map([1, 2, 3, 4])
+ assert list(fc.iter()) == [1, 4, 9, 16]
+ assert list(fc.iter(start=1, end=3)) == [4, 9]
+ ```
+
+ If `end` is not provided, it will iterate over all results.
+ """
+ num_inputs = await self.num_inputs()
+ if end is None:
+ end = num_inputs
+ if start < 0 or end > num_inputs:
+ raise ValueError(f"Invalid index range: {start} to {end} for {num_inputs} inputs")
+ async for _, item in self._invocation().enumerate(start_index=start, end_index=end):
+ yield item
+
+ @live_method
  async def get_call_graph(self) -> list[InputInfo]:
  """Returns a structure representing the call graph from a given root
  call ID, along with the status of execution for each node.
@@ -1783,9 +2004,10 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
  """
  assert self._client and self._client.stub
  request = api_pb2.FunctionGetCallGraphRequest(function_call_id=self.object_id)
- response = await retry_transient_errors(self._client.stub.FunctionGetCallGraph, request)
+ response = await self._client.stub.FunctionGetCallGraph(request)
  return _reconstruct_call_graph(response)

+ @live_method
  async def cancel(
  self,
  # if true, containers running the inputs are forcibly terminated
@@ -1801,7 +2023,7 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
  function_call_id=self.object_id, terminate_containers=terminate_containers
  )
  assert self._client and self._client.stub
- await retry_transient_errors(self._client.stub.FunctionCallCancel, request)
+ await self._client.stub.FunctionCallCancel(request)

  @staticmethod
  async def from_id(function_call_id: str, client: Optional[_Client] = None) -> "_FunctionCall[Any]":
@@ -1823,11 +2045,18 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
  if you no longer have access to the original object returned from `Function.spawn`.

  """
- if client is None:
- client = await _Client.from_env()

- fc: _FunctionCall[Any] = _FunctionCall._new_hydrated(function_call_id, client, None)
- return fc
+ async def _load(
+ self: _FunctionCall, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
+ ):
+ # this loader doesn't do anything in practice, but it will get the client from the load_context
+ self._hydrate(function_call_id, load_context.client, None)
+
+ rep = f"FunctionCall.from_id({function_call_id!r})"
+
+ return _FunctionCall._from_loader(
+ _load, rep, hydrate_lazily=True, load_context_overrides=LoadContext(client=client)
+ )

  @staticmethod
  async def gather(*function_calls: "_FunctionCall[T]") -> typing.Sequence[T]:
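`FunctionCall.from_id` no longer resolves a client eagerly: it now returns a lazily hydrated handle whose client is supplied through a `LoadContext` override at first use. A usage sketch (the function call ID is a placeholder):

```python
# Sketch: re-attach to a previously spawned call by ID and wait for its result.
import modal

fc = modal.FunctionCall.from_id("fc-XXXXXXXXXXXXXXXX")  # placeholder ID
result = fc.get(timeout=60)  # hydration happens lazily on this first call
```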