modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +402 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1025 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +196 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +943 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
  128. modal-0.72.13.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
modal/functions.py CHANGED
@@ -1,73 +1,78 @@
1
1
  # Copyright Modal Labs 2023
2
- import asyncio
2
+ import dataclasses
3
3
  import inspect
4
+ import textwrap
4
5
  import time
6
+ import typing
5
7
  import warnings
8
+ from collections.abc import AsyncGenerator, Collection, Sequence, Sized
6
9
  from dataclasses import dataclass
7
10
  from pathlib import PurePosixPath
8
11
  from typing import (
9
12
  TYPE_CHECKING,
10
13
  Any,
11
- AsyncGenerator,
12
- AsyncIterator,
13
14
  Callable,
14
- Collection,
15
- Dict,
16
- List,
17
15
  Optional,
18
- Sequence,
19
- Sized,
20
- Tuple,
21
- Type,
22
16
  Union,
23
17
  )
24
18
 
25
- from aiostream import stream
19
+ import typing_extensions
26
20
  from google.protobuf.message import Message
27
21
  from grpclib import GRPCError, Status
28
22
  from synchronicity.combined_types import MethodWithAio
23
+ from synchronicity.exceptions import UserCodeException
29
24
 
30
- from modal_proto import api_grpc, api_pb2
25
+ from modal_proto import api_pb2
26
+ from modal_proto.modal_api_grpc import ModalClientModal
31
27
 
32
28
  from ._location import parse_cloud_provider
33
- from ._output import OutputManager
34
29
  from ._pty import get_pty_info
35
30
  from ._resolver import Resolver
36
31
  from ._resources import convert_fn_config_to_resources_config
37
- from ._serialization import serialize
32
+ from ._runtime.execution_context import current_input_id, is_local
33
+ from ._serialization import serialize, serialize_proto_params
34
+ from ._traceback import print_server_warnings
38
35
  from ._utils.async_utils import (
36
+ TaskContext,
37
+ aclosing,
38
+ async_merge,
39
+ callable_to_agen,
39
40
  synchronize_api,
40
41
  synchronizer,
41
42
  warn_if_generator_is_not_consumed,
42
43
  )
44
+ from ._utils.deprecation import deprecation_warning, renamed_parameter
43
45
  from ._utils.function_utils import (
44
46
  ATTEMPT_TIMEOUT_GRACE_PERIOD,
45
47
  OUTPUTS_TIMEOUT,
48
+ FunctionCreationStatus,
46
49
  FunctionInfo,
47
50
  _create_input,
48
51
  _process_result,
49
52
  _stream_function_call_data,
50
- get_referred_objects,
53
+ get_function_type,
51
54
  is_async,
52
55
  )
53
56
  from ._utils.grpc_utils import retry_transient_errors
54
- from ._utils.mount_utils import validate_mount_points, validate_volumes
57
+ from ._utils.mount_utils import validate_network_file_systems, validate_volumes
55
58
  from .call_graph import InputInfo, _reconstruct_call_graph
56
59
  from .client import _Client
57
60
  from .cloud_bucket_mount import _CloudBucketMount, cloud_bucket_mounts_to_proto
58
61
  from .config import config
59
62
  from .exception import (
60
63
  ExecutionError,
64
+ FunctionTimeoutError,
65
+ InternalFailure,
61
66
  InvalidError,
62
67
  NotFoundError,
63
- deprecation_warning,
68
+ OutputExpiredError,
64
69
  )
65
- from .execution_context import current_input_id, is_local
66
70
  from .gpu import GPU_T, parse_gpu_config
67
71
  from .image import _Image
68
72
  from .mount import _get_client_mount, _Mount, get_auto_mounts
69
73
  from .network_file_system import _NetworkFileSystem, network_file_system_mount_protos
70
- from .object import Object, _get_environment_name, _Object, live_method, live_method_gen
74
+ from .object import _get_environment_name, _Object, live_method, live_method_gen
75
+ from .output import _get_output_manager
71
76
  from .parallel_map import (
72
77
  _for_each_async,
73
78
  _for_each_sync,
@@ -79,7 +84,7 @@ from .parallel_map import (
79
84
  _SynchronizedQueue,
80
85
  )
81
86
  from .proxy import _Proxy
82
- from .retries import Retries
87
+ from .retries import Retries, RetryManager
83
88
  from .schedule import Schedule
84
89
  from .scheduler_placement import SchedulerPlacement
85
90
  from .secret import _Secret
@@ -87,32 +92,72 @@ from .volume import _Volume
87
92
 
88
93
  if TYPE_CHECKING:
89
94
  import modal.app
95
+ import modal.cls
96
+ import modal.partial_function
97
+
98
+
99
+ @dataclasses.dataclass
100
+ class _RetryContext:
101
+ function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
102
+ retry_policy: api_pb2.FunctionRetryPolicy
103
+ function_call_jwt: str
104
+ input_jwt: str
105
+ input_id: str
106
+ item: api_pb2.FunctionPutInputsItem
90
107
 
91
108
 
92
109
  class _Invocation:
93
110
  """Internal client representation of a single-input call to a Modal Function or Generator"""
94
111
 
95
- def __init__(self, stub: api_grpc.ModalClientStub, function_call_id: str, client: _Client):
112
+ stub: ModalClientModal
113
+
114
+ def __init__(
115
+ self,
116
+ stub: ModalClientModal,
117
+ function_call_id: str,
118
+ client: _Client,
119
+ retry_context: Optional[_RetryContext] = None,
120
+ ):
96
121
  self.stub = stub
97
122
  self.client = client # Used by the deserializer.
98
123
  self.function_call_id = function_call_id # TODO: remove and use only input_id
124
+ self._retry_context = retry_context
99
125
 
100
126
  @staticmethod
101
- async def create(function_id: str, args, kwargs, client: _Client) -> "_Invocation":
127
+ async def create(
128
+ function: "_Function",
129
+ args,
130
+ kwargs,
131
+ *,
132
+ client: _Client,
133
+ function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
134
+ ) -> "_Invocation":
102
135
  assert client.stub
103
- item = await _create_input(args, kwargs, client)
136
+ function_id = function.object_id
137
+ item = await _create_input(args, kwargs, client, method_name=function._use_method_name)
104
138
 
105
139
  request = api_pb2.FunctionMapRequest(
106
140
  function_id=function_id,
107
141
  parent_input_id=current_input_id() or "",
108
142
  function_call_type=api_pb2.FUNCTION_CALL_TYPE_UNARY,
109
143
  pipelined_inputs=[item],
144
+ function_call_invocation_type=function_call_invocation_type,
110
145
  )
111
146
  response = await retry_transient_errors(client.stub.FunctionMap, request)
112
147
  function_call_id = response.function_call_id
113
148
 
114
149
  if response.pipelined_inputs:
115
- return _Invocation(client.stub, function_call_id, client)
150
+ assert len(response.pipelined_inputs) == 1
151
+ input = response.pipelined_inputs[0]
152
+ retry_context = _RetryContext(
153
+ function_call_invocation_type=function_call_invocation_type,
154
+ retry_policy=response.retry_policy,
155
+ function_call_jwt=response.function_call_jwt,
156
+ input_jwt=input.input_jwt,
157
+ input_id=input.input_id,
158
+ item=item,
159
+ )
160
+ return _Invocation(client.stub, function_call_id, client, retry_context)
116
161
 
117
162
  request_put = api_pb2.FunctionPutInputsRequest(
118
163
  function_id=function_id, inputs=[item], function_call_id=function_call_id
@@ -124,11 +169,20 @@ class _Invocation:
124
169
  processed_inputs = inputs_response.inputs
125
170
  if not processed_inputs:
126
171
  raise Exception("Could not create function call - the input queue seems to be full")
127
- return _Invocation(client.stub, function_call_id, client)
172
+ input = inputs_response.inputs[0]
173
+ retry_context = _RetryContext(
174
+ function_call_invocation_type=function_call_invocation_type,
175
+ retry_policy=response.retry_policy,
176
+ function_call_jwt=response.function_call_jwt,
177
+ input_jwt=input.input_jwt,
178
+ input_id=input.input_id,
179
+ item=item,
180
+ )
181
+ return _Invocation(client.stub, function_call_id, client, retry_context)
128
182
 
129
183
  async def pop_function_call_outputs(
130
- self, timeout: Optional[float], clear_on_success: bool
131
- ) -> AsyncIterator[api_pb2.FunctionGetOutputsItem]:
184
+ self, timeout: Optional[float], clear_on_success: bool, input_jwts: Optional[list[str]] = None
185
+ ) -> api_pb2.FunctionGetOutputsResponse:
132
186
  t0 = time.time()
133
187
  if timeout is None:
134
188
  backend_timeout = OUTPUTS_TIMEOUT
@@ -142,53 +196,100 @@ class _Invocation:
142
196
  timeout=backend_timeout,
143
197
  last_entry_id="0-0",
144
198
  clear_on_success=clear_on_success,
199
+ requested_at=time.time(),
200
+ input_jwts=input_jwts,
145
201
  )
146
202
  response: api_pb2.FunctionGetOutputsResponse = await retry_transient_errors(
147
203
  self.stub.FunctionGetOutputs,
148
204
  request,
149
205
  attempt_timeout=backend_timeout + ATTEMPT_TIMEOUT_GRACE_PERIOD,
150
206
  )
207
+
151
208
  if len(response.outputs) > 0:
152
- for item in response.outputs:
153
- yield item
154
- return
209
+ return response
155
210
 
156
211
  if timeout is not None:
157
212
  # update timeout in retry loop
158
213
  backend_timeout = min(OUTPUTS_TIMEOUT, t0 + timeout - time.time())
159
214
  if backend_timeout < 0:
160
- break
215
+ # return the last response to check for state of num_unfinished_inputs
216
+ return response
217
+
218
+ async def _retry_input(self) -> None:
219
+ ctx = self._retry_context
220
+ if not ctx:
221
+ raise ValueError("Cannot retry input when _retry_context is empty.")
222
+
223
+ item = api_pb2.FunctionRetryInputsItem(input_jwt=ctx.input_jwt, input=ctx.item.input)
224
+ request = api_pb2.FunctionRetryInputsRequest(function_call_jwt=ctx.function_call_jwt, inputs=[item])
225
+ await retry_transient_errors(
226
+ self.client.stub.FunctionRetryInputs,
227
+ request,
228
+ )
161
229
 
162
- async def run_function(self) -> Any:
230
+ async def _get_single_output(self, expected_jwt: Optional[str] = None) -> Any:
163
231
  # waits indefinitely for a single result for the function, and clear the outputs buffer after
