modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the registry.
- modal/__init__.py +13 -9
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +402 -398
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -60
- modal/_resources.py +26 -7
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1025 -0
- modal/{execution_context.py → _runtime/execution_context.py} +11 -2
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +123 -6
- modal/_traceback.py +47 -187
- modal/_tunnel.py +50 -14
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +386 -104
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +299 -98
- modal/_utils/grpc_testing.py +47 -34
- modal/_utils/grpc_utils.py +54 -21
- modal/_utils/hash_utils.py +51 -10
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +3 -3
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +12 -10
- modal/app.py +561 -323
- modal/app.pyi +474 -262
- modal/call_graph.py +7 -6
- modal/cli/_download.py +22 -6
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +203 -42
- modal/cli/config.py +12 -5
- modal/cli/container.py +61 -13
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +21 -48
- modal/cli/launch.py +28 -14
- modal/cli/network_file_system.py +57 -21
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +34 -9
- modal/cli/programs/vscode.py +58 -8
- modal/cli/queues.py +131 -0
- modal/cli/run.py +199 -96
- modal/cli/secret.py +5 -4
- modal/cli/token.py +7 -2
- modal/cli/utils.py +74 -8
- modal/cli/volume.py +97 -56
- modal/client.py +248 -144
- modal/client.pyi +156 -124
- modal/cloud_bucket_mount.py +43 -30
- modal/cloud_bucket_mount.pyi +32 -25
- modal/cls.py +528 -141
- modal/cls.pyi +189 -145
- modal/config.py +32 -15
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +50 -54
- modal/dict.pyi +120 -164
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +30 -43
- modal/experimental.py +62 -2
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +196 -0
- modal/functions.py +846 -428
- modal/functions.pyi +446 -387
- modal/gpu.py +57 -44
- modal/image.py +943 -417
- modal/image.pyi +584 -245
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +223 -90
- modal/mount.pyi +241 -243
- modal/network_file_system.py +85 -86
- modal/network_file_system.pyi +151 -110
- modal/object.py +66 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +73 -47
- modal/parallel_map.pyi +51 -63
- modal/partial_function.py +272 -107
- modal/partial_function.pyi +219 -120
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +96 -72
- modal/queue.pyi +210 -135
- modal/requirements/2024.04.txt +2 -1
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +45 -4
- modal/runner.py +325 -203
- modal/runner.pyi +124 -110
- modal/running_app.py +27 -4
- modal/sandbox.py +509 -231
- modal/sandbox.pyi +396 -169
- modal/schedule.py +2 -2
- modal/scheduler_placement.py +20 -3
- modal/secret.py +41 -25
- modal/secret.pyi +62 -42
- modal/serving.py +39 -49
- modal/serving.pyi +37 -43
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +123 -137
- modal/volume.pyi +228 -221
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
- modal-0.72.13.dist-info/RECORD +174 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +1 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1231 -531
- modal_proto/api_grpc.py +750 -430
- modal_proto/api_pb2.py +2102 -1176
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1329 -675
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_exec.py +0 -128
- modal/_container_io_manager.py +0 -646
- modal/_container_io_manager.pyi +0 -412
- modal/_sandbox_shell.py +0 -49
- modal/app_utils.py +0 -20
- modal/app_utils.pyi +0 -17
- modal/execution_context.pyi +0 -37
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal-0.62.115.dist-info/RECORD +0 -207
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -279
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -674
- test/client_test.py +0 -203
- test/cloud_bucket_mount_test.py +0 -22
- test/cls_test.py +0 -636
- test/config_test.py +0 -149
- test/conftest.py +0 -1485
- test/container_app_test.py +0 -50
- test/container_test.py +0 -1405
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -51
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -791
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -82
- test/helpers.py +0 -47
- test/image_test.py +0 -814
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -327
- test/network_file_system_test.py +0 -188
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -115
- test/resolver_test.py +0 -59
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -57
- test/secret_test.py +0 -89
- test/serialization_test.py +0 -50
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -361
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -397
- test/watcher_test.py +0 -58
- test/webhook_test.py +0 -145
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
modal/_container_io_manager.py
DELETED
@@ -1,646 +0,0 @@
-# Copyright Modal Labs 2024
-import asyncio
-import json
-import math
-import os
-import signal
-import time
-import traceback
-from pathlib import Path
-from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, List, Optional, Set, Tuple
-
-from google.protobuf.empty_pb2 import Empty
-from google.protobuf.message import Message
-from grpclib import Status
-from synchronicity.async_wrap import asynccontextmanager
-
-from modal_proto import api_pb2
-
-from ._serialization import deserialize, deserialize_data_format, serialize, serialize_data_format
-from ._traceback import extract_traceback
-from ._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer
-from ._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
-from ._utils.function_utils import _stream_function_call_data
-from ._utils.grpc_utils import get_proto_oneof, retry_transient_errors
-from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
-from .config import config, logger
-from .exception import InputCancellation, InvalidError
-from .running_app import RunningApp
-
-MAX_OUTPUT_BATCH_SIZE: int = 49
-
-RTT_S: float = 0.5  # conservative estimate of RTT in seconds.
-
-
-class UserException(Exception):
-    """Used to shut down the task gracefully."""
-
-
-class Sentinel:
-    """Used to get type-stubs to work with this object."""
-
-
-class _ContainerIOManager:
-    """Synchronizes all RPC calls and network operations for a running container.
-
-    TODO: maybe we shouldn't synchronize the whole class.
-    Then we could potentially move a bunch of the global functions onto it.
-    """
-
-    cancelled_input_ids: Set[str]
-    task_id: str
-    function_id: str
-    app_id: str
-    function_def: api_pb2.Function
-    checkpoint_id: Optional[str]
-
-    calls_completed: int
-    total_user_time: float
-    current_input_id: Optional[str]
-    current_input_started_at: Optional[float]
-
-    _input_concurrency: Optional[int]
-    _semaphore: Optional[asyncio.Semaphore]
-    _environment_name: str
-    _waiting_for_checkpoint: bool
-    _heartbeat_loop: Optional[asyncio.Task]
-
-    _is_interactivity_enabled: bool
-    _fetching_inputs: bool
-
-    _client: _Client
-
-    _GENERATOR_STOP_SENTINEL: ClassVar[Sentinel] = Sentinel()
-    _singleton: ClassVar[Optional["_ContainerIOManager"]] = None
-
-    def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):
-        self.cancelled_input_ids = set()
-        self.task_id = container_args.task_id
-        self.function_id = container_args.function_id
-        self.app_id = container_args.app_id
-        self.function_def = container_args.function_def
-        self.checkpoint_id = container_args.checkpoint_id or None
-
-        self.calls_completed = 0
-        self.total_user_time = 0.0
-        self.current_input_id = None
-        self.current_input_started_at = None
-
-        self._input_concurrency = None
-
-        self._semaphore = None
-        self._environment_name = container_args.environment_name
-        self._waiting_for_checkpoint = False
-        self._heartbeat_loop = None
-
-        self._is_interactivity_enabled = False
-        self._fetching_inputs = True
-
-        self._client = client
-        assert isinstance(self._client, _Client)
-
-    def __new__(cls, container_args: api_pb2.ContainerArguments, client: _Client) -> "_ContainerIOManager":
-        cls._singleton = super().__new__(cls)
-        cls._singleton._init(container_args, client)
-        return cls._singleton
-
-    @classmethod
-    def _reset_singleton(cls):
-        """Only used for tests."""
-        cls._singleton = None
-
-    async def _run_heartbeat_loop(self):
-        while 1:
-            t0 = time.monotonic()
-            try:
-                if await self._heartbeat_handle_cancellations():
-                    # got a cancellation event, fine to start another heartbeat immediately
-                    # since the cancellation queue should be empty on the worker server
-                    # however, we wait at least 1s to prevent short-circuiting the heartbeat loop
-                    # in case there is ever a bug. This means it will take at least 1s between
-                    # two subsequent cancellations on the same task at the moment
-                    await asyncio.sleep(1.0)
-                    continue
-            except Exception as exc:
-                # don't stop heartbeat loop if there are transient exceptions!
-                time_elapsed = time.monotonic() - t0
-                error = exc
-                logger.warning(f"Heartbeat attempt failed ({time_elapsed=}, {error=})")
-
-            heartbeat_duration = time.monotonic() - t0
-            time_until_next_hearbeat = max(0.0, HEARTBEAT_INTERVAL - heartbeat_duration)
-            await asyncio.sleep(time_until_next_hearbeat)
-
-    async def _heartbeat_handle_cancellations(self) -> bool:
-        # Return True if a cancellation event was received, in that case we shouldn't wait too long for another heartbeat
-
-        # Don't send heartbeats for tasks waiting to be checkpointed.
-        # Calling gRPC methods open new connections which block the
-        # checkpointing process.
-        if self._waiting_for_checkpoint:
-            return False
-
-        request = api_pb2.ContainerHeartbeatRequest(supports_graceful_input_cancellation=True)
-        if self.current_input_id is not None:
-            request.current_input_id = self.current_input_id
-        if self.current_input_started_at is not None:
-            request.current_input_started_at = self.current_input_started_at
-
-        # TODO(erikbern): capture exceptions?
-        response = await retry_transient_errors(
-            self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
-        )
-
-        if response.HasField("cancel_input_event"):
-            # Pause processing of the current input by signaling self a SIGUSR1.
-            input_ids_to_cancel = response.cancel_input_event.input_ids
-            if input_ids_to_cancel:
-                if self._input_concurrency > 1:
-                    logger.info(
-                        "Shutting down task to stop some subset of inputs (concurrent functions don't support fine-grained cancellation)"
-                    )
-                    # This is equivalent to a task cancellation or preemption from worker code,
-                    # except we do not send a SIGKILL to forcefully exit after 30 seconds.
-                    #
-                    # SIGINT always interrupts the main thread, but not any auxiliary threads. On a
-                    # sync function without concurrent inputs, this raises a KeyboardInterrupt. When
-                    # there are concurrent inputs, we cannot interrupt the thread pool, but the
-                    # interpreter stops waiting for daemon threads and exits. On async functions,
-                    # this signal lands outside the event loop, stopping `run_until_complete()`.
-                    os.kill(os.getpid(), signal.SIGINT)
-
-                elif self.current_input_id in input_ids_to_cancel:
-                    # This goes to a registered signal handler for sync Modal functions, or to the
-                    # `SignalHandlingEventLoop` for async functions.
-                    #
-                    # We only send this signal on functions that do not have concurrent inputs enabled.
-                    # This allows us to do fine-grained input cancellation. On sync functions, the
-                    # SIGUSR1 signal should interrupt the main thread where user code is running,
-                    # raising an InputCancellation() exception. On async functions, the signal should
-                    # reach a handler in SignalHandlingEventLoop, which cancels the task.
-                    os.kill(os.getpid(), signal.SIGUSR1)
-            return True
-        return False
-
-    @asynccontextmanager
-    async def heartbeats(self) -> AsyncGenerator[None, None]:
-        async with TaskContext() as tc:
-            self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
-            t.set_name("heartbeat loop")
-            try:
-                yield
-            finally:
-                t.cancel()
-
-    def stop_heartbeat(self):
-        if self._heartbeat_loop:
-            self._heartbeat_loop.cancel()
-
-    async def get_app_objects(self) -> RunningApp:
-        req = api_pb2.AppGetObjectsRequest(app_id=self.app_id, include_unindexed=True)
-        resp = await retry_transient_errors(self._client.stub.AppGetObjects, req)
-        logger.debug(f"AppGetObjects received {len(resp.items)} objects for app {self.app_id}")
-
-        tag_to_object_id = {}
-        object_handle_metadata = {}
-        for item in resp.items:
-            handle_metadata: Optional[Message] = get_proto_oneof(item.object, "handle_metadata_oneof")
-            object_handle_metadata[item.object.object_id] = handle_metadata
-            if item.tag:
-                tag_to_object_id[item.tag] = item.object.object_id
-
-        return RunningApp(
-            self.app_id,
-            environment_name=self._environment_name,
-            tag_to_object_id=tag_to_object_id,
-            object_handle_metadata=object_handle_metadata,
-        )
-
-    async def get_serialized_function(self) -> Tuple[Optional[Any], Callable]:
-        # Fetch the serialized function definition
-        request = api_pb2.FunctionGetSerializedRequest(function_id=self.function_id)
-        response = await self._client.stub.FunctionGetSerialized(request)
-        fun = self.deserialize(response.function_serialized)
-
-        if response.class_serialized:
-            cls = self.deserialize(response.class_serialized)
-        else:
-            cls = None
-
-        return cls, fun
-
-    def serialize(self, obj: Any) -> bytes:
-        return serialize(obj)
-
-    def deserialize(self, data: bytes) -> Any:
-        return deserialize(data, self._client)
-
-    @synchronizer.no_io_translation
-    def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
-        return serialize_data_format(obj, data_format)
-
-    def deserialize_data_format(self, data: bytes, data_format: int) -> Any:
-        return deserialize_data_format(data, data_format, self._client)
-
-    async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
-        """Read from the `data_in` stream of a function call."""
-        async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
-            yield data
-
-    async def put_data_out(
-        self,
-        function_call_id: str,
-        start_index: int,
-        data_format: int,
-        messages_bytes: List[Any],
-    ) -> None:
-        """Put data onto the `data_out` stream of a function call.
-
-        This is used for generator outputs, which includes web endpoint responses. Note that this
-        was introduced as a performance optimization in client version 0.57, so older clients will
-        still use the previous Postgres-backed system based on `FunctionPutOutputs()`.
-        """
-        data_chunks: List[api_pb2.DataChunk] = []
-        for i, message_bytes in enumerate(messages_bytes):
-            chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i)  # type: ignore
-            if len(message_bytes) > MAX_OBJECT_SIZE_BYTES:
-                chunk.data_blob_id = await blob_upload(message_bytes, self._client.stub)
-            else:
-                chunk.data = message_bytes
-            data_chunks.append(chunk)
-
-        req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks)
-        await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req)
-
-    async def generator_output_task(self, function_call_id: str, data_format: int, message_rx: asyncio.Queue) -> None:
-        """Task that feeds generator outputs into a function call's `data_out` stream."""
-        index = 1
-        received_sentinel = False
-        while not received_sentinel:
-            message = await message_rx.get()
-            if message is self._GENERATOR_STOP_SENTINEL:
-                break
-            # ASGI 'http.response.start' and 'http.response.body' msgs are observed to be separated by 1ms.
-            # If we don't sleep here for 1ms we end up with an extra call to .put_data_out().
-            if index == 1:
-                await asyncio.sleep(0.001)
-            messages_bytes = [serialize_data_format(message, data_format)]
-            total_size = len(messages_bytes[0]) + 512
-            while total_size < 16 * 1024 * 1024:  # 16 MiB, maximum size in a single message
-                try:
-                    message = message_rx.get_nowait()
-                except asyncio.QueueEmpty:
-                    break
-                if message is self._GENERATOR_STOP_SENTINEL:
-                    received_sentinel = True
-                    break
-                else:
-                    messages_bytes.append(serialize_data_format(message, data_format))
-                    total_size += len(messages_bytes[-1]) + 512  # 512 bytes for estimated framing overhead
-            await self.put_data_out(function_call_id, index, data_format, messages_bytes)
-            index += len(messages_bytes)
-
-    async def _queue_create(self, size: int) -> asyncio.Queue:
-        """Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
-        return asyncio.Queue(size)
-
-    async def _queue_put(self, queue: asyncio.Queue, value: Any) -> None:
-        """Put a value onto a queue, using the synchronicity event loop."""
-        await queue.put(value)
-
-    async def populate_input_blobs(self, item: api_pb2.FunctionInput):
-        args = await blob_download(item.args_blob_id, self._client.stub)
-
-        # Mutating
-        item.ClearField("args_blob_id")
-        item.args = args
-        return item
-
-    def get_average_call_time(self) -> float:
-        if self.calls_completed == 0:
-            return 0
-
-        return self.total_user_time / self.calls_completed
-
-    def get_max_inputs_to_fetch(self):
-        if self.calls_completed == 0:
-            return 1
-
-        return math.ceil(RTT_S / max(self.get_average_call_time(), 1e-6))
-
-    @synchronizer.no_io_translation
-    async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.FunctionInput]]:
-        request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
-        eof_received = False
-        iteration = 0
-        while not eof_received and self._fetching_inputs:
-            request.average_call_time = self.get_average_call_time()
-            request.max_values = self.get_max_inputs_to_fetch()  # Deprecated; remove.
-            request.input_concurrency = self._input_concurrency
-
-            await self._semaphore.acquire()
-            yielded = False
-            try:
-                # If number of active inputs is at max queue size, this will block.
-                iteration += 1
-                response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors(
-                    self._client.stub.FunctionGetInputs, request
-                )
-
-                if response.rate_limit_sleep_duration:
-                    logger.info(
-                        "Task exceeded rate limit, sleeping for %.2fs before trying again."
-                        % response.rate_limit_sleep_duration
-                    )
-                    await asyncio.sleep(response.rate_limit_sleep_duration)
-                elif response.inputs:
-                    # for input cancellations and concurrency logic we currently assume
-                    # that there is no input buffering in the container
-                    assert len(response.inputs) == 1
-
-                    for item in response.inputs:
-                        if item.kill_switch:
-                            logger.debug(f"Task {self.task_id} input kill signal input.")
-                            eof_received = True
-                            break
-                        if item.input_id in self.cancelled_input_ids:
-                            continue
-
-                        # If we got a pointer to a blob, download it from S3.
-                        if item.input.WhichOneof("args_oneof") == "args_blob_id":
-                            input_pb = await self.populate_input_blobs(item.input)
-                        else:
-                            input_pb = item.input
-
-                        # If yielded, allow semaphore to be released via complete_call
-                        yield (item.input_id, item.function_call_id, input_pb)
-                        yielded = True
-
-                        # We only support max_inputs = 1 at the moment
-                        if item.input.final_input or self.function_def.max_inputs == 1:
-                            eof_received = True
-                            break
-            finally:
-                if not yielded:
-                    self._semaphore.release()
-
-    @synchronizer.no_io_translation
-    async def run_inputs_outputs(self, input_concurrency: int = 1) -> AsyncIterator[Tuple[str, str, Any, Any]]:
-        # Ensure we do not fetch new inputs when container is too busy.
-        # Before trying to fetch an input, acquire the semaphore:
-        # - if no input is fetched, release the semaphore.
-        # - or, when the output for the fetched input is sent, release the semaphore.
-        self._input_concurrency = input_concurrency
-        self._semaphore = asyncio.Semaphore(input_concurrency)
-
-        try:
-            async for input_id, function_call_id, input_pb in self._generate_inputs():
-                args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {})
-                self.current_input_id, self.current_input_started_at = (input_id, time.time())
-                yield input_id, function_call_id, args, kwargs
-                self.current_input_id, self.current_input_started_at = (None, None)
-        finally:
-            # collect all active input slots, meaning all inputs have wrapped up.
-            for _ in range(input_concurrency):
-                await self._semaphore.acquire()
-
-    async def _push_output(self, input_id, started_at: float, data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs):
-        # upload data to S3 if too big.
-        if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES:
-            data_blob_id = await blob_upload(kwargs["data"], self._client.stub)
-            # mutating kwargs.
-            del kwargs["data"]
-            kwargs["data_blob_id"] = data_blob_id
-
-        output = api_pb2.FunctionPutOutputsItem(
-            input_id=input_id,
-            input_started_at=started_at,
-            output_created_at=time.time(),
-            result=api_pb2.GenericResult(**kwargs),
-            data_format=data_format,
-        )
-
-        await retry_transient_errors(
-            self._client.stub.FunctionPutOutputs,
-            api_pb2.FunctionPutOutputsRequest(outputs=[output]),
-            additional_status_codes=[Status.RESOURCE_EXHAUSTED],
-            max_retries=None,  # Retry indefinitely, trying every 1s.
-        )
-
-    def serialize_exception(self, exc: BaseException) -> Optional[bytes]:
-        try:
-            return self.serialize(exc)
-        except Exception as serialization_exc:
-            logger.info(f"Failed to serialize exception {exc}: {serialization_exc}")
-            # We can't always serialize exceptions.
-            return None
-
-    def serialize_traceback(self, exc: BaseException) -> Tuple[Optional[bytes], Optional[bytes]]:
-        serialized_tb, tb_line_cache = None, None
-
-        try:
-            tb_dict, line_cache = extract_traceback(exc, self.task_id)
-            serialized_tb = self.serialize(tb_dict)
-            tb_line_cache = self.serialize(line_cache)
-        except Exception:
-            logger.info("Failed to serialize exception traceback.")
-
-        return serialized_tb, tb_line_cache
-
-    @asynccontextmanager
-    async def handle_user_exception(self) -> AsyncGenerator[None, None]:
-        """Sets the task as failed in a way where it's not retried.
-
-        Used for handling exceptions from container lifecycle methods at the moment, which should
-        trigger a task failure state.
-        """
-        try:
-            yield
-        except KeyboardInterrupt:
-            # Send no task result in case we get sigint:ed by the runner
-            # The status of the input should have been handled externally already in that case
-            raise
-        except BaseException as exc:
-            # Since this is on a different thread, sys.exc_info() can't find the exception in the stack.
-            traceback.print_exception(type(exc), exc, exc.__traceback__)
-
-            serialized_tb, tb_line_cache = self.serialize_traceback(exc)
-
-            result = api_pb2.GenericResult(
-                status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
-                data=self.serialize_exception(exc),
-                exception=repr(exc),
-                traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
-                serialized_tb=serialized_tb,
-                tb_line_cache=tb_line_cache,
-            )
-
-            req = api_pb2.TaskResultRequest(result=result)
-            await retry_transient_errors(self._client.stub.TaskResult, req)
-
-            # Shut down the task gracefully
-            raise UserException()
-
-    @asynccontextmanager
-    async def handle_input_exception(self, input_id, started_at: float) -> AsyncGenerator[None, None]:
-        """Handle an exception while processing a function input."""
-        try:
-            yield
-        except KeyboardInterrupt:
-            raise
-        except (InputCancellation, asyncio.CancelledError):
-            # just skip creating any output for this input and keep going with the next instead
-            # it should have been marked as cancelled already in the backend at this point so it
-            # won't be retried
-            logger.warning(f"The current input ({input_id=}) was cancelled by a user request")
-            await self.complete_call(started_at)
-            return
-        except BaseException as exc:
-            # print exception so it's logged
-            traceback.print_exc()
-            serialized_tb, tb_line_cache = self.serialize_traceback(exc)
-
-            # Note: we're not serializing the traceback since it contains
-            # local references that means we can't unpickle it. We *are*
-            # serializing the exception, which may have some issues (there
-            # was an earlier note about it that it might not be possible
-            # to unpickle it in some cases). Let's watch out for issues.
-            await self._push_output(
-                input_id,
-                started_at=started_at,
-                data_format=api_pb2.DATA_FORMAT_PICKLE,
-                status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
-                data=self.serialize_exception(exc),
-                exception=repr(exc),
-                traceback=traceback.format_exc(),
-                serialized_tb=serialized_tb,
-                tb_line_cache=tb_line_cache,
-            )
-            await self.complete_call(started_at)
-
-    async def complete_call(self, started_at):
-        self.total_user_time += time.time() - started_at
-        self.calls_completed += 1
-        self._semaphore.release()
-
-    @synchronizer.no_io_translation
-    async def push_output(self, input_id, started_at: float, data: Any, data_format: int) -> None:
-        await self._push_output(
-            input_id,
-            started_at=started_at,
-            data_format=data_format,
-            status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
-            data=self.serialize_data_format(data, data_format),
-        )
-        await self.complete_call(started_at)
-
-    async def restore(self) -> None:
-        # Busy-wait for restore. `/__modal/restore-state.json` is created
-        # by the worker process with updates to the container config.
-        restored_path = Path(config.get("restore_state_path"))
-        start = time.perf_counter()
-        while not restored_path.exists():
-            logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)")
-            await asyncio.sleep(0.01)
-            continue
-
-        logger.debug("Container: restored")
-
-        # Look for state file and create new client with updated credentials.
-        # State data is serialized with key-value pairs, example: {"task_id": "tk-000"}
-        with restored_path.open("r") as file:
-            restored_state = json.load(file)
-
-        # Local ContainerIOManager state.
-        for key in ["task_id", "function_id"]:
-            if value := restored_state.get(key):
-                logger.debug(f"Updating ContainerIOManager.{key} = {value}")
-                setattr(self, key, restored_state[key])
-
-        # Env vars and global state.
-        for key, value in restored_state.items():
-            # Empty string indicates that value does not need to be updated.
-            if value != "":
-                config.override_locally(key, value)
-
-        # Restore input to default state.
-        self.current_input_id = None
-        self.current_input_started_at = None
-
-        self._client = await _Client.from_env()
-        self._waiting_for_checkpoint = False
-
-    async def checkpoint(self) -> None:
-        """Message server indicating that function is ready to be checkpointed."""
-        if self.checkpoint_id:
-            logger.debug(f"Checkpoint ID: {self.checkpoint_id}")
-
-        await self._client.stub.ContainerCheckpoint(
-            api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
-        )
-
-        self._waiting_for_checkpoint = True
-        await self._client._close()
-
-        logger.debug("Checkpointing request sent. Connection closed.")
-        await self.restore()
-
-    async def volume_commit(self, volume_ids: List[str]) -> None:
-        """
-        Perform volume commit for given `volume_ids`.
-        Only used on container exit to persist uncommitted changes on behalf of user.
-        """
-        if not volume_ids:
-            return
-        await asyncify(os.sync)()
-        results = await asyncio.gather(
-            *[
-                retry_transient_errors(
-                    self._client.stub.VolumeCommit,
-                    api_pb2.VolumeCommitRequest(volume_id=v_id),
-                    max_retries=9,
-                    base_delay=0.25,
-                    max_delay=256,
-                    delay_factor=2,
-                )
-                for v_id in volume_ids
-            ],
-            return_exceptions=True,
-        )
-        for volume_id, res in zip(volume_ids, results):
-            if isinstance(res, Exception):
-                logger.error(f"modal.Volume background commit failed for {volume_id}. Exception: {res}")
-            else:
-                logger.debug(f"modal.Volume background commit success for {volume_id}.")
-
-    async def interact(self):
-        if self._is_interactivity_enabled:
-            # Currently, interactivity is enabled forever
-            return
-        self._is_interactivity_enabled = True
-
-        if not self.function_def.pty_info:
-            raise InvalidError(
-                "Interactivity is not enabled in this function. Use MODAL_INTERACTIVE_FUNCTIONS=1 to enable interactivity."
-            )
-
-        if self.function_def.concurrency_limit > 1:
-            print(
-                "Warning: Interactivity is not supported on functions with concurrency > 1. You may experience unexpected behavior."
-            )
-
-        # todo(nathan): add warning if concurrency limit > 1. but idk how to check this here
-        # todo(nathan): check if function interactivity is enabled
-        try:
-            await self._client.stub.FunctionStartPtyShell(Empty())
-        except Exception as e:
-            print("Error: Failed to start PTY shell.")
-            raise e

-    @classmethod
-    def stop_fetching_inputs(cls):
-        assert cls._singleton
-        cls._singleton._fetching_inputs = False
-
-
-ContainerIOManager = synchronize_api(_ContainerIOManager)