modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (220)
  1. modal/__init__.py +17 -13
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +420 -937
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -59
  11. modal/_resources.py +51 -0
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/_runtime/execution_context.py +89 -0
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +134 -9
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +52 -16
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +479 -100
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +460 -171
  29. modal/_utils/grpc_testing.py +47 -31
  30. modal/_utils/grpc_utils.py +62 -109
  31. modal/_utils/hash_utils.py +61 -19
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +5 -7
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +14 -12
  43. modal/app.py +1003 -314
  44. modal/app.pyi +540 -264
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +63 -53
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +205 -45
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +62 -14
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +64 -58
  55. modal/cli/launch.py +32 -18
  56. modal/cli/network_file_system.py +64 -83
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +35 -10
  59. modal/cli/programs/vscode.py +60 -10
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +234 -131
  62. modal/cli/secret.py +8 -7
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +79 -10
  65. modal/cli/volume.py +110 -109
  66. modal/client.py +250 -144
  67. modal/client.pyi +157 -118
  68. modal/cloud_bucket_mount.py +108 -34
  69. modal/cloud_bucket_mount.pyi +32 -38
  70. modal/cls.py +535 -148
  71. modal/cls.pyi +190 -146
  72. modal/config.py +41 -19
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +111 -65
  76. modal/dict.pyi +136 -131
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +34 -43
  80. modal/experimental.py +61 -2
  81. modal/extensions/ipython.py +5 -5
  82. modal/file_io.py +537 -0
  83. modal/file_io.pyi +235 -0
  84. modal/file_pattern_matcher.py +197 -0
  85. modal/functions.py +906 -911
  86. modal/functions.pyi +466 -430
  87. modal/gpu.py +57 -44
  88. modal/image.py +1089 -479
  89. modal/image.pyi +584 -228
  90. modal/io_streams.py +434 -0
  91. modal/io_streams.pyi +122 -0
  92. modal/mount.py +314 -101
  93. modal/mount.pyi +241 -235
  94. modal/network_file_system.py +92 -92
  95. modal/network_file_system.pyi +152 -110
  96. modal/object.py +67 -36
  97. modal/object.pyi +166 -143
  98. modal/output.py +63 -0
  99. modal/parallel_map.py +434 -0
  100. modal/parallel_map.pyi +75 -0
  101. modal/partial_function.py +282 -117
  102. modal/partial_function.pyi +222 -129
  103. modal/proxy.py +15 -12
  104. modal/proxy.pyi +3 -8
  105. modal/queue.py +182 -65
  106. modal/queue.pyi +218 -118
  107. modal/requirements/2024.04.txt +29 -0
  108. modal/requirements/2024.10.txt +16 -0
  109. modal/requirements/README.md +21 -0
  110. modal/requirements/base-images.json +22 -0
  111. modal/retries.py +48 -7
  112. modal/runner.py +459 -156
  113. modal/runner.pyi +135 -71
  114. modal/running_app.py +38 -0
  115. modal/sandbox.py +514 -236
  116. modal/sandbox.pyi +397 -169
  117. modal/schedule.py +4 -4
  118. modal/scheduler_placement.py +20 -3
  119. modal/secret.py +56 -31
  120. modal/secret.pyi +62 -42
  121. modal/serving.py +51 -56
  122. modal/serving.pyi +44 -36
  123. modal/stream_type.py +15 -0
  124. modal/token_flow.py +5 -3
  125. modal/token_flow.pyi +37 -32
  126. modal/volume.py +285 -157
  127. modal/volume.pyi +249 -184
  128. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
  129. modal-0.72.11.dist-info/RECORD +174 -0
  130. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  131. modal_docs/gen_reference_docs.py +3 -1
  132. modal_docs/mdmd/mdmd.py +0 -1
  133. modal_docs/mdmd/signatures.py +5 -2
  134. modal_global_objects/images/base_images.py +28 -0
  135. modal_global_objects/mounts/python_standalone.py +2 -2
  136. modal_proto/__init__.py +1 -1
  137. modal_proto/api.proto +1288 -533
  138. modal_proto/api_grpc.py +856 -456
  139. modal_proto/api_pb2.py +2165 -1157
  140. modal_proto/api_pb2.pyi +8859 -0
  141. modal_proto/api_pb2_grpc.py +1674 -855
  142. modal_proto/api_pb2_grpc.pyi +1416 -0
  143. modal_proto/modal_api_grpc.py +149 -0
  144. modal_proto/modal_options_grpc.py +3 -0
  145. modal_proto/options_pb2.pyi +20 -0
  146. modal_proto/options_pb2_grpc.pyi +7 -0
  147. modal_proto/py.typed +0 -0
  148. modal_version/__init__.py +1 -1
  149. modal_version/_version_generated.py +2 -2
  150. modal/_asgi.py +0 -370
  151. modal/_container_entrypoint.pyi +0 -378
  152. modal/_container_exec.py +0 -128
  153. modal/_sandbox_shell.py +0 -49
  154. modal/shared_volume.py +0 -23
  155. modal/shared_volume.pyi +0 -24
  156. modal/stub.py +0 -783
  157. modal/stub.pyi +0 -332
  158. modal-0.62.16.dist-info/RECORD +0 -198
  159. modal_global_objects/images/conda.py +0 -15
  160. modal_global_objects/images/debian_slim.py +0 -15
  161. modal_global_objects/images/micromamba.py +0 -15
  162. test/__init__.py +0 -1
  163. test/aio_test.py +0 -12
  164. test/async_utils_test.py +0 -262
  165. test/blob_test.py +0 -67
  166. test/cli_imports_test.py +0 -149
  167. test/cli_test.py +0 -659
  168. test/client_test.py +0 -194
  169. test/cls_test.py +0 -630
  170. test/config_test.py +0 -137
  171. test/conftest.py +0 -1420
  172. test/container_app_test.py +0 -32
  173. test/container_test.py +0 -1389
  174. test/cpu_test.py +0 -23
  175. test/decorator_test.py +0 -85
  176. test/deprecation_test.py +0 -34
  177. test/dict_test.py +0 -33
  178. test/e2e_test.py +0 -68
  179. test/error_test.py +0 -7
  180. test/function_serialization_test.py +0 -32
  181. test/function_test.py +0 -653
  182. test/function_utils_test.py +0 -101
  183. test/gpu_test.py +0 -159
  184. test/grpc_utils_test.py +0 -141
  185. test/helpers.py +0 -42
  186. test/image_test.py +0 -669
  187. test/live_reload_test.py +0 -80
  188. test/lookup_test.py +0 -70
  189. test/mdmd_test.py +0 -329
  190. test/mount_test.py +0 -162
  191. test/mounted_files_test.py +0 -329
  192. test/network_file_system_test.py +0 -181
  193. test/notebook_test.py +0 -66
  194. test/object_test.py +0 -41
  195. test/package_utils_test.py +0 -25
  196. test/queue_test.py +0 -97
  197. test/resolver_test.py +0 -58
  198. test/retries_test.py +0 -67
  199. test/runner_test.py +0 -85
  200. test/sandbox_test.py +0 -191
  201. test/schedule_test.py +0 -15
  202. test/scheduler_placement_test.py +0 -29
  203. test/secret_test.py +0 -78
  204. test/serialization_test.py +0 -42
  205. test/stub_composition_test.py +0 -10
  206. test/stub_test.py +0 -360
  207. test/test_asgi_wrapper.py +0 -234
  208. test/token_flow_test.py +0 -18
  209. test/traceback_test.py +0 -135
  210. test/tunnel_test.py +0 -29
  211. test/utils_test.py +0 -88
  212. test/version_test.py +0 -14
  213. test/volume_test.py +0 -341
  214. test/watcher_test.py +0 -30
  215. test/webhook_test.py +0 -146
  216. /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
  217. /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
  218. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
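
For readers who want to reproduce or explore a comparison like this locally, here is a minimal sketch (not part of the published diff). It assumes `pip` is on PATH with network access; the version specifiers come from the header above, and the helper names are illustrative. It downloads both wheels and prints a unified diff of `modal/functions.py`, the file detailed below.

# Minimal sketch: reproduce a wheel-to-wheel diff locally.
# Assumes `pip` is available and the registry is reachable; the version
# specifiers come from the header above, everything else is illustrative.
import difflib
import subprocess
import tempfile
import zipfile
from pathlib import Path

def fetch_wheel(spec: str, dest: Path) -> Path:
    """Download just the named wheel (no dependencies) into `dest`."""
    dest.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        ["pip", "download", "--no-deps", "--only-binary", ":all:", spec, "-d", str(dest)],
        check=True,
    )
    return next(dest.glob("*.whl"))

def read_module(whl: Path, name: str) -> list[str]:
    """Read one module's source out of the wheel archive."""
    with zipfile.ZipFile(whl) as zf:
        return zf.read(name).decode().splitlines()

with tempfile.TemporaryDirectory() as tmp:
    old = fetch_wheel("modal==0.62.16", Path(tmp, "old"))
    new = fetch_wheel("modal==0.72.11", Path(tmp, "new"))
    diff = difflib.unified_diff(
        read_module(old, "modal/functions.py"),
        read_module(new, "modal/functions.py"),
        fromfile="0.62.16/modal/functions.py",
        tofile="0.72.11/modal/functions.py",
        lineterm="",
    )
    print("\n".join(diff))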
modal/functions.py CHANGED
@@ -1,253 +1,163 @@
  # Copyright Modal Labs 2023
- import asyncio
+ import dataclasses
  import inspect
+ import textwrap
  import time
+ import typing
  import warnings
- from contextvars import ContextVar
+ from collections.abc import AsyncGenerator, Collection, Sequence, Sized
  from dataclasses import dataclass
  from pathlib import PurePosixPath
  from typing import (
  TYPE_CHECKING,
  Any,
- AsyncGenerator,
- AsyncIterable,
- AsyncIterator,
  Callable,
- Collection,
- Dict,
- List,
- Literal,
  Optional,
- Sequence,
- Set,
- Sized,
- Tuple,
- Type,
  Union,
  )

- from aiostream import pipe, stream
+ import typing_extensions
  from google.protobuf.message import Message
  from grpclib import GRPCError, Status
- from grpclib.exceptions import StreamTerminatedError
+ from synchronicity.combined_types import MethodWithAio
  from synchronicity.exceptions import UserCodeException

- from modal import _pty, is_local
- from modal_proto import api_grpc, api_pb2
+ from modal_proto import api_pb2
+ from modal_proto.modal_api_grpc import ModalClientModal

  from ._location import parse_cloud_provider
- from ._output import OutputManager
+ from ._pty import get_pty_info
  from ._resolver import Resolver
- from ._serialization import deserialize, deserialize_data_format, serialize
- from ._traceback import append_modal_tb
+ from ._resources import convert_fn_config_to_resources_config
+ from ._runtime.execution_context import current_input_id, is_local
+ from ._serialization import serialize, serialize_proto_params
+ from ._traceback import print_server_warnings
  from ._utils.async_utils import (
- queue_batch_iterator,
+ TaskContext,
+ aclosing,
+ async_merge,
+ callable_to_agen,
  synchronize_api,
  synchronizer,
  warn_if_generator_is_not_consumed,
  )
- from ._utils.blob_utils import (
- BLOB_MAX_PARALLELISM,
- MAX_OBJECT_SIZE_BYTES,
- blob_download,
- blob_upload,
+ from ._utils.deprecation import deprecation_warning, renamed_parameter
+ from ._utils.function_utils import (
+ ATTEMPT_TIMEOUT_GRACE_PERIOD,
+ OUTPUTS_TIMEOUT,
+ FunctionCreationStatus,
+ FunctionInfo,
+ _create_input,
+ _process_result,
+ _stream_function_call_data,
+ get_function_type,
+ is_async,
  )
- from ._utils.function_utils import FunctionInfo, get_referred_objects, is_async
- from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors, unary_stream
- from ._utils.mount_utils import validate_mount_points, validate_volumes
+ from ._utils.grpc_utils import retry_transient_errors
+ from ._utils.mount_utils import validate_network_file_systems, validate_volumes
  from .call_graph import InputInfo, _reconstruct_call_graph
  from .client import _Client
  from .cloud_bucket_mount import _CloudBucketMount, cloud_bucket_mounts_to_proto
- from .config import config, logger
+ from .config import config
  from .exception import (
  ExecutionError,
  FunctionTimeoutError,
+ InternalFailure,
  InvalidError,
  NotFoundError,
- RemoteError,
- deprecation_warning,
+ OutputExpiredError,
  )
  from .gpu import GPU_T, parse_gpu_config
  from .image import _Image
- from .mount import _get_client_mount, _Mount
+ from .mount import _get_client_mount, _Mount, get_auto_mounts
  from .network_file_system import _NetworkFileSystem, network_file_system_mount_protos
- from .object import Object, _get_environment_name, _Object, live_method, live_method_gen
+ from .object import _get_environment_name, _Object, live_method, live_method_gen
+ from .output import _get_output_manager
+ from .parallel_map import (
+ _for_each_async,
+ _for_each_sync,
+ _map_async,
+ _map_invocation,
+ _map_sync,
+ _starmap_async,
+ _starmap_sync,
+ _SynchronizedQueue,
+ )
  from .proxy import _Proxy
- from .retries import Retries
+ from .retries import Retries, RetryManager
  from .schedule import Schedule
  from .scheduler_placement import SchedulerPlacement
  from .secret import _Secret
  from .volume import _Volume

- OUTPUTS_TIMEOUT = 55.0 # seconds
- ATTEMPT_TIMEOUT_GRACE_PERIOD = 5 # seconds
-
-
  if TYPE_CHECKING:
- import modal.stub
-
-
- def exc_with_hints(exc: BaseException):
- """mdmd:hidden"""
- if isinstance(exc, ImportError) and exc.msg == "attempted relative import with no known parent package":
- exc.msg += """\n
- HINT: For relative imports to work, you might need to run your modal app as a module. Try:
- - `python -m my_pkg.my_app` instead of `python my_pkg/my_app.py`
- - `modal deploy my_pkg.my_app` instead of `modal deploy my_pkg/my_app.py`
- """
- elif isinstance(
- exc, RuntimeError
- ) and "CUDA error: no kernel image is available for execution on the device" in str(exc):
- msg = (
- exc.args[0]
- + """\n
- HINT: This error usually indicates an outdated CUDA version. Older versions of torch (<=1.12)
- come with CUDA 10.2 by default. If pinning to an older torch version, you can specify a CUDA version
- manually, for example:
- - image.pip_install("torch==1.12.1+cu116", find_links="https://download.pytorch.org/whl/torch_stable.html")
- """
- )
- exc.args = (msg,)
-
- return exc
-
-
- async def _process_result(result: api_pb2.GenericResult, data_format: int, stub, client=None):
- if result.WhichOneof("data_oneof") == "data_blob_id":
- data = await blob_download(result.data_blob_id, stub)
- else:
- data = result.data
-
- if result.status == api_pb2.GenericResult.GENERIC_STATUS_TIMEOUT:
- raise FunctionTimeoutError(result.exception)
- elif result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
- if data:
- try:
- exc = deserialize(data, client)
- except Exception as deser_exc:
- raise ExecutionError(
- "Could not deserialize remote exception due to local error:\n"
- + f"{deser_exc}\n"
- + "This can happen if your local environment does not have the remote exception definitions.\n"
- + "Here is the remote traceback:\n"
- + f"{result.traceback}"
- )
- if not isinstance(exc, BaseException):
- raise ExecutionError(f"Got remote exception of incorrect type {type(exc)}")
-
- if result.serialized_tb:
- try:
- tb_dict = deserialize(result.serialized_tb, client)
- line_cache = deserialize(result.tb_line_cache, client)
- append_modal_tb(exc, tb_dict, line_cache)
- except Exception:
- pass
- uc_exc = UserCodeException(exc_with_hints(exc))
- raise uc_exc
- raise RemoteError(result.exception)
+ import modal.app
+ import modal.cls
+ import modal.partial_function

- try:
- return deserialize_data_format(data, data_format, client)
- except ModuleNotFoundError as deser_exc:
- raise ExecutionError(
- "Could not deserialize result due to error:\n"
- + f"{deser_exc}\n"
- + "This can happen if your local environment does not have a module that was used to construct the result. \n"
- )

-
- async def _create_input(args, kwargs, client, idx: Optional[int] = None) -> api_pb2.FunctionPutInputsItem:
- """Serialize function arguments and create a FunctionInput protobuf,
- uploading to blob storage if needed.
- """
- if idx is None:
- idx = 0
-
- args_serialized = serialize((args, kwargs))
-
- if len(args_serialized) > MAX_OBJECT_SIZE_BYTES:
- args_blob_id = await blob_upload(args_serialized, client.stub)
-
- return api_pb2.FunctionPutInputsItem(
- input=api_pb2.FunctionInput(args_blob_id=args_blob_id, data_format=api_pb2.DATA_FORMAT_PICKLE),
- idx=idx,
- )
- else:
- return api_pb2.FunctionPutInputsItem(
- input=api_pb2.FunctionInput(args=args_serialized, data_format=api_pb2.DATA_FORMAT_PICKLE),
- idx=idx,
- )
-
-
- async def _stream_function_call_data(
- client, function_call_id: str, variant: Literal["data_in", "data_out"]
- ) -> AsyncIterator[Any]:
- """Read from the `data_in` or `data_out` stream of a function call."""
- last_index = 0
- retries_remaining = 10
-
- if variant == "data_in":
- stub_fn = client.stub.FunctionCallGetDataIn
- elif variant == "data_out":
- stub_fn = client.stub.FunctionCallGetDataOut
- else:
- raise ValueError(f"Invalid variant {variant}")
-
- while True:
- req = api_pb2.FunctionCallGetDataRequest(function_call_id=function_call_id, last_index=last_index)
- try:
- async for chunk in unary_stream(stub_fn, req):
- if chunk.index <= last_index:
- continue
- last_index = chunk.index
- if chunk.data_blob_id:
- message_bytes = await blob_download(chunk.data_blob_id, client.stub)
- else:
- message_bytes = chunk.data
- message = deserialize_data_format(message_bytes, chunk.data_format, client)
- yield message
- except (GRPCError, StreamTerminatedError) as exc:
- if retries_remaining > 0:
- retries_remaining -= 1
- if isinstance(exc, GRPCError):
- if exc.status in RETRYABLE_GRPC_STATUS_CODES:
- await asyncio.sleep(1.0)
- continue
- elif isinstance(exc, StreamTerminatedError):
- continue
- raise
-
-
- @dataclass
- class _OutputValue:
- # box class for distinguishing None results from non-existing/None markers
- value: Any
+ @dataclasses.dataclass
+ class _RetryContext:
+ function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
+ retry_policy: api_pb2.FunctionRetryPolicy
+ function_call_jwt: str
+ input_jwt: str
+ input_id: str
+ item: api_pb2.FunctionPutInputsItem


  class _Invocation:
  """Internal client representation of a single-input call to a Modal Function or Generator"""

- def __init__(self, stub: api_grpc.ModalClientStub, function_call_id: str, client: _Client):
+ stub: ModalClientModal
+
+ def __init__(
+ self,
+ stub: ModalClientModal,
+ function_call_id: str,
+ client: _Client,
+ retry_context: Optional[_RetryContext] = None,
+ ):
  self.stub = stub
  self.client = client # Used by the deserializer.
  self.function_call_id = function_call_id # TODO: remove and use only input_id
+ self._retry_context = retry_context

  @staticmethod
- async def create(function_id: str, args, kwargs, client: _Client) -> "_Invocation":
+ async def create(
+ function: "_Function",
+ args,
+ kwargs,
+ *,
+ client: _Client,
+ function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
+ ) -> "_Invocation":
  assert client.stub
- item = await _create_input(args, kwargs, client)
+ function_id = function.object_id
+ item = await _create_input(args, kwargs, client, method_name=function._use_method_name)

  request = api_pb2.FunctionMapRequest(
  function_id=function_id,
  parent_input_id=current_input_id() or "",
  function_call_type=api_pb2.FUNCTION_CALL_TYPE_UNARY,
  pipelined_inputs=[item],
+ function_call_invocation_type=function_call_invocation_type,
  )
  response = await retry_transient_errors(client.stub.FunctionMap, request)
  function_call_id = response.function_call_id

  if response.pipelined_inputs:
- return _Invocation(client.stub, function_call_id, client)
+ assert len(response.pipelined_inputs) == 1
+ input = response.pipelined_inputs[0]
+ retry_context = _RetryContext(
+ function_call_invocation_type=function_call_invocation_type,
+ retry_policy=response.retry_policy,
+ function_call_jwt=response.function_call_jwt,
+ input_jwt=input.input_jwt,
+ input_id=input.input_id,
+ item=item,
+ )
+ return _Invocation(client.stub, function_call_id, client, retry_context)

  request_put = api_pb2.FunctionPutInputsRequest(
  function_id=function_id, inputs=[item], function_call_id=function_call_id
@@ -259,11 +169,20 @@ class _Invocation:
  processed_inputs = inputs_response.inputs
  if not processed_inputs:
  raise Exception("Could not create function call - the input queue seems to be full")
- return _Invocation(client.stub, function_call_id, client)
+ input = inputs_response.inputs[0]
+ retry_context = _RetryContext(
+ function_call_invocation_type=function_call_invocation_type,
+ retry_policy=response.retry_policy,
+ function_call_jwt=response.function_call_jwt,
+ input_jwt=input.input_jwt,
+ input_id=input.input_id,
+ item=item,
+ )
+ return _Invocation(client.stub, function_call_id, client, retry_context)

  async def pop_function_call_outputs(
- self, timeout: Optional[float], clear_on_success: bool
- ) -> AsyncIterator[api_pb2.FunctionGetOutputsItem]:
+ self, timeout: Optional[float], clear_on_success: bool, input_jwts: Optional[list[str]] = None
+ ) -> api_pb2.FunctionGetOutputsResponse:
  t0 = time.time()
  if timeout is None:
  backend_timeout = OUTPUTS_TIMEOUT
@@ -277,53 +196,100 @@ class _Invocation:
  timeout=backend_timeout,
  last_entry_id="0-0",
  clear_on_success=clear_on_success,
+ requested_at=time.time(),
+ input_jwts=input_jwts,
  )
  response: api_pb2.FunctionGetOutputsResponse = await retry_transient_errors(
  self.stub.FunctionGetOutputs,
  request,
  attempt_timeout=backend_timeout + ATTEMPT_TIMEOUT_GRACE_PERIOD,
  )
+
  if len(response.outputs) > 0:
- for item in response.outputs:
- yield item
- return
+ return response

  if timeout is not None:
  # update timeout in retry loop
  backend_timeout = min(OUTPUTS_TIMEOUT, t0 + timeout - time.time())
  if backend_timeout < 0:
- break
+ # return the last response to check for state of num_unfinished_inputs
+ return response
+
+ async def _retry_input(self) -> None:
+ ctx = self._retry_context
+ if not ctx:
+ raise ValueError("Cannot retry input when _retry_context is empty.")
+
+ item = api_pb2.FunctionRetryInputsItem(input_jwt=ctx.input_jwt, input=ctx.item.input)
+ request = api_pb2.FunctionRetryInputsRequest(function_call_jwt=ctx.function_call_jwt, inputs=[item])
+ await retry_transient_errors(
+ self.client.stub.FunctionRetryInputs,
+ request,
+ )

- async def run_function(self) -> Any:
+ async def _get_single_output(self, expected_jwt: Optional[str] = None) -> Any:
  # waits indefinitely for a single result for the function, and clear the outputs buffer after
  item: api_pb2.FunctionGetOutputsItem = (
- await stream.list(self.pop_function_call_outputs(timeout=None, clear_on_success=True))
- )[0]
- assert not item.result.gen_status
+ await self.pop_function_call_outputs(
+ timeout=None,
+ clear_on_success=True,
+ input_jwts=[expected_jwt] if expected_jwt else None,
+ )
+ ).outputs[0]
  return await _process_result(item.result, item.data_format, self.stub, self.client)

+ async def run_function(self) -> Any:
+ # Use retry logic only if retry policy is specified and
+ ctx = self._retry_context
+ if (
+ not ctx
+ or not ctx.retry_policy
+ or ctx.retry_policy.retries == 0
+ or ctx.function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
+ ):
+ return await self._get_single_output()
+
+ # User errors including timeouts are managed by the user specified retry policy.
+ user_retry_manager = RetryManager(ctx.retry_policy)
+
+ while True:
+ try:
+ return await self._get_single_output(ctx.input_jwt)
+ except (UserCodeException, FunctionTimeoutError) as exc:
+ await user_retry_manager.raise_or_sleep(exc)
+ except InternalFailure:
+ # For system failures on the server, we retry immediately.
+ pass
+ await self._retry_input()
+
  async def poll_function(self, timeout: Optional[float] = None):
  """Waits up to timeout for a result from a function.

  If timeout is `None`, waits indefinitely. This function is not
  cancellation-safe.
  """
- items: List[api_pb2.FunctionGetOutputsItem] = await stream.list(
- self.pop_function_call_outputs(timeout=timeout, clear_on_success=False)
+ response: api_pb2.FunctionGetOutputsResponse = await self.pop_function_call_outputs(
+ timeout=timeout, clear_on_success=False
  )
-
- if len(items) == 0:
+ if len(response.outputs) == 0 and response.num_unfinished_inputs == 0:
+ # if no unfinished inputs and no outputs, then function expired
+ raise OutputExpiredError()
+ elif len(response.outputs) == 0:
  raise TimeoutError()

- return await _process_result(items[0].result, items[0].data_format, self.stub, self.client)
+ return await _process_result(
+ response.outputs[0].result, response.outputs[0].data_format, self.stub, self.client
+ )

  async def run_generator(self):
- data_stream = _stream_function_call_data(self.client, self.function_call_id, variant="data_out")
- combined_stream = stream.merge(data_stream, stream.call(self.run_function)) # type: ignore
-
  items_received = 0
  items_total: Union[int, None] = None # populated when self.run_function() completes
- async with combined_stream.stream() as streamer:
+ async with aclosing(
+ async_merge(
+ _stream_function_call_data(self.client, self.function_call_id, variant="data_out"),
+ callable_to_agen(self.run_function),
+ )
+ ) as streamer:
  async for item in streamer:
  if isinstance(item, api_pb2.GeneratorDone):
  items_total = item.items_total
@@ -336,187 +302,29 @@ class _Invocation:
  break


- MAP_INVOCATION_CHUNK_SIZE = 49
-
-
- async def _map_invocation(
- function_id: str,
- input_stream: AsyncIterable[Any],
- kwargs: Dict[str, Any],
- client: _Client,
- order_outputs: bool,
- return_exceptions: bool,
- count_update_callback: Optional[Callable[[int, int], None]],
- ):
- assert client.stub
- request = api_pb2.FunctionMapRequest(
- function_id=function_id,
- parent_input_id=current_input_id() or "",
- function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
- return_exceptions=return_exceptions,
- )
- response = await retry_transient_errors(client.stub.FunctionMap, request)
-
- function_call_id = response.function_call_id
-
- have_all_inputs = False
- num_inputs = 0
- num_outputs = 0
- pending_outputs: Dict[str, int] = {} # Map input_id -> next expected gen_index value
- completed_outputs: Set[str] = set() # Set of input_ids whose outputs are complete (expecting no more values)
-
- input_queue: asyncio.Queue = asyncio.Queue()
-
- async def create_input(arg: Any) -> api_pb2.FunctionPutInputsItem:
- nonlocal num_inputs
- idx = num_inputs
- num_inputs += 1
- item = await _create_input(arg, kwargs, client, idx=idx)
- return item
-
- async def drain_input_generator():
- # Parallelize uploading blobs
- proto_input_stream = stream.iterate(input_stream) | pipe.map(
- create_input, # type: ignore[reportArgumentType]
- ordered=True,
- task_limit=BLOB_MAX_PARALLELISM,
- )
- async with proto_input_stream.stream() as streamer:
- async for item in streamer:
- await input_queue.put(item)
-
- # close queue iterator
- await input_queue.put(None)
- yield
-
- async def pump_inputs():
- assert client.stub
- nonlocal have_all_inputs
- async for items in queue_batch_iterator(input_queue, MAP_INVOCATION_CHUNK_SIZE):
- request = api_pb2.FunctionPutInputsRequest(
- function_id=function_id, inputs=items, function_call_id=function_call_id
- )
- logger.debug(
- f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
- )
- resp = await retry_transient_errors(
- client.stub.FunctionPutInputs,
- request,
- max_retries=None,
- max_delay=10,
- additional_status_codes=[Status.RESOURCE_EXHAUSTED],
- )
- for item in resp.inputs:
- pending_outputs.setdefault(item.input_id, 0)
- logger.debug(
- f"Successfully pushed {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
- )
-
- have_all_inputs = True
- yield
-
- async def get_all_outputs():
- assert client.stub
- nonlocal num_inputs, num_outputs, have_all_inputs
- last_entry_id = "0-0"
- while not have_all_inputs or len(pending_outputs) > len(completed_outputs):
- request = api_pb2.FunctionGetOutputsRequest(
- function_call_id=function_call_id,
- timeout=OUTPUTS_TIMEOUT,
- last_entry_id=last_entry_id,
- clear_on_success=False,
- )
- response = await retry_transient_errors(
- client.stub.FunctionGetOutputs,
- request,
- max_retries=20,
- attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
- )
-
- if len(response.outputs) == 0:
- continue
-
- last_entry_id = response.last_entry_id
- for item in response.outputs:
- pending_outputs.setdefault(item.input_id, 0)
- if item.input_id in completed_outputs:
- # If this input is already completed, it means the output has already been
- # processed and was received again due to a duplicate.
- continue
- completed_outputs.add(item.input_id)
- num_outputs += 1
- yield item
-
- async def get_all_outputs_and_clean_up():
- assert client.stub
- try:
- async for item in get_all_outputs():
- yield item
- finally:
- # "ack" that we have all outputs we are interested in and let backend clear results
- request = api_pb2.FunctionGetOutputsRequest(
- function_call_id=function_call_id,
- timeout=0,
- last_entry_id="0-0",
- clear_on_success=True,
- )
- await retry_transient_errors(client.stub.FunctionGetOutputs, request)
-
- async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> Tuple[int, Any]:
- try:
- output = await _process_result(item.result, item.data_format, client.stub, client)
- except Exception as e:
- if return_exceptions:
- output = e
- else:
- raise e
- return (item.idx, output)
-
- async def poll_outputs():
- outputs = stream.iterate(get_all_outputs_and_clean_up())
- outputs_fetched = outputs | pipe.map(fetch_output, ordered=True, task_limit=BLOB_MAX_PARALLELISM) # type: ignore
-
- # map to store out-of-order outputs received
- received_outputs = {}
- output_idx = 0
-
- async with outputs_fetched.stream() as streamer:
- async for idx, output in streamer:
- if count_update_callback is not None:
- count_update_callback(num_outputs, num_inputs)
- if not order_outputs:
- yield _OutputValue(output)
- else:
- # hold on to outputs for function maps, so we can reorder them correctly.
- received_outputs[idx] = output
- while output_idx in received_outputs:
- output = received_outputs.pop(output_idx)
- yield _OutputValue(output)
- output_idx += 1
-
- assert len(received_outputs) == 0
-
- response_gen = stream.merge(drain_input_generator(), pump_inputs(), poll_outputs())
-
- async with response_gen.stream() as streamer:
- async for response in streamer:
- if response is not None:
- yield response.value
-
-
  # Wrapper type for api_pb2.FunctionStats
  @dataclass(frozen=True)
  class FunctionStats:
  """Simple data structure storing stats for a running function."""

  backlog: int
- num_active_runners: int
  num_total_runners: int

+ def __getattr__(self, name):
+ if name == "num_active_runners":
+ msg = (
+ "'FunctionStats.num_active_runners' is deprecated."
+ " It currently always has a value of 0,"
+ " but it will be removed in a future release."
+ )
+ deprecation_warning((2024, 6, 14), msg)
+ return 0
+ raise AttributeError(f"'FunctionStats' object has no attribute '{name}'")
+

  def _parse_retries(
  retries: Optional[Union[int, Retries]],
- raw_f: Optional[Callable] = None,
+ source: str = "",
  ) -> Optional[api_pb2.FunctionRetryPolicy]:
  if isinstance(retries, int):
  return Retries(
@@ -529,118 +337,168 @@ def _parse_retries(
  elif retries is None:
  return None
  else:
- err_object = f"Function {raw_f}" if raw_f else "Function"
- raise InvalidError(
- f"{err_object} retries must be an integer or instance of modal.Retries. Found: {type(retries)}"
- )
+ extra = f" on {source}" if source else ""
+ msg = f"Retries parameter must be an integer or instance of modal.Retries. Found: {type(retries)}{extra}."
+ raise InvalidError(msg)


  @dataclass
- class FunctionEnv:
+ class _FunctionSpec:
  """
- Stores information about the function environment. This is used for `modal shell` to support
- running shells in the same environment as a user-defined function.
+ Stores information about a Function specification.
+ This is used for `modal shell` to support running shells with
+ the same configuration as a user-defined Function.
  """

  image: Optional[_Image]
  mounts: Sequence[_Mount]
  secrets: Sequence[_Secret]
- network_file_systems: Dict[Union[str, PurePosixPath], _NetworkFileSystem]
- volumes: Dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]]
- gpu: GPU_T
+ network_file_systems: dict[Union[str, PurePosixPath], _NetworkFileSystem]
+ volumes: dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]]
+ gpus: Union[GPU_T, list[GPU_T]] # TODO(irfansharif): Somehow assert that it's the first kind, in sandboxes
  cloud: Optional[str]
- cpu: Optional[float]
- memory: Optional[int]
+ cpu: Optional[Union[float, tuple[float, float]]]
+ memory: Optional[Union[int, tuple[int, int]]]
+ ephemeral_disk: Optional[int]
+ scheduler_placement: Optional[SchedulerPlacement]
+ proxy: Optional[_Proxy]
+
+
+ P = typing_extensions.ParamSpec("P")
+ ReturnType = typing.TypeVar("ReturnType", covariant=True)
+ OriginalReturnType = typing.TypeVar(
+ "OriginalReturnType", covariant=True
+ ) # differs from return type if ReturnType is coroutine


- class _Function(_Object, type_prefix="fu"):
+ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type_prefix="fu"):
  """Functions are the basic units of serverless execution on Modal.

  Generally, you will not construct a `Function` directly. Instead, use the
- `@stub.function()` decorator on the `Stub` object for your application.
+ `App.function()` decorator to register your Python functions with your App.
  """

  # TODO: more type annotations
  _info: Optional[FunctionInfo]
- _all_mounts: Collection[_Mount]
- _stub: "modal.stub._Stub"
- _obj: Any
+ _serve_mounts: frozenset[_Mount] # set at load time, only by loader
+ _app: Optional["modal.app._App"] = None
+ _obj: Optional["modal.cls._Obj"] = None # only set for InstanceServiceFunctions and bound instance methods
  _web_url: Optional[str]
- _is_remote_cls_method: bool = False # TODO(erikbern): deprecated
  _function_name: Optional[str]
  _is_method: bool
- _env: FunctionEnv
+ _spec: Optional[_FunctionSpec] = None
  _tag: str
  _raw_f: Callable[..., Any]
  _build_args: dict
- _parent: "_Function"
+
+ _is_generator: Optional[bool] = None
+ _cluster_size: Optional[int] = None
+
+ # when this is the method of a class/object function, invocation of this function
+ # should supply the method name in the FunctionInput:
+ _use_method_name: str = ""
+
+ _class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None
+ _method_handle_metadata: Optional[dict[str, "api_pb2.FunctionHandleMetadata"]] = None
+
+ def _bind_method(
+ self,
+ user_cls,
+ method_name: str,
+ partial_function: "modal.partial_function._PartialFunction",
+ ):
+ """mdmd:hidden
+
+ Creates a _Function that is bound to a specific class method name. This _Function is not uniquely tied
+ to any backend function -- its object_id is the function ID of the class service function.
+
+ """
+ class_service_function = self
+ assert class_service_function._info # has to be a local function to be able to "bind" it
+ assert not class_service_function._is_method # should not be used on an already bound method placeholder
+ assert not class_service_function._obj # should only be used on base function / class service function
+ full_name = f"{user_cls.__name__}.{method_name}"
+
+ rep = f"Method({full_name})"
+ fun = _Object.__new__(_Function)
+ fun._init(rep)
+ fun._tag = full_name
+ fun._raw_f = partial_function.raw_f
+ fun._info = FunctionInfo(
+ partial_function.raw_f, user_cls=user_cls, serialized=class_service_function.info.is_serialized()
+ ) # needed for .local()
+ fun._use_method_name = method_name
+ fun._app = class_service_function._app
+ fun._is_generator = partial_function.is_generator
+ fun._cluster_size = partial_function.cluster_size
+ fun._spec = class_service_function._spec
+ fun._is_method = True
+ return fun

  @staticmethod
  def from_args(
  info: FunctionInfo,
- stub,
+ app,
  image: _Image,
- secret: Optional[_Secret] = None,
  secrets: Sequence[_Secret] = (),
  schedule: Optional[Schedule] = None,
- is_generator=False,
- gpu: GPU_T = None,
+ is_generator: bool = False,
+ gpu: Union[GPU_T, list[GPU_T]] = None,
  # TODO: maybe break this out into a separate decorator for notebooks.
  mounts: Collection[_Mount] = (),
- network_file_systems: Dict[Union[str, PurePosixPath], _NetworkFileSystem] = {},
+ network_file_systems: dict[Union[str, PurePosixPath], _NetworkFileSystem] = {},
  allow_cross_region_volumes: bool = False,
- volumes: Dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]] = {},
+ volumes: dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]] = {},
  webhook_config: Optional[api_pb2.WebhookConfig] = None,
- memory: Optional[int] = None,
+ memory: Optional[Union[int, tuple[int, int]]] = None,
  proxy: Optional[_Proxy] = None,
  retries: Optional[Union[int, Retries]] = None,
  timeout: Optional[int] = None,
  concurrency_limit: Optional[int] = None,
  allow_concurrent_inputs: Optional[int] = None,
+ batch_max_size: Optional[int] = None,
+ batch_wait_ms: Optional[int] = None,
  container_idle_timeout: Optional[int] = None,
- cpu: Optional[float] = None,
+ cpu: Optional[Union[float, tuple[float, float]]] = None,
  keep_warm: Optional[int] = None, # keep_warm=True is equivalent to keep_warm=1
  cloud: Optional[str] = None,
- _experimental_boost: bool = False,
- _experimental_scheduler: bool = False,
- _experimental_scheduler_placement: Optional[SchedulerPlacement] = None,
+ scheduler_placement: Optional[SchedulerPlacement] = None,
  is_builder_function: bool = False,
  is_auto_snapshot: bool = False,
  enable_memory_snapshot: bool = False,
- checkpointing_enabled: Optional[bool] = None,
- allow_background_volume_commits: bool = False,
  block_network: bool = False,
+ i6pn_enabled: bool = False,
+ cluster_size: Optional[int] = None, # Experimental: Clustered functions
  max_inputs: Optional[int] = None,
+ ephemeral_disk: Optional[int] = None,
+ _experimental_buffer_containers: Optional[int] = None,
+ _experimental_proxy_ip: Optional[str] = None,
+ _experimental_custom_scaling_factor: Optional[float] = None,
  ) -> None:
  """mdmd:hidden"""
+ # Needed to avoid circular imports
+ from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags
+
  tag = info.get_tag()

- raw_f = info.raw_f
- assert callable(raw_f)
- if schedule is not None:
- if not info.is_nullary():
+ if info.raw_f:
+ raw_f = info.raw_f
+ assert callable(raw_f)
+ if schedule is not None and not info.is_nullary():
  raise InvalidError(
  f"Function {raw_f} has a schedule, so it needs to support being called with no arguments"
  )
-
- if secret is not None:
- deprecation_warning(
- (2024, 1, 31),
- "The singular `secret` parameter is deprecated. Pass a list to `secrets` instead.",
- )
- secrets = [secret, *secrets]
-
- if checkpointing_enabled is not None:
- deprecation_warning(
- (2024, 3, 4),
- "The argument `checkpointing_enabled` is now deprecated. Use `enable_memory_snapshot` instead.",
- )
- enable_memory_snapshot = checkpointing_enabled
+ else:
+ # must be a "class service function"
+ assert info.user_cls
+ assert not webhook_config
+ assert not schedule

  explicit_mounts = mounts

  if is_local():
  entrypoint_mounts = info.get_entrypoint_mount()
+
  all_mounts = [
  _get_client_mount(),
  *explicit_mounts,
@@ -648,45 +506,57 @@ class _Function(_Object, type_prefix="fu"):
  ]

  if config.get("automount"):
- automounts = info.get_auto_mounts()
- all_mounts += automounts
+ all_mounts += get_auto_mounts()
  else:
  # skip any mount introspection/logic inside containers, since the function
  # should already be hydrated
  # TODO: maybe the entire constructor should be exited early if not local?
  all_mounts = []

- retry_policy = _parse_retries(retries, raw_f)
+ retry_policy = _parse_retries(
+ retries, f"Function '{info.get_tag()}'" if info.raw_f else f"Class '{info.get_tag()}'"
+ )

- gpu_config = parse_gpu_config(gpu)
+ if webhook_config is not None and retry_policy is not None:
+ raise InvalidError(
+ "Web endpoints do not support retries.",
+ )
+
+ if is_generator and retry_policy is not None:
+ deprecation_warning(
+ (2024, 6, 25),
+ "Retries for generator functions are deprecated and will soon be removed.",
+ )

  if proxy:
  # HACK: remove this once we stop using ssh tunnels for this.
  if image:
+ # TODO(elias): this will cause an error if users use prior `.add_local_*` commands without copy=True
  image = image.apt_install("autossh")

- function_env = FunctionEnv(
+ function_spec = _FunctionSpec(
  mounts=all_mounts,
  secrets=secrets,
- gpu=gpu,
+ gpus=gpu,
  network_file_systems=network_file_systems,
  volumes=volumes,
  image=image,
  cloud=cloud,
  cpu=cpu,
  memory=memory,
+ ephemeral_disk=ephemeral_disk,
+ scheduler_placement=scheduler_placement,
+ proxy=proxy,
  )

- if info.cls and not is_auto_snapshot:
- # Needed to avoid circular imports
- from .partial_function import _find_callables_for_cls, _PartialFunctionFlags
-
- build_functions = list(_find_callables_for_cls(info.cls, _PartialFunctionFlags.BUILD).values())
- for build_function in build_functions:
- snapshot_info = FunctionInfo(build_function, cls=info.cls)
+ if info.user_cls and not is_auto_snapshot:
+ build_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.BUILD).items()
+ for k, pf in build_functions:
+ build_function = pf.raw_f
+ snapshot_info = FunctionInfo(build_function, user_cls=info.user_cls)
  snapshot_function = _Function.from_args(
  snapshot_info,
- stub=None,
+ app=None,
  image=image,
  secrets=secrets,
  gpu=gpu,
@@ -694,16 +564,17 @@ class _Function(_Object, type_prefix="fu"):
  network_file_systems=network_file_systems,
  volumes=volumes,
  memory=memory,
- timeout=86400, # TODO: make this an argument to `@build()`
+ timeout=pf.build_timeout,
  cpu=cpu,
+ ephemeral_disk=ephemeral_disk,
  is_builder_function=True,
  is_auto_snapshot=True,
- _experimental_scheduler_placement=_experimental_scheduler_placement,
+ scheduler_placement=scheduler_placement,
  )
  image = _Image._from_args(
  base_images={"base": image},
  build_function=snapshot_function,
- force_build=image.force_build,
+ force_build=image.force_build or pf.force_build,
  )

  if keep_warm is not None and not isinstance(keep_warm, int):
@@ -711,9 +582,15 @@ class _Function(_Object, type_prefix="fu"):

  if (keep_warm is not None) and (concurrency_limit is not None) and concurrency_limit < keep_warm:
  raise InvalidError(
- f"Function `{info.function_name}` has `{concurrency_limit=}`, strictly less than its `{keep_warm=}` parameter."
+ f"Function `{info.function_name}` has `{concurrency_limit=}`, "
+ f"strictly less than its `{keep_warm=}` parameter."
  )

+ if _experimental_custom_scaling_factor is not None and (
+ _experimental_custom_scaling_factor < 0 or _experimental_custom_scaling_factor > 1
+ ):
+ raise InvalidError("`_experimental_custom_scaling_factor` must be between 0.0 and 1.0 inclusive.")
+
  if not cloud and not is_builder_function:
  cloud = config.get("default_cloud")
  if cloud:
@@ -730,22 +607,56 @@ class _Function(_Object, type_prefix="fu"):
  else:
  raise InvalidError("Webhooks cannot be generators")

+ if info.raw_f and batch_max_size:
+ func_name = info.raw_f.__name__
+ if is_generator:
+ raise InvalidError(f"Modal batched function {func_name} cannot return generators")
+ for arg in inspect.signature(info.raw_f).parameters.values():
+ if arg.default is not inspect.Parameter.empty:
+ raise InvalidError(f"Modal batched function {func_name} does not accept default arguments.")
+
+ if container_idle_timeout is not None and container_idle_timeout <= 0:
+ raise InvalidError("`container_idle_timeout` must be > 0")
+
+ if max_inputs is not None:
+ if not isinstance(max_inputs, int):
+ raise InvalidError(f"`max_inputs` must be an int, not {type(max_inputs).__name__}")
+ if max_inputs <= 0:
+ raise InvalidError("`max_inputs` must be positive")
+ if max_inputs > 1:
+ raise InvalidError("Only `max_inputs=1` is currently supported")
+
  # Validate volumes
  validated_volumes = validate_volumes(volumes)
  cloud_bucket_mounts = [(k, v) for k, v in validated_volumes if isinstance(v, _CloudBucketMount)]
  validated_volumes = [(k, v) for k, v in validated_volumes if isinstance(v, _Volume)]

  # Validate NFS
- if not isinstance(network_file_systems, dict):
- raise InvalidError("network_file_systems must be a dict[str, NetworkFileSystem] where the keys are paths")
- validated_network_file_systems = validate_mount_points("Network file system", network_file_systems)
+ validated_network_file_systems = validate_network_file_systems(network_file_systems)

  # Validate image
  if image is not None and not isinstance(image, _Image):
  raise InvalidError(f"Expected modal.Image object. Got {type(image)}.")

- def _deps(only_explicit_mounts=False) -> List[_Object]:
- deps: List[_Object] = list(secrets)
+ method_definitions: Optional[dict[str, api_pb2.MethodDefinition]] = None
+
+ if info.user_cls:
+ method_definitions = {}
+ partial_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.FUNCTION)
+ for method_name, partial_function in partial_functions.items():
+ function_type = get_function_type(partial_function.is_generator)
+ function_name = f"{info.user_cls.__name__}.{method_name}"
+ method_definition = api_pb2.MethodDefinition(
+ webhook_config=partial_function.webhook_config,
+ function_type=function_type,
+ function_name=function_name,
+ )
+ method_definitions[method_name] = method_definition
+
+ function_type = get_function_type(is_generator)
+
+ def _deps(only_explicit_mounts=False) -> list[_Object]:
+ deps: list[_Object] = list(secrets)
  if only_explicit_mounts:
  # TODO: this is a bit hacky, but all_mounts may differ in the container vs locally
  # We don't want the function dependencies to change, so we have this way to force it to
@@ -769,271 +680,358 @@ class _Function(_Object, type_prefix="fu"):
  if cloud_bucket_mount.secret:
  deps.append(cloud_bucket_mount.secret)

- # Add implicit dependencies from the function's code
- objs: list[Object] = get_referred_objects(info.raw_f)
- _objs: list[_Object] = synchronizer._translate_in(objs) # type: ignore
- deps += _objs
  return deps

  async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
  assert resolver.client and resolver.client.stub
- if is_generator:
- function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR
- else:
- function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION

+ assert resolver.app_id
  req = api_pb2.FunctionPrecreateRequest(
  app_id=resolver.app_id,
  function_name=info.function_name,
  function_type=function_type,
- webhook_config=webhook_config,
  existing_function_id=existing_object_id or "",
  )
+ if method_definitions:
+ for method_name, method_definition in method_definitions.items():
+ req.method_definitions[method_name].CopyFrom(method_definition)
+ elif webhook_config:
+ req.webhook_config.CopyFrom(webhook_config)
  response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req)
  self._hydrate(response.function_id, resolver.client, response.handle_metadata)

  async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
  assert resolver.client and resolver.client.stub
- status_row = resolver.add_status_row()
- status_row.message(f"Creating {tag}...")
-
- if is_generator:
- function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR
- else:
- function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION
-
- if cpu is not None and cpu < 0.25:
- raise InvalidError(f"Invalid fractional CPU value {cpu}. Cannot have less than 0.25 CPU resources.")
- milli_cpu = int(1000 * cpu) if cpu is not None else 0
-
- timeout_secs = timeout
+ with FunctionCreationStatus(resolver, tag) as function_creation_status:
+ timeout_secs = timeout

- if stub and stub.is_interactive and not is_builder_function:
- pty_info = _pty.get_pty_info(shell=False)
- else:
- pty_info = None
-
- if info.is_serialized():
- # Use cloudpickle. Used when working w/ Jupyter notebooks.
- # serialize at _load time, not function decoration time
- # otherwise we can't capture a surrounding class for lifetime methods etc.
- function_serialized = info.serialized_function()
- class_serialized = serialize(info.cls) if info.cls is not None else None
-
- # Ensure that large data in global variables does not blow up the gRPC payload,
- # which has maximum size 100 MiB. We set the limit lower for performance reasons.
- if len(function_serialized) > 16 << 20: # 16 MiB
- raise InvalidError(
- f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
- "This is larger than the maximum limit of 16 MiB. "
- "Try reducing the size of the closure by using parameters or mounts, not large global variables."
- )
- elif len(function_serialized) > 256 << 10: # 256 KiB
- warnings.warn(
- f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
- "This is larger than the recommended limit of 256 KiB. "
- "Try reducing the size of the closure by using parameters or mounts, not large global variables."
+ if app and app.is_interactive and not is_builder_function:
+ pty_info = get_pty_info(shell=False)
+ else:
+ pty_info = None
+
+ if info.is_serialized():
+ # Use cloudpickle. Used when working w/ Jupyter notebooks.
+ # serialize at _load time, not function decoration time
+ # otherwise we can't capture a surrounding class for lifetime methods etc.
+ function_serialized = info.serialized_function()
+ class_serialized = serialize(info.user_cls) if info.user_cls is not None else None
+ # Ensure that large data in global variables does not blow up the gRPC payload,
+ # which has maximum size 100 MiB. We set the limit lower for performance reasons.
+ if len(function_serialized) > 16 << 20: # 16 MiB
+ raise InvalidError(
+ f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
+ "This is larger than the maximum limit of 16 MiB. "
+ "Try reducing the size of the closure by using parameters or mounts, "
+ "not large global variables."
+ )
+ elif len(function_serialized) > 256 << 10: # 256 KiB
+ warnings.warn(
+ f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
+ "This is larger than the recommended limit of 256 KiB. "
+ "Try reducing the size of the closure by using parameters or mounts, "
+ "not large global variables."
+ )
+ else:
+ function_serialized = None
+ class_serialized = None
+
+ app_name = ""
+ if app and app.name:
+ app_name = app.name
+
+ # Relies on dicts being ordered (true as of Python 3.6).
+ volume_mounts = [
+ api_pb2.VolumeMount(
+ mount_path=path,
+ volume_id=volume.object_id,
+ allow_background_commits=True,
  )
- else:
- function_serialized = None
- class_serialized = None
-
- stub_name = ""
- if stub and stub.name:
- stub_name = stub.name
-
- # Relies on dicts being ordered (true as of Python 3.6).
- volume_mounts = [
- api_pb2.VolumeMount(
- mount_path=path,
- volume_id=volume.object_id,
- allow_background_commits=allow_background_volume_commits,
+ for path, volume in validated_volumes
+ ]
+ loaded_mount_ids = {m.object_id for m in all_mounts} | {m.object_id for m in image._mount_layers}
+
+ # Get object dependencies
+ object_dependencies = []
+ for dep in _deps(only_explicit_mounts=True):
+ if not dep.object_id:
+ raise Exception(f"Dependency {dep} isn't hydrated")
+ object_dependencies.append(api_pb2.ObjectDependency(object_id=dep.object_id))
+
+ function_data: Optional[api_pb2.FunctionData] = None
+ function_definition: Optional[api_pb2.Function] = None
+
+ # Create function remotely
+ function_definition = api_pb2.Function(
+ module_name=info.module_name or "",
+ function_name=info.function_name,
+ mount_ids=loaded_mount_ids,
+ secret_ids=[secret.object_id for secret in secrets],
+ image_id=(image.object_id if image else ""),
+ definition_type=info.get_definition_type(),
+ function_serialized=function_serialized or b"",
+ class_serialized=class_serialized or b"",
+ function_type=function_type,
+ webhook_config=webhook_config,
+ method_definitions=method_definitions,
+ method_definitions_set=True,
+ shared_volume_mounts=network_file_system_mount_protos(
+ validated_network_file_systems, allow_cross_region_volumes
+ ),
+ volume_mounts=volume_mounts,
+ proxy_id=(proxy.object_id if proxy else None),
+ retry_policy=retry_policy,
+ timeout_secs=timeout_secs or 0,
+ task_idle_timeout_secs=container_idle_timeout or 0,
+ concurrency_limit=concurrency_limit or 0,
+ pty_info=pty_info,
+ cloud_provider=cloud_provider,
+ warm_pool_size=keep_warm or 0,
+ runtime=config.get("function_runtime"),
+ runtime_debug=config.get("function_runtime_debug"),
+ runtime_perf_record=config.get("runtime_perf_record"),
+ app_name=app_name,
+ is_builder_function=is_builder_function,
+ target_concurrent_inputs=allow_concurrent_inputs or 0,
+ batch_max_size=batch_max_size or 0,
+ batch_linger_ms=batch_wait_ms or 0,
+ worker_id=config.get("worker_id"),
+ is_auto_snapshot=is_auto_snapshot,
+ is_method=bool(info.user_cls) and not info.is_service_class(),
+ checkpointing_enabled=enable_memory_snapshot,
+ object_dependencies=object_dependencies,
+ block_network=block_network,
+ max_inputs=max_inputs or 0,
+ cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts),
+ scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
+ is_class=info.is_service_class(),
+ class_parameter_info=info.class_parameter_info(),
+ i6pn_enabled=i6pn_enabled,
+ schedule=schedule.proto_message if schedule is not None else None,
+ snapshot_debug=config.get("snapshot_debug"),
+ _experimental_group_size=cluster_size or 0, # Experimental: Clustered functions
+ _experimental_concurrent_cancellations=True,
+ _experimental_buffer_containers=_experimental_buffer_containers or 0,
+ _experimental_proxy_ip=_experimental_proxy_ip,
+ _experimental_custom_scaling=_experimental_custom_scaling_factor is not None,
  )
- for path, volume in validated_volumes
- ]
- loaded_mount_ids = {m.object_id for m in all_mounts}
-
- # Get object dependencies
- object_dependencies = []
- for dep in _deps(only_explicit_mounts=True):
859
- if not dep.object_id:
860
- raise Exception(f"Dependency {dep} isn't hydrated")
861
- object_dependencies.append(api_pb2.ObjectDependency(object_id=dep.object_id))
862
-
863
- # Create function remotely
864
- function_definition = api_pb2.Function(
865
- module_name=info.module_name or "",
866
- function_name=info.function_name,
867
- mount_ids=loaded_mount_ids,
868
- secret_ids=[secret.object_id for secret in secrets],
869
- image_id=(image.object_id if image else ""),
870
- definition_type=info.definition_type,
871
- function_serialized=function_serialized or b"",
872
- class_serialized=class_serialized or b"",
873
- function_type=function_type,
874
- resources=api_pb2.Resources(milli_cpu=milli_cpu, gpu_config=gpu_config, memory_mb=memory or 0),
875
- webhook_config=webhook_config,
876
- shared_volume_mounts=network_file_system_mount_protos(
877
- validated_network_file_systems, allow_cross_region_volumes
878
- ),
879
- volume_mounts=volume_mounts,
880
- proxy_id=(proxy.object_id if proxy else None),
881
- retry_policy=retry_policy,
882
- timeout_secs=timeout_secs or 0,
883
- task_idle_timeout_secs=container_idle_timeout or 0,
884
- concurrency_limit=concurrency_limit or 0,
885
- pty_info=pty_info,
886
- cloud_provider=cloud_provider,
887
- warm_pool_size=keep_warm or 0,
888
- runtime=config.get("function_runtime"),
889
- runtime_debug=config.get("function_runtime_debug"),
890
- stub_name=stub_name,
891
- is_builder_function=is_builder_function,
892
- allow_concurrent_inputs=allow_concurrent_inputs or 0,
893
- worker_id=config.get("worker_id"),
894
- is_auto_snapshot=is_auto_snapshot,
895
- is_method=bool(info.cls),
896
- checkpointing_enabled=enable_memory_snapshot,
897
- is_checkpointing_function=False,
898
- object_dependencies=object_dependencies,
899
- block_network=block_network,
900
- max_inputs=max_inputs or 0,
901
- cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts),
902
- _experimental_boost=_experimental_boost,
903
- _experimental_scheduler=_experimental_scheduler,
904
- _experimental_scheduler_placement=_experimental_scheduler_placement.proto
905
- if _experimental_scheduler_placement
906
- else None,
907
- )
908
- request = api_pb2.FunctionCreateRequest(
909
- app_id=resolver.app_id,
910
- function=function_definition,
911
- schedule=schedule.proto_message if schedule is not None else None,
912
- existing_function_id=existing_object_id or "",
913
- )
914
- try:
915
- response: api_pb2.FunctionCreateResponse = await retry_transient_errors(
916
- resolver.client.stub.FunctionCreate, request
917
- )
918
- except GRPCError as exc:
919
- if exc.status == Status.INVALID_ARGUMENT:
920
- raise InvalidError(exc.message)
921
- if exc.status == Status.FAILED_PRECONDITION:
922
- raise InvalidError(exc.message)
923
- if exc.message and "Received :status = '413'" in exc.message:
924
- raise InvalidError(f"Function {raw_f} is too large to deploy.")
925
- raise
926
-
927
- if response.function.web_url:
928
- # Ensure terms used here match terms used in modal.com/docs/guide/webhook-urls doc.
929
- if response.function.web_url_info.truncated:
930
- suffix = " [grey70](label truncated)[/grey70]"
931
- elif response.function.web_url_info.has_unique_hash:
932
- suffix = " [grey70](label includes conflict-avoidance hash)[/grey70]"
933
- elif response.function.web_url_info.label_stolen:
934
- suffix = " [grey70](label stolen)[/grey70]"
935
- else:
936
- suffix = ""
937
- # TODO: this is only printed when we're showing progress. Maybe move this somewhere else.
938
- status_row.finish(f"Created {tag} => [magenta underline]{response.web_url}[/magenta underline]{suffix}")
939
-
940
- # Print custom domain in terminal
941
- for custom_domain in response.function.custom_domain_info:
942
- custom_domain_status_row = resolver.add_status_row()
943
- custom_domain_status_row.finish(
944
- f"Custom domain for {tag} => [magenta underline]{custom_domain.url}[/magenta underline]{suffix}"
818
+
819
+ if isinstance(gpu, list):
820
+ function_data = api_pb2.FunctionData(
821
+ module_name=function_definition.module_name,
822
+ function_name=function_definition.function_name,
823
+ function_type=function_definition.function_type,
824
+ warm_pool_size=function_definition.warm_pool_size,
825
+ concurrency_limit=function_definition.concurrency_limit,
826
+ task_idle_timeout_secs=function_definition.task_idle_timeout_secs,
827
+ worker_id=function_definition.worker_id,
828
+ timeout_secs=function_definition.timeout_secs,
829
+ web_url=function_definition.web_url,
830
+ web_url_info=function_definition.web_url_info,
831
+ webhook_config=function_definition.webhook_config,
832
+ custom_domain_info=function_definition.custom_domain_info,
833
+ schedule=schedule.proto_message if schedule is not None else None,
834
+ is_class=function_definition.is_class,
835
+ class_parameter_info=function_definition.class_parameter_info,
836
+ is_method=function_definition.is_method,
837
+ use_function_id=function_definition.use_function_id,
838
+ use_method_name=function_definition.use_method_name,
839
+ method_definitions=function_definition.method_definitions,
840
+ method_definitions_set=function_definition.method_definitions_set,
841
+ _experimental_group_size=function_definition._experimental_group_size,
842
+ _experimental_buffer_containers=function_definition._experimental_buffer_containers,
843
+ _experimental_custom_scaling=function_definition._experimental_custom_scaling,
844
+ _experimental_proxy_ip=function_definition._experimental_proxy_ip,
845
+ snapshot_debug=function_definition.snapshot_debug,
846
+ runtime_perf_record=function_definition.runtime_perf_record,
945
847
  )
946
848
 
947
- else:
948
- status_row.finish(f"Created {tag}.")
849
+ ranked_functions = []
850
+ for rank, _gpu in enumerate(gpu):
851
+ function_definition_copy = api_pb2.Function()
852
+ function_definition_copy.CopyFrom(function_definition)
853
+
854
+ function_definition_copy.resources.CopyFrom(
855
+ convert_fn_config_to_resources_config(
856
+ cpu=cpu, memory=memory, gpu=_gpu, ephemeral_disk=ephemeral_disk
857
+ ),
858
+ )
859
+ ranked_function = api_pb2.FunctionData.RankedFunction(
860
+ rank=rank,
861
+ function=function_definition_copy,
862
+ )
863
+ ranked_functions.append(ranked_function)
864
+ function_data.ranked_functions.extend(ranked_functions)
865
+ function_definition = None # function_definition is not used in this case
866
+ else:
867
+ # TODO(irfansharif): Assert on this specific type once we get rid of python 3.9.
868
+ # assert isinstance(gpu, GPU_T) # includes the case where gpu==None case
869
+ function_definition.resources.CopyFrom(
870
+ convert_fn_config_to_resources_config(
871
+ cpu=cpu, memory=memory, gpu=gpu, ephemeral_disk=ephemeral_disk
872
+ ), # type: ignore
873
+ )
949
874
 
875
+ assert resolver.app_id
876
+ assert (function_definition is None) != (function_data is None) # xor
877
+ request = api_pb2.FunctionCreateRequest(
878
+ app_id=resolver.app_id,
879
+ function=function_definition,
880
+ function_data=function_data,
881
+ existing_function_id=existing_object_id or "",
882
+ defer_updates=True,
883
+ )
884
+ try:
885
+ response: api_pb2.FunctionCreateResponse = await retry_transient_errors(
886
+ resolver.client.stub.FunctionCreate, request
887
+ )
888
+ except GRPCError as exc:
889
+ if exc.status == Status.INVALID_ARGUMENT:
890
+ raise InvalidError(exc.message)
891
+ if exc.status == Status.FAILED_PRECONDITION:
892
+ raise InvalidError(exc.message)
893
+ if exc.message and "Received :status = '413'" in exc.message:
894
+ raise InvalidError(f"Function {info.function_name} is too large to deploy.")
895
+ raise
896
+ function_creation_status.set_response(response)
897
+ serve_mounts = {m for m in all_mounts if m.is_local()} # needed for modal.serve file watching
898
+ serve_mounts |= image._serve_mounts
899
+ obj._serve_mounts = frozenset(serve_mounts)
950
900
  self._hydrate(response.function_id, resolver.client, response.handle_metadata)
951
901
 
952
902
  rep = f"Function({tag})"
953
903
  obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps)
954
904
 
955
- obj._raw_f = raw_f
905
+ obj._raw_f = info.raw_f
956
906
  obj._info = info
957
907
  obj._tag = tag
958
- obj._all_mounts = all_mounts # needed for modal.serve file watching
959
- obj._stub = stub # needed for CLI right now
908
+ obj._app = app # needed for CLI right now
960
909
  obj._obj = None
961
910
  obj._is_generator = is_generator
962
- obj._is_method = bool(info.cls)
963
- obj._env = function_env # needed for modal shell
911
+ obj._cluster_size = cluster_size
912
+ obj._is_method = False
913
+ obj._spec = function_spec # needed for modal shell
964
914
 
965
- # Used to check whether we should rebuild an image using run_function
966
- # Plaintext source and arg definition for the function, so it's part of the image
967
- # hash. We can't use the cloudpickle hash because it's not very stable.
915
+ # Used to check whether we should rebuild a modal.Image which uses `run_function`.
916
+ gpus: list[GPU_T] = gpu if isinstance(gpu, list) else [gpu]
968
917
  obj._build_args = dict( # See get_build_def
969
918
  secrets=repr(secrets),
970
- gpu_config=repr(gpu_config),
919
+ gpu_config=repr([parse_gpu_config(_gpu) for _gpu in gpus]),
971
920
  mounts=repr(mounts),
972
921
  network_file_systems=repr(network_file_systems),
973
922
  )
923
+ # these key are excluded if empty to avoid rebuilds on client upgrade
924
+ if volumes:
925
+ obj._build_args["volumes"] = repr(volumes)
926
+ if cloud or scheduler_placement:
927
+ obj._build_args["cloud"] = repr(cloud)
928
+ obj._build_args["scheduler_placement"] = repr(scheduler_placement)
974
929
 
975
930
  return obj
976
931
 
977
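The new `FunctionData`/`RankedFunction` path above kicks in when a list of GPU types is supplied: each entry is converted with `convert_fn_config_to_resources_config` and registered in rank order, so the scheduler can fall back from the preferred GPU to the next one. A minimal sketch of how this surfaces in user code, assuming a list is passed to the `gpu=` parameter of `@app.function` (the app name and function body here are hypothetical):

```python
import modal

app = modal.App("gpu-fallback-demo")  # hypothetical app name

# Rank 0 ("h100") is preferred; rank 1 ("a100") is the fallback the
# scheduler uses when the preferred GPU type is unavailable.
@app.function(gpu=["h100", "a100"])
def train(batch: list) -> float:
    return sum(batch)
```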
-    def from_parametrized(
+    def _bind_parameters(
        self,
-        obj,
-        from_other_workspace: bool,
+        obj: "modal.cls._Obj",
        options: Optional[api_pb2.FunctionOptions],
        args: Sized,
-        kwargs: Dict[str, Any],
+        kwargs: dict[str, Any],
    ) -> "_Function":
-        """mdmd:hidden"""
+        """mdmd:hidden
 
-        async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
-            if not self._parent.is_hydrated:
+        Binds a class-function to a specific instance of (init params, options) or a new workspace
+        """
+
+        # In some cases, reuse the base function, i.e. not create new clones of each method or the "service function"
+        can_use_parent = len(args) + len(kwargs) == 0 and options is None
+        parent = self
+
+        async def _load(param_bound_func: _Function, resolver: Resolver, existing_object_id: Optional[str]):
+            if parent is None:
+                raise ExecutionError("Can't find the parent class' service function")
+            try:
+                identity = f"{parent.info.function_name} class service function"
+            except Exception:
+                # Can't always look up the function name that way, so fall back to generic message
+                identity = "class service function for a parameterized class"
+            if not parent.is_hydrated:
+                if parent.app._running_app is None:
+                    reason = ", because the App it is defined on is not running"
+                else:
+                    reason = ""
                raise ExecutionError(
-                    "Base function in class has not been hydrated. This might happen if an object is"
-                    " defined on a different stub, or if it's on the same stub but it didn't get"
-                    " created because it wasn't defined in global scope."
+                    f"The {identity} has not been hydrated with the metadata it needs to run on Modal{reason}."
                )
-            assert self._parent._client.stub
-            serialized_params = serialize((args, kwargs))
+
+            assert parent._client.stub
+
+            if can_use_parent:
+                # We can end up here if parent wasn't hydrated when class was instantiated, but has been since.
+                param_bound_func._hydrate_from_other(parent)
+                return
+
+            if (
+                parent._class_parameter_info
+                and parent._class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO
+            ):
+                if args:
+                    # TODO(elias) - We could potentially support positional args as well, if we want to?
+                    raise InvalidError(
+                        "Can't use positional arguments with modal.parameter-based synthetic constructors.\n"
+                        "Use (<parameter_name>=value) keyword arguments when constructing classes instead."
+                    )
+                serialized_params = serialize_proto_params(kwargs, parent._class_parameter_info.schema)
+            else:
+                serialized_params = serialize((args, kwargs))
            environment_name = _get_environment_name(None, resolver)
+            assert parent is not None
            req = api_pb2.FunctionBindParamsRequest(
-                function_id=self._parent._object_id,
+                function_id=parent._object_id,
                serialized_params=serialized_params,
                function_options=options,
                environment_name=environment_name
                or "",  # TODO: investigate shouldn't environment name always be specified here?
            )
-            response = await retry_transient_errors(self._parent._client.stub.FunctionBindParams, req)
-            self._hydrate(response.bound_function_id, self._parent._client, response.handle_metadata)
-
-        fun = _Function._from_loader(_load, "Function(parametrized)", hydrate_lazily=True)
-        if len(args) + len(kwargs) == 0 and not from_other_workspace and options is None and self.is_hydrated:
-            # Edge case that lets us hydrate all objects right away
-            fun._hydrate_from_other(self)
-        fun._is_remote_cls_method = True  # TODO(erikbern): deprecated
+
+            response = await retry_transient_errors(parent._client.stub.FunctionBindParams, req)
+            param_bound_func._hydrate(response.bound_function_id, parent._client, response.handle_metadata)
+
+        fun: _Function = _Function._from_loader(_load, "Function(parametrized)", hydrate_lazily=True)
+
+        if can_use_parent and parent.is_hydrated:
+            # skip the resolver altogether:
+            fun._hydrate_from_other(parent)
+
        fun._info = self._info
        fun._obj = obj
-        fun._is_generator = self._is_generator
-        fun._is_method = True
-        fun._parent = self
-
        return fun
 
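`_bind_parameters` is what runs under the hood when a deployed class is instantiated: the (init params, options) pair is bound to the base "service function" via `FunctionBindParams`, and proto-format class parameters only accept keyword arguments. A hedged usage sketch (the app and class names are hypothetical):

```python
import modal

# Instantiating a deployed class binds its constructor arguments to the
# class service function on the server side.
Model = modal.Cls.lookup("my-app", "Model")  # hypothetical deployment

# modal.parameter-based classes require keyword arguments; positional
# arguments raise InvalidError, as enforced in _load above.
m = Model(model_name="base")
```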
    @live_method
    async def keep_warm(self, warm_pool_size: int) -> None:
-        """Set the warm pool size for the function (including parametrized functions).
+        """Set the warm pool size for the function.
 
-        Please exercise care when using this advanced feature! Setting and forgetting a warm pool on functions can lead to increased costs.
+        Please exercise care when using this advanced feature!
+        Setting and forgetting a warm pool on functions can lead to increased costs.
 
-        ```python
+        ```python notest
        # Usage on a regular function.
        f = modal.Function.lookup("my-app", "function")
        f.keep_warm(2)
 
        # Usage on a parametrized function.
        Model = modal.Cls.lookup("my-app", "Model")
-        Model("fine-tuned-model").inference.keep_warm(2)
+        Model("fine-tuned-model").keep_warm(2)
        ```
        """
+        if self._is_method:
+            raise InvalidError(
+                textwrap.dedent(
+                    """
+                    The `.keep_warm()` method cannot be used on Modal class *methods* deployed using Modal >v0.63.
 
+                    Call `.keep_warm()` on the class *instance* instead.
+                    """
+                )
+            )
        assert self._client and self._client.stub
        request = api_pb2.FunctionUpdateSchedulingParamsRequest(
            function_id=self._object_id, warm_pool_size_override=warm_pool_size
@@ -1041,17 +1039,22 @@ class _Function(_Object, type_prefix="fu"):
        await retry_transient_errors(self._client.stub.FunctionUpdateSchedulingParams, request)
 
    @classmethod
+    @renamed_parameter((2024, 12, 18), "tag", "name")
    def from_name(
-        cls: Type["_Function"],
+        cls: type["_Function"],
        app_name: str,
-        tag: Optional[str] = None,
+        name: str,
        namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
        environment_name: Optional[str] = None,
    ) -> "_Function":
-        """Retrieve a function with a given name and tag.
+        """Reference a Function from a deployed App by its name.
+
+        In contrast to `modal.Function.lookup`, this is a lazy method
+        that defers hydrating the local object with metadata from
+        Modal servers until the first time it is actually used.
 
        ```python
-        other_function = modal.Function.from_name("other-app", "function")
+        f = modal.Function.from_name("other-app", "function")
        ```
        """
 
@@ -1059,7 +1062,7 @@ class _Function(_Object, type_prefix="fu"):
        assert resolver.client and resolver.client.stub
        request = api_pb2.FunctionGetRequest(
            app_name=app_name,
-            object_tag=tag or "",
+            object_tag=name,
            namespace=namespace,
            environment_name=_get_environment_name(environment_name, resolver) or "",
        )
@@ -1071,26 +1074,32 @@ class _Function(_Object, type_prefix="fu"):
        else:
            raise
 
+        print_server_warnings(response.server_warnings)
+
        self._hydrate(response.function_id, resolver.client, response.handle_metadata)
 
        rep = f"Ref({app_name})"
-        return cls._from_loader(_load_remote, rep, is_another_app=True)
+        return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True)
 
    @staticmethod
+    @renamed_parameter((2024, 12, 18), "tag", "name")
    async def lookup(
        app_name: str,
-        tag: Optional[str] = None,
+        name: str,
        namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
        client: Optional[_Client] = None,
        environment_name: Optional[str] = None,
    ) -> "_Function":
-        """Lookup a function with a given name and tag.
+        """Lookup a Function from a deployed App by its name.
 
-        ```python
-        other_function = modal.Function.lookup("other-app", "function")
+        In contrast to `modal.Function.from_name`, this is an eager method
+        that will hydrate the local object with metadata from Modal servers.
+
+        ```python notest
+        f = modal.Function.lookup("other-app", "function")
        ```
        """
-        obj = _Function.from_name(app_name, tag, namespace=namespace, environment_name=environment_name)
+        obj = _Function.from_name(app_name, name, namespace=namespace, environment_name=environment_name)
        if client is None:
            client = await _Client.from_env()
        resolver = Resolver(client=client)
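The two accessors now differ only in when hydration happens: `from_name` is lazy and defers the server round-trip, while `lookup` hydrates eagerly. A short sketch of the distinction (app and function names hypothetical):

```python
import modal

# Lazy: no server round-trip yet; metadata is fetched on first use.
f_lazy = modal.Function.from_name("other-app", "function")

# Eager: hydrates immediately, so a missing deployment fails right here.
f_eager = modal.Function.lookup("other-app", "function")

result = f_lazy.remote(42)  # first use triggers hydration, then runs remotely
```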
@@ -1104,9 +1113,18 @@ class _Function(_Object, type_prefix="fu"):
        return self._tag
 
    @property
-    def stub(self) -> "modal.stub._Stub":
+    def app(self) -> "modal.app._App":
        """mdmd:hidden"""
-        return self._stub
+        if self._app is None:
+            raise ExecutionError("The app has not been assigned on the function at this point")
+
+        return self._app
+
+    @property
+    def stub(self) -> "modal.app._App":
+        """mdmd:hidden"""
+        # Deprecated soon, only for backwards compatibility
+        return self.app
 
    @property
    def info(self) -> FunctionInfo:
@@ -1115,12 +1133,15 @@ class _Function(_Object, type_prefix="fu"):
        return self._info
 
    @property
-    def env(self) -> FunctionEnv:
+    def spec(self) -> _FunctionSpec:
        """mdmd:hidden"""
-        return self._env
+        assert self._spec
+        return self._spec
 
    def get_build_def(self) -> str:
        """mdmd:hidden"""
+        # Plaintext source and arg definition for the function, so it's part of the image
+        # hash. We can't use the cloudpickle hash because it's not very stable.
        assert hasattr(self, "_raw_f") and hasattr(self, "_build_args")
        return f"{inspect.getsource(self._raw_f)}\n{repr(self._build_args)}"
 
@@ -1130,208 +1151,170 @@ class _Function(_Object, type_prefix="fu"):
        # Overridden concrete implementation of base class method
        self._progress = None
        self._is_generator = None
+        self._cluster_size = None
        self._web_url = None
-        self._output_mgr: Optional[OutputManager] = None
-        self._mute_cancellation = (
-            False  # set when a user terminates the app intentionally, to prevent useless traceback spam
-        )
        self._function_name = None
        self._info = None
+        self._serve_mounts = frozenset()
 
    def _hydrate_metadata(self, metadata: Optional[Message]):
        # Overridden concrete implementation of base class method
-        assert metadata and isinstance(metadata, (api_pb2.Function, api_pb2.FunctionHandleMetadata))
+        assert metadata and isinstance(metadata, api_pb2.FunctionHandleMetadata)
        self._is_generator = metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
        self._web_url = metadata.web_url
        self._function_name = metadata.function_name
        self._is_method = metadata.is_method
+        self._use_method_name = metadata.use_method_name
+        self._class_parameter_info = metadata.class_parameter_info
+        self._method_handle_metadata = dict(metadata.method_handle_metadata)
+        self._definition_id = metadata.definition_id
 
    def _get_metadata(self):
        # Overridden concrete implementation of base class method
-        assert self._function_name
+        assert self._function_name, f"Function name must be set before metadata can be retrieved for {self}"
        return api_pb2.FunctionHandleMetadata(
            function_name=self._function_name,
-            function_type=(
-                api_pb2.Function.FUNCTION_TYPE_GENERATOR
-                if self._is_generator
-                else api_pb2.Function.FUNCTION_TYPE_FUNCTION
-            ),
+            function_type=get_function_type(self._is_generator),
            web_url=self._web_url or "",
+            use_method_name=self._use_method_name,
+            is_method=self._is_method,
+            class_parameter_info=self._class_parameter_info,
+            definition_id=self._definition_id,
+            method_handle_metadata=self._method_handle_metadata,
        )
 
-    def _set_mute_cancellation(self, value: bool = True):
-        self._mute_cancellation = value
-
-    def _set_output_mgr(self, output_mgr: OutputManager):
-        self._output_mgr = output_mgr
+    def _check_no_web_url(self, fn_name: str):
+        if self._web_url:
+            raise InvalidError(
+                f"A webhook function cannot be invoked for remote execution with `.{fn_name}`. "
+                f"Invoke this function via its web url '{self._web_url}' "
+                + f"or call it locally: {self._function_name}.local()"
+            )
 
+    # TODO (live_method on properties is not great, since it could be blocking the event loop from async contexts)
    @property
-    def web_url(self) -> str:
+    @live_method
+    async def web_url(self) -> str:
        """URL of a Function running as a web endpoint."""
        if not self._web_url:
            raise ValueError(
-                f"No web_url can be found for function {self._function_name}. web_url can only be referenced from a running app context"
+                f"No web_url can be found for function {self._function_name}. web_url "
+                "can only be referenced from a running app context"
            )
        return self._web_url
 
    @property
-    def is_generator(self) -> bool:
+    async def is_generator(self) -> bool:
        """mdmd:hidden"""
-        assert self._is_generator is not None
+        # hacky: kind of like @live_method, but not hydrating if we have the value already from local source
+        if self._is_generator is not None:
+            # this is set if the function or class is local
+            return self._is_generator
+
+        # not set - this is a from_name lookup - hydrate
+        await self.resolve()
+        assert self._is_generator is not None  # should be set now
        return self._is_generator
 
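The new `_check_no_web_url` helper centralizes a guard that previously lived inline in each calling method: functions exposed as web endpoints must be invoked over HTTP (or with `.local()`), never with `.remote()`-style invocation. A hedged sketch of the resulting behavior (the endpoint name is hypothetical):

```python
import modal

f = modal.Function.lookup("my-app", "my_endpoint")  # hypothetical web endpoint

try:
    f.remote()  # raises InvalidError via _check_no_web_url("remote")
except modal.exception.InvalidError as exc:
    print(exc)  # message points at the web url and at .local()
```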
-    async def _map(self, input_stream: AsyncIterable[Any], order_outputs: bool, return_exceptions: bool, kwargs={}):
-        if self._web_url:
-            raise InvalidError(
-                "A web endpoint function cannot be directly invoked for parallel remote execution. "
-                f"Invoke this function via its web url '{self._web_url}' or call it locally: {self._function_name}()."
-            )
+    @property
+    def cluster_size(self) -> int:
+        """mdmd:hidden"""
+        return self._cluster_size or 1
+
+    @live_method_gen
+    async def _map(
+        self, input_queue: _SynchronizedQueue, order_outputs: bool, return_exceptions: bool
+    ) -> AsyncGenerator[Any, None]:
+        """mdmd:hidden
+
+        Synchronicity-wrapped map implementation. To be safe against invocations of user code in
+        the synchronicity thread it doesn't accept an [async]iterator, and instead takes a
+        _SynchronizedQueue instance that is fed by higher level functions like .map()
+
+        _SynchronizedQueue is used instead of asyncio.Queue so that the main thread can put
+        items in the queue safely.
+        """
+        self._check_no_web_url("map")
        if self._is_generator:
            raise InvalidError("A generator function cannot be called with `.map(...)`.")
 
        assert self._function_name
-        count_update_callback = (
-            self._output_mgr.function_progress_callback(self._function_name, total=None) if self._output_mgr else None
-        )
+        if output_mgr := _get_output_manager():
+            count_update_callback = output_mgr.function_progress_callback(self._function_name, total=None)
+        else:
+            count_update_callback = None
+
+        async with aclosing(
+            _map_invocation(
+                self,  # type: ignore
+                input_queue,
+                self._client,
+                order_outputs,
+                return_exceptions,
+                count_update_callback,
+            )
+        ) as stream:
+            async for item in stream:
+                yield item
 
-        async for item in _map_invocation(
-            self.object_id,
-            input_stream,
+    async def _call_function(self, args, kwargs) -> ReturnType:
+        if config.get("client_retries"):
+            function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
+        else:
+            function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY
+        invocation = await _Invocation.create(
+            self,
+            args,
            kwargs,
-            self._client,
-            order_outputs,
-            return_exceptions,
-            count_update_callback,
-        ):
-            yield item
+            client=self._client,
+            function_call_invocation_type=function_call_invocation_type,
+        )
 
-    async def _call_function(self, args, kwargs):
-        invocation = await _Invocation.create(self.object_id, args, kwargs, self._client)
-        try:
-            return await invocation.run_function()
-        except asyncio.CancelledError:
-            # this can happen if the user terminates a program, triggering a cancellation cascade
-            if not self._mute_cancellation:
-                raise
+        return await invocation.run_function()
 
-    async def _call_function_nowait(self, args, kwargs) -> _Invocation:
-        return await _Invocation.create(self.object_id, args, kwargs, self._client)
+    async def _call_function_nowait(
+        self, args, kwargs, function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
+    ) -> _Invocation:
+        return await _Invocation.create(
+            self, args, kwargs, client=self._client, function_call_invocation_type=function_call_invocation_type
+        )
 
-    @warn_if_generator_is_not_consumed
+    @warn_if_generator_is_not_consumed()
    @live_method_gen
    @synchronizer.no_input_translation
    async def _call_generator(self, args, kwargs):
-        invocation = await _Invocation.create(self.object_id, args, kwargs, self._client)
+        invocation = await _Invocation.create(
+            self,
+            args,
+            kwargs,
+            client=self._client,
+            function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
+        )
        async for res in invocation.run_generator():
            yield res
 
    @synchronizer.no_io_translation
    async def _call_generator_nowait(self, args, kwargs):
-        return await _Invocation.create(self.object_id, args, kwargs, self._client)
-
-    @warn_if_generator_is_not_consumed
-    @live_method_gen
-    @synchronizer.no_input_translation
-    async def map(
-        self,
-        *input_iterators,  # one input iterator per argument in the mapped-over function/generator
-        kwargs={},  # any extra keyword arguments for the function
-        order_outputs: bool = True,  # return outputs in order
-        return_exceptions: bool = False,  # propagate exceptions (False) or aggregate them in the results list (True)
-    ) -> AsyncGenerator[Any, None]:
-        """Parallel map over a set of inputs.
-
-        Takes one iterator argument per argument in the function being mapped over.
-
-        Example:
-        ```python
-        @stub.function()
-        def my_func(a):
-            return a ** 2
-
-
-        @stub.local_entrypoint()
-        def main():
-            assert list(my_func.map([1, 2, 3, 4])) == [1, 4, 9, 16]
-        ```
-
-        If applied to a `stub.function`, `map()` returns one result per input and the output order
-        is guaranteed to be the same as the input order. Set `order_outputs=False` to return results
-        in the order that they are completed instead.
-
-        `return_exceptions` can be used to treat exceptions as successful results:
-
-        ```python
-        @stub.function()
-        def my_func(a):
-            if a == 2:
-                raise Exception("ohno")
-            return a ** 2
-
-
-        @stub.local_entrypoint()
-        def main():
-            # [0, 1, UserCodeException(Exception('ohno'))]
-            print(list(my_func.map(range(3), return_exceptions=True)))
-        ```
-        """
-
-        input_stream = stream.zip(*(stream.iterate(it) for it in input_iterators))
-        async for item in self._map(input_stream, order_outputs, return_exceptions, kwargs):
-            yield item
-
-    @synchronizer.no_input_translation
-    async def for_each(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False):
-        """Execute function for all inputs, ignoring outputs.
-
-        Convenient alias for `.map()` in cases where the function just needs to be called,
-        as the caller doesn't have to consume the generator to process the inputs.
-        """
-        # TODO(erikbern): it would be better if this is more like a map_spawn that immediately exits
-        # rather than iterating over the result
-        async for _ in self.map(
-            *input_iterators, kwargs=kwargs, order_outputs=False, return_exceptions=ignore_exceptions
-        ):
-            pass
-
-    @warn_if_generator_is_not_consumed
-    @live_method_gen
-    @synchronizer.no_input_translation
-    async def starmap(
-        self, input_iterator, kwargs={}, order_outputs: bool = True, return_exceptions: bool = False
-    ) -> AsyncGenerator[Any, None]:
-        """Like `map`, but spreads arguments over multiple function arguments.
-
-        Assumes every input is a sequence (e.g. a tuple).
-
-        Example:
-        ```python
-        @stub.function()
-        def my_func(a, b):
-            return a + b
-
-
-        @stub.local_entrypoint()
-        def main():
-            assert list(my_func.starmap([(1, 2), (3, 4)])) == [3, 7]
-        ```
-        """
-        input_stream = stream.iterate(input_iterator)
-        async for item in self._map(input_stream, order_outputs, return_exceptions, kwargs):
-            yield item
+        deprecation_warning(
+            (2024, 12, 11),
+            "Calling spawn on a generator function is deprecated and will soon raise an exception.",
+        )
+        return await _Invocation.create(
+            self,
+            args,
+            kwargs,
+            client=self._client,
+            function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY,
+        )
 
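The inline `map`/`for_each`/`starmap` implementations (and their docstring examples) are removed in this hunk; they are reattached further down via `MethodWithAio` wrappers that feed the `_SynchronizedQueue`-based `_map` above. Their call-side usage is unchanged; a brief sketch under that assumption (app name and function body hypothetical):

```python
import modal

app = modal.App("map-demo")  # hypothetical app name

@app.function()
def square(a: int) -> int:
    return a**2

@app.local_entrypoint()
def main():
    assert list(square.map([1, 2, 3, 4])) == [1, 4, 9, 16]
    assert list(square.starmap([(1,), (2,)])) == [1, 4]  # spreads each tuple over the args
    square.for_each(range(10))  # run for side effects, ignoring outputs
```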
    @synchronizer.no_io_translation
    @live_method
-    async def remote(self, *args, **kwargs) -> Any:
+    async def remote(self, *args: P.args, **kwargs: P.kwargs) -> ReturnType:
        """
        Calls the function remotely, executing it with the given arguments and returning the execution's result.
        """
        # TODO: Generics/TypeVars
-        if self._web_url:
-            raise InvalidError(
-                "A web endpoint function cannot be invoked for remote execution with `.remote`. "
-                f"Invoke this function via its web url '{self._web_url}' or call it locally: {self._function_name}()."
-            )
+        self._check_no_web_url("remote")
        if self._is_generator:
            raise InvalidError(
                "A generator function cannot be called with `.remote(...)`. Use `.remote_gen(...)` instead."
@@ -1346,11 +1329,7 @@ class _Function(_Object, type_prefix="fu"):
        Calls the generator remotely, executing it with the given arguments and returning the execution's result.
        """
        # TODO: Generics/TypeVars
-        if self._web_url:
-            raise InvalidError(
-                "A web endpoint function cannot be invoked for remote execution with `.remote`. "
-                f"Invoke this function via its web url '{self._web_url}' or call it locally: {self._function_name}()."
-            )
+        self._check_no_web_url("remote_gen")
 
        if not self._is_generator:
            raise InvalidError(
@@ -1359,22 +1338,15 @@ class _Function(_Object, type_prefix="fu"):
        async for item in self._call_generator(args, kwargs):  # type: ignore
            yield item
 
-    @synchronizer.no_io_translation
-    @live_method
-    async def shell(self, *args, **kwargs) -> None:
-        if self._is_generator:
-            async for item in self._call_generator(args, kwargs):
-                pass
-        else:
-            await self._call_function(args, kwargs)
+    def _is_local(self):
+        return self._info is not None
 
-    def _get_is_remote_cls_method(self):
-        return self._is_remote_cls_method
-
-    def _get_info(self):
+    def _get_info(self) -> FunctionInfo:
+        if not self._info:
+            raise ExecutionError("Can't get info for a function that isn't locally defined")
        return self._info
 
-    def _get_obj(self):
+    def _get_obj(self) -> Optional["modal.cls._Obj"]:
        if not self._is_method:
            return None
        elif not self._obj:
@@ -1383,83 +1355,129 @@ class _Function(_Object, type_prefix="fu"):
        return self._obj
 
    @synchronizer.nowrap
-    def local(self, *args, **kwargs) -> Any:
+    def local(self, *args: P.args, **kwargs: P.kwargs) -> OriginalReturnType:
        """
        Calls the function locally, executing it with the given arguments and returning the execution's result.
-        This method allows a caller to execute the standard Python function wrapped by Modal.
+
+        The function will execute in the same environment as the caller, just like calling the underlying function
+        directly in Python. In particular, only secrets available in the caller environment will be available
+        through environment variables.
        """
        # TODO(erikbern): it would be nice to remove the nowrap thing, but right now that would cause
        # "user code" to run on the synchronicity thread, which seems bad
-        info = self._get_info()
-        if not info:
+        if not self._is_local():
            msg = (
-                "The definition for this function is missing so it is not possible to invoke it locally. "
+                "The definition for this function is missing here so it is not possible to invoke it locally. "
                "If this function was retrieved via `Function.lookup` you need to use `.remote()`."
            )
            raise ExecutionError(msg)
 
-        obj = self._get_obj()
+        info = self._get_info()
+        if not info.raw_f:
+            # Here if calling .local on a service function itself which should never happen
+            # TODO: check if we end up here in a container for a serialized function?
+            raise ExecutionError("Can't call .local on service function")
+
+        if is_local() and (self.spec.volumes or self.spec.network_file_systems):
+            warnings.warn(
+                f"The {info.function_name} function is executing locally "
+                + "and will not have access to the mounted Volume or NetworkFileSystem data"
+            )
+
+        obj: Optional["modal.cls._Obj"] = self._get_obj()
 
        if not obj:
            fun = info.raw_f
            return fun(*args, **kwargs)
        else:
            # This is a method on a class, so bind the self to the function
-            local_obj = obj.get_local_obj()
-            fun = info.raw_f.__get__(local_obj)
+            user_cls_instance = obj._cached_user_cls_instance()
+            fun = info.raw_f.__get__(user_cls_instance)
 
+            # TODO: replace implicit local enter/exit with a context manager
            if is_async(info.raw_f):
                # We want to run __aenter__ and fun in the same coroutine
                async def coro():
-                    await obj.aenter()
+                    await obj._aenter()
                    return await fun(*args, **kwargs)
 
-                return coro()
+                return coro()  # type: ignore
            else:
-                obj.enter()
+                obj._enter()
                return fun(*args, **kwargs)
 
    @synchronizer.no_input_translation
    @live_method
-    async def spawn(self, *args, **kwargs) -> Optional["_FunctionCall"]:
-        """Calls the function with the given arguments, without waiting for the results.
+    async def _experimental_spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_FunctionCall[ReturnType]":
+        """[Experimental] Calls the function with the given arguments, without waiting for the results.
 
-        Returns a `modal.functions.FunctionCall` object, that can later be polled or waited for using `.get(timeout=...)`.
+        This experimental version of the spawn method allows up to 1 million inputs to be spawned.
+
+        Returns a `modal.functions.FunctionCall` object, that can later be polled or
+        waited for using `.get(timeout=...)`.
        Conceptually similar to `multiprocessing.pool.apply_async`, or a Future/Promise in other contexts.
+        """
+        self._check_no_web_url("_experimental_spawn")
+        if self._is_generator:
+            invocation = await self._call_generator_nowait(args, kwargs)
+        else:
+            invocation = await self._call_function_nowait(
+                args, kwargs, function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC
+            )
 
-        *Note:* `.spawn()` on a modal generator function does call and execute the generator, but does not currently
-        return a function handle for polling the result.
+        fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
+        fc._is_generator = self._is_generator if self._is_generator else False
+        return fc
+
+    @synchronizer.no_input_translation
+    @live_method
+    async def spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_FunctionCall[ReturnType]":
+        """Calls the function with the given arguments, without waiting for the results.
+
+        Returns a `modal.functions.FunctionCall` object, that can later be polled or
+        waited for using `.get(timeout=...)`.
+        Conceptually similar to `multiprocessing.pool.apply_async`, or a Future/Promise in other contexts.
        """
+        self._check_no_web_url("spawn")
        if self._is_generator:
-            await self._call_generator_nowait(args, kwargs)
-            return None
+            invocation = await self._call_generator_nowait(args, kwargs)
+        else:
+            invocation = await self._call_function_nowait(
+                args, kwargs, api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY
+            )
 
-        invocation = await self._call_function_nowait(args, kwargs)
-        return _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
+        fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
+        fc._is_generator = self._is_generator if self._is_generator else False
+        return fc
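Both `spawn` variants now return a hydrated `_FunctionCall` carrying an `_is_generator` flag, so the handle can later be polled with `.get()` (or iterated with `.get_gen()` for generators). A usage sketch (app and function names hypothetical):

```python
import modal

f = modal.Function.lookup("my-app", "process")  # hypothetical deployed function

fc = f.spawn(42)             # returns immediately with a FunctionCall handle
print(fc.object_id)          # "fc-..." id, reusable via FunctionCall.from_id
result = fc.get(timeout=60)  # block up to 60s for the result
```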
 
    def get_raw_f(self) -> Callable[..., Any]:
        """Return the inner Python object wrapped by this Modal Function."""
-        if not self._info:
-            raise AttributeError("_info has not been set on this FunctionHandle and not available in this context")
-
-        return self._info.raw_f
+        return self._raw_f
 
    @live_method
    async def get_current_stats(self) -> FunctionStats:
        """Return a `FunctionStats` object describing the current function's queue and runner counts."""
        assert self._client.stub
-        resp = await self._client.stub.FunctionGetCurrentStats(
-            api_pb2.FunctionGetCurrentStatsRequest(function_id=self.object_id)
-        )
-        return FunctionStats(
-            backlog=resp.backlog, num_active_runners=resp.num_active_tasks, num_total_runners=resp.num_total_tasks
+        resp = await retry_transient_errors(
+            self._client.stub.FunctionGetCurrentStats,
+            api_pb2.FunctionGetCurrentStatsRequest(function_id=self.object_id),
+            total_timeout=10.0,
        )
+        return FunctionStats(backlog=resp.backlog, num_total_runners=resp.num_total_tasks)
+
+    # A bit hacky - but the map-style functions need to not be synchronicity-wrapped
+    # in order to not execute their input iterators on the synchronicity event loop.
+    # We still need to wrap them using MethodWithAio to maintain a synchronicity-like
+    # api with `.aio` and get working type-stubs and reference docs generation:
+    map = MethodWithAio(_map_sync, _map_async, synchronizer)
+    starmap = MethodWithAio(_starmap_sync, _starmap_async, synchronizer)
+    for_each = MethodWithAio(_for_each_sync, _for_each_async, synchronizer)
 
 
 Function = synchronize_api(_Function)
 
 
-class _FunctionCall(_Object, type_prefix="fc"):
+class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
    """A reference to an executed function call.
 
    Constructed using `.spawn(...)` on a Modal function with the same
@@ -1470,11 +1488,13 @@ class _FunctionCall(_Object, type_prefix="fc"):
    Conceptually similar to a Future/Promise/AsyncResult in other contexts and languages.
    """
 
+    _is_generator: bool = False
+
    def _invocation(self):
        assert self._client.stub
        return _Invocation(self._client.stub, self.object_id, self._client)
 
-    async def get(self, timeout: Optional[float] = None):
+    async def get(self, timeout: Optional[float] = None) -> ReturnType:
        """Get the result of the function call.
 
        This function waits indefinitely by default. It takes an optional
@@ -1483,9 +1503,23 @@ class _FunctionCall(_Object, type_prefix="fc"):
        The returned coroutine is not cancellation-safe.
        """
+
+        if self._is_generator:
+            raise Exception("Cannot get the result of a generator function call. Use `get_gen` instead.")
+
        return await self._invocation().poll_function(timeout=timeout)
 
-    async def get_call_graph(self) -> List[InputInfo]:
+    async def get_gen(self) -> AsyncGenerator[Any, None]:
+        """
+        Calls the generator remotely, executing it with the given arguments and returning the execution's result.
+        """
+        if not self._is_generator:
+            raise Exception("Cannot iterate over a non-generator function call. Use `get` instead.")
+
+        async for res in self._invocation().run_generator():
+            yield res
+
+    async def get_call_graph(self) -> list[InputInfo]:
        """Returns a structure representing the call graph from a given root
        call ID, along with the status of execution for each node.
 
@@ -1497,24 +1531,38 @@ class _FunctionCall(_Object, type_prefix="fc"):
        response = await retry_transient_errors(self._client.stub.FunctionGetCallGraph, request)
        return _reconstruct_call_graph(response)
 
-    async def cancel(self):
-        """Cancels the function call, which will stop its execution and mark its inputs as [`TERMINATED`](/docs/reference/modal.call_graph#modalcall_graphinputstatus)."""
-        request = api_pb2.FunctionCallCancelRequest(function_call_id=self.object_id)
+    async def cancel(
+        self,
+        terminate_containers: bool = False,  # if true, containers running the inputs are forcibly terminated
+    ):
+        """Cancels the function call, which will stop its execution and mark its inputs as
+        [`TERMINATED`](/docs/reference/modal.call_graph#modalcall_graphinputstatus).
+
+        If `terminate_containers=True` - the containers running the cancelled inputs are all terminated
+        causing any non-cancelled inputs on those containers to be rescheduled in new containers.
+        """
+        request = api_pb2.FunctionCallCancelRequest(
+            function_call_id=self.object_id, terminate_containers=terminate_containers
+        )
        assert self._client and self._client.stub
        await retry_transient_errors(self._client.stub.FunctionCallCancel, request)
 
    @staticmethod
-    async def from_id(function_call_id: str, client: Optional[_Client] = None) -> "_FunctionCall":
+    async def from_id(
+        function_call_id: str, client: Optional[_Client] = None, is_generator: bool = False
+    ) -> "_FunctionCall":
        if client is None:
            client = await _Client.from_env()
 
-        return _FunctionCall._new_hydrated(function_call_id, client, None)
+        fc = _FunctionCall._new_hydrated(function_call_id, client, None)
+        fc._is_generator = is_generator
+        return fc
 
 
 FunctionCall = synchronize_api(_FunctionCall)
 
 
-async def _gather(*function_calls: _FunctionCall):
+async def _gather(*function_calls: _FunctionCall[ReturnType]) -> typing.Sequence[ReturnType]:
    """Wait until all Modal function calls have results before returning
 
    Accepts a variable number of FunctionCall objects as returned by `Function.spawn()`.
@@ -1532,63 +1580,10 @@ async def _gather(*function_calls: _FunctionCall):
    ```
    """
    try:
-        return await asyncio.gather(*[fc.get() for fc in function_calls])
+        return await TaskContext.gather(*[fc.get() for fc in function_calls])
    except Exception as exc:
        # TODO: kill all running function calls
        raise exc
 
 
 gather = synchronize_api(_gather)
-
-
-_current_input_id: ContextVar = ContextVar("_current_input_id")
-_current_function_call_id: ContextVar = ContextVar("_current_function_call_id")
-
-
-def current_input_id() -> Optional[str]:
-    """Returns the input ID for the current input.
-
-    Can only be called from Modal function (i.e. in a container context).
-
-    ```python
-    from modal import current_input_id
-
-    @stub.function()
-    def process_stuff():
-        print(f"Starting to process {current_input_id()}")
-    ```
-    """
-    try:
-        return _current_input_id.get()
-    except LookupError:
-        return None
-
-
-def current_function_call_id() -> Optional[str]:
-    """Returns the function call ID for the current input.
-
-    Can only be called from Modal function (i.e. in a container context).
-
-    ```python
-    from modal import current_function_call_id
-
-    @stub.function()
-    def process_stuff():
-        print(f"Starting to process input from {current_function_call_id()}")
-    ```
-    """
-    try:
-        return _current_function_call_id.get()
-    except LookupError:
-        return None
-
-
-def _set_current_context_ids(input_id: str, function_call_id: str) -> Callable[[], None]:
-    input_token = _current_input_id.set(input_id)
-    function_call_token = _current_function_call_id.set(function_call_id)
-
-    def _reset_current_context_ids():
-        _current_input_id.reset(input_token)
-        _current_function_call_id.reset(function_call_token)
-
-    return _reset_current_context_ids