164
232
  item: api_pb2.FunctionGetOutputsItem = (
165
- await stream.list(self.pop_function_call_outputs(timeout=None, clear_on_success=True))
166
- )[0]
167
- assert not item.result.gen_status
233
+ await self.pop_function_call_outputs(
234
+ timeout=None,
235
+ clear_on_success=True,
236
+ input_jwts=[expected_jwt] if expected_jwt else None,
237
+ )
238
+ ).outputs[0]
168
239
  return await _process_result(item.result, item.data_format, self.stub, self.client)
169
240
 
241
+ async def run_function(self) -> Any:
242
+ # Use retry logic only if retry policy is specified and
243
+ ctx = self._retry_context
244
+ if (
245
+ not ctx
246
+ or not ctx.retry_policy
247
+ or ctx.retry_policy.retries == 0
248
+ or ctx.function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
249
+ ):
250
+ return await self._get_single_output()
251
+
252
+ # User errors including timeouts are managed by the user specified retry policy.
253
+ user_retry_manager = RetryManager(ctx.retry_policy)
254
+
255
+ while True:
256
+ try:
257
+ return await self._get_single_output(ctx.input_jwt)
258
+ except (UserCodeException, FunctionTimeoutError) as exc:
259
+ await user_retry_manager.raise_or_sleep(exc)
260
+ except InternalFailure:
261
+ # For system failures on the server, we retry immediately.
262
+ pass
263
+ await self._retry_input()
264
+
170
265
  async def poll_function(self, timeout: Optional[float] = None):
171
266
  """Waits up to timeout for a result from a function.
172
267
 
173
268
  If timeout is `None`, waits indefinitely. This function is not
174
269
  cancellation-safe.
175
270
  """
176
- items: List[api_pb2.FunctionGetOutputsItem] = await stream.list(
177
- self.pop_function_call_outputs(timeout=timeout, clear_on_success=False)
271
+ response: api_pb2.FunctionGetOutputsResponse = await self.pop_function_call_outputs(
272
+ timeout=timeout, clear_on_success=False
178
273
  )
179
-
180
- if len(items) == 0:
274
+ if len(response.outputs) == 0 and response.num_unfinished_inputs == 0:
275
+ # if no unfinished inputs and no outputs, then function expired
276
+ raise OutputExpiredError()
277
+ elif len(response.outputs) == 0:
181
278
  raise TimeoutError()
182
279
 
183
- return await _process_result(items[0].result, items[0].data_format, self.stub, self.client)
280
+ return await _process_result(
281
+ response.outputs[0].result, response.outputs[0].data_format, self.stub, self.client
282
+ )
184
283
 
185
284
  async def run_generator(self):
186
- data_stream = _stream_function_call_data(self.client, self.function_call_id, variant="data_out")
187
- combined_stream = stream.merge(data_stream, stream.call(self.run_function)) # type: ignore
188
-
189
285
  items_received = 0
190
286
  items_total: Union[int, None] = None # populated when self.run_function() completes
191
- async with combined_stream.stream() as streamer:
287
+ async with aclosing(
288
+ async_merge(
289
+ _stream_function_call_data(self.client, self.function_call_id, variant="data_out"),
290
+ callable_to_agen(self.run_function),
291
+ )
292
+ ) as streamer:
192
293
  async for item in streamer:
193
294
  if isinstance(item, api_pb2.GeneratorDone):
194
295
  items_total = item.items_total
@@ -207,13 +308,23 @@ class FunctionStats:
207
308
  """Simple data structure storing stats for a running function."""
208
309
 
209
310
  backlog: int
210
- num_active_runners: int
211
311
  num_total_runners: int
212
312
 
313
+ def __getattr__(self, name):
314
+ if name == "num_active_runners":
315
+ msg = (
316
+ "'FunctionStats.num_active_runners' is deprecated."
317
+ " It currently always has a value of 0,"
318
+ " but it will be removed in a future release."
319
+ )
320
+ deprecation_warning((2024, 6, 14), msg)
321
+ return 0
322
+ raise AttributeError(f"'FunctionStats' object has no attribute '{name}'")
323
+
213
324
 
214
325
  def _parse_retries(
215
326
  retries: Optional[Union[int, Retries]],
216
- raw_f: Optional[Callable] = None,
327
+ source: str = "",
217
328
  ) -> Optional[api_pb2.FunctionRetryPolicy]:
218
329
  if isinstance(retries, int):
219
330
  return Retries(
@@ -226,10 +337,9 @@ def _parse_retries(
226
337
  elif retries is None:
227
338
  return None
228
339
  else:
229
- err_object = f"Function {raw_f}" if raw_f else "Function"
230
- raise InvalidError(
231
- f"{err_object} retries must be an integer or instance of modal.Retries. Found: {type(retries)}"
232
- )
340
+ extra = f" on {source}" if source else ""
341
+ msg = f"Retries parameter must be an integer or instance of modal.Retries. Found: {type(retries)}{extra}."
342
+ raise InvalidError(msg)
233
343
 
234
344
 
235
345
  @dataclass
@@ -243,103 +353,152 @@ class _FunctionSpec:
243
353
  image: Optional[_Image]
244
354
  mounts: Sequence[_Mount]
245
355
  secrets: Sequence[_Secret]
246
- network_file_systems: Dict[Union[str, PurePosixPath], _NetworkFileSystem]
247
- volumes: Dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]]
248
- gpu: GPU_T
356
+ network_file_systems: dict[Union[str, PurePosixPath], _NetworkFileSystem]
357
+ volumes: dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]]
358
+ gpus: Union[GPU_T, list[GPU_T]] # TODO(irfansharif): Somehow assert that it's the first kind, in sandboxes
249
359
  cloud: Optional[str]
250
- cpu: Optional[float]
251
- memory: Optional[Union[int, Tuple[int, int]]]
360
+ cpu: Optional[Union[float, tuple[float, float]]]
361
+ memory: Optional[Union[int, tuple[int, int]]]
362
+ ephemeral_disk: Optional[int]
363
+ scheduler_placement: Optional[SchedulerPlacement]
364
+ proxy: Optional[_Proxy]
365
+
252
366
 
367
+ P = typing_extensions.ParamSpec("P")
368
+ ReturnType = typing.TypeVar("ReturnType", covariant=True)
369
+ OriginalReturnType = typing.TypeVar(
370
+ "OriginalReturnType", covariant=True
371
+ ) # differs from return type if ReturnType is coroutine
253
372
 
254
- class _Function(_Object, type_prefix="fu"):
373
+
374
+ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type_prefix="fu"):
255
375
  """Functions are the basic units of serverless execution on Modal.
256
376
 
257
377
  Generally, you will not construct a `Function` directly. Instead, use the
258
- `@app.function()` decorator on the `App` object (formerly called "Stub")
259
- for your application.
378
+ `App.function()` decorator to register your Python functions with your App.
260
379
  """
261
380
 
262
381
  # TODO: more type annotations
263
382
  _info: Optional[FunctionInfo]
264
- _all_mounts: Collection[_Mount]
265
- _app: "modal.app._App"
266
- _obj: Any
383
+ _serve_mounts: frozenset[_Mount] # set at load time, only by loader
384
+ _app: Optional["modal.app._App"] = None
385
+ _obj: Optional["modal.cls._Obj"] = None # only set for InstanceServiceFunctions and bound instance methods
267
386
  _web_url: Optional[str]
268
- _is_remote_cls_method: bool = False # TODO(erikbern): deprecated
269
387
  _function_name: Optional[str]
270
388
  _is_method: bool
271
- _spec: _FunctionSpec
389
+ _spec: Optional[_FunctionSpec] = None
272
390
  _tag: str
273
391
  _raw_f: Callable[..., Any]
274
392
  _build_args: dict
275
- _parent: "_Function"
393
+
394
+ _is_generator: Optional[bool] = None
395
+ _cluster_size: Optional[int] = None
396
+
397
+ # when this is the method of a class/object function, invocation of this function
398
+ # should supply the method name in the FunctionInput:
399
+ _use_method_name: str = ""
400
+
401
+ _class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None
402
+ _method_handle_metadata: Optional[dict[str, "api_pb2.FunctionHandleMetadata"]] = None
403
+
404
+ def _bind_method(
405
+ self,
406
+ user_cls,
407
+ method_name: str,
408
+ partial_function: "modal.partial_function._PartialFunction",
409
+ ):
410
+ """mdmd:hidden
411
+
412
+ Creates a _Function that is bound to a specific class method name. This _Function is not uniquely tied
413
+ to any backend function -- its object_id is the function ID of the class service function.
414
+
415
+ """
416
+ class_service_function = self
417
+ assert class_service_function._info # has to be a local function to be able to "bind" it
418
+ assert not class_service_function._is_method # should not be used on an already bound method placeholder
419
+ assert not class_service_function._obj # should only be used on base function / class service function
420
+ full_name = f"{user_cls.__name__}.{method_name}"
421
+
422
+ rep = f"Method({full_name})"
423
+ fun = _Object.__new__(_Function)
424
+ fun._init(rep)
425
+ fun._tag = full_name
426
+ fun._raw_f = partial_function.raw_f
427
+ fun._info = FunctionInfo(
428
+ partial_function.raw_f, user_cls=user_cls, serialized=class_service_function.info.is_serialized()
429
+ ) # needed for .local()
430
+ fun._use_method_name = method_name
431
+ fun._app = class_service_function._app
432
+ fun._is_generator = partial_function.is_generator
433
+ fun._cluster_size = partial_function.cluster_size
434
+ fun._spec = class_service_function._spec
435
+ fun._is_method = True
436
+ return fun
276
437
 
277
438
  @staticmethod
278
439
  def from_args(
279
440
  info: FunctionInfo,
280
441
  app,
281
442
  image: _Image,
282
- secret: Optional[_Secret] = None,
283
443
  secrets: Sequence[_Secret] = (),
284
444
  schedule: Optional[Schedule] = None,
285
- is_generator=False,
286
- gpu: GPU_T = None,
445
+ is_generator: bool = False,
446
+ gpu: Union[GPU_T, list[GPU_T]] = None,
287
447
  # TODO: maybe break this out into a separate decorator for notebooks.
288
448
  mounts: Collection[_Mount] = (),
289
- network_file_systems: Dict[Union[str, PurePosixPath], _NetworkFileSystem] = {},
449
+ network_file_systems: dict[Union[str, PurePosixPath], _NetworkFileSystem] = {},
290
450
  allow_cross_region_volumes: bool = False,
291
- volumes: Dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]] = {},
451
+ volumes: dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]] = {},
292
452
  webhook_config: Optional[api_pb2.WebhookConfig] = None,
293
- memory: Optional[Union[int, Tuple[int, int]]] = None,
453
+ memory: Optional[Union[int, tuple[int, int]]] = None,
294
454
  proxy: Optional[_Proxy] = None,
295
455
  retries: Optional[Union[int, Retries]] = None,
296
456
  timeout: Optional[int] = None,
297
457
  concurrency_limit: Optional[int] = None,
298
458
  allow_concurrent_inputs: Optional[int] = None,
