modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- modal/__init__.py +17 -13
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +420 -937
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -59
- modal/_resources.py +51 -0
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1036 -0
- modal/_runtime/execution_context.py +89 -0
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +134 -9
- modal/_traceback.py +47 -187
- modal/_tunnel.py +52 -16
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +479 -100
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +460 -171
- modal/_utils/grpc_testing.py +47 -31
- modal/_utils/grpc_utils.py +62 -109
- modal/_utils/hash_utils.py +61 -19
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +5 -7
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +14 -12
- modal/app.py +1003 -314
- modal/app.pyi +540 -264
- modal/call_graph.py +7 -6
- modal/cli/_download.py +63 -53
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +205 -45
- modal/cli/config.py +12 -5
- modal/cli/container.py +62 -14
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +64 -58
- modal/cli/launch.py +32 -18
- modal/cli/network_file_system.py +64 -83
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +35 -10
- modal/cli/programs/vscode.py +60 -10
- modal/cli/queues.py +131 -0
- modal/cli/run.py +234 -131
- modal/cli/secret.py +8 -7
- modal/cli/token.py +7 -2
- modal/cli/utils.py +79 -10
- modal/cli/volume.py +110 -109
- modal/client.py +250 -144
- modal/client.pyi +157 -118
- modal/cloud_bucket_mount.py +108 -34
- modal/cloud_bucket_mount.pyi +32 -38
- modal/cls.py +535 -148
- modal/cls.pyi +190 -146
- modal/config.py +41 -19
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +111 -65
- modal/dict.pyi +136 -131
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +34 -43
- modal/experimental.py +61 -2
- modal/extensions/ipython.py +5 -5
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +197 -0
- modal/functions.py +906 -911
- modal/functions.pyi +466 -430
- modal/gpu.py +57 -44
- modal/image.py +1089 -479
- modal/image.pyi +584 -228
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +314 -101
- modal/mount.pyi +241 -235
- modal/network_file_system.py +92 -92
- modal/network_file_system.pyi +152 -110
- modal/object.py +67 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +434 -0
- modal/parallel_map.pyi +75 -0
- modal/partial_function.py +282 -117
- modal/partial_function.pyi +222 -129
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +182 -65
- modal/queue.pyi +218 -118
- modal/requirements/2024.04.txt +29 -0
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +48 -7
- modal/runner.py +459 -156
- modal/runner.pyi +135 -71
- modal/running_app.py +38 -0
- modal/sandbox.py +514 -236
- modal/sandbox.pyi +397 -169
- modal/schedule.py +4 -4
- modal/scheduler_placement.py +20 -3
- modal/secret.py +56 -31
- modal/secret.pyi +62 -42
- modal/serving.py +51 -56
- modal/serving.pyi +44 -36
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +285 -157
- modal/volume.pyi +249 -184
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
- modal-0.72.11.dist-info/RECORD +174 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +5 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1288 -533
- modal_proto/api_grpc.py +856 -456
- modal_proto/api_pb2.py +2165 -1157
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1674 -855
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_entrypoint.pyi +0 -378
- modal/_container_exec.py +0 -128
- modal/_sandbox_shell.py +0 -49
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal/stub.py +0 -783
- modal/stub.pyi +0 -332
- modal-0.62.16.dist-info/RECORD +0 -198
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -262
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -659
- test/client_test.py +0 -194
- test/cls_test.py +0 -630
- test/config_test.py +0 -137
- test/conftest.py +0 -1420
- test/container_app_test.py +0 -32
- test/container_test.py +0 -1389
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -33
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -653
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -141
- test/helpers.py +0 -42
- test/image_test.py +0 -669
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -329
- test/network_file_system_test.py +0 -181
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -97
- test/resolver_test.py +0 -58
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -29
- test/secret_test.py +0 -78
- test/serialization_test.py +0 -42
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -360
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -341
- test/watcher_test.py +0 -30
- test/webhook_test.py +0 -146
- /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
- /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
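The headline change in this version range is visible in the listing itself: `modal/stub.py` and `modal/stub.pyi` are gone, while `modal/app.py` grows by roughly a thousand lines — the `Stub` class was renamed to `App` during the 0.6x series. A minimal migration sketch (the app name is hypothetical):

```python
import modal

# Before (modal 0.62.x):
#   stub = modal.Stub("example")
#   @stub.function()

# After (modal 0.72.x):
app = modal.App("example")

@app.function()
def hello() -> str:
    return "hello"
```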
modal/functions.py
CHANGED
@@ -1,253 +1,163 @@
 # Copyright Modal Labs 2023
-import
+import dataclasses
 import inspect
+import textwrap
 import time
+import typing
 import warnings
-from
+from collections.abc import AsyncGenerator, Collection, Sequence, Sized
 from dataclasses import dataclass
 from pathlib import PurePosixPath
 from typing import (
     TYPE_CHECKING,
     Any,
-    AsyncGenerator,
-    AsyncIterable,
-    AsyncIterator,
     Callable,
-    Collection,
-    Dict,
-    List,
-    Literal,
     Optional,
-    Sequence,
-    Set,
-    Sized,
-    Tuple,
-    Type,
     Union,
 )

-
+import typing_extensions
 from google.protobuf.message import Message
 from grpclib import GRPCError, Status
-from
+from synchronicity.combined_types import MethodWithAio
 from synchronicity.exceptions import UserCodeException

-from
-from modal_proto import
+from modal_proto import api_pb2
+from modal_proto.modal_api_grpc import ModalClientModal

 from ._location import parse_cloud_provider
-from .
+from ._pty import get_pty_info
 from ._resolver import Resolver
-from .
-from .
+from ._resources import convert_fn_config_to_resources_config
+from ._runtime.execution_context import current_input_id, is_local
+from ._serialization import serialize, serialize_proto_params
+from ._traceback import print_server_warnings
 from ._utils.async_utils import (
-
+    TaskContext,
+    aclosing,
+    async_merge,
+    callable_to_agen,
     synchronize_api,
     synchronizer,
     warn_if_generator_is_not_consumed,
 )
-from ._utils.
-
-
-
-
+from ._utils.deprecation import deprecation_warning, renamed_parameter
+from ._utils.function_utils import (
+    ATTEMPT_TIMEOUT_GRACE_PERIOD,
+    OUTPUTS_TIMEOUT,
+    FunctionCreationStatus,
+    FunctionInfo,
+    _create_input,
+    _process_result,
+    _stream_function_call_data,
+    get_function_type,
+    is_async,
 )
-from ._utils.
-from ._utils.
-from ._utils.mount_utils import validate_mount_points, validate_volumes
+from ._utils.grpc_utils import retry_transient_errors
+from ._utils.mount_utils import validate_network_file_systems, validate_volumes
 from .call_graph import InputInfo, _reconstruct_call_graph
 from .client import _Client
 from .cloud_bucket_mount import _CloudBucketMount, cloud_bucket_mounts_to_proto
-from .config import config
+from .config import config
 from .exception import (
     ExecutionError,
     FunctionTimeoutError,
+    InternalFailure,
     InvalidError,
     NotFoundError,
-
-    deprecation_warning,
+    OutputExpiredError,
 )
 from .gpu import GPU_T, parse_gpu_config
 from .image import _Image
-from .mount import _get_client_mount, _Mount
+from .mount import _get_client_mount, _Mount, get_auto_mounts
 from .network_file_system import _NetworkFileSystem, network_file_system_mount_protos
-from .object import
+from .object import _get_environment_name, _Object, live_method, live_method_gen
+from .output import _get_output_manager
+from .parallel_map import (
+    _for_each_async,
+    _for_each_sync,
+    _map_async,
+    _map_invocation,
+    _map_sync,
+    _starmap_async,
+    _starmap_sync,
+    _SynchronizedQueue,
+)
 from .proxy import _Proxy
-from .retries import Retries
+from .retries import Retries, RetryManager
 from .schedule import Schedule
 from .scheduler_placement import SchedulerPlacement
 from .secret import _Secret
 from .volume import _Volume

-OUTPUTS_TIMEOUT = 55.0  # seconds
-ATTEMPT_TIMEOUT_GRACE_PERIOD = 5  # seconds
-
-
 if TYPE_CHECKING:
-    import modal.
-
-
-def exc_with_hints(exc: BaseException):
-    """mdmd:hidden"""
-    if isinstance(exc, ImportError) and exc.msg == "attempted relative import with no known parent package":
-        exc.msg += """\n
-HINT: For relative imports to work, you might need to run your modal app as a module. Try:
-- `python -m my_pkg.my_app` instead of `python my_pkg/my_app.py`
-- `modal deploy my_pkg.my_app` instead of `modal deploy my_pkg/my_app.py`
-"""
-    elif isinstance(
-        exc, RuntimeError
-    ) and "CUDA error: no kernel image is available for execution on the device" in str(exc):
-        msg = (
-            exc.args[0]
-            + """\n
-HINT: This error usually indicates an outdated CUDA version. Older versions of torch (<=1.12)
-come with CUDA 10.2 by default. If pinning to an older torch version, you can specify a CUDA version
-manually, for example:
-- image.pip_install("torch==1.12.1+cu116", find_links="https://download.pytorch.org/whl/torch_stable.html")
-"""
-        )
-        exc.args = (msg,)
-
-    return exc
-
-
-async def _process_result(result: api_pb2.GenericResult, data_format: int, stub, client=None):
-    if result.WhichOneof("data_oneof") == "data_blob_id":
-        data = await blob_download(result.data_blob_id, stub)
-    else:
-        data = result.data
-
-    if result.status == api_pb2.GenericResult.GENERIC_STATUS_TIMEOUT:
-        raise FunctionTimeoutError(result.exception)
-    elif result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
-        if data:
-            try:
-                exc = deserialize(data, client)
-            except Exception as deser_exc:
-                raise ExecutionError(
-                    "Could not deserialize remote exception due to local error:\n"
-                    + f"{deser_exc}\n"
-                    + "This can happen if your local environment does not have the remote exception definitions.\n"
-                    + "Here is the remote traceback:\n"
-                    + f"{result.traceback}"
-                )
-            if not isinstance(exc, BaseException):
-                raise ExecutionError(f"Got remote exception of incorrect type {type(exc)}")
-
-            if result.serialized_tb:
-                try:
-                    tb_dict = deserialize(result.serialized_tb, client)
-                    line_cache = deserialize(result.tb_line_cache, client)
-                    append_modal_tb(exc, tb_dict, line_cache)
-                except Exception:
-                    pass
-            uc_exc = UserCodeException(exc_with_hints(exc))
-            raise uc_exc
-        raise RemoteError(result.exception)
+    import modal.app
+    import modal.cls
+    import modal.partial_function

-    try:
-        return deserialize_data_format(data, data_format, client)
-    except ModuleNotFoundError as deser_exc:
-        raise ExecutionError(
-            "Could not deserialize result due to error:\n"
-            + f"{deser_exc}\n"
-            + "This can happen if your local environment does not have a module that was used to construct the result. \n"
-        )

-
-
-""
-
-
-
-
-
-    args_serialized = serialize((args, kwargs))
-
-    if len(args_serialized) > MAX_OBJECT_SIZE_BYTES:
-        args_blob_id = await blob_upload(args_serialized, client.stub)
-
-        return api_pb2.FunctionPutInputsItem(
-            input=api_pb2.FunctionInput(args_blob_id=args_blob_id, data_format=api_pb2.DATA_FORMAT_PICKLE),
-            idx=idx,
-        )
-    else:
-        return api_pb2.FunctionPutInputsItem(
-            input=api_pb2.FunctionInput(args=args_serialized, data_format=api_pb2.DATA_FORMAT_PICKLE),
-            idx=idx,
-        )
-
-
-async def _stream_function_call_data(
-    client, function_call_id: str, variant: Literal["data_in", "data_out"]
-) -> AsyncIterator[Any]:
-    """Read from the `data_in` or `data_out` stream of a function call."""
-    last_index = 0
-    retries_remaining = 10
-
-    if variant == "data_in":
-        stub_fn = client.stub.FunctionCallGetDataIn
-    elif variant == "data_out":
-        stub_fn = client.stub.FunctionCallGetDataOut
-    else:
-        raise ValueError(f"Invalid variant {variant}")
-
-    while True:
-        req = api_pb2.FunctionCallGetDataRequest(function_call_id=function_call_id, last_index=last_index)
-        try:
-            async for chunk in unary_stream(stub_fn, req):
-                if chunk.index <= last_index:
-                    continue
-                last_index = chunk.index
-                if chunk.data_blob_id:
-                    message_bytes = await blob_download(chunk.data_blob_id, client.stub)
-                else:
-                    message_bytes = chunk.data
-                message = deserialize_data_format(message_bytes, chunk.data_format, client)
-                yield message
-        except (GRPCError, StreamTerminatedError) as exc:
-            if retries_remaining > 0:
-                retries_remaining -= 1
-                if isinstance(exc, GRPCError):
-                    if exc.status in RETRYABLE_GRPC_STATUS_CODES:
-                        await asyncio.sleep(1.0)
-                        continue
-                elif isinstance(exc, StreamTerminatedError):
-                    continue
-            raise
-
-
-@dataclass
-class _OutputValue:
-    # box class for distinguishing None results from non-existing/None markers
-    value: Any
+@dataclasses.dataclass
+class _RetryContext:
+    function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
+    retry_policy: api_pb2.FunctionRetryPolicy
+    function_call_jwt: str
+    input_jwt: str
+    input_id: str
+    item: api_pb2.FunctionPutInputsItem


 class _Invocation:
     """Internal client representation of a single-input call to a Modal Function or Generator"""

-
+    stub: ModalClientModal
+
+    def __init__(
+        self,
+        stub: ModalClientModal,
+        function_call_id: str,
+        client: _Client,
+        retry_context: Optional[_RetryContext] = None,
+    ):
         self.stub = stub
         self.client = client  # Used by the deserializer.
         self.function_call_id = function_call_id  # TODO: remove and use only input_id
+        self._retry_context = retry_context

     @staticmethod
-    async def create(
+    async def create(
+        function: "_Function",
+        args,
+        kwargs,
+        *,
+        client: _Client,
+        function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
+    ) -> "_Invocation":
         assert client.stub
-
+        function_id = function.object_id
+        item = await _create_input(args, kwargs, client, method_name=function._use_method_name)

         request = api_pb2.FunctionMapRequest(
             function_id=function_id,
             parent_input_id=current_input_id() or "",
             function_call_type=api_pb2.FUNCTION_CALL_TYPE_UNARY,
             pipelined_inputs=[item],
+            function_call_invocation_type=function_call_invocation_type,
         )
         response = await retry_transient_errors(client.stub.FunctionMap, request)
         function_call_id = response.function_call_id

         if response.pipelined_inputs:
-
+            assert len(response.pipelined_inputs) == 1
+            input = response.pipelined_inputs[0]
+            retry_context = _RetryContext(
+                function_call_invocation_type=function_call_invocation_type,
+                retry_policy=response.retry_policy,
+                function_call_jwt=response.function_call_jwt,
+                input_jwt=input.input_jwt,
+                input_id=input.input_id,
+                item=item,
+            )
+            return _Invocation(client.stub, function_call_id, client, retry_context)

         request_put = api_pb2.FunctionPutInputsRequest(
             function_id=function_id, inputs=[item], function_call_id=function_call_id
@@ -259,11 +169,20 @@ class _Invocation:
         processed_inputs = inputs_response.inputs
         if not processed_inputs:
             raise Exception("Could not create function call - the input queue seems to be full")
-
+        input = inputs_response.inputs[0]
+        retry_context = _RetryContext(
+            function_call_invocation_type=function_call_invocation_type,
+            retry_policy=response.retry_policy,
+            function_call_jwt=response.function_call_jwt,
+            input_jwt=input.input_jwt,
+            input_id=input.input_id,
+            item=item,
+        )
+        return _Invocation(client.stub, function_call_id, client, retry_context)

     async def pop_function_call_outputs(
-        self, timeout: Optional[float], clear_on_success: bool
-    ) ->
+        self, timeout: Optional[float], clear_on_success: bool, input_jwts: Optional[list[str]] = None
+    ) -> api_pb2.FunctionGetOutputsResponse:
         t0 = time.time()
         if timeout is None:
             backend_timeout = OUTPUTS_TIMEOUT
@@ -277,53 +196,100 @@ class _Invocation:
                 timeout=backend_timeout,
                 last_entry_id="0-0",
                 clear_on_success=clear_on_success,
+                requested_at=time.time(),
+                input_jwts=input_jwts,
             )
             response: api_pb2.FunctionGetOutputsResponse = await retry_transient_errors(
                 self.stub.FunctionGetOutputs,
                 request,
                 attempt_timeout=backend_timeout + ATTEMPT_TIMEOUT_GRACE_PERIOD,
             )
+
             if len(response.outputs) > 0:
-
-                yield item
-                return
+                return response

             if timeout is not None:
                 # update timeout in retry loop
                 backend_timeout = min(OUTPUTS_TIMEOUT, t0 + timeout - time.time())
                 if backend_timeout < 0:
-
+                    # return the last response to check for state of num_unfinished_inputs
+                    return response
+
+    async def _retry_input(self) -> None:
+        ctx = self._retry_context
+        if not ctx:
+            raise ValueError("Cannot retry input when _retry_context is empty.")
+
+        item = api_pb2.FunctionRetryInputsItem(input_jwt=ctx.input_jwt, input=ctx.item.input)
+        request = api_pb2.FunctionRetryInputsRequest(function_call_jwt=ctx.function_call_jwt, inputs=[item])
+        await retry_transient_errors(
+            self.client.stub.FunctionRetryInputs,
+            request,
+        )

-    async def
+    async def _get_single_output(self, expected_jwt: Optional[str] = None) -> Any:
         # waits indefinitely for a single result for the function, and clear the outputs buffer after
         item: api_pb2.FunctionGetOutputsItem = (
-            await
-
-
+            await self.pop_function_call_outputs(
+                timeout=None,
+                clear_on_success=True,
+                input_jwts=[expected_jwt] if expected_jwt else None,
+            )
+        ).outputs[0]
         return await _process_result(item.result, item.data_format, self.stub, self.client)

+    async def run_function(self) -> Any:
+        # Use retry logic only if retry policy is specified and
+        ctx = self._retry_context
+        if (
+            not ctx
+            or not ctx.retry_policy
+            or ctx.retry_policy.retries == 0
+            or ctx.function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
+        ):
+            return await self._get_single_output()
+
+        # User errors including timeouts are managed by the user specified retry policy.
+        user_retry_manager = RetryManager(ctx.retry_policy)
+
+        while True:
+            try:
+                return await self._get_single_output(ctx.input_jwt)
+            except (UserCodeException, FunctionTimeoutError) as exc:
+                await user_retry_manager.raise_or_sleep(exc)
+            except InternalFailure:
+                # For system failures on the server, we retry immediately.
+                pass
+            await self._retry_input()
+
     async def poll_function(self, timeout: Optional[float] = None):
         """Waits up to timeout for a result from a function.

         If timeout is `None`, waits indefinitely. This function is not
         cancellation-safe.
         """
-
-
+        response: api_pb2.FunctionGetOutputsResponse = await self.pop_function_call_outputs(
+            timeout=timeout, clear_on_success=False
         )
-
-
+        if len(response.outputs) == 0 and response.num_unfinished_inputs == 0:
+            # if no unfinished inputs and no outputs, then function expired
+            raise OutputExpiredError()
+        elif len(response.outputs) == 0:
             raise TimeoutError()

-        return await _process_result(
+        return await _process_result(
+            response.outputs[0].result, response.outputs[0].data_format, self.stub, self.client
+        )

     async def run_generator(self):
-        data_stream = _stream_function_call_data(self.client, self.function_call_id, variant="data_out")
-        combined_stream = stream.merge(data_stream, stream.call(self.run_function))  # type: ignore
-
         items_received = 0
         items_total: Union[int, None] = None  # populated when self.run_function() completes
-        async with
+        async with aclosing(
+            async_merge(
+                _stream_function_call_data(self.client, self.function_call_id, variant="data_out"),
+                callable_to_agen(self.run_function),
+            )
+        ) as streamer:
             async for item in streamer:
                 if isinstance(item, api_pb2.GeneratorDone):
                     items_total = item.items_total
@@ -336,187 +302,29 @@ class _Invocation:
                     break


-MAP_INVOCATION_CHUNK_SIZE = 49
-
-
-async def _map_invocation(
-    function_id: str,
-    input_stream: AsyncIterable[Any],
-    kwargs: Dict[str, Any],
-    client: _Client,
-    order_outputs: bool,
-    return_exceptions: bool,
-    count_update_callback: Optional[Callable[[int, int], None]],
-):
-    assert client.stub
-    request = api_pb2.FunctionMapRequest(
-        function_id=function_id,
-        parent_input_id=current_input_id() or "",
-        function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
-        return_exceptions=return_exceptions,
-    )
-    response = await retry_transient_errors(client.stub.FunctionMap, request)
-
-    function_call_id = response.function_call_id
-
-    have_all_inputs = False
-    num_inputs = 0
-    num_outputs = 0
-    pending_outputs: Dict[str, int] = {}  # Map input_id -> next expected gen_index value
-    completed_outputs: Set[str] = set()  # Set of input_ids whose outputs are complete (expecting no more values)
-
-    input_queue: asyncio.Queue = asyncio.Queue()
-
-    async def create_input(arg: Any) -> api_pb2.FunctionPutInputsItem:
-        nonlocal num_inputs
-        idx = num_inputs
-        num_inputs += 1
-        item = await _create_input(arg, kwargs, client, idx=idx)
-        return item
-
-    async def drain_input_generator():
-        # Parallelize uploading blobs
-        proto_input_stream = stream.iterate(input_stream) | pipe.map(
-            create_input,  # type: ignore[reportArgumentType]
-            ordered=True,
-            task_limit=BLOB_MAX_PARALLELISM,
-        )
-        async with proto_input_stream.stream() as streamer:
-            async for item in streamer:
-                await input_queue.put(item)
-
-        # close queue iterator
-        await input_queue.put(None)
-        yield
-
-    async def pump_inputs():
-        assert client.stub
-        nonlocal have_all_inputs
-        async for items in queue_batch_iterator(input_queue, MAP_INVOCATION_CHUNK_SIZE):
-            request = api_pb2.FunctionPutInputsRequest(
-                function_id=function_id, inputs=items, function_call_id=function_call_id
-            )
-            logger.debug(
-                f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
-            )
-            resp = await retry_transient_errors(
-                client.stub.FunctionPutInputs,
-                request,
-                max_retries=None,
-                max_delay=10,
-                additional_status_codes=[Status.RESOURCE_EXHAUSTED],
-            )
-            for item in resp.inputs:
-                pending_outputs.setdefault(item.input_id, 0)
-            logger.debug(
-                f"Successfully pushed {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
-            )
-
-        have_all_inputs = True
-        yield
-
-    async def get_all_outputs():
-        assert client.stub
-        nonlocal num_inputs, num_outputs, have_all_inputs
-        last_entry_id = "0-0"
-        while not have_all_inputs or len(pending_outputs) > len(completed_outputs):
-            request = api_pb2.FunctionGetOutputsRequest(
-                function_call_id=function_call_id,
-                timeout=OUTPUTS_TIMEOUT,
-                last_entry_id=last_entry_id,
-                clear_on_success=False,
-            )
-            response = await retry_transient_errors(
-                client.stub.FunctionGetOutputs,
-                request,
-                max_retries=20,
-                attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
-            )
-
-            if len(response.outputs) == 0:
-                continue
-
-            last_entry_id = response.last_entry_id
-            for item in response.outputs:
-                pending_outputs.setdefault(item.input_id, 0)
-                if item.input_id in completed_outputs:
-                    # If this input is already completed, it means the output has already been
-                    # processed and was received again due to a duplicate.
-                    continue
-                completed_outputs.add(item.input_id)
-                num_outputs += 1
-                yield item
-
-    async def get_all_outputs_and_clean_up():
-        assert client.stub
-        try:
-            async for item in get_all_outputs():
-                yield item
-        finally:
-            # "ack" that we have all outputs we are interested in and let backend clear results
-            request = api_pb2.FunctionGetOutputsRequest(
-                function_call_id=function_call_id,
-                timeout=0,
-                last_entry_id="0-0",
-                clear_on_success=True,
-            )
-            await retry_transient_errors(client.stub.FunctionGetOutputs, request)
-
-    async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> Tuple[int, Any]:
-        try:
-            output = await _process_result(item.result, item.data_format, client.stub, client)
-        except Exception as e:
-            if return_exceptions:
-                output = e
-            else:
-                raise e
-        return (item.idx, output)
-
-    async def poll_outputs():
-        outputs = stream.iterate(get_all_outputs_and_clean_up())
-        outputs_fetched = outputs | pipe.map(fetch_output, ordered=True, task_limit=BLOB_MAX_PARALLELISM)  # type: ignore
-
-        # map to store out-of-order outputs received
-        received_outputs = {}
-        output_idx = 0
-
-        async with outputs_fetched.stream() as streamer:
-            async for idx, output in streamer:
-                if count_update_callback is not None:
-                    count_update_callback(num_outputs, num_inputs)
-                if not order_outputs:
-                    yield _OutputValue(output)
-                else:
-                    # hold on to outputs for function maps, so we can reorder them correctly.
-                    received_outputs[idx] = output
-                    while output_idx in received_outputs:
-                        output = received_outputs.pop(output_idx)
-                        yield _OutputValue(output)
-                        output_idx += 1
-
-        assert len(received_outputs) == 0
-
-    response_gen = stream.merge(drain_input_generator(), pump_inputs(), poll_outputs())
-
-    async with response_gen.stream() as streamer:
-        async for response in streamer:
-            if response is not None:
-                yield response.value
-
-
 # Wrapper type for api_pb2.FunctionStats
 @dataclass(frozen=True)
 class FunctionStats:
     """Simple data structure storing stats for a running function."""

     backlog: int
-    num_active_runners: int
     num_total_runners: int

+    def __getattr__(self, name):
+        if name == "num_active_runners":
+            msg = (
+                "'FunctionStats.num_active_runners' is deprecated."
+                " It currently always has a value of 0,"
+                " but it will be removed in a future release."
+            )
+            deprecation_warning((2024, 6, 14), msg)
+            return 0
+        raise AttributeError(f"'FunctionStats' object has no attribute '{name}'")
+

 def _parse_retries(
     retries: Optional[Union[int, Retries]],
-
+    source: str = "",
 ) -> Optional[api_pb2.FunctionRetryPolicy]:
     if isinstance(retries, int):
         return Retries(
@@ -529,118 +337,168 @@ def _parse_retries(
     elif retries is None:
         return None
     else:
-
-
-
-        )
+        extra = f" on {source}" if source else ""
+        msg = f"Retries parameter must be an integer or instance of modal.Retries. Found: {type(retries)}{extra}."
+        raise InvalidError(msg)


 @dataclass
-class
+class _FunctionSpec:
     """
-    Stores information about
-
+    Stores information about a Function specification.
+    This is used for `modal shell` to support running shells with
+    the same configuration as a user-defined Function.
     """

     image: Optional[_Image]
     mounts: Sequence[_Mount]
     secrets: Sequence[_Secret]
-    network_file_systems:
-    volumes:
-
+    network_file_systems: dict[Union[str, PurePosixPath], _NetworkFileSystem]
+    volumes: dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]]
+    gpus: Union[GPU_T, list[GPU_T]]  # TODO(irfansharif): Somehow assert that it's the first kind, in sandboxes
     cloud: Optional[str]
-    cpu: Optional[float]
-    memory: Optional[int]
+    cpu: Optional[Union[float, tuple[float, float]]]
+    memory: Optional[Union[int, tuple[int, int]]]
+    ephemeral_disk: Optional[int]
+    scheduler_placement: Optional[SchedulerPlacement]
+    proxy: Optional[_Proxy]
+
+
+P = typing_extensions.ParamSpec("P")
+ReturnType = typing.TypeVar("ReturnType", covariant=True)
+OriginalReturnType = typing.TypeVar(
+    "OriginalReturnType", covariant=True
+)  # differs from return type if ReturnType is coroutine


-class _Function(_Object, type_prefix="fu"):
+class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type_prefix="fu"):
     """Functions are the basic units of serverless execution on Modal.

     Generally, you will not construct a `Function` directly. Instead, use the
-
+    `App.function()` decorator to register your Python functions with your App.
     """

     # TODO: more type annotations
     _info: Optional[FunctionInfo]
-
-
-    _obj:
+    _serve_mounts: frozenset[_Mount]  # set at load time, only by loader
+    _app: Optional["modal.app._App"] = None
+    _obj: Optional["modal.cls._Obj"] = None  # only set for InstanceServiceFunctions and bound instance methods
     _web_url: Optional[str]
-    _is_remote_cls_method: bool = False  # TODO(erikbern): deprecated
     _function_name: Optional[str]
     _is_method: bool
-
+    _spec: Optional[_FunctionSpec] = None
     _tag: str
     _raw_f: Callable[..., Any]
     _build_args: dict
-
+
+    _is_generator: Optional[bool] = None
+    _cluster_size: Optional[int] = None
+
+    # when this is the method of a class/object function, invocation of this function
+    # should supply the method name in the FunctionInput:
+    _use_method_name: str = ""
+
+    _class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None
+    _method_handle_metadata: Optional[dict[str, "api_pb2.FunctionHandleMetadata"]] = None
+
+    def _bind_method(
+        self,
+        user_cls,
+        method_name: str,
+        partial_function: "modal.partial_function._PartialFunction",
+    ):
+        """mdmd:hidden
+
+        Creates a _Function that is bound to a specific class method name. This _Function is not uniquely tied
+        to any backend function -- its object_id is the function ID of the class service function.
+
+        """
+        class_service_function = self
+        assert class_service_function._info  # has to be a local function to be able to "bind" it
+        assert not class_service_function._is_method  # should not be used on an already bound method placeholder
+        assert not class_service_function._obj  # should only be used on base function / class service function
+        full_name = f"{user_cls.__name__}.{method_name}"
+
+        rep = f"Method({full_name})"
+        fun = _Object.__new__(_Function)
+        fun._init(rep)
+        fun._tag = full_name
+        fun._raw_f = partial_function.raw_f
+        fun._info = FunctionInfo(
+            partial_function.raw_f, user_cls=user_cls, serialized=class_service_function.info.is_serialized()
+        )  # needed for .local()
+        fun._use_method_name = method_name
+        fun._app = class_service_function._app
+        fun._is_generator = partial_function.is_generator
+        fun._cluster_size = partial_function.cluster_size
+        fun._spec = class_service_function._spec
+        fun._is_method = True
+        return fun

     @staticmethod
     def from_args(
         info: FunctionInfo,
-
+        app,
         image: _Image,
-        secret: Optional[_Secret] = None,
         secrets: Sequence[_Secret] = (),
         schedule: Optional[Schedule] = None,
-        is_generator=False,
-        gpu: GPU_T = None,
+        is_generator: bool = False,
+        gpu: Union[GPU_T, list[GPU_T]] = None,
         # TODO: maybe break this out into a separate decorator for notebooks.
         mounts: Collection[_Mount] = (),
-        network_file_systems:
+        network_file_systems: dict[Union[str, PurePosixPath], _NetworkFileSystem] = {},
         allow_cross_region_volumes: bool = False,
-        volumes:
+        volumes: dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]] = {},
         webhook_config: Optional[api_pb2.WebhookConfig] = None,
-        memory: Optional[int] = None,
+        memory: Optional[Union[int, tuple[int, int]]] = None,
         proxy: Optional[_Proxy] = None,
         retries: Optional[Union[int, Retries]] = None,
         timeout: Optional[int] = None,
         concurrency_limit: Optional[int] = None,
         allow_concurrent_inputs: Optional[int] = None,
+        batch_max_size: Optional[int] = None,
+        batch_wait_ms: Optional[int] = None,
         container_idle_timeout: Optional[int] = None,
-        cpu: Optional[float] = None,
+        cpu: Optional[Union[float, tuple[float, float]]] = None,
         keep_warm: Optional[int] = None,  # keep_warm=True is equivalent to keep_warm=1
         cloud: Optional[str] = None,
-
-        _experimental_scheduler: bool = False,
-        _experimental_scheduler_placement: Optional[SchedulerPlacement] = None,
+        scheduler_placement: Optional[SchedulerPlacement] = None,
         is_builder_function: bool = False,
         is_auto_snapshot: bool = False,
         enable_memory_snapshot: bool = False,
-        checkpointing_enabled: Optional[bool] = None,
-        allow_background_volume_commits: bool = False,
         block_network: bool = False,
+        i6pn_enabled: bool = False,
+        cluster_size: Optional[int] = None,  # Experimental: Clustered functions
         max_inputs: Optional[int] = None,
+        ephemeral_disk: Optional[int] = None,
+        _experimental_buffer_containers: Optional[int] = None,
+        _experimental_proxy_ip: Optional[str] = None,
+        _experimental_custom_scaling_factor: Optional[float] = None,
     ) -> None:
         """mdmd:hidden"""
+        # Needed to avoid circular imports
+        from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags
+
         tag = info.get_tag()

-
-
-
-        if not info.is_nullary():
+        if info.raw_f:
+            raw_f = info.raw_f
+            assert callable(raw_f)
+            if schedule is not None and not info.is_nullary():
                 raise InvalidError(
                     f"Function {raw_f} has a schedule, so it needs to support being called with no arguments"
                 )
-
-
-
-
-
-        )
-        secrets = [secret, *secrets]
-
-        if checkpointing_enabled is not None:
-            deprecation_warning(
-                (2024, 3, 4),
-                "The argument `checkpointing_enabled` is now deprecated. Use `enable_memory_snapshot` instead.",
-            )
-            enable_memory_snapshot = checkpointing_enabled
+        else:
+            # must be a "class service function"
+            assert info.user_cls
+            assert not webhook_config
+            assert not schedule

         explicit_mounts = mounts

         if is_local():
             entrypoint_mounts = info.get_entrypoint_mount()
+
             all_mounts = [
                 _get_client_mount(),
                 *explicit_mounts,
@@ -648,45 +506,57 @@ class _Function(_Object, type_prefix="fu"):
             ]

             if config.get("automount"):
-
-                all_mounts += automounts
+                all_mounts += get_auto_mounts()
         else:
             # skip any mount introspection/logic inside containers, since the function
             # should already be hydrated
             # TODO: maybe the entire constructor should be exited early if not local?
             all_mounts = []

-        retry_policy = _parse_retries(
+        retry_policy = _parse_retries(
+            retries, f"Function '{info.get_tag()}'" if info.raw_f else f"Class '{info.get_tag()}'"
+        )

-
+        if webhook_config is not None and retry_policy is not None:
+            raise InvalidError(
+                "Web endpoints do not support retries.",
+            )
+
+        if is_generator and retry_policy is not None:
+            deprecation_warning(
+                (2024, 6, 25),
+                "Retries for generator functions are deprecated and will soon be removed.",
+            )

         if proxy:
             # HACK: remove this once we stop using ssh tunnels for this.
             if image:
+                # TODO(elias): this will cause an error if users use prior `.add_local_*` commands without copy=True
                 image = image.apt_install("autossh")

-
+        function_spec = _FunctionSpec(
             mounts=all_mounts,
             secrets=secrets,
-
+            gpus=gpu,
             network_file_systems=network_file_systems,
             volumes=volumes,
             image=image,
             cloud=cloud,
             cpu=cpu,
             memory=memory,
+            ephemeral_disk=ephemeral_disk,
+            scheduler_placement=scheduler_placement,
+            proxy=proxy,
         )

-        if info.
-
-
-
-
-            for build_function in build_functions:
-                snapshot_info = FunctionInfo(build_function, cls=info.cls)
+        if info.user_cls and not is_auto_snapshot:
+            build_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.BUILD).items()
+            for k, pf in build_functions:
+                build_function = pf.raw_f
+                snapshot_info = FunctionInfo(build_function, user_cls=info.user_cls)
                 snapshot_function = _Function.from_args(
                     snapshot_info,
-
+                    app=None,
                     image=image,
                     secrets=secrets,
                     gpu=gpu,
@@ -694,16 +564,17 @@ class _Function(_Object, type_prefix="fu"):
                     network_file_systems=network_file_systems,
                     volumes=volumes,
                     memory=memory,
-                    timeout=
+                    timeout=pf.build_timeout,
                     cpu=cpu,
+                    ephemeral_disk=ephemeral_disk,
                     is_builder_function=True,
                     is_auto_snapshot=True,
-
+                    scheduler_placement=scheduler_placement,
                 )
                 image = _Image._from_args(
                     base_images={"base": image},
                     build_function=snapshot_function,
-                    force_build=image.force_build,
+                    force_build=image.force_build or pf.force_build,
                 )

         if keep_warm is not None and not isinstance(keep_warm, int):
@@ -711,9 +582,15 @@ class _Function(_Object, type_prefix="fu"):

         if (keep_warm is not None) and (concurrency_limit is not None) and concurrency_limit < keep_warm:
             raise InvalidError(
-                f"Function `{info.function_name}` has `{concurrency_limit=}`,
+                f"Function `{info.function_name}` has `{concurrency_limit=}`, "
+                f"strictly less than its `{keep_warm=}` parameter."
             )

+        if _experimental_custom_scaling_factor is not None and (
+            _experimental_custom_scaling_factor < 0 or _experimental_custom_scaling_factor > 1
+        ):
+            raise InvalidError("`_experimental_custom_scaling_factor` must be between 0.0 and 1.0 inclusive.")
+
         if not cloud and not is_builder_function:
             cloud = config.get("default_cloud")
         if cloud:
@@ -730,22 +607,56 @@ class _Function(_Object, type_prefix="fu"):
             else:
                 raise InvalidError("Webhooks cannot be generators")

+        if info.raw_f and batch_max_size:
+            func_name = info.raw_f.__name__
+            if is_generator:
+                raise InvalidError(f"Modal batched function {func_name} cannot return generators")
+            for arg in inspect.signature(info.raw_f).parameters.values():
+                if arg.default is not inspect.Parameter.empty:
+                    raise InvalidError(f"Modal batched function {func_name} does not accept default arguments.")
+
+        if container_idle_timeout is not None and container_idle_timeout <= 0:
+            raise InvalidError("`container_idle_timeout` must be > 0")
+
+        if max_inputs is not None:
+            if not isinstance(max_inputs, int):
+                raise InvalidError(f"`max_inputs` must be an int, not {type(max_inputs).__name__}")
+            if max_inputs <= 0:
+                raise InvalidError("`max_inputs` must be positive")
+            if max_inputs > 1:
+                raise InvalidError("Only `max_inputs=1` is currently supported")
+
         # Validate volumes
         validated_volumes = validate_volumes(volumes)
         cloud_bucket_mounts = [(k, v) for k, v in validated_volumes if isinstance(v, _CloudBucketMount)]
         validated_volumes = [(k, v) for k, v in validated_volumes if isinstance(v, _Volume)]

         # Validate NFS
-
-            raise InvalidError("network_file_systems must be a dict[str, NetworkFileSystem] where the keys are paths")
-        validated_network_file_systems = validate_mount_points("Network file system", network_file_systems)
+        validated_network_file_systems = validate_network_file_systems(network_file_systems)

         # Validate image
         if image is not None and not isinstance(image, _Image):
             raise InvalidError(f"Expected modal.Image object. Got {type(image)}.")

-
-
+        method_definitions: Optional[dict[str, api_pb2.MethodDefinition]] = None
+
+        if info.user_cls:
+            method_definitions = {}
+            partial_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.FUNCTION)
+            for method_name, partial_function in partial_functions.items():
+                function_type = get_function_type(partial_function.is_generator)
+                function_name = f"{info.user_cls.__name__}.{method_name}"
+                method_definition = api_pb2.MethodDefinition(
+                    webhook_config=partial_function.webhook_config,
+                    function_type=function_type,
+                    function_name=function_name,
+                )
+                method_definitions[method_name] = method_definition
+
+        function_type = get_function_type(is_generator)
+
+        def _deps(only_explicit_mounts=False) -> list[_Object]:
+            deps: list[_Object] = list(secrets)
             if only_explicit_mounts:
                 # TODO: this is a bit hacky, but all_mounts may differ in the container vs locally
                 # We don't want the function dependencies to change, so we have this way to force it to
@@ -769,271 +680,358 @@ class _Function(_Object, type_prefix="fu"):
|
|
769
680
|
if cloud_bucket_mount.secret:
|
770
681
|
deps.append(cloud_bucket_mount.secret)
|
771
682
|
|
772
|
-
# Add implicit dependencies from the function's code
|
773
|
-
objs: list[Object] = get_referred_objects(info.raw_f)
|
774
|
-
_objs: list[_Object] = synchronizer._translate_in(objs) # type: ignore
|
775
|
-
deps += _objs
|
776
683
|
return deps
|
777
684
|
|
778
685
|
async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
|
779
686
|
assert resolver.client and resolver.client.stub
|
780
|
-
if is_generator:
|
781
|
-
function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR
|
782
|
-
else:
|
783
|
-
function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION
|
784
687
|
|
688
|
+
assert resolver.app_id
|
785
689
|
req = api_pb2.FunctionPrecreateRequest(
|
786
690
|
app_id=resolver.app_id,
|
787
691
|
function_name=info.function_name,
|
788
692
|
function_type=function_type,
|
789
|
-
webhook_config=webhook_config,
|
790
693
|
existing_function_id=existing_object_id or "",
|
791
694
|
)
|
695
|
+
if method_definitions:
|
696
|
+
for method_name, method_definition in method_definitions.items():
|
697
|
+
req.method_definitions[method_name].CopyFrom(method_definition)
|
698
|
+
elif webhook_config:
|
699
|
+
req.webhook_config.CopyFrom(webhook_config)
|
792
700
|
response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req)
|
793
701
|
self._hydrate(response.function_id, resolver.client, response.handle_metadata)
|
794
702
|
|
795
703
|
async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
|
796
704
|
assert resolver.client and resolver.client.stub
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
if is_generator:
|
801
|
-
function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR
|
802
|
-
else:
|
803
|
-
function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION
|
804
|
-
|
805
|
-
if cpu is not None and cpu < 0.25:
|
806
|
-
raise InvalidError(f"Invalid fractional CPU value {cpu}. Cannot have less than 0.25 CPU resources.")
|
807
|
-
milli_cpu = int(1000 * cpu) if cpu is not None else 0
|
808
|
-
|
809
|
-
timeout_secs = timeout
|
705
|
+
with FunctionCreationStatus(resolver, tag) as function_creation_status:
|
706
|
+
timeout_secs = timeout
|
810
707
|
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
708
|
+
if app and app.is_interactive and not is_builder_function:
|
709
|
+
pty_info = get_pty_info(shell=False)
|
710
|
+
else:
|
711
|
+
pty_info = None
|
712
|
+
|
713
|
+
if info.is_serialized():
|
714
|
+
# Use cloudpickle. Used when working w/ Jupyter notebooks.
|
715
|
+
# serialize at _load time, not function decoration time
|
716
|
+
# otherwise we can't capture a surrounding class for lifetime methods etc.
|
717
|
+
function_serialized = info.serialized_function()
|
718
|
+
class_serialized = serialize(info.user_cls) if info.user_cls is not None else None
|
719
|
+
# Ensure that large data in global variables does not blow up the gRPC payload,
|
720
|
+
# which has maximum size 100 MiB. We set the limit lower for performance reasons.
|
721
|
+
if len(function_serialized) > 16 << 20: # 16 MiB
|
722
|
+
raise InvalidError(
|
723
|
+
f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
|
724
|
+
"This is larger than the maximum limit of 16 MiB. "
|
725
|
+
"Try reducing the size of the closure by using parameters or mounts, "
|
726
|
+
"not large global variables."
|
727
|
+
)
|
728
|
+
elif len(function_serialized) > 256 << 10: # 256 KiB
|
729
|
+
warnings.warn(
|
730
|
+
f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
|
731
|
+
"This is larger than the recommended limit of 256 KiB. "
|
732
|
+
"Try reducing the size of the closure by using parameters or mounts, "
|
733
|
+
"not large global variables."
|
734
|
+
)
|
735
|
+
else:
|
736
|
+
function_serialized = None
|
737
|
+
class_serialized = None
+
+            app_name = ""
+            if app and app.name:
+                app_name = app.name
+
+            # Relies on dicts being ordered (true as of Python 3.6).
+            volume_mounts = [
+                api_pb2.VolumeMount(
+                    mount_path=path,
+                    volume_id=volume.object_id,
+                    allow_background_commits=True,
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                for path, volume in validated_volumes
+            ]
+            loaded_mount_ids = {m.object_id for m in all_mounts} | {m.object_id for m in image._mount_layers}
+
+            # Get object dependencies
+            object_dependencies = []
+            for dep in _deps(only_explicit_mounts=True):
+                if not dep.object_id:
+                    raise Exception(f"Dependency {dep} isn't hydrated")
+                object_dependencies.append(api_pb2.ObjectDependency(object_id=dep.object_id))
+
+            function_data: Optional[api_pb2.FunctionData] = None
+            function_definition: Optional[api_pb2.Function] = None
+
+            # Create function remotely
+            function_definition = api_pb2.Function(
+                module_name=info.module_name or "",
+                function_name=info.function_name,
+                mount_ids=loaded_mount_ids,
+                secret_ids=[secret.object_id for secret in secrets],
+                image_id=(image.object_id if image else ""),
+                definition_type=info.get_definition_type(),
+                function_serialized=function_serialized or b"",
+                class_serialized=class_serialized or b"",
+                function_type=function_type,
+                webhook_config=webhook_config,
+                method_definitions=method_definitions,
+                method_definitions_set=True,
+                shared_volume_mounts=network_file_system_mount_protos(
+                    validated_network_file_systems, allow_cross_region_volumes
+                ),
+                volume_mounts=volume_mounts,
+                proxy_id=(proxy.object_id if proxy else None),
+                retry_policy=retry_policy,
+                timeout_secs=timeout_secs or 0,
+                task_idle_timeout_secs=container_idle_timeout or 0,
+                concurrency_limit=concurrency_limit or 0,
+                pty_info=pty_info,
+                cloud_provider=cloud_provider,
+                warm_pool_size=keep_warm or 0,
+                runtime=config.get("function_runtime"),
+                runtime_debug=config.get("function_runtime_debug"),
+                runtime_perf_record=config.get("runtime_perf_record"),
+                app_name=app_name,
+                is_builder_function=is_builder_function,
+                target_concurrent_inputs=allow_concurrent_inputs or 0,
+                batch_max_size=batch_max_size or 0,
+                batch_linger_ms=batch_wait_ms or 0,
+                worker_id=config.get("worker_id"),
+                is_auto_snapshot=is_auto_snapshot,
+                is_method=bool(info.user_cls) and not info.is_service_class(),
+                checkpointing_enabled=enable_memory_snapshot,
+                object_dependencies=object_dependencies,
+                block_network=block_network,
+                max_inputs=max_inputs or 0,
+                cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts),
+                scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
+                is_class=info.is_service_class(),
+                class_parameter_info=info.class_parameter_info(),
+                i6pn_enabled=i6pn_enabled,
+                schedule=schedule.proto_message if schedule is not None else None,
+                snapshot_debug=config.get("snapshot_debug"),
+                _experimental_group_size=cluster_size or 0,  # Experimental: Clustered functions
+                _experimental_concurrent_cancellations=True,
+                _experimental_buffer_containers=_experimental_buffer_containers or 0,
+                _experimental_proxy_ip=_experimental_proxy_ip,
+                _experimental_custom_scaling=_experimental_custom_scaling_factor is not None,
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                retry_policy=retry_policy,
-                timeout_secs=timeout_secs or 0,
-                task_idle_timeout_secs=container_idle_timeout or 0,
-                concurrency_limit=concurrency_limit or 0,
-                pty_info=pty_info,
-                cloud_provider=cloud_provider,
-                warm_pool_size=keep_warm or 0,
-                runtime=config.get("function_runtime"),
-                runtime_debug=config.get("function_runtime_debug"),
-                stub_name=stub_name,
-                is_builder_function=is_builder_function,
-                allow_concurrent_inputs=allow_concurrent_inputs or 0,
-                worker_id=config.get("worker_id"),
-                is_auto_snapshot=is_auto_snapshot,
-                is_method=bool(info.cls),
-                checkpointing_enabled=enable_memory_snapshot,
-                is_checkpointing_function=False,
-                object_dependencies=object_dependencies,
-                block_network=block_network,
-                max_inputs=max_inputs or 0,
-                cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts),
-                _experimental_boost=_experimental_boost,
-                _experimental_scheduler=_experimental_scheduler,
-                _experimental_scheduler_placement=_experimental_scheduler_placement.proto
-                if _experimental_scheduler_placement
-                else None,
-            )
-            request = api_pb2.FunctionCreateRequest(
-                app_id=resolver.app_id,
-                function=function_definition,
-                schedule=schedule.proto_message if schedule is not None else None,
-                existing_function_id=existing_object_id or "",
-            )
-            try:
-                response: api_pb2.FunctionCreateResponse = await retry_transient_errors(
-                    resolver.client.stub.FunctionCreate, request
-                )
-            except GRPCError as exc:
-                if exc.status == Status.INVALID_ARGUMENT:
-                    raise InvalidError(exc.message)
-                if exc.status == Status.FAILED_PRECONDITION:
-                    raise InvalidError(exc.message)
-                if exc.message and "Received :status = '413'" in exc.message:
-                    raise InvalidError(f"Function {raw_f} is too large to deploy.")
-                raise
-
-            if response.function.web_url:
-                # Ensure terms used here match terms used in modal.com/docs/guide/webhook-urls doc.
-                if response.function.web_url_info.truncated:
-                    suffix = " [grey70](label truncated)[/grey70]"
-                elif response.function.web_url_info.has_unique_hash:
-                    suffix = " [grey70](label includes conflict-avoidance hash)[/grey70]"
-                elif response.function.web_url_info.label_stolen:
-                    suffix = " [grey70](label stolen)[/grey70]"
-                else:
-                    suffix = ""
-                # TODO: this is only printed when we're showing progress. Maybe move this somewhere else.
-                status_row.finish(f"Created {tag} => [magenta underline]{response.web_url}[/magenta underline]{suffix}")
-
-                # Print custom domain in terminal
-                for custom_domain in response.function.custom_domain_info:
-                    custom_domain_status_row = resolver.add_status_row()
-                    custom_domain_status_row.finish(
-                        f"Custom domain for {tag} => [magenta underline]{custom_domain.url}[/magenta underline]{suffix}"
+
+            if isinstance(gpu, list):
+                function_data = api_pb2.FunctionData(
+                    module_name=function_definition.module_name,
+                    function_name=function_definition.function_name,
+                    function_type=function_definition.function_type,
+                    warm_pool_size=function_definition.warm_pool_size,
+                    concurrency_limit=function_definition.concurrency_limit,
+                    task_idle_timeout_secs=function_definition.task_idle_timeout_secs,
+                    worker_id=function_definition.worker_id,
+                    timeout_secs=function_definition.timeout_secs,
+                    web_url=function_definition.web_url,
+                    web_url_info=function_definition.web_url_info,
+                    webhook_config=function_definition.webhook_config,
+                    custom_domain_info=function_definition.custom_domain_info,
+                    schedule=schedule.proto_message if schedule is not None else None,
+                    is_class=function_definition.is_class,
+                    class_parameter_info=function_definition.class_parameter_info,
+                    is_method=function_definition.is_method,
+                    use_function_id=function_definition.use_function_id,
+                    use_method_name=function_definition.use_method_name,
+                    method_definitions=function_definition.method_definitions,
+                    method_definitions_set=function_definition.method_definitions_set,
+                    _experimental_group_size=function_definition._experimental_group_size,
+                    _experimental_buffer_containers=function_definition._experimental_buffer_containers,
+                    _experimental_custom_scaling=function_definition._experimental_custom_scaling,
+                    _experimental_proxy_ip=function_definition._experimental_proxy_ip,
+                    snapshot_debug=function_definition.snapshot_debug,
+                    runtime_perf_record=function_definition.runtime_perf_record,
                 )

-
-
+                ranked_functions = []
+                for rank, _gpu in enumerate(gpu):
+                    function_definition_copy = api_pb2.Function()
+                    function_definition_copy.CopyFrom(function_definition)
+
+                    function_definition_copy.resources.CopyFrom(
+                        convert_fn_config_to_resources_config(
+                            cpu=cpu, memory=memory, gpu=_gpu, ephemeral_disk=ephemeral_disk
+                        ),
+                    )
+                    ranked_function = api_pb2.FunctionData.RankedFunction(
+                        rank=rank,
+                        function=function_definition_copy,
+                    )
+                    ranked_functions.append(ranked_function)
+                function_data.ranked_functions.extend(ranked_functions)
+                function_definition = None  # function_definition is not used in this case
+            else:
+                # TODO(irfansharif): Assert on this specific type once we get rid of python 3.9.
+                # assert isinstance(gpu, GPU_T)  # includes the case where gpu==None case
+                function_definition.resources.CopyFrom(
+                    convert_fn_config_to_resources_config(
+                        cpu=cpu, memory=memory, gpu=gpu, ephemeral_disk=ephemeral_disk
+                    ),  # type: ignore
+                )
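The FunctionData/RankedFunction path added above is what backs passing a list of GPU types, with earlier entries preferred by rank. A sketch of how a caller would exercise it; the app name, GPU strings, and function body are illustrative assumptions, not part of this diff:

```python
import modal

app = modal.App("gpu-fallback-demo")

# A list value for `gpu` produces one RankedFunction per entry; earlier
# entries (lower ranks) are preferred when capacity allows.
@app.function(gpu=["h100", "a100", "any"])
def train() -> str:
    import subprocess

    # Report which GPU the scheduler actually assigned to this container.
    return subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True).stdout
```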

+            assert resolver.app_id
+            assert (function_definition is None) != (function_data is None)  # xor
+            request = api_pb2.FunctionCreateRequest(
+                app_id=resolver.app_id,
+                function=function_definition,
+                function_data=function_data,
+                existing_function_id=existing_object_id or "",
+                defer_updates=True,
+            )
+            try:
+                response: api_pb2.FunctionCreateResponse = await retry_transient_errors(
+                    resolver.client.stub.FunctionCreate, request
+                )
+            except GRPCError as exc:
+                if exc.status == Status.INVALID_ARGUMENT:
+                    raise InvalidError(exc.message)
+                if exc.status == Status.FAILED_PRECONDITION:
+                    raise InvalidError(exc.message)
+                if exc.message and "Received :status = '413'" in exc.message:
+                    raise InvalidError(f"Function {info.function_name} is too large to deploy.")
+                raise
+            function_creation_status.set_response(response)
+            serve_mounts = {m for m in all_mounts if m.is_local()}  # needed for modal.serve file watching
+            serve_mounts |= image._serve_mounts
+            obj._serve_mounts = frozenset(serve_mounts)
             self._hydrate(response.function_id, resolver.client, response.handle_metadata)

         rep = f"Function({tag})"
         obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps)

-        obj._raw_f = raw_f
+        obj._raw_f = info.raw_f
         obj._info = info
         obj._tag = tag
-        obj.
-        obj._stub = stub  # needed for CLI right now
+        obj._app = app  # needed for CLI right now
         obj._obj = None
         obj._is_generator = is_generator
-        obj.
-        obj.
+        obj._cluster_size = cluster_size
+        obj._is_method = False
+        obj._spec = function_spec  # needed for modal shell

-        # Used to check whether we should rebuild
-
-        # hash. We can't use the cloudpickle hash because it's not very stable.
+        # Used to check whether we should rebuild a modal.Image which uses `run_function`.
+        gpus: list[GPU_T] = gpu if isinstance(gpu, list) else [gpu]
         obj._build_args = dict(  # See get_build_def
             secrets=repr(secrets),
-            gpu_config=repr(
+            gpu_config=repr([parse_gpu_config(_gpu) for _gpu in gpus]),
             mounts=repr(mounts),
             network_file_systems=repr(network_file_systems),
         )
+        # these key are excluded if empty to avoid rebuilds on client upgrade
+        if volumes:
+            obj._build_args["volumes"] = repr(volumes)
+        if cloud or scheduler_placement:
+            obj._build_args["cloud"] = repr(cloud)
+            obj._build_args["scheduler_placement"] = repr(scheduler_placement)

         return obj

-    def
         self,
-        obj,
-        from_other_workspace: bool,
+    def _bind_parameters(
+        obj: "modal.cls._Obj",
         options: Optional[api_pb2.FunctionOptions],
         args: Sized,
-        kwargs:
+        kwargs: dict[str, Any],
     ) -> "_Function":
-        """mdmd:hidden
+        """mdmd:hidden

-
-
+        Binds a class-function to a specific instance of (init params, options) or a new workspace
+        """
+
+        # In some cases, reuse the base function, i.e. not create new clones of each method or the "service function"
+        can_use_parent = len(args) + len(kwargs) == 0 and options is None
+        parent = self
+
+        async def _load(param_bound_func: _Function, resolver: Resolver, existing_object_id: Optional[str]):
+            if parent is None:
+                raise ExecutionError("Can't find the parent class' service function")
+            try:
+                identity = f"{parent.info.function_name} class service function"
+            except Exception:
+                # Can't always look up the function name that way, so fall back to generic message
+                identity = "class service function for a parameterized class"
+            if not parent.is_hydrated:
+                if parent.app._running_app is None:
+                    reason = ", because the App it is defined on is not running"
+                else:
+                    reason = ""
                 raise ExecutionError(
-                    "
-                    " defined on a different stub, or if it's on the same stub but it didn't get"
-                    " created because it wasn't defined in global scope."
+                    f"The {identity} has not been hydrated with the metadata it needs to run on Modal{reason}."
                 )
-
-
+
+            assert parent._client.stub
+
+            if can_use_parent:
+                # We can end up here if parent wasn't hydrated when class was instantiated, but has been since.
+                param_bound_func._hydrate_from_other(parent)
+                return
+
+            if (
+                parent._class_parameter_info
+                and parent._class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO
+            ):
+                if args:
+                    # TODO(elias) - We could potentially support positional args as well, if we want to?
+                    raise InvalidError(
+                        "Can't use positional arguments with modal.parameter-based synthetic constructors.\n"
+                        "Use (<parameter_name>=value) keyword arguments when constructing classes instead."
+                    )
+                serialized_params = serialize_proto_params(kwargs, parent._class_parameter_info.schema)
+            else:
+                serialized_params = serialize((args, kwargs))
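The proto-format branch above is the reason classes parameterized with `modal.parameter` accept keyword arguments only. A sketch from the caller's side; the app name, class name, and `predict` method are assumptions for illustration:

```python
import modal

# Hypothetical deployed class whose fields are declared with modal.parameter().
Model = modal.Cls.lookup("my-app", "Model")

m = Model(size="large")  # keyword args are serialized against the proto schema
# Model("large") would raise InvalidError: positional args are not supported
print(m.predict.remote("hello"))  # hypothetical method on the class
```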
             environment_name = _get_environment_name(None, resolver)
+            assert parent is not None
             req = api_pb2.FunctionBindParamsRequest(
-                function_id=
+                function_id=parent._object_id,
                 serialized_params=serialized_params,
                 function_options=options,
                 environment_name=environment_name
                 or "",  # TODO: investigate shouldn't environment name always be specified here?
             )
-
-
-
-
-
-
-
-
+
+            response = await retry_transient_errors(parent._client.stub.FunctionBindParams, req)
+            param_bound_func._hydrate(response.bound_function_id, parent._client, response.handle_metadata)
+
+        fun: _Function = _Function._from_loader(_load, "Function(parametrized)", hydrate_lazily=True)
+
+        if can_use_parent and parent.is_hydrated:
+            # skip the resolver altogether:
+            fun._hydrate_from_other(parent)
+
         fun._info = self._info
         fun._obj = obj
-        fun._is_generator = self._is_generator
-        fun._is_method = True
-        fun._parent = self
-
         return fun

     @live_method
     async def keep_warm(self, warm_pool_size: int) -> None:
-        """Set the warm pool size for the function
+        """Set the warm pool size for the function.

-        Please exercise care when using this advanced feature!
+        Please exercise care when using this advanced feature!
+        Setting and forgetting a warm pool on functions can lead to increased costs.

-        ```python
+        ```python notest
         # Usage on a regular function.
         f = modal.Function.lookup("my-app", "function")
         f.keep_warm(2)

         # Usage on a parametrized function.
         Model = modal.Cls.lookup("my-app", "Model")
-        Model("fine-tuned-model").
+        Model("fine-tuned-model").keep_warm(2)
         ```
         """
+        if self._is_method:
+            raise InvalidError(
+                textwrap.dedent(
+                    """
+                    The `.keep_warm()` method can not be used on Modal class *methods* deployed using Modal >v0.63.

+                    Call `.keep_warm()` on the class *instance* instead.
+                    """
+                )
+            )
         assert self._client and self._client.stub
         request = api_pb2.FunctionUpdateSchedulingParamsRequest(
             function_id=self._object_id, warm_pool_size_override=warm_pool_size
@@ -1041,17 +1039,22 @@ class _Function(_Object, type_prefix="fu"):
         await retry_transient_errors(self._client.stub.FunctionUpdateSchedulingParams, request)

     @classmethod
+    @renamed_parameter((2024, 12, 18), "tag", "name")
     def from_name(
-        cls:
+        cls: type["_Function"],
         app_name: str,
-
+        name: str,
         namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
         environment_name: Optional[str] = None,
     ) -> "_Function":
-        """
+        """Reference a Function from a deployed App by its name.
+
+        In contrast to `modal.Function.lookup`, this is a lazy method
+        that defers hydrating the local object with metadata from
+        Modal servers until the first time it is actually used.

         ```python
-
+        f = modal.Function.from_name("other-app", "function")
         ```
         """

@@ -1059,7 +1062,7 @@ class _Function(_Object, type_prefix="fu"):
         assert resolver.client and resolver.client.stub
         request = api_pb2.FunctionGetRequest(
             app_name=app_name,
-            object_tag=
+            object_tag=name,
             namespace=namespace,
             environment_name=_get_environment_name(environment_name, resolver) or "",
         )
@@ -1071,26 +1074,32 @@ class _Function(_Object, type_prefix="fu"):
             else:
                 raise

+        print_server_warnings(response.server_warnings)
+
         self._hydrate(response.function_id, resolver.client, response.handle_metadata)

         rep = f"Ref({app_name})"
-        return cls._from_loader(_load_remote, rep, is_another_app=True)
+        return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True)

     @staticmethod
+    @renamed_parameter((2024, 12, 18), "tag", "name")
     async def lookup(
         app_name: str,
-
+        name: str,
         namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
         client: Optional[_Client] = None,
         environment_name: Optional[str] = None,
     ) -> "_Function":
-        """Lookup a
+        """Lookup a Function from a deployed App by its name.

-
-
+        In contrast to `modal.Function.from_name`, this is an eager method
+        that will hydrate the local object with metadata from Modal servers.
+
+        ```python notest
+        f = modal.Function.lookup("other-app", "function")
         ```
         """
-        obj = _Function.from_name(app_name,
+        obj = _Function.from_name(app_name, name, namespace=namespace, environment_name=environment_name)
         if client is None:
             client = await _Client.from_env()
         resolver = Resolver(client=client)
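Taken together, the two docstrings above describe a lazy/eager split between the two reference styles. A sketch contrasting them; the app and function names are illustrative:

```python
import modal

# Lazy: no server round-trip yet; hydration is deferred to first use.
f_lazy = modal.Function.from_name("other-app", "function")

# Eager: resolves against Modal servers immediately (fails now if missing).
f_eager = modal.Function.lookup("other-app", "function")

result = f_lazy.remote(42)  # triggers hydration, then runs remotely
```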
@@ -1104,9 +1113,18 @@ class _Function(_Object, type_prefix="fu"):
         return self._tag

     @property
-    def
+    def app(self) -> "modal.app._App":
         """mdmd:hidden"""
-
+        if self._app is None:
+            raise ExecutionError("The app has not been assigned on the function at this point")
+
+        return self._app
+
+    @property
+    def stub(self) -> "modal.app._App":
+        """mdmd:hidden"""
+        # Deprecated soon, only for backwards compatibility
+        return self.app

     @property
     def info(self) -> FunctionInfo:
@@ -1115,12 +1133,15 @@ class _Function(_Object, type_prefix="fu"):
         return self._info

     @property
-    def
+    def spec(self) -> _FunctionSpec:
         """mdmd:hidden"""
-
+        assert self._spec
+        return self._spec

     def get_build_def(self) -> str:
         """mdmd:hidden"""
+        # Plaintext source and arg definition for the function, so it's part of the image
+        # hash. We can't use the cloudpickle hash because it's not very stable.
         assert hasattr(self, "_raw_f") and hasattr(self, "_build_args")
         return f"{inspect.getsource(self._raw_f)}\n{repr(self._build_args)}"

@@ -1130,208 +1151,170 @@ class _Function(_Object, type_prefix="fu"):
         # Overridden concrete implementation of base class method
         self._progress = None
         self._is_generator = None
+        self._cluster_size = None
         self._web_url = None
-        self._output_mgr: Optional[OutputManager] = None
-        self._mute_cancellation = (
-            False  # set when a user terminates the app intentionally, to prevent useless traceback spam
-        )
         self._function_name = None
         self._info = None
+        self._serve_mounts = frozenset()

     def _hydrate_metadata(self, metadata: Optional[Message]):
         # Overridden concrete implementation of base class method
-        assert metadata and isinstance(metadata,
+        assert metadata and isinstance(metadata, api_pb2.FunctionHandleMetadata)
         self._is_generator = metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
         self._web_url = metadata.web_url
         self._function_name = metadata.function_name
         self._is_method = metadata.is_method
+        self._use_method_name = metadata.use_method_name
+        self._class_parameter_info = metadata.class_parameter_info
+        self._method_handle_metadata = dict(metadata.method_handle_metadata)
+        self._definition_id = metadata.definition_id

     def _get_metadata(self):
         # Overridden concrete implementation of base class method
-        assert self._function_name
+        assert self._function_name, f"Function name must be set before metadata can be retrieved for {self}"
         return api_pb2.FunctionHandleMetadata(
             function_name=self._function_name,
-            function_type=(
-                api_pb2.Function.FUNCTION_TYPE_GENERATOR
-                if self._is_generator
-                else api_pb2.Function.FUNCTION_TYPE_FUNCTION
-            ),
+            function_type=get_function_type(self._is_generator),
             web_url=self._web_url or "",
+            use_method_name=self._use_method_name,
+            is_method=self._is_method,
+            class_parameter_info=self._class_parameter_info,
+            definition_id=self._definition_id,
+            method_handle_metadata=self._method_handle_metadata,
         )

-    def
-        self.
-
-
-
+    def _check_no_web_url(self, fn_name: str):
+        if self._web_url:
+            raise InvalidError(
+                f"A webhook function cannot be invoked for remote execution with `.{fn_name}`. "
+                f"Invoke this function via its web url '{self._web_url}' "
+                + f"or call it locally: {self._function_name}.local()"
+            )

+    # TODO (live_method on properties is not great, since it could be blocking the event loop from async contexts)
     @property
-
+    @live_method
+    async def web_url(self) -> str:
         """URL of a Function running as a web endpoint."""
         if not self._web_url:
             raise ValueError(
-                f"No web_url can be found for function {self._function_name}. web_url
+                f"No web_url can be found for function {self._function_name}. web_url "
+                "can only be referenced from a running app context"
             )
         return self._web_url
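Since `web_url` is now resolved through `@live_method`, accessing it on a looked-up function can hydrate it first. A sketch of typical access; the names and URL shape are illustrative assumptions:

```python
import modal

f = modal.Function.lookup("my-app", "web_endpoint")
print(f.web_url)  # e.g. "https://workspace--my-app-web-endpoint.modal.run"
```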

     @property
-    def is_generator(self) -> bool:
+    async def is_generator(self) -> bool:
         """mdmd:hidden"""
-
+        # hacky: kind of like @live_method, but not hydrating if we have the value already from local source
+        if self._is_generator is not None:
+            # this is set if the function or class is local
+            return self._is_generator
+
+        # not set - this is a from_name lookup - hydrate
+        await self.resolve()
+        assert self._is_generator is not None  # should be set now
         return self._is_generator

-
-
-
-
-
-
+    @property
+    def cluster_size(self) -> int:
+        """mdmd:hidden"""
+        return self._cluster_size or 1
+
+    @live_method_gen
+    async def _map(
+        self, input_queue: _SynchronizedQueue, order_outputs: bool, return_exceptions: bool
+    ) -> AsyncGenerator[Any, None]:
+        """mdmd:hidden
+
+        Synchronicity-wrapped map implementation. To be safe against invocations of user code in
+        the synchronicity thread it doesn't accept an [async]iterator, and instead takes a
+        _SynchronizedQueue instance that is fed by higher level functions like .map()
+
+        _SynchronizedQueue is used instead of asyncio.Queue so that the main thread can put
+        items in the queue safely.
+        """
+        self._check_no_web_url("map")
         if self._is_generator:
             raise InvalidError("A generator function cannot be called with `.map(...)`.")

         assert self._function_name
-
-
-
+        if output_mgr := _get_output_manager():
+            count_update_callback = output_mgr.function_progress_callback(self._function_name, total=None)
+        else:
+            count_update_callback = None
+
+        async with aclosing(
+            _map_invocation(
+                self,  # type: ignore
+                input_queue,
+                self._client,
+                order_outputs,
+                return_exceptions,
+                count_update_callback,
+            )
+        ) as stream:
+            async for item in stream:
+                yield item

-
-
-
+    async def _call_function(self, args, kwargs) -> ReturnType:
+        if config.get("client_retries"):
+            function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
+        else:
+            function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY
+        invocation = await _Invocation.create(
+            self,
+            args,
             kwargs,
-            self._client,
-
-
-            count_update_callback,
-        ):
-            yield item
+            client=self._client,
+            function_call_invocation_type=function_call_invocation_type,
+        )

-
-        invocation = await _Invocation.create(self.object_id, args, kwargs, self._client)
-        try:
-            return await invocation.run_function()
-        except asyncio.CancelledError:
-            # this can happen if the user terminates a program, triggering a cancellation cascade
-            if not self._mute_cancellation:
-                raise
+        return await invocation.run_function()

-    async def _call_function_nowait(
-
+    async def _call_function_nowait(
+        self, args, kwargs, function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
+    ) -> _Invocation:
+        return await _Invocation.create(
+            self, args, kwargs, client=self._client, function_call_invocation_type=function_call_invocation_type
+        )

-    @warn_if_generator_is_not_consumed
+    @warn_if_generator_is_not_consumed()
     @live_method_gen
     @synchronizer.no_input_translation
     async def _call_generator(self, args, kwargs):
-        invocation = await _Invocation.create(
+        invocation = await _Invocation.create(
+            self,
+            args,
+            kwargs,
+            client=self._client,
+            function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
+        )
         async for res in invocation.run_generator():
             yield res

     @synchronizer.no_io_translation
     async def _call_generator_nowait(self, args, kwargs):
-
-
-
-
-
-
-
-
-
-
-
-    ) -> AsyncGenerator[Any, None]:
-        """Parallel map over a set of inputs.
-
-        Takes one iterator argument per argument in the function being mapped over.
-
-        Example:
-        ```python
-        @stub.function()
-        def my_func(a):
-            return a ** 2
-
-
-        @stub.local_entrypoint()
-        def main():
-            assert list(my_func.map([1, 2, 3, 4])) == [1, 4, 9, 16]
-        ```
-
-        If applied to a `stub.function`, `map()` returns one result per input and the output order
-        is guaranteed to be the same as the input order. Set `order_outputs=False` to return results
-        in the order that they are completed instead.
-
-        `return_exceptions` can be used to treat exceptions as successful results:
-
-        ```python
-        @stub.function()
-        def my_func(a):
-            if a == 2:
-                raise Exception("ohno")
-            return a ** 2
-
-
-        @stub.local_entrypoint()
-        def main():
-            # [0, 1, UserCodeException(Exception('ohno'))]
-            print(list(my_func.map(range(3), return_exceptions=True)))
-        ```
-        """
-
-        input_stream = stream.zip(*(stream.iterate(it) for it in input_iterators))
-        async for item in self._map(input_stream, order_outputs, return_exceptions, kwargs):
-            yield item
-
-    @synchronizer.no_input_translation
-    async def for_each(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False):
-        """Execute function for all inputs, ignoring outputs.
-
-        Convenient alias for `.map()` in cases where the function just needs to be called.
-        as the caller doesn't have to consume the generator to process the inputs.
-        """
-        # TODO(erikbern): it would be better if this is more like a map_spawn that immediately exits
-        # rather than iterating over the result
-        async for _ in self.map(
-            *input_iterators, kwargs=kwargs, order_outputs=False, return_exceptions=ignore_exceptions
-        ):
-            pass
-
-    @warn_if_generator_is_not_consumed
-    @live_method_gen
-    @synchronizer.no_input_translation
-    async def starmap(
-        self, input_iterator, kwargs={}, order_outputs: bool = True, return_exceptions: bool = False
-    ) -> AsyncGenerator[Any, None]:
-        """Like `map`, but spreads arguments over multiple function arguments.
-
-        Assumes every input is a sequence (e.g. a tuple).
-
-        Example:
-        ```python
-        @stub.function()
-        def my_func(a, b):
-            return a + b
-
-
-        @stub.local_entrypoint()
-        def main():
-            assert list(my_func.starmap([(1, 2), (3, 4)])) == [3, 7]
-        ```
-        """
-        input_stream = stream.iterate(input_iterator)
-        async for item in self._map(input_stream, order_outputs, return_exceptions, kwargs):
-            yield item
+        deprecation_warning(
+            (2024, 12, 11),
+            "Calling spawn on a generator function is deprecated and will soon raise an exception.",
+        )
+        return await _Invocation.create(
+            self,
+            args,
+            kwargs,
+            client=self._client,
+            function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY,
+        )

     @synchronizer.no_io_translation
     @live_method
-    async def remote(self, *args, **kwargs) ->
+    async def remote(self, *args: P.args, **kwargs: P.kwargs) -> ReturnType:
         """
         Calls the function remotely, executing it with the given arguments and returning the execution's result.
         """
         # TODO: Generics/TypeVars
-
-        raise InvalidError(
-            "A web endpoint function cannot be invoked for remote execution with `.remote`. "
-            f"Invoke this function via its web url '{self._web_url}' or call it locally: {self._function_name}()."
-        )
+        self._check_no_web_url("remote")
         if self._is_generator:
             raise InvalidError(
                 "A generator function cannot be called with `.remote(...)`. Use `.remote_gen(...)` instead."
@@ -1346,11 +1329,7 @@ class _Function(_Object, type_prefix="fu"):
         Calls the generator remotely, executing it with the given arguments and returning the execution's result.
         """
         # TODO: Generics/TypeVars
-
-        raise InvalidError(
-            "A web endpoint function cannot be invoked for remote execution with `.remote`. "
-            f"Invoke this function via its web url '{self._web_url}' or call it locally: {self._function_name}()."
-        )
+        self._check_no_web_url("remote_gen")

         if not self._is_generator:
             raise InvalidError(
@@ -1359,22 +1338,15 @@ class _Function(_Object, type_prefix="fu"):
         async for item in self._call_generator(args, kwargs):  # type: ignore
             yield item

-
-
-    async def shell(self, *args, **kwargs) -> None:
-        if self._is_generator:
-            async for item in self._call_generator(args, kwargs):
-                pass
-        else:
-            await self._call_function(args, kwargs)
+    def _is_local(self):
+        return self._info is not None

-    def
-
-
-    def _get_info(self):
+    def _get_info(self) -> FunctionInfo:
+        if not self._info:
+            raise ExecutionError("Can't get info for a function that isn't locally defined")
         return self._info

-    def _get_obj(self):
+    def _get_obj(self) -> Optional["modal.cls._Obj"]:
         if not self._is_method:
             return None
         elif not self._obj:
@@ -1383,83 +1355,129 @@ class _Function(_Object, type_prefix="fu"):
         return self._obj

     @synchronizer.nowrap
-    def local(self, *args, **kwargs) ->
+    def local(self, *args: P.args, **kwargs: P.kwargs) -> OriginalReturnType:
         """
         Calls the function locally, executing it with the given arguments and returning the execution's result.
-
+
+        The function will execute in the same environment as the caller, just like calling the underlying function
+        directly in Python. In particular, only secrets available in the caller environment will be available
+        through environment variables.
         """
         # TODO(erikbern): it would be nice to remove the nowrap thing, but right now that would cause
         # "user code" to run on the synchronicity thread, which seems bad
-
-        if not info:
+        if not self._is_local():
             msg = (
-                "The definition for this function is missing so it is not possible to invoke it locally. "
+                "The definition for this function is missing here so it is not possible to invoke it locally. "
                 "If this function was retrieved via `Function.lookup` you need to use `.remote()`."
             )
             raise ExecutionError(msg)

-
+        info = self._get_info()
+        if not info.raw_f:
+            # Here if calling .local on a service function itself which should never happen
+            # TODO: check if we end up here in a container for a serialized function?
+            raise ExecutionError("Can't call .local on service function")
+
+        if is_local() and self.spec.volumes or self.spec.network_file_systems:
+            warnings.warn(
+                f"The {info.function_name} function is executing locally "
+                + "and will not have access to the mounted Volume or NetworkFileSystem data"
+            )
+
+        obj: Optional["modal.cls._Obj"] = self._get_obj()

         if not obj:
             fun = info.raw_f
             return fun(*args, **kwargs)
         else:
             # This is a method on a class, so bind the self to the function
-
-            fun = info.raw_f.__get__(
+            user_cls_instance = obj._cached_user_cls_instance()
+            fun = info.raw_f.__get__(user_cls_instance)

+            # TODO: replace implicit local enter/exit with a context manager
             if is_async(info.raw_f):
                 # We want to run __aenter__ and fun in the same coroutine
                 async def coro():
-                    await obj.
+                    await obj._aenter()
                     return await fun(*args, **kwargs)

-                return coro()
+                return coro()  # type: ignore
             else:
-                obj.
+                obj._enter()
                 return fun(*args, **kwargs)

     @synchronizer.no_input_translation
     @live_method
-    async def
-        """Calls the function with the given arguments, without waiting for the results.
+    async def _experimental_spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_FunctionCall[ReturnType]":
+        """[Experimental] Calls the function with the given arguments, without waiting for the results.

-
+        This experimental version of the spawn method allows up to 1 million inputs to be spawned.
+
+        Returns a `modal.functions.FunctionCall` object, that can later be polled or
+        waited for using `.get(timeout=...)`.
         Conceptually similar to `multiprocessing.pool.apply_async`, or a Future/Promise in other contexts.
+        """
+        self._check_no_web_url("_experimental_spawn")
+        if self._is_generator:
+            invocation = await self._call_generator_nowait(args, kwargs)
+        else:
+            invocation = await self._call_function_nowait(
+                args, kwargs, function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC
+            )

-
-
+        fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
+        fc._is_generator = self._is_generator if self._is_generator else False
+        return fc
+
+    @synchronizer.no_input_translation
+    @live_method
+    async def spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_FunctionCall[ReturnType]":
+        """Calls the function with the given arguments, without waiting for the results.
+
+        Returns a `modal.functions.FunctionCall` object, that can later be polled or
+        waited for using `.get(timeout=...)`.
+        Conceptually similar to `multiprocessing.pool.apply_async`, or a Future/Promise in other contexts.
         """
+        self._check_no_web_url("spawn")
         if self._is_generator:
-            await self._call_generator_nowait(args, kwargs)
-
+            invocation = await self._call_generator_nowait(args, kwargs)
+        else:
+            invocation = await self._call_function_nowait(
+                args, kwargs, api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY
+            )

-
-
+        fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
+        fc._is_generator = self._is_generator if self._is_generator else False
+        return fc
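The reworked `spawn` above returns a typed `FunctionCall` handle rather than awaiting the result. A sketch of the fire-then-wait pattern built on it; the app and function names are illustrative assumptions:

```python
import modal

f = modal.Function.lookup("my-app", "process")

fc = f.spawn(42)             # returns immediately with a FunctionCall handle
# ... do other work here ...
result = fc.get(timeout=60)  # block up to 60s for the result
```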

     def get_raw_f(self) -> Callable[..., Any]:
         """Return the inner Python object wrapped by this Modal Function."""
-
-        raise AttributeError("_info has not been set on this FunctionHandle and not available in this context")
-
-        return self._info.raw_f
+        return self._raw_f

     @live_method
     async def get_current_stats(self) -> FunctionStats:
         """Return a `FunctionStats` object describing the current function's queue and runner counts."""
         assert self._client.stub
-        resp = await
-
-
-
-            backlog=resp.backlog, num_active_runners=resp.num_active_tasks, num_total_runners=resp.num_total_tasks
+        resp = await retry_transient_errors(
+            self._client.stub.FunctionGetCurrentStats,
+            api_pb2.FunctionGetCurrentStatsRequest(function_id=self.object_id),
+            total_timeout=10.0,
         )
+        return FunctionStats(backlog=resp.backlog, num_total_runners=resp.num_total_tasks)
+
+    # A bit hacky - but the map-style functions need to not be synchronicity-wrapped
+    # in order to not execute their input iterators on the synchronicity event loop.
+    # We still need to wrap them using MethodWithAio to maintain a synchronicity-like
+    # api with `.aio` and get working type-stubs and reference docs generation:
+    map = MethodWithAio(_map_sync, _map_async, synchronizer)
+    starmap = MethodWithAio(_starmap_sync, _starmap_async, synchronizer)
+    for_each = MethodWithAio(_for_each_sync, _for_each_async, synchronizer)
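The `MethodWithAio` assignments above keep the familiar call shapes for the map-style helpers even though they are no longer plain synchronicity-wrapped methods. A sketch of the three entry points; the app and function are illustrative:

```python
import modal

app = modal.App("map-demo")

@app.function()
def square(x: int) -> int:
    return x * x

@app.local_entrypoint()
def main():
    print(list(square.map(range(4))))          # [0, 1, 4, 9]
    print(list(square.starmap([(1,), (2,)])))  # [1, 4]
    square.for_each(range(4))                  # run for side effects, ignore outputs
```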


 Function = synchronize_api(_Function)


-class _FunctionCall(_Object, type_prefix="fc"):
+class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
     """A reference to an executed function call.

     Constructed using `.spawn(...)` on a Modal function with the same
@@ -1470,11 +1488,13 @@ class _FunctionCall(_Object, type_prefix="fc"):
     Conceptually similar to a Future/Promise/AsyncResult in other contexts and languages.
     """

+    _is_generator: bool = False
+
     def _invocation(self):
         assert self._client.stub
         return _Invocation(self._client.stub, self.object_id, self._client)

-    async def get(self, timeout: Optional[float] = None):
+    async def get(self, timeout: Optional[float] = None) -> ReturnType:
         """Get the result of the function call.

         This function waits indefinitely by default. It takes an optional
@@ -1483,9 +1503,23 @@ class _FunctionCall(_Object, type_prefix="fc"):

         The returned coroutine is not cancellation-safe.
         """
+
+        if self._is_generator:
+            raise Exception("Cannot get the result of a generator function call. Use `get_gen` instead.")
+
         return await self._invocation().poll_function(timeout=timeout)

-    async def
+    async def get_gen(self) -> AsyncGenerator[Any, None]:
+        """
+        Calls the generator remotely, executing it with the given arguments and returning the execution's result.
+        """
+        if not self._is_generator:
+            raise Exception("Cannot iterate over a non-generator function call. Use `get` instead.")
+
+        async for res in self._invocation().run_generator():
+            yield res
+
+    async def get_call_graph(self) -> list[InputInfo]:
         """Returns a structure representing the call graph from a given root
         call ID, along with the status of execution for each node.

@@ -1497,24 +1531,38 @@ class _FunctionCall(_Object, type_prefix="fc"):
         response = await retry_transient_errors(self._client.stub.FunctionGetCallGraph, request)
         return _reconstruct_call_graph(response)

-    async def cancel(
-
-
+    async def cancel(
+        self,
+        terminate_containers: bool = False,  # if true, containers running the inputs are forcibly terminated
+    ):
+        """Cancels the function call, which will stop its execution and mark its inputs as
+        [`TERMINATED`](/docs/reference/modal.call_graph#modalcall_graphinputstatus).
+
+        If `terminate_containers=True` - the containers running the cancelled inputs are all terminated
+        causing any non-cancelled inputs on those containers to be rescheduled in new containers.
+        """
+        request = api_pb2.FunctionCallCancelRequest(
+            function_call_id=self.object_id, terminate_containers=terminate_containers
+        )
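The new `terminate_containers` flag above gives `cancel` a heavier mode. A sketch of both forms; the function call ID is a placeholder:

```python
import modal

fc = modal.functions.FunctionCall.from_id("fc-XXXXXXXX")  # placeholder id

fc.cancel()                           # mark remaining inputs TERMINATED
fc.cancel(terminate_containers=True)  # additionally kill the running containers
```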
         assert self._client and self._client.stub
         await retry_transient_errors(self._client.stub.FunctionCallCancel, request)

     @staticmethod
-    async def from_id(
+    async def from_id(
+        function_call_id: str, client: Optional[_Client] = None, is_generator: bool = False
+    ) -> "_FunctionCall":
         if client is None:
             client = await _Client.from_env()

-
+        fc = _FunctionCall._new_hydrated(function_call_id, client, None)
+        fc._is_generator = is_generator
+        return fc


 FunctionCall = synchronize_api(_FunctionCall)


-async def _gather(*function_calls: _FunctionCall):
+async def _gather(*function_calls: _FunctionCall[ReturnType]) -> typing.Sequence[ReturnType]:
     """Wait until all Modal function calls have results before returning

     Accepts a variable number of FunctionCall objects as returned by `Function.spawn()`.
@@ -1532,63 +1580,10 @@ async def _gather(*function_calls: _FunctionCall):
     ```
     """
     try:
-        return await
+        return await TaskContext.gather(*[fc.get() for fc in function_calls])
     except Exception as exc:
         # TODO: kill all running function calls
         raise exc


 gather = synchronize_api(_gather)
-
-
-_current_input_id: ContextVar = ContextVar("_current_input_id")
-_current_function_call_id: ContextVar = ContextVar("_current_function_call_id")
-
-
-def current_input_id() -> Optional[str]:
-    """Returns the input ID for the current input.
-
-    Can only be called from Modal function (i.e. in a container context).
-
-    ```python
-    from modal import current_input_id
-
-    @stub.function()
-    def process_stuff():
-        print(f"Starting to process {current_input_id()}")
-    ```
-    """
-    try:
-        return _current_input_id.get()
-    except LookupError:
-        return None
-
-
-def current_function_call_id() -> Optional[str]:
-    """Returns the function call ID for the current input.
-
-    Can only be called from Modal function (i.e. in a container context).
-
-    ```python
-    from modal import current_function_call_id
-
-    @stub.function()
-    def process_stuff():
-        print(f"Starting to process input from {current_function_call_id()}")
-    ```
-    """
-    try:
-        return _current_function_call_id.get()
-    except LookupError:
-        return None
-
-
-def _set_current_context_ids(input_id: str, function_call_id: str) -> Callable[[], None]:
-    input_token = _current_input_id.set(input_id)
-    function_call_token = _current_function_call_id.set(function_call_id)
-
-    def _reset_current_context_ids():
-        _current_input_id.reset(input_token)
-        _current_function_call_id.reset(function_call_token)
-
-    return _reset_current_context_ids
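`gather` pairs naturally with `spawn`, waiting on several calls at once. A sketch; the app and function names are illustrative:

```python
import modal

f = modal.Function.lookup("my-app", "process")

fc1 = f.spawn(1)
fc2 = f.spawn(2)
results = modal.functions.gather(fc1, fc2)  # blocks until both results are ready
```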