459
+ batch_max_size: Optional[int] = None,
460
+ batch_wait_ms: Optional[int] = None,
299
461
  container_idle_timeout: Optional[int] = None,
300
- cpu: Optional[float] = None,
462
+ cpu: Optional[Union[float, tuple[float, float]]] = None,
301
463
  keep_warm: Optional[int] = None, # keep_warm=True is equivalent to keep_warm=1
302
464
  cloud: Optional[str] = None,
303
- _experimental_boost: bool = False,
304
- _experimental_scheduler: bool = False,
305
- _experimental_scheduler_placement: Optional[SchedulerPlacement] = None,
465
+ scheduler_placement: Optional[SchedulerPlacement] = None,
306
466
  is_builder_function: bool = False,
307
467
  is_auto_snapshot: bool = False,
308
468
  enable_memory_snapshot: bool = False,
309
- checkpointing_enabled: Optional[bool] = None,
310
- allow_background_volume_commits: bool = False,
311
469
  block_network: bool = False,
470
+ i6pn_enabled: bool = False,
471
+ cluster_size: Optional[int] = None, # Experimental: Clustered functions
312
472
  max_inputs: Optional[int] = None,
473
+ ephemeral_disk: Optional[int] = None,
474
+ _experimental_buffer_containers: Optional[int] = None,
475
+ _experimental_proxy_ip: Optional[str] = None,
476
+ _experimental_custom_scaling_factor: Optional[float] = None,
313
477
  ) -> None:
314
478
  """mdmd:hidden"""
479
+ # Needed to avoid circular imports
480
+ from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags
481
+
315
482
  tag = info.get_tag()
316
483
 
317
- raw_f = info.raw_f
318
- assert callable(raw_f)
319
- if schedule is not None:
320
- if not info.is_nullary():
484
+ if info.raw_f:
485
+ raw_f = info.raw_f
486
+ assert callable(raw_f)
487
+ if schedule is not None and not info.is_nullary():
321
488
  raise InvalidError(
322
489
  f"Function {raw_f} has a schedule, so it needs to support being called with no arguments"
323
490
  )
324
-
325
- if secret is not None:
326
- deprecation_warning(
327
- (2024, 1, 31),
328
- "The singular `secret` parameter is deprecated. Pass a list to `secrets` instead.",
329
- )
330
- secrets = [secret, *secrets]
331
-
332
- if checkpointing_enabled is not None:
333
- deprecation_warning(
334
- (2024, 3, 4),
335
- "The argument `checkpointing_enabled` is now deprecated. Use `enable_memory_snapshot` instead.",
336
- )
337
- enable_memory_snapshot = checkpointing_enabled
491
+ else:
492
+ # must be a "class service function"
493
+ assert info.user_cls
494
+ assert not webhook_config
495
+ assert not schedule
338
496
 
339
497
  explicit_mounts = mounts
340
498
 
341
499
  if is_local():
342
500
  entrypoint_mounts = info.get_entrypoint_mount()
501
+
343
502
  all_mounts = [
344
503
  _get_client_mount(),
345
504
  *explicit_mounts,
@@ -354,34 +513,47 @@ class _Function(_Object, type_prefix="fu"):
354
513
  # TODO: maybe the entire constructor should be exited early if not local?
355
514
  all_mounts = []
356
515
 
357
- retry_policy = _parse_retries(retries, raw_f)
516
+ retry_policy = _parse_retries(
517
+ retries, f"Function '{info.get_tag()}'" if info.raw_f else f"Class '{info.get_tag()}'"
518
+ )
358
519
 
359
- gpu_config = parse_gpu_config(gpu)
520
+ if webhook_config is not None and retry_policy is not None:
521
+ raise InvalidError(
522
+ "Web endpoints do not support retries.",
523
+ )
524
+
525
+ if is_generator and retry_policy is not None:
526
+ deprecation_warning(
527
+ (2024, 6, 25),
528
+ "Retries for generator functions are deprecated and will soon be removed.",
529
+ )
360
530
 
361
531
  if proxy:
362
532
  # HACK: remove this once we stop using ssh tunnels for this.
363
533
  if image:
534
+ # TODO(elias): this will cause an error if users use prior `.add_local_*` commands without copy=True
364
535
  image = image.apt_install("autossh")
365
536
 
366
537
  function_spec = _FunctionSpec(
367
538
  mounts=all_mounts,
368
539
  secrets=secrets,
369
- gpu=gpu,
540
+ gpus=gpu,
370
541
  network_file_systems=network_file_systems,
371
542
  volumes=volumes,
372
543
  image=image,
373
544
  cloud=cloud,
374
545
  cpu=cpu,
375
546
  memory=memory,
547
+ ephemeral_disk=ephemeral_disk,
548
+ scheduler_placement=scheduler_placement,
549
+ proxy=proxy,
376
550
  )
377
551
 
378
- if info.cls and not is_auto_snapshot:
379
- # Needed to avoid circular imports
380
- from .partial_function import _find_callables_for_cls, _PartialFunctionFlags
381
-
382
- build_functions = list(_find_callables_for_cls(info.cls, _PartialFunctionFlags.BUILD).values())
383
- for build_function in build_functions:
384
- snapshot_info = FunctionInfo(build_function, cls=info.cls)
552
+ if info.user_cls and not is_auto_snapshot:
553
+ build_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.BUILD).items()
554
+ for k, pf in build_functions:
555
+ build_function = pf.raw_f
556
+ snapshot_info = FunctionInfo(build_function, user_cls=info.user_cls)
385
557
  snapshot_function = _Function.from_args(
386
558
  snapshot_info,
387
559
  app=None,
@@ -392,16 +564,17 @@ class _Function(_Object, type_prefix="fu"):
392
564
  network_file_systems=network_file_systems,
393
565
  volumes=volumes,
394
566
  memory=memory,
395
- timeout=86400, # TODO: make this an argument to `@build()`
567
+ timeout=pf.build_timeout,
396
568
  cpu=cpu,
569
+ ephemeral_disk=ephemeral_disk,
397
570
  is_builder_function=True,
398
571
  is_auto_snapshot=True,
399
- _experimental_scheduler_placement=_experimental_scheduler_placement,
572
+ scheduler_placement=scheduler_placement,
400
573
  )
401
574
  image = _Image._from_args(
402
575
  base_images={"base": image},
403
576
  build_function=snapshot_function,
404
- force_build=image.force_build,
577
+ force_build=image.force_build or pf.force_build,
405
578
  )
406
579
 
407
580
  if keep_warm is not None and not isinstance(keep_warm, int):
@@ -409,9 +582,15 @@ class _Function(_Object, type_prefix="fu"):
409
582
 
410
583
  if (keep_warm is not None) and (concurrency_limit is not None) and concurrency_limit < keep_warm:
411
584
  raise InvalidError(
412
- f"Function `{info.function_name}` has `{concurrency_limit=}`, strictly less than its `{keep_warm=}` parameter."
585
+ f"Function `{info.function_name}` has `{concurrency_limit=}`, "
586
+ f"strictly less than its `{keep_warm=}` parameter."
413
587
  )
414
588
 
589
+ if _experimental_custom_scaling_factor is not None and (
590
+ _experimental_custom_scaling_factor < 0 or _experimental_custom_scaling_factor > 1
591
+ ):
592
+ raise InvalidError("`_experimental_custom_scaling_factor` must be between 0.0 and 1.0 inclusive.")
593
+
415
594
  if not cloud and not is_builder_function:
416
595
  cloud = config.get("default_cloud")
417
596
  if cloud:
@@ -428,22 +607,56 @@ class _Function(_Object, type_prefix="fu"):
428
607
  else:
429
608
  raise InvalidError("Webhooks cannot be generators")
430
609
 
610
+ if info.raw_f and batch_max_size:
611
+ func_name = info.raw_f.__name__
612
+ if is_generator:
613
+ raise InvalidError(f"Modal batched function {func_name} cannot return generators")
614
+ for arg in inspect.signature(info.raw_f).parameters.values():
615
+ if arg.default is not inspect.Parameter.empty:
616
+ raise InvalidError(f"Modal batched function {func_name} does not accept default arguments.")
617
+
618
+ if container_idle_timeout is not None and container_idle_timeout <= 0:
619
+ raise InvalidError("`container_idle_timeout` must be > 0")
620
+
621
+ if max_inputs is not None:
622
+ if not isinstance(max_inputs, int):
623
+ raise InvalidError(f"`max_inputs` must be an int, not {type(max_inputs).__name__}")
624
+ if max_inputs <= 0:
625
+ raise InvalidError("`max_inputs` must be positive")
626
+ if max_inputs > 1:
627
+ raise InvalidError("Only `max_inputs=1` is currently supported")
628
+
431
629
  # Validate volumes
432
630
  validated_volumes = validate_volumes(volumes)
433
631
  cloud_bucket_mounts = [(k, v) for k, v in validated_volumes if isinstance(v, _CloudBucketMount)]
434
632
  validated_volumes = [(k, v) for k, v in validated_volumes if isinstance(v, _Volume)]
435
633
 
436
634
  # Validate NFS
437
- if not isinstance(network_file_systems, dict):
438
- raise InvalidError("network_file_systems must be a dict[str, NetworkFileSystem] where the keys are paths")
439
- validated_network_file_systems = validate_mount_points("Network file system", network_file_systems)
635
+ validated_network_file_systems = validate_network_file_systems(network_file_systems)
440
636
 
441
637
  # Validate image
442
638
  if image is not None and not isinstance(image, _Image):
443
639
  raise InvalidError(f"Expected modal.Image object. Got {type(image)}.")
444
640
 
445
- def _deps(only_explicit_mounts=False) -> List[_Object]:
446
- deps: List[_Object] = list(secrets)
641
+ method_definitions: Optional[dict[str, api_pb2.MethodDefinition]] = None
642
+
643
+ if info.user_cls:
644
+ method_definitions = {}
645
+ partial_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.FUNCTION)
646
+ for method_name, partial_function in partial_functions.items():
647
+ function_type = get_function_type(partial_function.is_generator)
648
+ function_name = f"{info.user_cls.__name__}.{method_name}"
649
+ method_definition = api_pb2.MethodDefinition(
650
+ webhook_config=partial_function.webhook_config,
651
+ function_type=function_type,
652
+ function_name=function_name,
653
+ )
654
+ method_definitions[method_name] = method_definition
655
+
656
+ function_type = get_function_type(is_generator)
657
+
658
+ def _deps(only_explicit_mounts=False) -> list[_Object]:
659
+ deps: list[_Object] = list(secrets)
447
660
  if only_explicit_mounts:
448
661
  # TODO: this is a bit hacky, but all_mounts may differ in the container vs locally
449
662
  # We don't want the function dependencies to change, so we have this way to force it to
@@ -467,267 +680,358 @@ class _Function(_Object, type_prefix="fu"):
467
680
  if cloud_bucket_mount.secret:
468
681
  deps.append(cloud_bucket_mount.secret)
469
682
 
470
- # Add implicit dependencies from the function's code
471
- objs: list[Object] = get_referred_objects(info.raw_f)
472
- _objs: list[_Object] = synchronizer._translate_in(objs) # type: ignore
473
- deps += _objs
474
683
  return deps
475
684
 
476
685
  async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
477
686
  assert resolver.client and resolver.client.stub
478
- if is_generator:
479
- function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR
480
- else:
481
- function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION
482
687
 
688
+ assert resolver.app_id
483
689
  req = api_pb2.FunctionPrecreateRequest(
484
690
  app_id=resolver.app_id,
485
691
  function_name=info.function_name,
486
692
  function_type=function_type,
487
- webhook_config=webhook_config,
488
693
  existing_function_id=existing_object_id or "",
489
694
  )
695
+ if method_definitions:
696
+ for method_name, method_definition in method_definitions.items():
697
+ req.method_definitions[method_name].CopyFrom(method_definition)
698
+ elif webhook_config:
699
+ req.webhook_config.CopyFrom(webhook_config)
490
700
  response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req)
491
701
  self._hydrate(response.function_id, resolver.client, response.handle_metadata)
492
702
 
493
703
  async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
494
704
  assert resolver.client and resolver.client.stub
495
- status_row = resolver.add_status_row()
496
- status_row.message(f"Creating {tag}...")
705
+ with FunctionCreationStatus(resolver, tag) as function_creation_status:
706
+ timeout_secs = timeout
497
707
 
498
- if is_generator:
499
- function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR
500
- else:
501
- function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION
502
-
503
- timeout_secs = timeout
504
-
505
- if app and app.is_interactive and not is_builder_function:
506
- pty_info = get_pty_info(shell=False)
507
- else:
508
- pty_info = None
509
-
510
- if info.is_serialized():
511
- # Use cloudpickle. Used when working w/ Jupyter notebooks.
512
- # serialize at _load time, not function decoration time
513
- # otherwise we can't capture a surrounding class for lifetime methods etc.
514
- function_serialized = info.serialized_function()
515
- class_serialized = serialize(info.cls) if info.cls is not None else None
516
-
517
- # Ensure that large data in global variables does not blow up the gRPC payload,
518
- # which has maximum size 100 MiB. We set the limit lower for performance reasons.
519
- if len(function_serialized) > 16 << 20: # 16 MiB
520
- raise InvalidError(
521
- f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
522
- "This is larger than the maximum limit of 16 MiB. "
523
- "Try reducing the size of the closure by using parameters or mounts, not large global variables."
524
- )
525
- elif len(function_serialized) > 256 << 10: # 256 KiB
526
- warnings.warn(
527
- f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
528
- "This is larger than the recommended limit of 256 KiB. "
529
- "Try reducing the size of the closure by using parameters or mounts, not large global variables."
708
+ if app and app.is_interactive and not is_builder_function:
709
+ pty_info = get_pty_info(shell=False)
710
+ else:
711
+ pty_info = None
712
+
713
+ if info.is_serialized():
714
+ # Use cloudpickle. Used when working w/ Jupyter notebooks.
715
+ # serialize at _load time, not function decoration time
716
+ # otherwise we can't capture a surrounding class for lifetime methods etc.
717
+ function_serialized = info.serialized_function()
718
+ class_serialized = serialize(info.user_cls) if info.user_cls is not None else None
719
+ # Ensure that large data in global variables does not blow up the gRPC payload,
720
+ # which has maximum size 100 MiB. We set the limit lower for performance reasons.
721
+ if len(function_serialized) > 16 << 20: # 16 MiB
722
+ raise InvalidError(
723
+ f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
724
+ "This is larger than the maximum limit of 16 MiB. "
725
+ "Try reducing the size of the closure by using parameters or mounts, "
726
+ "not large global variables."
727
+ )
728
+ elif len(function_serialized) > 256 << 10: # 256 KiB
729
+ warnings.warn(
730
+ f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
731
+ "This is larger than the recommended limit of 256 KiB. "
732
+ "Try reducing the size of the closure by using parameters or mounts, "
733
+ "not large global variables."
734
+ )
735
+ else:
736
+ function_serialized = None
737
+ class_serialized = None
738
+
739
+ app_name = ""
740
+ if app and app.name:
741
+ app_name = app.name
742
+
743
+ # Relies on dicts being ordered (true as of Python 3.6).
744
+ volume_mounts = [
745
+ api_pb2.VolumeMount(
746
+ mount_path=path,
747
+ volume_id=volume.object_id,
748
+ allow_background_commits=True,
530
749
  )
531
- else:
532
- function_serialized = None
533
- class_serialized = None
534
-
535
- app_name = ""
536
- if app and app.name:
537
- app_name = app.name
538
-
539
- # Relies on dicts being ordered (true as of Python 3.6).
540
- volume_mounts = [
541
- api_pb2.VolumeMount(
542
- mount_path=path,
543
- volume_id=volume.object_id,
544
- allow_background_commits=allow_background_volume_commits,
750
+ for path, volume in validated_volumes
751
+ ]
752
+ loaded_mount_ids = {m.object_id for m in all_mounts} | {m.object_id for m in image._mount_layers}
753
+
754
+ # Get object dependencies
755
+ object_dependencies = []
756
+ for dep in _deps(only_explicit_mounts=True):
757
+ if not dep.object_id:
758
+ raise Exception(f"Dependency {dep} isn't hydrated")
759
+ object_dependencies.append(api_pb2.ObjectDependency(object_id=dep.object_id))
760
+
761
+ function_data: Optional[api_pb2.FunctionData] = None
762
+ function_definition: Optional[api_pb2.Function] = None
763
+
764
+ # Create function remotely
765
+ function_definition = api_pb2.Function(
766
+ module_name=info.module_name or "",
767
+ function_name=info.function_name,
768
+ mount_ids=loaded_mount_ids,
769
+ secret_ids=[secret.object_id for secret in secrets],
770
+ image_id=(image.object_id if image else ""),
771
+ definition_type=info.get_definition_type(),
772
+ function_serialized=function_serialized or b"",
773
+ class_serialized=class_serialized or b"",
774
+ function_type=function_type,
775
+ webhook_config=webhook_config,
776
+ method_definitions=method_definitions,
777
+ method_definitions_set=True,
778
+ shared_volume_mounts=network_file_system_mount_protos(
779
+ validated_network_file_systems, allow_cross_region_volumes
780
+ ),
781
+ volume_mounts=volume_mounts,
782
+ proxy_id=(proxy.object_id if proxy else None),
783
+ retry_policy=retry_policy,
784
+ timeout_secs=timeout_secs or 0,
785
+ task_idle_timeout_secs=container_idle_timeout or 0,
786
+ concurrency_limit=concurrency_limit or 0,
787
+ pty_info=pty_info,
788
+ cloud_provider=cloud_provider,
789
+ warm_pool_size=keep_warm or 0,
790
+ runtime=config.get("function_runtime"),
791
+ runtime_debug=config.get("function_runtime_debug"),
792
+ runtime_perf_record=config.get("runtime_perf_record"),
793
+ app_name=app_name,
794
+ is_builder_function=is_builder_function,
795
+ target_concurrent_inputs=allow_concurrent_inputs or 0,
796
+ batch_max_size=batch_max_size or 0,
797
+ batch_linger_ms=batch_wait_ms or 0,
798
+ worker_id=config.get("worker_id"),
799
+ is_auto_snapshot=is_auto_snapshot,
800
+ is_method=bool(info.user_cls) and not info.is_service_class(),
801
+ checkpointing_enabled=enable_memory_snapshot,
802
+ object_dependencies=object_dependencies,
803
+ block_network=block_network,
804
+ max_inputs=max_inputs or 0,
805
+ cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts),
806
+ scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
807
+ is_class=info.is_service_class(),
808
+ class_parameter_info=info.class_parameter_info(),
809
+ i6pn_enabled=i6pn_enabled,
810
+ schedule=schedule.proto_message if schedule is not None else None,
811
+ snapshot_debug=config.get("snapshot_debug"),
812
+ _experimental_group_size=cluster_size or 0, # Experimental: Clustered functions
813
+ _experimental_concurrent_cancellations=True,
814
+ _experimental_buffer_containers=_experimental_buffer_containers or 0,
815
+ _experimental_proxy_ip=_experimental_proxy_ip,
816
+ _experimental_custom_scaling=_experimental_custom_scaling_factor is not None,
545
817
  )
546
- for path, volume in validated_volumes
547
- ]
548
- loaded_mount_ids = {m.object_id for m in all_mounts}
549
-
550
- # Get object dependencies
551
- object_dependencies = []
552
- for dep in _deps(only_explicit_mounts=True):
553
- if not dep.object_id:
554
- raise Exception(f"Dependency {dep} isn't hydrated")
555
- object_dependencies.append(api_pb2.ObjectDependency(object_id=dep.object_id))
556
-
557
- # Create function remotely
558
- function_definition = api_pb2.Function(
559
- module_name=info.module_name or "",
560
- function_name=info.function_name,
561
- mount_ids=loaded_mount_ids,
562
- secret_ids=[secret.object_id for secret in secrets],
563
- image_id=(image.object_id if image else ""),
564
- definition_type=info.definition_type,
565
- function_serialized=function_serialized or b"",
566
- class_serialized=class_serialized or b"",
567
- function_type=function_type,
568
- resources=convert_fn_config_to_resources_config(cpu=cpu, memory=memory, gpu=gpu),
569
- webhook_config=webhook_config,
570
- shared_volume_mounts=network_file_system_mount_protos(
571
- validated_network_file_systems, allow_cross_region_volumes
572
- ),
573
- volume_mounts=volume_mounts,
574
- proxy_id=(proxy.object_id if proxy else None),
575
- retry_policy=retry_policy,
576
- timeout_secs=timeout_secs or 0,
577
- task_idle_timeout_secs=container_idle_timeout or 0,
578
- concurrency_limit=concurrency_limit or 0,
579
- pty_info=pty_info,
580
- cloud_provider=cloud_provider,
581
- warm_pool_size=keep_warm or 0,
582
- runtime=config.get("function_runtime"),
583
- runtime_debug=config.get("function_runtime_debug"),
584
- app_name=app_name,
585
- is_builder_function=is_builder_function,
586
- allow_concurrent_inputs=allow_concurrent_inputs or 0,
587
- worker_id=config.get("worker_id"),
588
- is_auto_snapshot=is_auto_snapshot,
589
- is_method=bool(info.cls),
590
- checkpointing_enabled=enable_memory_snapshot,
591
- is_checkpointing_function=False,
592
- object_dependencies=object_dependencies,
593
- block_network=block_network,
594
- max_inputs=max_inputs or 0,
595
- cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts),
596
- _experimental_boost=_experimental_boost,
597
- _experimental_scheduler=_experimental_scheduler,
598
- _experimental_scheduler_placement=_experimental_scheduler_placement.proto
599
- if _experimental_scheduler_placement
600
- else None,
601
- )
602
- request = api_pb2.FunctionCreateRequest(
603
- app_id=resolver.app_id,
604
- function=function_definition,
605
- schedule=schedule.proto_message if schedule is not None else None,
606
- existing_function_id=existing_object_id or "",
607
- )
608
- try:
609
- response: api_pb2.FunctionCreateResponse = await retry_transient_errors(
610
- resolver.client.stub.FunctionCreate, request
611
- )
612
- except GRPCError as exc:
613
- if exc.status == Status.INVALID_ARGUMENT:
614
- raise InvalidError(exc.message)
615
- if exc.status == Status.FAILED_PRECONDITION:
616
- raise InvalidError(exc.message)
617
- if exc.message and "Received :status = '413'" in exc.message:
618
- raise InvalidError(f"Function {raw_f} is too large to deploy.")
619
- raise
620
-
621
- if response.function.web_url:
622
- # Ensure terms used here match terms used in modal.com/docs/guide/webhook-urls doc.
623
- if response.function.web_url_info.truncated:
624
- suffix = " [grey70](label truncated)[/grey70]"
625
- elif response.function.web_url_info.has_unique_hash:
626
- suffix = " [grey70](label includes conflict-avoidance hash)[/grey70]"
627
- elif response.function.web_url_info.label_stolen:
628
- suffix = " [grey70](label stolen)[/grey70]"
629
- else:
630
- suffix = ""
631
- # TODO: this is only printed when we're showing progress. Maybe move this somewhere else.
632
- status_row.finish(f"Created {tag} => [magenta underline]{response.web_url}[/magenta underline]{suffix}")
633
-
634
- # Print custom domain in terminal
635
- for custom_domain in response.function.custom_domain_info:
636
- custom_domain_status_row = resolver.add_status_row()
637
- custom_domain_status_row.finish(
638
- f"Custom domain for {tag} => [magenta underline]{custom_domain.url}[/magenta underline]{suffix}"
818
+
819
+ if isinstance(gpu, list):
820
+ function_data = api_pb2.FunctionData(
821
+ module_name=function_definition.module_name,
822
+ function_name=function_definition.function_name,
823
+ function_type=function_definition.function_type,
824
+ warm_pool_size=function_definition.warm_pool_size,
825
+ concurrency_limit=function_definition.concurrency_limit,
826
+ task_idle_timeout_secs=function_definition.task_idle_timeout_secs,
827
+ worker_id=function_definition.worker_id,
828
+ timeout_secs=function_definition.timeout_secs,
829
+ web_url=function_definition.web_url,
830
+ web_url_info=function_definition.web_url_info,
831
+ webhook_config=function_definition.webhook_config,
832
+ custom_domain_info=function_definition.custom_domain_info,
833
+ schedule=schedule.proto_message if schedule is not None else None,
834
+ is_class=function_definition.is_class,
835
+ class_parameter_info=function_definition.class_parameter_info,
836
+ is_method=function_definition.is_method,
837
+ use_function_id=function_definition.use_function_id,
838
+ use_method_name=function_definition.use_method_name,
839
+ method_definitions=function_definition.method_definitions,
840
+ method_definitions_set=function_definition.method_definitions_set,
841
+ _experimental_group_size=function_definition._experimental_group_size,
842
+ _experimental_buffer_containers=function_definition._experimental_buffer_containers,
843
+ _experimental_custom_scaling=function_definition._experimental_custom_scaling,
844
+ _experimental_proxy_ip=function_definition._experimental_proxy_ip,
845
+ snapshot_debug=function_definition.snapshot_debug,
846
+ runtime_perf_record=function_definition.runtime_perf_record,
639
847
  )
640
848
 
641
- else:
642
- status_row.finish(f"Created {tag}.")
849
+ ranked_functions = []
850
+ for rank, _gpu in enumerate(gpu):
851
+ function_definition_copy = api_pb2.Function()
852
+ function_definition_copy.CopyFrom(function_definition)
853
+
854
+ function_definition_copy.resources.CopyFrom(
855
+ convert_fn_config_to_resources_config(
856
+ cpu=cpu, memory=memory, gpu=_gpu, ephemeral_disk=ephemeral_disk
857
+ ),
858
+ )
859
+ ranked_function = api_pb2.FunctionData.RankedFunction(
860
+ rank=rank,
861
+ function=function_definition_copy,
862
+ )
863
+ ranked_functions.append(ranked_function)
864
+ function_data.ranked_functions.extend(ranked_functions)
865
+ function_definition = None # function_definition is not used in this case
866
+ else:
867
+ # TODO(irfansharif): Assert on this specific type once we get rid of python 3.9.
868
+ # assert isinstance(gpu, GPU_T) # includes the case where gpu==None case
869
+ function_definition.resources.CopyFrom(
870
+ convert_fn_config_to_resources_config(
871
+ cpu=cpu, memory=memory, gpu=gpu, ephemeral_disk=ephemeral_disk
872
+ ), # type: ignore
873
+ )
643
874
 
875
+ assert resolver.app_id
876
+ assert (function_definition is None) != (function_data is None) # xor
877
+ request = api_pb2.FunctionCreateRequest(
878
+ app_id=resolver.app_id,
879
+ function=function_definition,
880
+ function_data=function_data,
881
+ existing_function_id=existing_object_id or "",
882
+ defer_updates=True,
883
+ )
884
+ try:
885
+ response: api_pb2.FunctionCreateResponse = await retry_transient_errors(
886
+ resolver.client.stub.FunctionCreate, request
887
+ )
888
+ except GRPCError as exc:
889
+ if exc.status == Status.INVALID_ARGUMENT:
890
+ raise InvalidError(exc.message)
891
+ if exc.status == Status.FAILED_PRECONDITION:
892
+ raise InvalidError(exc.message)
893
+ if exc.message and "Received :status = '413'" in exc.message:
894
+ raise InvalidError(f"Function {info.function_name} is too large to deploy.")
895
+ raise
896
+ function_creation_status.set_response(response)
897
+ serve_mounts = {m for m in all_mounts if m.is_local()} # needed for modal.serve file watching
898
+ serve_mounts |= image._serve_mounts
899
+ obj._serve_mounts = frozenset(serve_mounts)
644
900
  self._hydrate(response.function_id, resolver.client, response.handle_metadata)
645
901
 
646
902
  rep = f"Function({tag})"
647
903
  obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps)
648
904
 
649
- obj._raw_f = raw_f
905
+ obj._raw_f = info.raw_f
650
906
  obj._info = info
651
907
  obj._tag = tag
652
- obj._all_mounts = all_mounts # needed for modal.serve file watching
653
908
  obj._app = app # needed for CLI right now
654
909
  obj._obj = None
655
910
  obj._is_generator = is_generator
656
- obj._is_method = bool(info.cls)
911
+ obj._cluster_size = cluster_size
912
+ obj._is_method = False
657
913
  obj._spec = function_spec # needed for modal shell
658
914
 
659
- # Used to check whether we should rebuild an image using run_function
660
- # Plaintext source and arg definition for the function, so it's part of the image
661
- # hash. We can't use the cloudpickle hash because it's not very stable.
915
+ # Used to check whether we should rebuild a modal.Image which uses `run_function`.
916
+ gpus: list[GPU_T] = gpu if isinstance(gpu, list) else [gpu]
662
917
  obj._build_args = dict( # See get_build_def
663
918
  secrets=repr(secrets),
664
- gpu_config=repr(gpu_config),
919
+ gpu_config=repr([parse_gpu_config(_gpu) for _gpu in gpus]),
665
920
  mounts=repr(mounts),
666
921
  network_file_systems=repr(network_file_systems),
667
922
  )
923
+ # these key are excluded if empty to avoid rebuilds on client upgrade
924
+ if volumes:
925
+ obj._build_args["volumes"] = repr(volumes)
926
+ if cloud or scheduler_placement:
927
+ obj._build_args["cloud"] = repr(cloud)
928
+ obj._build_args["scheduler_placement"] = repr(scheduler_placement)
668
929
 
669
930
  return obj
670
931
 
671
- def from_parametrized(
932
+ def _bind_parameters(
672
933
  self,
673
- obj,
674
- from_other_workspace: bool,
934
+ obj: "modal.cls._Obj",
675
935
  options: Optional[api_pb2.FunctionOptions],
676
936
  args: Sized,
677
- kwargs: Dict[str, Any],
937
+ kwargs: dict[str, Any],
678
938
  ) -> "_Function":
679
- """mdmd:hidden"""
939
+ """mdmd:hidden
680
940
 
681
- async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
682
- if not self._parent.is_hydrated:
941
+ Binds a class-function to a specific instance of (init params, options) or a new workspace
942
+ """
943
+
944
+ # In some cases, reuse the base function, i.e. not create new clones of each method or the "service function"
945
+ can_use_parent = len(args) + len(kwargs) == 0 and options is None
946
+ parent = self
947
+
948
+ async def _load(param_bound_func: _Function, resolver: Resolver, existing_object_id: Optional[str]):
949
+ if parent is None:
950
+ raise ExecutionError("Can't find the parent class' service function")
951
+ try:
952
+ identity = f"{parent.info.function_name} class service function"
953
+ except Exception:
954
+ # Can't always look up the function name that way, so fall back to generic message
955
+ identity = "class service function for a parameterized class"
956
+ if not parent.is_hydrated:
957
+ if parent.app._running_app is None:
958
+ reason = ", because the App it is defined on is not running"
959
+ else:
960
+ reason = ""
683
961
  raise ExecutionError(
684
- "Base function in class has not been hydrated. This might happen if an object is"
685
- " defined on a different stub, or if it's on the same stub but it didn't get"
686
- " created because it wasn't defined in global scope."
962
+ f"The {identity} has not been hydrated with the metadata it needs to run on Modal{reason}."
687
963
  )
688
- assert self._parent._client.stub
689
- serialized_params = serialize((args, kwargs))
964
+
965
+ assert parent._client.stub
966
+
967
+ if can_use_parent:
968
+ # We can end up here if parent wasn't hydrated when class was instantiated, but has been since.
969
+ param_bound_func._hydrate_from_other(parent)
970
+ return
971
+
972
+ if (
973
+ parent._class_parameter_info
974
+ and parent._class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO
975
+ ):
976
+ if args:
977
+ # TODO(elias) - We could potentially support positional args as well, if we want to?
978
+ raise InvalidError(
979
+ "Can't use positional arguments with modal.parameter-based synthetic constructors.\n"
980
+ "Use (<parameter_name>=value) keyword arguments when constructing classes instead."
981
+ )
982
+ serialized_params = serialize_proto_params(kwargs, parent._class_parameter_info.schema)
983
+ else:
984
+ serialized_params = serialize((args, kwargs))
690
985
  environment_name = _get_environment_name(None, resolver)
986
+ assert parent is not None
691
987
  req = api_pb2.FunctionBindParamsRequest(
692
- function_id=self._parent._object_id,
988
+ function_id=parent._object_id,
693
989
  serialized_params=serialized_params,
694
990
  function_options=options,
695
991
  environment_name=environment_name
696
992
  or "", # TODO: investigate shouldn't environment name always be specified here?
697
993
  )
698
- response = await retry_transient_errors(self._parent._client.stub.FunctionBindParams, req)
699
- self._hydrate(response.bound_function_id, self._parent._client, response.handle_metadata)
700
-
701
- fun = _Function._from_loader(_load, "Function(parametrized)", hydrate_lazily=True)
702
- if len(args) + len(kwargs) == 0 and not from_other_workspace and options is None and self.is_hydrated:
703
- # Edge case that lets us hydrate all objects right away
704
- fun._hydrate_from_other(self)
705
- fun._is_remote_cls_method = True # TODO(erikbern): deprecated
994
+
995
+ response = await retry_transient_errors(parent._client.stub.FunctionBindParams, req)
996
+ param_bound_func._hydrate(response.bound_function_id, parent._client, response.handle_metadata)
997
+
998
+ fun: _Function = _Function._from_loader(_load, "Function(parametrized)", hydrate_lazily=True)
999
+
1000
+ if can_use_parent and parent.is_hydrated:
1001
+ # skip the resolver altogether:
1002
+ fun._hydrate_from_other(parent)
1003
+
706
1004
  fun._info = self._info
707
1005
  fun._obj = obj
708
- fun._is_generator = self._is_generator
709
- fun._is_method = True
710
- fun._parent = self
711
-
712
1006
  return fun
713
1007
 
714
1008
  @live_method
715
1009
  async def keep_warm(self, warm_pool_size: int) -> None:
716
- """Set the warm pool size for the function (including parametrized functions).
1010
+ """Set the warm pool size for the function.
717
1011
 
718
- Please exercise care when using this advanced feature! Setting and forgetting a warm pool on functions can lead to increased costs.
1012
+ Please exercise care when using this advanced feature!
1013
+ Setting and forgetting a warm pool on functions can lead to increased costs.
719
1014
 
720
- ```python
1015
+ ```python notest
721
1016
  # Usage on a regular function.
722
1017
  f = modal.Function.lookup("my-app", "function")
723
1018
  f.keep_warm(2)
724
1019
 
725
1020
  # Usage on a parametrized function.
726
1021
  Model = modal.Cls.lookup("my-app", "Model")
727
- Model("fine-tuned-model").inference.keep_warm(2)
1022
+ Model("fine-tuned-model").keep_warm(2)
728
1023
  ```
729
1024
  """
1025
+ if self._is_method:
1026
+ raise InvalidError(
1027
+ textwrap.dedent(
1028
+ """
1029
+ The `.keep_warm()` method can not be used on Modal class *methods* deployed using Modal >v0.63.
730
1030
 
1031
+ Call `.keep_warm()` on the class *instance* instead.
1032
+ """
1033
+ )
1034
+ )
731
1035
  assert self._client and self._client.stub
732
1036
  request = api_pb2.FunctionUpdateSchedulingParamsRequest(
733
1037
  function_id=self._object_id, warm_pool_size_override=warm_pool_size
@@ -735,17 +1039,22 @@ class _Function(_Object, type_prefix="fu"):
735
1039
  await retry_transient_errors(self._client.stub.FunctionUpdateSchedulingParams, request)
736
1040
 
737
1041
  @classmethod
1042
+ @renamed_parameter((2024, 12, 18), "tag", "name")
738
1043
  def from_name(
739
- cls: Type["_Function"],
1044
+ cls: type["_Function"],
740
1045
  app_name: str,
741
- tag: Optional[str] = None,
1046
+ name: str,
742
1047
  namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
743
1048
  environment_name: Optional[str] = None,
744
1049
  ) -> "_Function":
745
- """Retrieve a function with a given name and tag.
1050
+ """Reference a Function from a deployed App by its name.
1051
+
1052
+ In contast to `modal.Function.lookup`, this is a lazy method
1053
+ that defers hydrating the local object with metadata from
1054
+ Modal servers until the first time it is actually used.
746
1055
 
747
1056
  ```python
748
- other_function = modal.Function.from_name("other-app", "function")
1057
+ f = modal.Function.from_name("other-app", "function")
749
1058
  ```
750
1059
  """
751
1060
 
@@ -753,7 +1062,7 @@ class _Function(_Object, type_prefix="fu"):
753
1062
  assert resolver.client and resolver.client.stub
754
1063
  request = api_pb2.FunctionGetRequest(
755
1064
  app_name=app_name,
756
- object_tag=tag or "",
1065
+ object_tag=name,
757
1066
  namespace=namespace,
758
1067
  environment_name=_get_environment_name(environment_name, resolver) or "",
759
1068
  )
@@ -765,26 +1074,32 @@ class _Function(_Object, type_prefix="fu"):
765
1074
  else:
766
1075
  raise
767
1076
 
1077
+ print_server_warnings(response.server_warnings)
1078
+
768
1079
  self._hydrate(response.function_id, resolver.client, response.handle_metadata)
769
1080
 
770
1081
  rep = f"Ref({app_name})"
771
- return cls._from_loader(_load_remote, rep, is_another_app=True)
1082
+ return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True)
772
1083
 
773
1084
  @staticmethod
1085
+ @renamed_parameter((2024, 12, 18), "tag", "name")
774
1086
  async def lookup(
775
1087
  app_name: str,
776
- tag: Optional[str] = None,
1088
+ name: str,
777
1089
  namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
778
1090
  client: Optional[_Client] = None,
779
1091
  environment_name: Optional[str] = None,
780
1092
  ) -> "_Function":
781
- """Lookup a function with a given name and tag.
1093
+ """Lookup a Function from a deployed App by its name.
782
1094
 
783
- ```python
784
- other_function = modal.Function.lookup("other-app", "function")
1095
+ In contrast to `modal.Function.from_name`, this is an eager method
1096
+ that will hydrate the local object with metadata from Modal servers.
1097
+
1098
+ ```python notest
1099
+ f = modal.Function.lookup("other-app", "function")
785
1100
  ```
786
1101
  """
787
- obj = _Function.from_name(app_name, tag, namespace=namespace, environment_name=environment_name)
1102
+ obj = _Function.from_name(app_name, name, namespace=namespace, environment_name=environment_name)
788
1103
  if client is None:
789
1104
  client = await _Client.from_env()
790
1105
  resolver = Resolver(client=client)
@@ -800,13 +1115,16 @@ class _Function(_Object, type_prefix="fu"):
800
1115
  @property
801
1116
  def app(self) -> "modal.app._App":
802
1117
  """mdmd:hidden"""
1118
+ if self._app is None:
1119
+ raise ExecutionError("The app has not been assigned on the function at this point")
1120
+
803
1121
  return self._app
804
1122
 
805
1123
  @property
806
1124
  def stub(self) -> "modal.app._App":
807
1125
  """mdmd:hidden"""
808
1126
  # Deprecated soon, only for backwards compatibility
809
- return self._app
1127
+ return self.app
810
1128
 
811
1129
  @property
812
1130
  def info(self) -> FunctionInfo:
@@ -817,10 +1135,13 @@ class _Function(_Object, type_prefix="fu"):
817
1135
  @property
818
1136
  def spec(self) -> _FunctionSpec:
819
1137
  """mdmd:hidden"""
1138
+ assert self._spec
820
1139
  return self._spec
821
1140
 
822
1141
  def get_build_def(self) -> str:
823
1142
  """mdmd:hidden"""
1143
+ # Plaintext source and arg definition for the function, so it's part of the image
1144
+ # hash. We can't use the cloudpickle hash because it's not very stable.
824
1145
  assert hasattr(self, "_raw_f") and hasattr(self, "_build_args")
825
1146
  return f"{inspect.getsource(self._raw_f)}\n{repr(self._build_args)}"
826
1147
 
@@ -830,128 +1151,170 @@ class _Function(_Object, type_prefix="fu"):
830
1151
  # Overridden concrete implementation of base class method
831
1152
  self._progress = None
832
1153
  self._is_generator = None
1154
+ self._cluster_size = None
833
1155
  self._web_url = None
834
- self._output_mgr: Optional[OutputManager] = None
835
- self._mute_cancellation = (
836
- False # set when a user terminates the app intentionally, to prevent useless traceback spam
837
- )
838
1156
  self._function_name = None
839
1157
  self._info = None
1158
+ self._serve_mounts = frozenset()
840
1159
 
841
1160
  def _hydrate_metadata(self, metadata: Optional[Message]):
842
1161
  # Overridden concrete implementation of base class method
843
- assert metadata and isinstance(metadata, (api_pb2.Function, api_pb2.FunctionHandleMetadata))
1162
+ assert metadata and isinstance(metadata, api_pb2.FunctionHandleMetadata)
844
1163
  self._is_generator = metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
845
1164
  self._web_url = metadata.web_url
846
1165
  self._function_name = metadata.function_name
847
1166
  self._is_method = metadata.is_method
1167
+ self._use_method_name = metadata.use_method_name
1168
+ self._class_parameter_info = metadata.class_parameter_info
1169
+ self._method_handle_metadata = dict(metadata.method_handle_metadata)
1170
+ self._definition_id = metadata.definition_id
848
1171
 
849
1172
  def _get_metadata(self):
850
1173
  # Overridden concrete implementation of base class method
851
- assert self._function_name
1174
+ assert self._function_name, f"Function name must be set before metadata can be retrieved for {self}"
852
1175
  return api_pb2.FunctionHandleMetadata(
853
1176
  function_name=self._function_name,
854
- function_type=(
855
- api_pb2.Function.FUNCTION_TYPE_GENERATOR
856
- if self._is_generator
857
- else api_pb2.Function.FUNCTION_TYPE_FUNCTION
858
- ),
1177
+ function_type=get_function_type(self._is_generator),
859
1178
  web_url=self._web_url or "",
1179
+ use_method_name=self._use_method_name,
1180
+ is_method=self._is_method,
1181
+ class_parameter_info=self._class_parameter_info,
1182
+ definition_id=self._definition_id,
1183
+ method_handle_metadata=self._method_handle_metadata,
860
1184
  )
861
1185
 
862
- def _set_mute_cancellation(self, value: bool = True):
863
- self._mute_cancellation = value
864
-
865
- def _set_output_mgr(self, output_mgr: OutputManager):
866
- self._output_mgr = output_mgr
1186
+ def _check_no_web_url(self, fn_name: str):
1187
+ if self._web_url:
1188
+ raise InvalidError(
1189
+ f"A webhook function cannot be invoked for remote execution with `.{fn_name}`. "
1190
+ f"Invoke this function via its web url '{self._web_url}' "
1191
+ + f"or call it locally: {self._function_name}.local()"
1192
+ )
867
1193
 
1194
+ # TODO (live_method on properties is not great, since it could be blocking the event loop from async contexts)
868
1195
  @property
869
- def web_url(self) -> str:
1196
+ @live_method
1197
+ async def web_url(self) -> str:
870
1198
  """URL of a Function running as a web endpoint."""
871
1199
  if not self._web_url:
872
1200
  raise ValueError(
873
- f"No web_url can be found for function {self._function_name}. web_url can only be referenced from a running app context"
1201
+ f"No web_url can be found for function {self._function_name}. web_url "
1202
+ "can only be referenced from a running app context"
874
1203
  )
875
1204
  return self._web_url
876
1205
 
877
1206
  @property
878
- def is_generator(self) -> bool:
1207
+ async def is_generator(self) -> bool:
879
1208
  """mdmd:hidden"""
880
- assert self._is_generator is not None
1209
+ # hacky: kind of like @live_method, but not hydrating if we have the value already from local source
1210
+ if self._is_generator is not None:
1211
+ # this is set if the function or class is local
1212
+ return self._is_generator
1213
+
1214
+ # not set - this is a from_name lookup - hydrate
1215
+ await self.resolve()
1216
+ assert self._is_generator is not None # should be set now
881
1217
  return self._is_generator
882
1218
 
1219
+ @property
1220
+ def cluster_size(self) -> int:
1221
+ """mdmd:hidden"""
1222
+ return self._cluster_size or 1
1223
+
883
1224
  @live_method_gen
884
1225
  async def _map(
885
1226
  self, input_queue: _SynchronizedQueue, order_outputs: bool, return_exceptions: bool
886
1227
  ) -> AsyncGenerator[Any, None]:
887
1228
  """mdmd:hidden
888
1229
 
889
- Synchronicity-wrapped map implementation. To be safe against invocations of user code in the synchronicity thread
890
- it doesn't accept an [async]iterator, and instead takes a _SynchronizedQueue instance that is fed by
891
- higher level functions like .map()
1230
+ Synchronicity-wrapped map implementation. To be safe against invocations of user code in
1231
+ the synchronicity thread it doesn't accept an [async]iterator, and instead takes a
1232
+ _SynchronizedQueue instance that is fed by higher level functions like .map()
892
1233
 
893
1234
  _SynchronizedQueue is used instead of asyncio.Queue so that the main thread can put
894
1235
  items in the queue safely.
895
1236
  """
896
- if self._web_url:
897
- raise InvalidError(
898
- "A web endpoint function cannot be directly invoked for parallel remote execution. "
899
- f"Invoke this function via its web url '{self._web_url}' or call it locally: {self._function_name}()."
900
- )
1237
+ self._check_no_web_url("map")
901
1238
  if self._is_generator:
902
1239
  raise InvalidError("A generator function cannot be called with `.map(...)`.")
903
1240
 
904
1241
  assert self._function_name
905
- count_update_callback = (
906
- self._output_mgr.function_progress_callback(self._function_name, total=None) if self._output_mgr else None
907
- )
1242
+ if output_mgr := _get_output_manager():
1243
+ count_update_callback = output_mgr.function_progress_callback(self._function_name, total=None)
1244
+ else:
1245
+ count_update_callback = None
1246
+
1247
+ async with aclosing(
1248
+ _map_invocation(
1249
+ self, # type: ignore
1250
+ input_queue,
1251
+ self._client,
1252
+ order_outputs,
1253
+ return_exceptions,
1254
+ count_update_callback,
1255
+ )
1256
+ ) as stream:
1257
+ async for item in stream:
1258
+ yield item
908
1259
 
909
- async for item in _map_invocation(
910
- self.object_id,
911
- input_queue,
912
- self._client,
913
- order_outputs,
914
- return_exceptions,
915
- count_update_callback,
916
- ):
917
- yield item
1260
+ async def _call_function(self, args, kwargs) -> ReturnType:
1261
+ if config.get("client_retries"):
1262
+ function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
1263
+ else:
1264
+ function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY
1265
+ invocation = await _Invocation.create(
1266
+ self,
1267
+ args,
1268
+ kwargs,
1269
+ client=self._client,
1270
+ function_call_invocation_type=function_call_invocation_type,
1271
+ )
918
1272
 
919
- async def _call_function(self, args, kwargs):
920
- invocation = await _Invocation.create(self.object_id, args, kwargs, self._client)
921
- try:
922
- return await invocation.run_function()
923
- except asyncio.CancelledError:
924
- # this can happen if the user terminates a program, triggering a cancellation cascade
925
- if not self._mute_cancellation:
926
- raise
1273
+ return await invocation.run_function()
927
1274
 
928
- async def _call_function_nowait(self, args, kwargs) -> _Invocation:
929
- return await _Invocation.create(self.object_id, args, kwargs, self._client)
1275
+ async def _call_function_nowait(
1276
+ self, args, kwargs, function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
1277
+ ) -> _Invocation:
1278
+ return await _Invocation.create(
1279
+ self, args, kwargs, client=self._client, function_call_invocation_type=function_call_invocation_type
1280
+ )
930
1281
 
931
1282
  @warn_if_generator_is_not_consumed()
932
1283
  @live_method_gen
933
1284
  @synchronizer.no_input_translation
934
1285
  async def _call_generator(self, args, kwargs):
935
- invocation = await _Invocation.create(self.object_id, args, kwargs, self._client)
1286
+ invocation = await _Invocation.create(
1287
+ self,
1288
+ args,
1289
+ kwargs,
1290
+ client=self._client,
1291
+ function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
1292
+ )
936
1293
  async for res in invocation.run_generator():
937
1294
  yield res
938
1295
 
939
1296
  @synchronizer.no_io_translation
940
1297
  async def _call_generator_nowait(self, args, kwargs):
941
- return await _Invocation.create(self.object_id, args, kwargs, self._client)
1298
+ deprecation_warning(
1299
+ (2024, 12, 11),
1300
+ "Calling spawn on a generator function is deprecated and will soon raise an exception.",
1301
+ )
1302
+ return await _Invocation.create(
1303
+ self,
1304
+ args,
1305
+ kwargs,
1306
+ client=self._client,
1307
+ function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY,
1308
+ )
942
1309
 
943
1310
  @synchronizer.no_io_translation
944
1311
  @live_method
945
- async def remote(self, *args, **kwargs) -> Any:
1312
+ async def remote(self, *args: P.args, **kwargs: P.kwargs) -> ReturnType:
946
1313
  """
947
1314
  Calls the function remotely, executing it with the given arguments and returning the execution's result.
948
1315
  """
949
1316
  # TODO: Generics/TypeVars
950
- if self._web_url:
951
- raise InvalidError(
952
- "A web endpoint function cannot be invoked for remote execution with `.remote`. "
953
- f"Invoke this function via its web url '{self._web_url}' or call it locally: {self._function_name}()."
954
- )
1317
+ self._check_no_web_url("remote")
955
1318
  if self._is_generator:
956
1319
  raise InvalidError(
957
1320
  "A generator function cannot be called with `.remote(...)`. Use `.remote_gen(...)` instead."
@@ -966,11 +1329,7 @@ class _Function(_Object, type_prefix="fu"):
966
1329
  Calls the generator remotely, executing it with the given arguments and returning the execution's result.
967
1330
  """
968
1331
  # TODO: Generics/TypeVars
969
- if self._web_url:
970
- raise InvalidError(
971
- "A web endpoint function cannot be invoked for remote execution with `.remote`. "
972
- f"Invoke this function via its web url '{self._web_url}' or call it locally: {self._function_name}()."
973
- )
1332
+ self._check_no_web_url("remote_gen")
974
1333
 
975
1334
  if not self._is_generator:
976
1335
  raise InvalidError(
@@ -979,22 +1338,15 @@ class _Function(_Object, type_prefix="fu"):
979
1338
  async for item in self._call_generator(args, kwargs): # type: ignore
980
1339
  yield item
981
1340
 
982
- @synchronizer.no_io_translation
983
- @live_method
984
- async def shell(self, *args, **kwargs) -> None:
985
- if self._is_generator:
986
- async for item in self._call_generator(args, kwargs):
987
- pass
988
- else:
989
- await self._call_function(args, kwargs)
990
-
991
- def _get_is_remote_cls_method(self):
992
- return self._is_remote_cls_method
1341
+ def _is_local(self):
1342
+ return self._info is not None
993
1343
 
994
- def _get_info(self):
1344
+ def _get_info(self) -> FunctionInfo:
1345
+ if not self._info:
1346
+ raise ExecutionError("Can't get info for a function that isn't locally defined")
995
1347
  return self._info
996
1348
 
997
- def _get_obj(self):
1349
+ def _get_obj(self) -> Optional["modal.cls._Obj"]:
998
1350
  if not self._is_method:
999
1351
  return None
1000
1352
  elif not self._obj:
@@ -1003,79 +1355,115 @@ class _Function(_Object, type_prefix="fu"):
1003
1355
  return self._obj
1004
1356
 
1005
1357
  @synchronizer.nowrap
1006
- def local(self, *args, **kwargs) -> Any:
1358
+ def local(self, *args: P.args, **kwargs: P.kwargs) -> OriginalReturnType:
1007
1359
  """
1008
1360
  Calls the function locally, executing it with the given arguments and returning the execution's result.
1009
1361
 
1010
1362
  The function will execute in the same environment as the caller, just like calling the underlying function
1011
- directly in Python. In particular, secrets will not be available through environment variables.
1363
+ directly in Python. In particular, only secrets available in the caller environment will be available
1364
+ through environment variables.
1012
1365
  """
1013
1366
  # TODO(erikbern): it would be nice to remove the nowrap thing, but right now that would cause
1014
1367
  # "user code" to run on the synchronicity thread, which seems bad
1015
- info = self._get_info()
1016
- if not info:
1368
+ if not self._is_local():
1017
1369
  msg = (
1018
- "The definition for this function is missing so it is not possible to invoke it locally. "
1370
+ "The definition for this function is missing here so it is not possible to invoke it locally. "
1019
1371
  "If this function was retrieved via `Function.lookup` you need to use `.remote()`."
1020
1372
  )
1021
1373
  raise ExecutionError(msg)
1022
1374
 
1023
- obj = self._get_obj()
1375
+ info = self._get_info()
1376
+ if not info.raw_f:
1377
+ # Here if calling .local on a service function itself which should never happen
1378
+ # TODO: check if we end up here in a container for a serialized function?
1379
+ raise ExecutionError("Can't call .local on service function")
1380
+
1381
+ if is_local() and self.spec.volumes or self.spec.network_file_systems:
1382
+ warnings.warn(
1383
+ f"The {info.function_name} function is executing locally "
1384
+ + "and will not have access to the mounted Volume or NetworkFileSystem data"
1385
+ )
1386
+
1387
+ obj: Optional["modal.cls._Obj"] = self._get_obj()
1024
1388
 
1025
1389
  if not obj:
1026
1390
  fun = info.raw_f
1027
1391
  return fun(*args, **kwargs)
1028
1392
  else:
1029
1393
  # This is a method on a class, so bind the self to the function
1030
- local_obj = obj.get_local_obj()
1031
- fun = info.raw_f.__get__(local_obj)
1394
+ user_cls_instance = obj._cached_user_cls_instance()
1395
+ fun = info.raw_f.__get__(user_cls_instance)
1032
1396
 
1397
+ # TODO: replace implicit local enter/exit with a context manager
1033
1398
  if is_async(info.raw_f):
1034
1399
  # We want to run __aenter__ and fun in the same coroutine
1035
1400
  async def coro():
1036
- await obj.aenter()
1401
+ await obj._aenter()
1037
1402
  return await fun(*args, **kwargs)
1038
1403
 
1039
- return coro()
1404
+ return coro() # type: ignore
1040
1405
  else:
1041
- obj.enter()
1406
+ obj._enter()
1042
1407
  return fun(*args, **kwargs)
1043
1408
 
1044
1409
  @synchronizer.no_input_translation
1045
1410
  @live_method
1046
- async def spawn(self, *args, **kwargs) -> Optional["_FunctionCall"]:
1047
- """Calls the function with the given arguments, without waiting for the results.
1411
+ async def _experimental_spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_FunctionCall[ReturnType]":
1412
+ """[Experimental] Calls the function with the given arguments, without waiting for the results.
1413
+
1414
+ This experimental version of the spawn method allows up to 1 million inputs to be spawned.
1048
1415
 
1049
- Returns a `modal.functions.FunctionCall` object, that can later be polled or waited for using `.get(timeout=...)`.
1416
+ Returns a `modal.functions.FunctionCall` object, that can later be polled or
1417
+ waited for using `.get(timeout=...)`.
1050
1418
  Conceptually similar to `multiprocessing.pool.apply_async`, or a Future/Promise in other contexts.
1419
+ """
1420
+ self._check_no_web_url("_experimental_spawn")
1421
+ if self._is_generator:
1422
+ invocation = await self._call_generator_nowait(args, kwargs)
1423
+ else:
1424
+ invocation = await self._call_function_nowait(
1425
+ args, kwargs, function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC
1426
+ )
1427
+
1428
+ fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
1429
+ fc._is_generator = self._is_generator if self._is_generator else False
1430
+ return fc
1051
1431
 
1052
- *Note:* `.spawn()` on a modal generator function does call and execute the generator, but does not currently
1053
- return a function handle for polling the result.
1432
+ @synchronizer.no_input_translation
1433
+ @live_method
1434
+ async def spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_FunctionCall[ReturnType]":
1435
+ """Calls the function with the given arguments, without waiting for the results.
1436
+
1437
+ Returns a `modal.functions.FunctionCall` object, that can later be polled or
1438
+ waited for using `.get(timeout=...)`.
1439
+ Conceptually similar to `multiprocessing.pool.apply_async`, or a Future/Promise in other contexts.
1054
1440
  """
1441
+ self._check_no_web_url("spawn")
1055
1442
  if self._is_generator:
1056
- await self._call_generator_nowait(args, kwargs)
1057
- return None
1443
+ invocation = await self._call_generator_nowait(args, kwargs)
1444
+ else:
1445
+ invocation = await self._call_function_nowait(
1446
+ args, kwargs, api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY
1447
+ )
1058
1448
 
1059
- invocation = await self._call_function_nowait(args, kwargs)
1060
- return _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
1449
+ fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
1450
+ fc._is_generator = self._is_generator if self._is_generator else False
1451
+ return fc
1061
1452
 
1062
1453
  def get_raw_f(self) -> Callable[..., Any]:
1063
1454
  """Return the inner Python object wrapped by this Modal Function."""
1064
- if not self._info:
1065
- raise AttributeError("_info has not been set on this FunctionHandle and not available in this context")
1066
-
1067
- return self._info.raw_f
1455
+ return self._raw_f
1068
1456
 
1069
1457
  @live_method
1070
1458
  async def get_current_stats(self) -> FunctionStats:
1071
1459
  """Return a `FunctionStats` object describing the current function's queue and runner counts."""
1072
1460
  assert self._client.stub
1073
- resp = await self._client.stub.FunctionGetCurrentStats(
1074
- api_pb2.FunctionGetCurrentStatsRequest(function_id=self.object_id)
1075
- )
1076
- return FunctionStats(
1077
- backlog=resp.backlog, num_active_runners=resp.num_active_tasks, num_total_runners=resp.num_total_tasks
1461
+ resp = await retry_transient_errors(
1462
+ self._client.stub.FunctionGetCurrentStats,
1463
+ api_pb2.FunctionGetCurrentStatsRequest(function_id=self.object_id),
1464
+ total_timeout=10.0,
1078
1465
  )
1466
+ return FunctionStats(backlog=resp.backlog, num_total_runners=resp.num_total_tasks)
1079
1467
 
1080
1468
  # A bit hacky - but the map-style functions need to not be synchronicity-wrapped
1081
1469
  # in order to not execute their input iterators on the synchronicity event loop.
@@ -1089,7 +1477,7 @@ class _Function(_Object, type_prefix="fu"):
1089
1477
  Function = synchronize_api(_Function)
1090
1478
 
1091
1479
 
1092
- class _FunctionCall(_Object, type_prefix="fc"):
1480
+ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
1093
1481
  """A reference to an executed function call.
1094
1482
 
1095
1483
  Constructed using `.spawn(...)` on a Modal function with the same
@@ -1100,11 +1488,13 @@ class _FunctionCall(_Object, type_prefix="fc"):
1100
1488
  Conceptually similar to a Future/Promise/AsyncResult in other contexts and languages.
1101
1489
  """
1102
1490
 
1491
+ _is_generator: bool = False
1492
+
1103
1493
  def _invocation(self):
1104
1494
  assert self._client.stub
1105
1495
  return _Invocation(self._client.stub, self.object_id, self._client)
1106
1496
 
1107
- async def get(self, timeout: Optional[float] = None):
1497
+ async def get(self, timeout: Optional[float] = None) -> ReturnType:
1108
1498
  """Get the result of the function call.
1109
1499
 
1110
1500
  This function waits indefinitely by default. It takes an optional
@@ -1113,9 +1503,23 @@ class _FunctionCall(_Object, type_prefix="fc"):
1113
1503
 
1114
1504
  The returned coroutine is not cancellation-safe.
1115
1505
  """
1506
+
1507
+ if self._is_generator:
1508
+ raise Exception("Cannot get the result of a generator function call. Use `get_gen` instead.")
1509
+
1116
1510
  return await self._invocation().poll_function(timeout=timeout)
1117
1511
 
1118
- async def get_call_graph(self) -> List[InputInfo]:
1512
+ async def get_gen(self) -> AsyncGenerator[Any, None]:
1513
+ """
1514
+ Calls the generator remotely, executing it with the given arguments and returning the execution's result.
1515
+ """
1516
+ if not self._is_generator:
1517
+ raise Exception("Cannot iterate over a non-generator function call. Use `get` instead.")
1518
+
1519
+ async for res in self._invocation().run_generator():
1520
+ yield res
1521
+
1522
+ async def get_call_graph(self) -> list[InputInfo]:
1119
1523
  """Returns a structure representing the call graph from a given root
1120
1524
  call ID, along with the status of execution for each node.
1121
1525
 
@@ -1127,24 +1531,38 @@ class _FunctionCall(_Object, type_prefix="fc"):
1127
1531
  response = await retry_transient_errors(self._client.stub.FunctionGetCallGraph, request)
1128
1532
  return _reconstruct_call_graph(response)
1129
1533
 
1130
- async def cancel(self):
1131
- """Cancels the function call, which will stop its execution and mark its inputs as [`TERMINATED`](/docs/reference/modal.call_graph#modalcall_graphinputstatus)."""
1132
- request = api_pb2.FunctionCallCancelRequest(function_call_id=self.object_id)
1534
+ async def cancel(
1535
+ self,
1536
+ terminate_containers: bool = False, # if true, containers running the inputs are forcibly terminated
1537
+ ):
1538
+ """Cancels the function call, which will stop its execution and mark its inputs as
1539
+ [`TERMINATED`](/docs/reference/modal.call_graph#modalcall_graphinputstatus).
1540
+
1541
+ If `terminate_containers=True` - the containers running the cancelled inputs are all terminated
1542
+ causing any non-cancelled inputs on those containers to be rescheduled in new containers.
1543
+ """
1544
+ request = api_pb2.FunctionCallCancelRequest(
1545
+ function_call_id=self.object_id, terminate_containers=terminate_containers
1546
+ )
1133
1547
  assert self._client and self._client.stub
1134
1548
  await retry_transient_errors(self._client.stub.FunctionCallCancel, request)
1135
1549
 
1136
1550
  @staticmethod
1137
- async def from_id(function_call_id: str, client: Optional[_Client] = None) -> "_FunctionCall":
1551
+ async def from_id(
1552
+ function_call_id: str, client: Optional[_Client] = None, is_generator: bool = False
1553
+ ) -> "_FunctionCall":
1138
1554
  if client is None:
1139
1555
  client = await _Client.from_env()
1140
1556
 
1141
- return _FunctionCall._new_hydrated(function_call_id, client, None)
1557
+ fc = _FunctionCall._new_hydrated(function_call_id, client, None)
1558
+ fc._is_generator = is_generator
1559
+ return fc
1142
1560
 
1143
1561
 
1144
1562
  FunctionCall = synchronize_api(_FunctionCall)
1145
1563
 
1146
1564
 
1147
- async def _gather(*function_calls: _FunctionCall):
1565
+ async def _gather(*function_calls: _FunctionCall[ReturnType]) -> typing.Sequence[ReturnType]:
1148
1566
  """Wait until all Modal function calls have results before returning
1149
1567
 
1150
1568
  Accepts a variable number of FunctionCall objects as returned by `Function.spawn()`.
@@ -1162,7 +1580,7 @@ async def _gather(*function_calls: _FunctionCall):
1162
1580
  ```
1163
1581
  """
1164
1582
  try:
1165
- return await asyncio.gather(*[fc.get() for fc in function_calls])
1583
+ return await TaskContext.gather(*[fc.get() for fc in function_calls])
1166
1584
  except Exception as exc:
1167
1585
  # TODO: kill all running function calls
1168
1586
  raise exc