modal 0.62.115__py3-none-any.whl → 0.72.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__init__.py +13 -9
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +407 -398
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -60
- modal/_resources.py +26 -7
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1036 -0
- modal/{execution_context.py → _runtime/execution_context.py} +11 -2
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +123 -6
- modal/_traceback.py +47 -187
- modal/_tunnel.py +50 -14
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +386 -104
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +299 -98
- modal/_utils/grpc_testing.py +47 -34
- modal/_utils/grpc_utils.py +54 -21
- modal/_utils/hash_utils.py +51 -10
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +3 -3
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +12 -10
- modal/app.py +561 -323
- modal/app.pyi +474 -262
- modal/call_graph.py +7 -6
- modal/cli/_download.py +22 -6
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +203 -42
- modal/cli/config.py +12 -5
- modal/cli/container.py +61 -13
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +21 -48
- modal/cli/launch.py +28 -14
- modal/cli/network_file_system.py +57 -21
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +34 -9
- modal/cli/programs/vscode.py +58 -8
- modal/cli/queues.py +131 -0
- modal/cli/run.py +199 -96
- modal/cli/secret.py +5 -4
- modal/cli/token.py +7 -2
- modal/cli/utils.py +74 -8
- modal/cli/volume.py +97 -56
- modal/client.py +248 -144
- modal/client.pyi +156 -124
- modal/cloud_bucket_mount.py +43 -30
- modal/cloud_bucket_mount.pyi +32 -25
- modal/cls.py +528 -141
- modal/cls.pyi +189 -145
- modal/config.py +32 -15
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +50 -54
- modal/dict.pyi +120 -164
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +30 -43
- modal/experimental.py +62 -2
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +197 -0
- modal/functions.py +846 -428
- modal/functions.pyi +446 -387
- modal/gpu.py +57 -44
- modal/image.py +946 -417
- modal/image.pyi +584 -245
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +223 -90
- modal/mount.pyi +241 -243
- modal/network_file_system.py +85 -86
- modal/network_file_system.pyi +151 -110
- modal/object.py +66 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +73 -47
- modal/parallel_map.pyi +51 -63
- modal/partial_function.py +272 -107
- modal/partial_function.pyi +219 -120
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +96 -72
- modal/queue.pyi +210 -135
- modal/requirements/2024.04.txt +2 -1
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +45 -4
- modal/runner.py +325 -203
- modal/runner.pyi +124 -110
- modal/running_app.py +27 -4
- modal/sandbox.py +509 -231
- modal/sandbox.pyi +396 -169
- modal/schedule.py +2 -2
- modal/scheduler_placement.py +20 -3
- modal/secret.py +41 -25
- modal/secret.pyi +62 -42
- modal/serving.py +39 -49
- modal/serving.pyi +37 -43
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +123 -137
- modal/volume.pyi +228 -221
- {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/METADATA +5 -5
- modal-0.72.11.dist-info/RECORD +174 -0
- {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +1 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1231 -531
- modal_proto/api_grpc.py +750 -430
- modal_proto/api_pb2.py +2102 -1176
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1329 -675
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_exec.py +0 -128
- modal/_container_io_manager.py +0 -646
- modal/_container_io_manager.pyi +0 -412
- modal/_sandbox_shell.py +0 -49
- modal/app_utils.py +0 -20
- modal/app_utils.pyi +0 -17
- modal/execution_context.pyi +0 -37
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal-0.62.115.dist-info/RECORD +0 -207
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -279
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -674
- test/client_test.py +0 -203
- test/cloud_bucket_mount_test.py +0 -22
- test/cls_test.py +0 -636
- test/config_test.py +0 -149
- test/conftest.py +0 -1485
- test/container_app_test.py +0 -50
- test/container_test.py +0 -1405
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -51
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -791
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -82
- test/helpers.py +0 -47
- test/image_test.py +0 -814
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -327
- test/network_file_system_test.py +0 -188
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -115
- test/resolver_test.py +0 -59
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -57
- test/secret_test.py +0 -89
- test/serialization_test.py +0 -50
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -361
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -397
- test/watcher_test.py +0 -58
- test/webhook_test.py +0 -145
- {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
- {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
- {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
test/conftest.py
DELETED
@@ -1,1485 +0,0 @@
|
|
1
|
-
# Copyright Modal Labs 2024
|
2
|
-
from __future__ import annotations
|
3
|
-
|
4
|
-
import asyncio
|
5
|
-
import contextlib
|
6
|
-
import dataclasses
|
7
|
-
import hashlib
|
8
|
-
import inspect
|
9
|
-
import os
|
10
|
-
import pytest
|
11
|
-
import shutil
|
12
|
-
import sys
|
13
|
-
import tempfile
|
14
|
-
import textwrap
|
15
|
-
import threading
|
16
|
-
import traceback
|
17
|
-
from collections import defaultdict
|
18
|
-
from pathlib import Path
|
19
|
-
from typing import Dict, Iterator, Optional, get_args
|
20
|
-
|
21
|
-
import aiohttp.web
|
22
|
-
import aiohttp.web_runner
|
23
|
-
import grpclib.server
|
24
|
-
import pkg_resources
|
25
|
-
import pytest_asyncio
|
26
|
-
from google.protobuf.empty_pb2 import Empty
|
27
|
-
from grpclib import GRPCError, Status
|
28
|
-
|
29
|
-
import modal._serialization
|
30
|
-
from modal import __version__, config
|
31
|
-
from modal._container_io_manager import _ContainerIOManager
|
32
|
-
from modal._serialization import serialize_data_format
|
33
|
-
from modal._utils.async_utils import asyncify, synchronize_api
|
34
|
-
from modal._utils.grpc_testing import patch_mock_servicer
|
35
|
-
from modal._utils.grpc_utils import find_free_port
|
36
|
-
from modal._utils.http_utils import run_temporary_http_server
|
37
|
-
from modal._vendor import cloudpickle
|
38
|
-
from modal.client import Client
|
39
|
-
from modal.image import ImageBuilderVersion
|
40
|
-
from modal.mount import client_mount_name
|
41
|
-
from modal_proto import api_grpc, api_pb2
|
42
|
-
|
43
|
-
|
44
|
-
@dataclasses.dataclass
|
45
|
-
class VolumeFile:
|
46
|
-
data: bytes
|
47
|
-
data_blob_id: str
|
48
|
-
mode: int
|
49
|
-
|
50
|
-
|
51
|
-
# TODO: Isolate all test config from the host
|
52
|
-
@pytest.fixture(scope="session", autouse=True)
|
53
|
-
def set_env():
|
54
|
-
os.environ["MODAL_ENVIRONMENT"] = "main"
|
55
|
-
|
56
|
-
|
57
|
-
@patch_mock_servicer
|
58
|
-
class MockClientServicer(api_grpc.ModalClientBase):
|
59
|
-
# TODO(erikbern): add more annotations
|
60
|
-
container_inputs: list[api_pb2.FunctionGetInputsResponse]
|
61
|
-
container_outputs: list[api_pb2.FunctionPutOutputsRequest]
|
62
|
-
fc_data_in: defaultdict[str, asyncio.Queue[api_pb2.DataChunk]]
|
63
|
-
fc_data_out: defaultdict[str, asyncio.Queue[api_pb2.DataChunk]]
|
64
|
-
|
65
|
-
def __init__(self, blob_host, blobs):
|
66
|
-
self.use_blob_outputs = False
|
67
|
-
self.put_outputs_barrier = threading.Barrier(
|
68
|
-
1, timeout=10
|
69
|
-
) # set to non-1 to get lock-step of output pushing within a test
|
70
|
-
self.get_inputs_barrier = threading.Barrier(
|
71
|
-
1, timeout=10
|
72
|
-
) # set to non-1 to get lock-step of input releases within a test
|
73
|
-
|
74
|
-
self.app_state_history = defaultdict(list)
|
75
|
-
self.app_heartbeats: Dict[str, int] = defaultdict(int)
|
76
|
-
self.container_checkpoint_requests = 0
|
77
|
-
self.n_blobs = 0
|
78
|
-
self.blob_host = blob_host
|
79
|
-
self.blobs = blobs # shared dict
|
80
|
-
self.requests = []
|
81
|
-
self.done = False
|
82
|
-
self.rate_limit_sleep_duration = None
|
83
|
-
self.fail_get_inputs = False
|
84
|
-
self.slow_put_inputs = False
|
85
|
-
self.container_inputs = []
|
86
|
-
self.container_outputs = []
|
87
|
-
self.fc_data_in = defaultdict(lambda: asyncio.Queue()) # unbounded
|
88
|
-
self.fc_data_out = defaultdict(lambda: asyncio.Queue()) # unbounded
|
89
|
-
self.queue = []
|
90
|
-
self.deployed_apps = {
|
91
|
-
client_mount_name(): "ap-x",
|
92
|
-
}
|
93
|
-
self.app_objects = {}
|
94
|
-
self.app_single_objects = {}
|
95
|
-
self.app_unindexed_objects = {
|
96
|
-
"ap-1": ["im-1", "vo-1"],
|
97
|
-
}
|
98
|
-
self.n_inputs = 0
|
99
|
-
self.n_queues = 0
|
100
|
-
self.n_dict_heartbeats = 0
|
101
|
-
self.n_queue_heartbeats = 0
|
102
|
-
self.n_nfs_heartbeats = 0
|
103
|
-
self.n_vol_heartbeats = 0
|
104
|
-
self.n_mounts = 0
|
105
|
-
self.n_mount_files = 0
|
106
|
-
self.mount_contents = {}
|
107
|
-
self.files_name2sha = {}
|
108
|
-
self.files_sha2data = {}
|
109
|
-
self.function_id_for_function_call = {}
|
110
|
-
self.client_calls = {}
|
111
|
-
self.function_is_running = False
|
112
|
-
self.n_functions = 0
|
113
|
-
self.n_schedules = 0
|
114
|
-
self.function2schedule = {}
|
115
|
-
self.function_create_error = False
|
116
|
-
self.heartbeat_status_code = None
|
117
|
-
self.n_apps = 0
|
118
|
-
self.classes = {}
|
119
|
-
|
120
|
-
self.task_result = None
|
121
|
-
|
122
|
-
self.nfs_files: Dict[str, Dict[str, api_pb2.SharedVolumePutFileRequest]] = defaultdict(dict)
|
123
|
-
self.volume_files: Dict[str, Dict[str, VolumeFile]] = defaultdict(dict)
|
124
|
-
self.images = {}
|
125
|
-
self.image_build_function_ids = {}
|
126
|
-
self.image_builder_versions = {}
|
127
|
-
self.force_built_images = []
|
128
|
-
self.fail_blob_create = []
|
129
|
-
self.blob_create_metadata = None
|
130
|
-
self.blob_multipart_threshold = 10_000_000
|
131
|
-
|
132
|
-
self.precreated_functions = set()
|
133
|
-
self.app_functions = {}
|
134
|
-
self.fcidx = 0
|
135
|
-
|
136
|
-
self.function_serialized = None
|
137
|
-
self.class_serialized = None
|
138
|
-
|
139
|
-
self.client_hello_metadata = None
|
140
|
-
|
141
|
-
self.dicts = {}
|
142
|
-
self.secrets = {}
|
143
|
-
|
144
|
-
self.deployed_dicts = {}
|
145
|
-
self.deployed_mounts = {
|
146
|
-
(client_mount_name(), api_pb2.DEPLOYMENT_NAMESPACE_GLOBAL): "mo-123",
|
147
|
-
}
|
148
|
-
self.deployed_nfss = {}
|
149
|
-
self.deployed_queues = {}
|
150
|
-
self.deployed_secrets = {}
|
151
|
-
self.deployed_volumes = {}
|
152
|
-
|
153
|
-
self.cleared_function_calls = set()
|
154
|
-
|
155
|
-
self.cancelled_calls = []
|
156
|
-
|
157
|
-
self.app_client_disconnect_count = 0
|
158
|
-
self.app_get_logs_initial_count = 0
|
159
|
-
self.app_set_objects_count = 0
|
160
|
-
|
161
|
-
self.volume_counter = 0
|
162
|
-
# Volume-id -> commit/reload count
|
163
|
-
self.volume_commits: Dict[str, int] = defaultdict(lambda: 0)
|
164
|
-
self.volume_reloads: Dict[str, int] = defaultdict(lambda: 0)
|
165
|
-
|
166
|
-
self.sandbox_defs = []
|
167
|
-
self.sandbox: asyncio.subprocess.Process = None
|
168
|
-
|
169
|
-
# Whether the sandbox is executing a shell program in interactive mode.
|
170
|
-
self.sandbox_is_interactive = False
|
171
|
-
self.sandbox_shell_prompt = "TEST_PROMPT# "
|
172
|
-
self.sandbox_result: Optional[api_pb2.GenericResult] = None
|
173
|
-
|
174
|
-
self.token_flow_localhost_port = None
|
175
|
-
self.queue_max_len = 1_00
|
176
|
-
|
177
|
-
self.container_heartbeat_response = None
|
178
|
-
self.container_heartbeat_abort = threading.Event()
|
179
|
-
|
180
|
-
@self.function_body
|
181
|
-
def default_function_body(*args, **kwargs):
|
182
|
-
return sum(arg**2 for arg in args) + sum(value**2 for key, value in kwargs.items())
|
183
|
-
|
184
|
-
def function_body(self, func):
|
185
|
-
"""Decorator for setting the function that will be called for any FunctionGetOutputs calls"""
|
186
|
-
self._function_body = func
|
187
|
-
return func
|
188
|
-
|
189
|
-
def container_heartbeat_return_now(self, response: api_pb2.ContainerHeartbeatResponse):
|
190
|
-
self.container_heartbeat_response = response
|
191
|
-
self.container_heartbeat_abort.set()
|
192
|
-
|
193
|
-
def get_function_metadata(self, object_id: str) -> api_pb2.FunctionHandleMetadata:
|
194
|
-
definition: api_pb2.Function = self.app_functions[object_id]
|
195
|
-
return api_pb2.FunctionHandleMetadata(
|
196
|
-
function_name=definition.function_name,
|
197
|
-
function_type=definition.function_type,
|
198
|
-
web_url=definition.web_url,
|
199
|
-
is_method=definition.is_method,
|
200
|
-
)
|
201
|
-
|
202
|
-
def get_class_metadata(self, object_id: str) -> api_pb2.ClassHandleMetadata:
|
203
|
-
class_handle_metadata = api_pb2.ClassHandleMetadata()
|
204
|
-
for f_name, f_id in self.classes[object_id].items():
|
205
|
-
function_handle_metadata = self.get_function_metadata(f_id)
|
206
|
-
class_handle_metadata.methods.append(
|
207
|
-
api_pb2.ClassMethod(
|
208
|
-
function_name=f_name, function_id=f_id, function_handle_metadata=function_handle_metadata
|
209
|
-
)
|
210
|
-
)
|
211
|
-
return class_handle_metadata
|
212
|
-
|
213
|
-
def get_object_metadata(self, object_id) -> api_pb2.Object:
|
214
|
-
if object_id.startswith("fu-"):
|
215
|
-
res = api_pb2.Object(function_handle_metadata=self.get_function_metadata(object_id))
|
216
|
-
|
217
|
-
elif object_id.startswith("cs-"):
|
218
|
-
res = api_pb2.Object(class_handle_metadata=self.get_class_metadata(object_id))
|
219
|
-
|
220
|
-
elif object_id.startswith("mo-"):
|
221
|
-
mount_handle_metadata = api_pb2.MountHandleMetadata(content_checksum_sha256_hex="abc123")
|
222
|
-
res = api_pb2.Object(mount_handle_metadata=mount_handle_metadata)
|
223
|
-
|
224
|
-
elif object_id.startswith("sb-"):
|
225
|
-
sandbox_handle_metadata = api_pb2.SandboxHandleMetadata(result=self.sandbox_result)
|
226
|
-
res = api_pb2.Object(sandbox_handle_metadata=sandbox_handle_metadata)
|
227
|
-
|
228
|
-
else:
|
229
|
-
res = api_pb2.Object()
|
230
|
-
|
231
|
-
res.object_id = object_id
|
232
|
-
return res
|
233
|
-
|
234
|
-
### App
|
235
|
-
|
236
|
-
async def AppCreate(self, stream):
|
237
|
-
request: api_pb2.AppCreateRequest = await stream.recv_message()
|
238
|
-
self.requests.append(request)
|
239
|
-
self.n_apps += 1
|
240
|
-
app_id = f"ap-{self.n_apps}"
|
241
|
-
self.app_state_history[app_id].append(api_pb2.APP_STATE_INITIALIZING)
|
242
|
-
await stream.send_message(
|
243
|
-
api_pb2.AppCreateResponse(app_id=app_id, app_logs_url="https://modaltest.com/apps/ap-123")
|
244
|
-
)
|
245
|
-
|
246
|
-
async def AppClientDisconnect(self, stream):
|
247
|
-
request: api_pb2.AppClientDisconnectRequest = await stream.recv_message()
|
248
|
-
self.requests.append(request)
|
249
|
-
self.done = True
|
250
|
-
self.app_client_disconnect_count += 1
|
251
|
-
state_history = self.app_state_history[request.app_id]
|
252
|
-
if state_history[-1] not in [api_pb2.APP_STATE_DETACHED, api_pb2.APP_STATE_DEPLOYED]:
|
253
|
-
state_history.append(api_pb2.APP_STATE_STOPPED)
|
254
|
-
await stream.send_message(Empty())
|
255
|
-
|
256
|
-
async def AppGetLogs(self, stream):
|
257
|
-
request: api_pb2.AppGetLogsRequest = await stream.recv_message()
|
258
|
-
if not request.last_entry_id:
|
259
|
-
# Just count initial requests
|
260
|
-
self.app_get_logs_initial_count += 1
|
261
|
-
last_entry_id = "1"
|
262
|
-
else:
|
263
|
-
last_entry_id = str(int(request.last_entry_id) + 1)
|
264
|
-
await asyncio.sleep(0.5)
|
265
|
-
log = api_pb2.TaskLogs(data=f"hello, world ({last_entry_id})\n", file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT)
|
266
|
-
await stream.send_message(api_pb2.TaskLogsBatch(entry_id=last_entry_id, items=[log]))
|
267
|
-
if self.done:
|
268
|
-
await stream.send_message(api_pb2.TaskLogsBatch(app_done=True))
|
269
|
-
|
270
|
-
async def AppGetObjects(self, stream):
|
271
|
-
request: api_pb2.AppGetObjectsRequest = await stream.recv_message()
|
272
|
-
object_ids = self.app_objects.get(request.app_id, {})
|
273
|
-
objects = list(object_ids.items())
|
274
|
-
if request.include_unindexed:
|
275
|
-
unindexed_object_ids = self.app_unindexed_objects.get(request.app_id, [])
|
276
|
-
objects += [(None, object_id) for object_id in unindexed_object_ids]
|
277
|
-
items = [
|
278
|
-
api_pb2.AppGetObjectsItem(tag=tag, object=self.get_object_metadata(object_id)) for tag, object_id in objects
|
279
|
-
]
|
280
|
-
await stream.send_message(api_pb2.AppGetObjectsResponse(items=items))
|
281
|
-
|
282
|
-
async def AppSetObjects(self, stream):
|
283
|
-
request: api_pb2.AppSetObjectsRequest = await stream.recv_message()
|
284
|
-
self.app_objects[request.app_id] = dict(request.indexed_object_ids)
|
285
|
-
self.app_unindexed_objects[request.app_id] = list(request.unindexed_object_ids)
|
286
|
-
if request.single_object_id:
|
287
|
-
self.app_single_objects[request.app_id] = request.single_object_id
|
288
|
-
self.app_set_objects_count += 1
|
289
|
-
if request.new_app_state:
|
290
|
-
self.app_state_history[request.app_id].append(request.new_app_state)
|
291
|
-
await stream.send_message(Empty())
|
292
|
-
|
293
|
-
async def AppDeploy(self, stream):
|
294
|
-
request: api_pb2.AppDeployRequest = await stream.recv_message()
|
295
|
-
self.deployed_apps[request.name] = request.app_id
|
296
|
-
self.app_state_history[request.app_id].append(api_pb2.APP_STATE_DEPLOYED)
|
297
|
-
await stream.send_message(api_pb2.AppDeployResponse(url="http://test.modal.com/foo/bar"))
|
298
|
-
|
299
|
-
async def AppGetByDeploymentName(self, stream):
|
300
|
-
request: api_pb2.AppGetByDeploymentNameRequest = await stream.recv_message()
|
301
|
-
await stream.send_message(api_pb2.AppGetByDeploymentNameResponse(app_id=self.deployed_apps.get(request.name)))
|
302
|
-
|
303
|
-
async def AppHeartbeat(self, stream):
|
304
|
-
request: api_pb2.AppHeartbeatRequest = await stream.recv_message()
|
305
|
-
self.requests.append(request)
|
306
|
-
self.app_heartbeats[request.app_id] += 1
|
307
|
-
await stream.send_message(Empty())
|
308
|
-
|
309
|
-
async def AppList(self, stream):
|
310
|
-
await stream.recv_message()
|
311
|
-
apps = []
|
312
|
-
for app_name, app_id in self.deployed_apps.items():
|
313
|
-
apps.append(api_pb2.AppStats(name=app_name, description=app_name, app_id=app_id))
|
314
|
-
await stream.send_message(api_pb2.AppListResponse(apps=apps))
|
315
|
-
|
316
|
-
### Checkpoint
|
317
|
-
|
318
|
-
async def ContainerCheckpoint(self, stream):
|
319
|
-
request: api_pb2.ContainerCheckpointRequest = await stream.recv_message()
|
320
|
-
self.requests.append(request)
|
321
|
-
self.container_checkpoint_requests += 1
|
322
|
-
await stream.send_message(Empty())
|
323
|
-
|
324
|
-
### Blob
|
325
|
-
|
326
|
-
async def BlobCreate(self, stream):
|
327
|
-
req = await stream.recv_message()
|
328
|
-
# This is used to test retry_transient_errors, see grpc_utils_test.py
|
329
|
-
self.blob_create_metadata = stream.metadata
|
330
|
-
if len(self.fail_blob_create) > 0:
|
331
|
-
status_code = self.fail_blob_create.pop()
|
332
|
-
raise GRPCError(status_code, "foobar")
|
333
|
-
elif req.content_length > self.blob_multipart_threshold:
|
334
|
-
blob_id = await self.next_blob_id()
|
335
|
-
num_parts = (req.content_length + self.blob_multipart_threshold - 1) // self.blob_multipart_threshold
|
336
|
-
upload_urls = []
|
337
|
-
for part_number in range(num_parts):
|
338
|
-
upload_url = f"{self.blob_host}/upload?blob_id={blob_id}&part_number={part_number}"
|
339
|
-
upload_urls.append(upload_url)
|
340
|
-
|
341
|
-
await stream.send_message(
|
342
|
-
api_pb2.BlobCreateResponse(
|
343
|
-
blob_id=blob_id,
|
344
|
-
multipart=api_pb2.MultiPartUpload(
|
345
|
-
part_length=self.blob_multipart_threshold,
|
346
|
-
upload_urls=upload_urls,
|
347
|
-
completion_url=f"{self.blob_host}/complete_multipart?blob_id={blob_id}",
|
348
|
-
),
|
349
|
-
)
|
350
|
-
)
|
351
|
-
else:
|
352
|
-
blob_id = await self.next_blob_id()
|
353
|
-
upload_url = f"{self.blob_host}/upload?blob_id={blob_id}"
|
354
|
-
await stream.send_message(api_pb2.BlobCreateResponse(blob_id=blob_id, upload_url=upload_url))
|
355
|
-
|
356
|
-
async def next_blob_id(self):
|
357
|
-
self.n_blobs += 1
|
358
|
-
blob_id = f"bl-{self.n_blobs}"
|
359
|
-
return blob_id
|
360
|
-
|
361
|
-
async def BlobGet(self, stream):
|
362
|
-
request: api_pb2.BlobGetRequest = await stream.recv_message()
|
363
|
-
download_url = f"{self.blob_host}/download?blob_id={request.blob_id}"
|
364
|
-
await stream.send_message(api_pb2.BlobGetResponse(download_url=download_url))
|
365
|
-
|
366
|
-
### Class
|
367
|
-
|
368
|
-
async def ClassCreate(self, stream):
|
369
|
-
request: api_pb2.ClassCreateRequest = await stream.recv_message()
|
370
|
-
assert request.app_id
|
371
|
-
methods: dict[str, str] = {method.function_name: method.function_id for method in request.methods}
|
372
|
-
class_id = "cs-" + str(len(self.classes))
|
373
|
-
self.classes[class_id] = methods
|
374
|
-
await stream.send_message(
|
375
|
-
api_pb2.ClassCreateResponse(class_id=class_id, handle_metadata=self.get_class_metadata(class_id))
|
376
|
-
)
|
377
|
-
|
378
|
-
async def ClassGet(self, stream):
|
379
|
-
request: api_pb2.ClassGetRequest = await stream.recv_message()
|
380
|
-
app_id = self.deployed_apps.get(request.app_name)
|
381
|
-
app_objects = self.app_objects[app_id]
|
382
|
-
object_id = app_objects.get(request.object_tag)
|
383
|
-
if object_id is None:
|
384
|
-
raise GRPCError(Status.NOT_FOUND, f"can't find object {request.object_tag}")
|
385
|
-
await stream.send_message(
|
386
|
-
api_pb2.ClassGetResponse(class_id=object_id, handle_metadata=self.get_class_metadata(object_id))
|
387
|
-
)
|
388
|
-
|
389
|
-
### Client
|
390
|
-
|
391
|
-
async def ClientHello(self, stream):
|
392
|
-
request: Empty = await stream.recv_message()
|
393
|
-
self.requests.append(request)
|
394
|
-
self.client_create_metadata = stream.metadata
|
395
|
-
client_version = stream.metadata["x-modal-client-version"]
|
396
|
-
image_builder_version = max(get_args(ImageBuilderVersion))
|
397
|
-
warning = ""
|
398
|
-
assert stream.user_agent.startswith(f"modal-client/{__version__} ")
|
399
|
-
if stream.metadata.get("x-modal-token-id") == "bad":
|
400
|
-
raise GRPCError(Status.UNAUTHENTICATED, "bad bad bad")
|
401
|
-
elif client_version == "unauthenticated":
|
402
|
-
raise GRPCError(Status.UNAUTHENTICATED, "failed authentication")
|
403
|
-
elif client_version == "deprecated":
|
404
|
-
warning = "SUPER OLD"
|
405
|
-
elif client_version == "timeout":
|
406
|
-
await asyncio.sleep(60)
|
407
|
-
elif pkg_resources.parse_version(client_version) < pkg_resources.parse_version(__version__):
|
408
|
-
raise GRPCError(Status.FAILED_PRECONDITION, "Old client")
|
409
|
-
resp = api_pb2.ClientHelloResponse(warning=warning, image_builder_version=image_builder_version)
|
410
|
-
await stream.send_message(resp)
|
411
|
-
|
412
|
-
# Container
|
413
|
-
|
414
|
-
async def ContainerHeartbeat(self, stream):
|
415
|
-
request: api_pb2.ContainerHeartbeatRequest = await stream.recv_message()
|
416
|
-
self.requests.append(request)
|
417
|
-
# Return earlier than the usual 15-second heartbeat to avoid suspending tests.
|
418
|
-
await asyncify(self.container_heartbeat_abort.wait)(5)
|
419
|
-
if self.container_heartbeat_response:
|
420
|
-
await stream.send_message(self.container_heartbeat_response)
|
421
|
-
self.container_heartbeat_response = None
|
422
|
-
else:
|
423
|
-
await stream.send_message(api_pb2.ContainerHeartbeatResponse())
|
424
|
-
|
425
|
-
async def ContainerExec(self, stream):
|
426
|
-
_request: api_pb2.ContainerExecRequest = await stream.recv_message()
|
427
|
-
await stream.send_message(api_pb2.ContainerExecResponse(exec_id="container_exec_id"))
|
428
|
-
|
429
|
-
async def ContainerExecGetOutput(self, stream):
|
430
|
-
_request: api_pb2.ContainerExecGetOutputRequest = await stream.recv_message()
|
431
|
-
await stream.send_message(
|
432
|
-
api_pb2.RuntimeOutputBatch(
|
433
|
-
items=[
|
434
|
-
api_pb2.RuntimeOutputMessage(
|
435
|
-
file_descriptor=api_pb2.FileDescriptor.FILE_DESCRIPTOR_STDOUT, message="Hello World"
|
436
|
-
)
|
437
|
-
]
|
438
|
-
)
|
439
|
-
)
|
440
|
-
await stream.send_message(api_pb2.RuntimeOutputBatch(exit_code=0))
|
441
|
-
|
442
|
-
### Dict
|
443
|
-
|
444
|
-
async def DictCreate(self, stream):
|
445
|
-
request: api_pb2.DictCreateRequest = await stream.recv_message()
|
446
|
-
if request.existing_dict_id:
|
447
|
-
dict_id = request.existing_dict_id
|
448
|
-
else:
|
449
|
-
dict_id = f"di-{len(self.dicts)}"
|
450
|
-
self.dicts[dict_id] = {}
|
451
|
-
await stream.send_message(api_pb2.DictCreateResponse(dict_id=dict_id))
|
452
|
-
|
453
|
-
async def DictGetOrCreate(self, stream):
|
454
|
-
request: api_pb2.DictGetOrCreateRequest = await stream.recv_message()
|
455
|
-
k = (request.deployment_name, request.namespace, request.environment_name)
|
456
|
-
if k in self.deployed_dicts:
|
457
|
-
dict_id = self.deployed_dicts[k]
|
458
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
|
459
|
-
dict_id = f"di-{len(self.dicts)}"
|
460
|
-
self.dicts[dict_id] = {entry.key: entry.value for entry in request.data}
|
461
|
-
self.deployed_dicts[k] = dict_id
|
462
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
|
463
|
-
dict_id = f"di-{len(self.dicts)}"
|
464
|
-
self.dicts[dict_id] = {entry.key: entry.value for entry in request.data}
|
465
|
-
else:
|
466
|
-
raise GRPCError(Status.NOT_FOUND, "Queue not found")
|
467
|
-
await stream.send_message(api_pb2.DictGetOrCreateResponse(dict_id=dict_id))
|
468
|
-
|
469
|
-
async def DictHeartbeat(self, stream):
|
470
|
-
await stream.recv_message()
|
471
|
-
self.n_dict_heartbeats += 1
|
472
|
-
await stream.send_message(Empty())
|
473
|
-
|
474
|
-
async def DictDelete(self, stream):
|
475
|
-
request: api_pb2.DictDeleteRequest = await stream.recv_message()
|
476
|
-
self.deployed_dicts = {k: v for k, v in self.deployed_dicts.items() if v != request.dict_id}
|
477
|
-
await stream.send_message(Empty())
|
478
|
-
|
479
|
-
async def DictClear(self, stream):
|
480
|
-
request: api_pb2.DictGetRequest = await stream.recv_message()
|
481
|
-
self.dicts[request.dict_id] = {}
|
482
|
-
await stream.send_message(Empty())
|
483
|
-
|
484
|
-
async def DictGet(self, stream):
|
485
|
-
request: api_pb2.DictGetRequest = await stream.recv_message()
|
486
|
-
d = self.dicts[request.dict_id]
|
487
|
-
await stream.send_message(api_pb2.DictGetResponse(value=d.get(request.key), found=bool(request.key in d)))
|
488
|
-
|
489
|
-
async def DictLen(self, stream):
|
490
|
-
request: api_pb2.DictLenRequest = await stream.recv_message()
|
491
|
-
await stream.send_message(api_pb2.DictLenResponse(len=len(self.dicts[request.dict_id])))
|
492
|
-
|
493
|
-
async def DictUpdate(self, stream):
|
494
|
-
request: api_pb2.DictUpdateRequest = await stream.recv_message()
|
495
|
-
for update in request.updates:
|
496
|
-
self.dicts[request.dict_id][update.key] = update.value
|
497
|
-
await stream.send_message(api_pb2.DictUpdateResponse())
|
498
|
-
|
499
|
-
async def DictContents(self, stream):
|
500
|
-
request: api_pb2.DictGetRequest = await stream.recv_message()
|
501
|
-
for k, v in self.dicts[request.dict_id].items():
|
502
|
-
await stream.send_message(api_pb2.DictEntry(key=k, value=v))
|
503
|
-
|
504
|
-
### Function
|
505
|
-
|
506
|
-
async def FunctionBindParams(self, stream):
|
507
|
-
request: api_pb2.FunctionBindParamsRequest = await stream.recv_message()
|
508
|
-
assert request.function_id
|
509
|
-
assert request.serialized_params
|
510
|
-
self.n_functions += 1
|
511
|
-
function_id = f"fu-{self.n_functions}"
|
512
|
-
|
513
|
-
await stream.send_message(api_pb2.FunctionBindParamsResponse(bound_function_id=function_id))
|
514
|
-
|
515
|
-
@contextlib.contextmanager
|
516
|
-
def input_lockstep(self) -> Iterator[threading.Barrier]:
|
517
|
-
self.get_inputs_barrier = threading.Barrier(2, timeout=10)
|
518
|
-
yield self.get_inputs_barrier
|
519
|
-
self.get_inputs_barrier = threading.Barrier(1)
|
520
|
-
|
521
|
-
@contextlib.contextmanager
|
522
|
-
def output_lockstep(self) -> Iterator[threading.Barrier]:
|
523
|
-
self.put_outputs_barrier = threading.Barrier(2, timeout=10)
|
524
|
-
yield self.put_outputs_barrier
|
525
|
-
self.put_outputs_barrier = threading.Barrier(1)
|
526
|
-
|
527
|
-
async def FunctionGetInputs(self, stream):
|
528
|
-
self.get_inputs_barrier.wait()
|
529
|
-
request: api_pb2.FunctionGetInputsRequest = await stream.recv_message()
|
530
|
-
assert request.function_id
|
531
|
-
if self.fail_get_inputs:
|
532
|
-
raise GRPCError(Status.INTERNAL)
|
533
|
-
elif self.rate_limit_sleep_duration is not None:
|
534
|
-
s = self.rate_limit_sleep_duration
|
535
|
-
self.rate_limit_sleep_duration = None
|
536
|
-
await stream.send_message(api_pb2.FunctionGetInputsResponse(rate_limit_sleep_duration=s))
|
537
|
-
elif not self.container_inputs:
|
538
|
-
await asyncio.sleep(1.0)
|
539
|
-
await stream.send_message(api_pb2.FunctionGetInputsResponse(inputs=[]))
|
540
|
-
else:
|
541
|
-
await stream.send_message(self.container_inputs.pop(0))
|
542
|
-
|
543
|
-
async def FunctionPutOutputs(self, stream):
|
544
|
-
self.put_outputs_barrier.wait()
|
545
|
-
request: api_pb2.FunctionPutOutputsRequest = await stream.recv_message()
|
546
|
-
self.container_outputs.append(request)
|
547
|
-
await stream.send_message(Empty())
|
548
|
-
|
549
|
-
async def FunctionPrecreate(self, stream):
|
550
|
-
req: api_pb2.FunctionPrecreateRequest = await stream.recv_message()
|
551
|
-
if not req.existing_function_id:
|
552
|
-
self.n_functions += 1
|
553
|
-
function_id = f"fu-{self.n_functions}"
|
554
|
-
else:
|
555
|
-
function_id = req.existing_function_id
|
556
|
-
|
557
|
-
self.precreated_functions.add(function_id)
|
558
|
-
|
559
|
-
web_url = "http://xyz.internal" if req.HasField("webhook_config") and req.webhook_config.type else None
|
560
|
-
await stream.send_message(
|
561
|
-
api_pb2.FunctionPrecreateResponse(
|
562
|
-
function_id=function_id,
|
563
|
-
handle_metadata=api_pb2.FunctionHandleMetadata(
|
564
|
-
function_name=req.function_name,
|
565
|
-
function_type=req.function_type,
|
566
|
-
web_url=web_url,
|
567
|
-
),
|
568
|
-
)
|
569
|
-
)
|
570
|
-
|
571
|
-
async def FunctionCreate(self, stream):
|
572
|
-
request: api_pb2.FunctionCreateRequest = await stream.recv_message()
|
573
|
-
if self.function_create_error:
|
574
|
-
raise GRPCError(Status.INTERNAL, "Function create failed")
|
575
|
-
if request.existing_function_id:
|
576
|
-
function_id = request.existing_function_id
|
577
|
-
else:
|
578
|
-
self.n_functions += 1
|
579
|
-
function_id = f"fu-{self.n_functions}"
|
580
|
-
if request.schedule:
|
581
|
-
self.function2schedule[function_id] = request.schedule
|
582
|
-
function = api_pb2.Function()
|
583
|
-
function.CopyFrom(request.function)
|
584
|
-
if function.webhook_config.type:
|
585
|
-
function.web_url = "http://xyz.internal"
|
586
|
-
|
587
|
-
self.app_functions[function_id] = function
|
588
|
-
await stream.send_message(
|
589
|
-
api_pb2.FunctionCreateResponse(
|
590
|
-
function_id=function_id,
|
591
|
-
function=function,
|
592
|
-
handle_metadata=api_pb2.FunctionHandleMetadata(
|
593
|
-
function_name=function.function_name,
|
594
|
-
function_type=function.function_type,
|
595
|
-
web_url=function.web_url,
|
596
|
-
),
|
597
|
-
)
|
598
|
-
)
|
599
|
-
|
600
|
-
async def FunctionGet(self, stream):
|
601
|
-
request: api_pb2.FunctionGetRequest = await stream.recv_message()
|
602
|
-
app_id = self.deployed_apps.get(request.app_name)
|
603
|
-
app_objects = self.app_objects[app_id]
|
604
|
-
object_id = app_objects.get(request.object_tag)
|
605
|
-
if object_id is None:
|
606
|
-
raise GRPCError(Status.NOT_FOUND, f"can't find object {request.object_tag}")
|
607
|
-
await stream.send_message(
|
608
|
-
api_pb2.FunctionGetResponse(function_id=object_id, handle_metadata=self.get_function_metadata(object_id))
|
609
|
-
)
|
610
|
-
|
611
|
-
async def FunctionMap(self, stream):
|
612
|
-
self.fcidx += 1
|
613
|
-
request: api_pb2.FunctionMapRequest = await stream.recv_message()
|
614
|
-
function_call_id = f"fc-{self.fcidx}"
|
615
|
-
self.function_id_for_function_call[function_call_id] = request.function_id
|
616
|
-
await stream.send_message(api_pb2.FunctionMapResponse(function_call_id=function_call_id))
|
617
|
-
|
618
|
-
async def FunctionPutInputs(self, stream):
|
619
|
-
request: api_pb2.FunctionPutInputsRequest = await stream.recv_message()
|
620
|
-
response_items = []
|
621
|
-
function_call_inputs = self.client_calls.setdefault(request.function_call_id, [])
|
622
|
-
for item in request.inputs:
|
623
|
-
if item.input.WhichOneof("args_oneof") == "args":
|
624
|
-
args, kwargs = modal._serialization.deserialize(item.input.args, None)
|
625
|
-
else:
|
626
|
-
args, kwargs = modal._serialization.deserialize(self.blobs[item.input.args_blob_id], None)
|
627
|
-
|
628
|
-
input_id = f"in-{self.n_inputs}"
|
629
|
-
self.n_inputs += 1
|
630
|
-
response_items.append(api_pb2.FunctionPutInputsResponseItem(input_id=input_id, idx=item.idx))
|
631
|
-
function_call_inputs.append(((item.idx, input_id), (args, kwargs)))
|
632
|
-
if self.slow_put_inputs:
|
633
|
-
await asyncio.sleep(0.001)
|
634
|
-
await stream.send_message(api_pb2.FunctionPutInputsResponse(inputs=response_items))
|
635
|
-
|
636
|
-
async def FunctionGetOutputs(self, stream):
|
637
|
-
request: api_pb2.FunctionGetOutputsRequest = await stream.recv_message()
|
638
|
-
if request.clear_on_success:
|
639
|
-
self.cleared_function_calls.add(request.function_call_id)
|
640
|
-
|
641
|
-
client_calls = self.client_calls.get(request.function_call_id, [])
|
642
|
-
if client_calls and not self.function_is_running:
|
643
|
-
popidx = len(client_calls) // 2 # simulate that results don't always come in order
|
644
|
-
(idx, input_id), (args, kwargs) = client_calls.pop(popidx)
|
645
|
-
output_exc = None
|
646
|
-
try:
|
647
|
-
res = self._function_body(*args, **kwargs)
|
648
|
-
|
649
|
-
if inspect.iscoroutine(res):
|
650
|
-
result = await res
|
651
|
-
result_data_format = api_pb2.DATA_FORMAT_PICKLE
|
652
|
-
elif inspect.isgenerator(res):
|
653
|
-
count = 0
|
654
|
-
for item in res:
|
655
|
-
count += 1
|
656
|
-
await self.fc_data_out[request.function_call_id].put(
|
657
|
-
api_pb2.DataChunk(
|
658
|
-
data_format=api_pb2.DATA_FORMAT_PICKLE,
|
659
|
-
data=serialize_data_format(item, api_pb2.DATA_FORMAT_PICKLE),
|
660
|
-
index=count,
|
661
|
-
)
|
662
|
-
)
|
663
|
-
result = api_pb2.GeneratorDone(items_total=count)
|
664
|
-
result_data_format = api_pb2.DATA_FORMAT_GENERATOR_DONE
|
665
|
-
else:
|
666
|
-
result = res
|
667
|
-
result_data_format = api_pb2.DATA_FORMAT_PICKLE
|
668
|
-
except Exception as exc:
|
669
|
-
serialized_exc = cloudpickle.dumps(exc)
|
670
|
-
result = api_pb2.GenericResult(
|
671
|
-
status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
|
672
|
-
data=serialized_exc,
|
673
|
-
exception=repr(exc),
|
674
|
-
traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
|
675
|
-
)
|
676
|
-
output_exc = api_pb2.FunctionGetOutputsItem(
|
677
|
-
input_id=input_id, idx=idx, result=result, gen_index=0, data_format=api_pb2.DATA_FORMAT_PICKLE
|
678
|
-
)
|
679
|
-
|
680
|
-
if output_exc:
|
681
|
-
output = output_exc
|
682
|
-
else:
|
683
|
-
serialized_data = serialize_data_format(result, result_data_format)
|
684
|
-
if self.use_blob_outputs:
|
685
|
-
blob_id = await self.next_blob_id()
|
686
|
-
self.blobs[blob_id] = serialized_data
|
687
|
-
data_kwargs = {
|
688
|
-
"data_blob_id": blob_id,
|
689
|
-
}
|
690
|
-
else:
|
691
|
-
data_kwargs = {"data": serialized_data}
|
692
|
-
output = api_pb2.FunctionGetOutputsItem(
|
693
|
-
input_id=input_id,
|
694
|
-
idx=idx,
|
695
|
-
result=api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS, **data_kwargs),
|
696
|
-
data_format=result_data_format,
|
697
|
-
)
|
698
|
-
|
699
|
-
await stream.send_message(api_pb2.FunctionGetOutputsResponse(outputs=[output]))
|
700
|
-
else:
|
701
|
-
await stream.send_message(api_pb2.FunctionGetOutputsResponse(outputs=[]))
|
702
|
-
|
703
|
-
async def FunctionGetSerialized(self, stream):
|
704
|
-
await stream.send_message(
|
705
|
-
api_pb2.FunctionGetSerializedResponse(
|
706
|
-
function_serialized=self.function_serialized,
|
707
|
-
class_serialized=self.class_serialized,
|
708
|
-
)
|
709
|
-
)
|
710
|
-
|
711
|
-
async def FunctionCallCancel(self, stream):
|
712
|
-
req = await stream.recv_message()
|
713
|
-
self.cancelled_calls.append(req.function_call_id)
|
714
|
-
await stream.send_message(Empty())
|
715
|
-
|
716
|
-
async def FunctionCallGetDataIn(self, stream):
|
717
|
-
req: api_pb2.FunctionCallGetDataRequest = await stream.recv_message()
|
718
|
-
while True:
|
719
|
-
chunk = await self.fc_data_in[req.function_call_id].get()
|
720
|
-
await stream.send_message(chunk)
|
721
|
-
|
722
|
-
async def FunctionCallGetDataOut(self, stream):
|
723
|
-
req: api_pb2.FunctionCallGetDataRequest = await stream.recv_message()
|
724
|
-
while True:
|
725
|
-
chunk = await self.fc_data_out[req.function_call_id].get()
|
726
|
-
await stream.send_message(chunk)
|
727
|
-
|
728
|
-
async def FunctionCallPutDataOut(self, stream):
|
729
|
-
req: api_pb2.FunctionCallPutDataRequest = await stream.recv_message()
|
730
|
-
for chunk in req.data_chunks:
|
731
|
-
await self.fc_data_out[req.function_call_id].put(chunk)
|
732
|
-
await stream.send_message(Empty())
|
733
|
-
|
734
|
-
### Image
|
735
|
-
|
736
|
-
async def ImageGetOrCreate(self, stream):
|
737
|
-
request: api_pb2.ImageGetOrCreateRequest = await stream.recv_message()
|
738
|
-
idx = len(self.images) + 1
|
739
|
-
image_id = f"im-{idx}"
|
740
|
-
|
741
|
-
self.images[image_id] = request.image
|
742
|
-
self.image_build_function_ids[image_id] = request.build_function_id
|
743
|
-
self.image_builder_versions[image_id] = request.builder_version
|
744
|
-
if request.force_build:
|
745
|
-
self.force_built_images.append(image_id)
|
746
|
-
await stream.send_message(api_pb2.ImageGetOrCreateResponse(image_id=image_id))
|
747
|
-
|
748
|
-
async def ImageJoinStreaming(self, stream):
|
749
|
-
await stream.recv_message()
|
750
|
-
task_log_1 = api_pb2.TaskLogs(data="hello, world\n", file_descriptor=api_pb2.FILE_DESCRIPTOR_INFO)
|
751
|
-
task_log_2 = api_pb2.TaskLogs(
|
752
|
-
task_progress=api_pb2.TaskProgress(
|
753
|
-
len=1, pos=0, progress_type=api_pb2.IMAGE_SNAPSHOT_UPLOAD, description="xyz"
|
754
|
-
)
|
755
|
-
)
|
756
|
-
await stream.send_message(api_pb2.ImageJoinStreamingResponse(task_logs=[task_log_1, task_log_2]))
|
757
|
-
await stream.send_message(
|
758
|
-
api_pb2.ImageJoinStreamingResponse(
|
759
|
-
result=api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS)
|
760
|
-
)
|
761
|
-
)
|
762
|
-
|
763
|
-
### Mount
|
764
|
-
|
765
|
-
async def MountPutFile(self, stream):
|
766
|
-
request: api_pb2.MountPutFileRequest = await stream.recv_message()
|
767
|
-
if request.WhichOneof("data_oneof") is not None:
|
768
|
-
self.files_sha2data[request.sha256_hex] = {"data": request.data, "data_blob_id": request.data_blob_id}
|
769
|
-
self.n_mount_files += 1
|
770
|
-
await stream.send_message(api_pb2.MountPutFileResponse(exists=True))
|
771
|
-
else:
|
772
|
-
await stream.send_message(api_pb2.MountPutFileResponse(exists=False))
|
773
|
-
|
774
|
-
async def MountGetOrCreate(self, stream):
|
775
|
-
request: api_pb2.MountGetOrCreateRequest = await stream.recv_message()
|
776
|
-
k = (request.deployment_name, request.namespace)
|
777
|
-
if request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
|
778
|
-
if k not in self.deployed_mounts:
|
779
|
-
raise GRPCError(Status.NOT_FOUND, "Mount not found")
|
780
|
-
mount_id = self.deployed_mounts[k]
|
781
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
|
782
|
-
self.n_mounts += 1
|
783
|
-
mount_id = f"mo-{self.n_mounts}"
|
784
|
-
self.deployed_mounts[k] = mount_id
|
785
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_ANONYMOUS_OWNED_BY_APP:
|
786
|
-
self.n_mounts += 1
|
787
|
-
mount_id = f"mo-{self.n_mounts}"
|
788
|
-
|
789
|
-
else:
|
790
|
-
raise Exception("unsupported creation type")
|
791
|
-
|
792
|
-
mount_content = self.mount_contents[mount_id] = {}
|
793
|
-
for file in request.files:
|
794
|
-
mount_content[file.filename] = self.files_name2sha[file.filename] = file.sha256_hex
|
795
|
-
|
796
|
-
await stream.send_message(
|
797
|
-
api_pb2.MountGetOrCreateResponse(
|
798
|
-
mount_id=mount_id, handle_metadata=api_pb2.MountHandleMetadata(content_checksum_sha256_hex="deadbeef")
|
799
|
-
)
|
800
|
-
)
|
801
|
-
|
802
|
-
### Proxy
|
803
|
-
|
804
|
-
async def ProxyGetOrCreate(self, stream):
|
805
|
-
await stream.recv_message()
|
806
|
-
await stream.send_message(api_pb2.ProxyGetOrCreateResponse(proxy_id="pr-123"))
|
807
|
-
|
808
|
-
### Queue
|
809
|
-
|
810
|
-
async def QueueCreate(self, stream):
|
811
|
-
request: api_pb2.QueueCreateRequest = await stream.recv_message()
|
812
|
-
if request.existing_queue_id:
|
813
|
-
queue_id = request.existing_queue_id
|
814
|
-
else:
|
815
|
-
self.n_queues += 1
|
816
|
-
queue_id = f"qu-{self.n_queues}"
|
817
|
-
await stream.send_message(api_pb2.QueueCreateResponse(queue_id=queue_id))
|
818
|
-
|
819
|
-
async def QueueGetOrCreate(self, stream):
|
820
|
-
request: api_pb2.QueueGetOrCreateRequest = await stream.recv_message()
|
821
|
-
k = (request.deployment_name, request.namespace, request.environment_name)
|
822
|
-
if k in self.deployed_queues:
|
823
|
-
queue_id = self.deployed_queues[k]
|
824
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
|
825
|
-
self.n_queues += 1
|
826
|
-
queue_id = f"qu-{self.n_queues}"
|
827
|
-
self.deployed_queues[k] = queue_id
|
828
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
|
829
|
-
self.n_queues += 1
|
830
|
-
queue_id = f"qu-{self.n_queues}"
|
831
|
-
else:
|
832
|
-
raise GRPCError(Status.NOT_FOUND, "Queue not found")
|
833
|
-
await stream.send_message(api_pb2.QueueGetOrCreateResponse(queue_id=queue_id))
|
834
|
-
|
835
|
-
async def QueueDelete(self, stream):
|
836
|
-
request: api_pb2.QueueDeleteRequest = await stream.recv_message()
|
837
|
-
self.deployed_queues = {k: v for k, v in self.deployed_queues.items() if v != request.queue_id}
|
838
|
-
await stream.send_message(Empty())
|
839
|
-
|
840
|
-
async def QueueHeartbeat(self, stream):
|
841
|
-
await stream.recv_message()
|
842
|
-
self.n_queue_heartbeats += 1
|
843
|
-
await stream.send_message(Empty())
|
844
|
-
|
845
|
-
async def QueuePut(self, stream):
|
846
|
-
request: api_pb2.QueuePutRequest = await stream.recv_message()
|
847
|
-
if len(self.queue) >= self.queue_max_len:
|
848
|
-
raise GRPCError(Status.RESOURCE_EXHAUSTED, f"Hit servicer's max len for Queues: {self.queue_max_len}")
|
849
|
-
self.queue += request.values
|
850
|
-
await stream.send_message(Empty())
|
851
|
-
|
852
|
-
async def QueueGet(self, stream):
|
853
|
-
await stream.recv_message()
|
854
|
-
if len(self.queue) > 0:
|
855
|
-
values = [self.queue.pop(0)]
|
856
|
-
else:
|
857
|
-
values = []
|
858
|
-
await stream.send_message(api_pb2.QueueGetResponse(values=values))
|
859
|
-
|
860
|
-
async def QueueLen(self, stream):
|
861
|
-
await stream.recv_message()
|
862
|
-
await stream.send_message(api_pb2.QueueLenResponse(len=len(self.queue)))
|
863
|
-
|
864
|
-
async def QueueNextItems(self, stream):
|
865
|
-
request: api_pb2.QueueNextItemsRequest = await stream.recv_message()
|
866
|
-
next_item_idx = int(request.last_entry_id) + 1 if request.last_entry_id else 0
|
867
|
-
if next_item_idx < len(self.queue):
|
868
|
-
item = api_pb2.QueueItem(value=self.queue[next_item_idx], entry_id=f"{next_item_idx}")
|
869
|
-
await stream.send_message(api_pb2.QueueNextItemsResponse(items=[item]))
|
870
|
-
else:
|
871
|
-
if request.item_poll_timeout > 0:
|
872
|
-
await asyncio.sleep(0.1)
|
873
|
-
await stream.send_message(api_pb2.QueueNextItemsResponse(items=[]))
|
874
|
-
|
875
|
-
### Sandbox
|
876
|
-
|
877
|
-
async def SandboxCreate(self, stream):
|
878
|
-
request: api_pb2.SandboxCreateRequest = await stream.recv_message()
|
879
|
-
if request.definition.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
|
880
|
-
self.sandbox_is_interactive = True
|
881
|
-
|
882
|
-
self.sandbox = await asyncio.subprocess.create_subprocess_exec(
|
883
|
-
*request.definition.entrypoint_args,
|
884
|
-
stdout=asyncio.subprocess.PIPE,
|
885
|
-
stderr=asyncio.subprocess.PIPE,
|
886
|
-
stdin=asyncio.subprocess.PIPE,
|
887
|
-
)
|
888
|
-
|
889
|
-
self.sandbox_defs.append(request.definition)
|
890
|
-
|
891
|
-
await stream.send_message(api_pb2.SandboxCreateResponse(sandbox_id="sb-123"))
|
892
|
-
|
893
|
-
async def SandboxGetLogs(self, stream):
|
894
|
-
request: api_pb2.SandboxGetLogsRequest = await stream.recv_message()
|
895
|
-
f: asyncio.StreamReader
|
896
|
-
if self.sandbox_is_interactive:
|
897
|
-
# sends an empty message to simulate PTY
|
898
|
-
await stream.send_message(
|
899
|
-
api_pb2.TaskLogsBatch(
|
900
|
-
items=[api_pb2.TaskLogs(data=self.sandbox_shell_prompt, file_descriptor=request.file_descriptor)]
|
901
|
-
)
|
902
|
-
)
|
903
|
-
|
904
|
-
if request.file_descriptor == api_pb2.FILE_DESCRIPTOR_STDOUT:
|
905
|
-
# Blocking read until EOF is returned.
|
906
|
-
f = self.sandbox.stdout
|
907
|
-
else:
|
908
|
-
f = self.sandbox.stderr
|
909
|
-
|
910
|
-
async for message in f:
|
911
|
-
await stream.send_message(
|
912
|
-
api_pb2.TaskLogsBatch(
|
913
|
-
items=[api_pb2.TaskLogs(data=message.decode("utf-8"), file_descriptor=request.file_descriptor)]
|
914
|
-
)
|
915
|
-
)
|
916
|
-
|
917
|
-
await stream.send_message(api_pb2.TaskLogsBatch(eof=True))
|
918
|
-
|
919
|
-
async def SandboxWait(self, stream):
|
920
|
-
request: api_pb2.SandboxWaitRequest = await stream.recv_message()
|
921
|
-
try:
|
922
|
-
await asyncio.wait_for(self.sandbox.wait(), request.timeout)
|
923
|
-
except asyncio.TimeoutError:
|
924
|
-
pass
|
925
|
-
|
926
|
-
if self.sandbox.returncode is None:
|
927
|
-
# This happens when request.timeout is 0 and the sandbox hasn't completed.
|
928
|
-
await stream.send_message(api_pb2.SandboxWaitResponse())
|
929
|
-
return
|
930
|
-
elif self.sandbox.returncode != 0:
|
931
|
-
result = api_pb2.GenericResult(
|
932
|
-
status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, exitcode=self.sandbox.returncode
|
933
|
-
)
|
934
|
-
else:
|
935
|
-
result = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS)
|
936
|
-
self.sandbox_result = result
|
937
|
-
await stream.send_message(api_pb2.SandboxWaitResponse(result=result))
|
938
|
-
|
939
|
-
async def SandboxTerminate(self, stream):
|
940
|
-
self.sandbox.terminate()
|
941
|
-
await stream.send_message(api_pb2.SandboxTerminateResponse())
|
942
|
-
|
943
|
-
async def SandboxGetTaskId(self, stream):
|
944
|
-
# only used for `modal shell` / `modal container exec`
|
945
|
-
_request: api_pb2.SandboxGetTaskIdRequest = await stream.recv_message()
|
946
|
-
await stream.send_message(api_pb2.SandboxGetTaskIdResponse(task_id="modal_container_exec"))
|
947
|
-
|
948
|
-
async def SandboxStdinWrite(self, stream):
|
949
|
-
request: api_pb2.SandboxStdinWriteRequest = await stream.recv_message()
|
950
|
-
|
951
|
-
self.sandbox.stdin.write(request.input)
|
952
|
-
await self.sandbox.stdin.drain()
|
953
|
-
|
954
|
-
if request.eof:
|
955
|
-
self.sandbox.stdin.close()
|
956
|
-
await stream.send_message(api_pb2.SandboxStdinWriteResponse())
|
957
|
-
|
958
|
-
### Secret
|
959
|
-
|
960
|
-
async def SecretGetOrCreate(self, stream):
|
961
|
-
request: api_pb2.SecretGetOrCreateRequest = await stream.recv_message()
|
962
|
-
k = (request.deployment_name, request.namespace, request.environment_name)
|
963
|
-
if request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_ANONYMOUS_OWNED_BY_APP:
|
964
|
-
secret_id = "st-" + str(len(self.secrets))
|
965
|
-
self.secrets[secret_id] = request.env_dict
|
966
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
|
967
|
-
if k in self.deployed_secrets:
|
968
|
-
raise GRPCError(Status.ALREADY_EXISTS, "Already exists")
|
969
|
-
secret_id = None
|
970
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_OVERWRITE_IF_EXISTS:
|
971
|
-
secret_id = None
|
972
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
|
973
|
-
if k not in self.deployed_secrets:
|
974
|
-
raise GRPCError(Status.NOT_FOUND, "No such secret")
|
975
|
-
secret_id = self.deployed_secrets[k]
|
976
|
-
else:
|
977
|
-
raise Exception("unsupported creation type")
|
978
|
-
|
979
|
-
if secret_id is None: # Create one
|
980
|
-
secret_id = "st-" + str(len(self.secrets))
|
981
|
-
self.secrets[secret_id] = request.env_dict
|
982
|
-
self.deployed_secrets[k] = secret_id
|
983
|
-
|
984
|
-
await stream.send_message(api_pb2.SecretGetOrCreateResponse(secret_id=secret_id))
|
985
|
-
|
986
|
-
async def SecretList(self, stream):
|
987
|
-
await stream.recv_message()
|
988
|
-
items = [api_pb2.SecretListItem(label=f"dummy-secret-{i}") for i, _ in enumerate(self.secrets)]
|
989
|
-
await stream.send_message(api_pb2.SecretListResponse(items=items))
|
990
|
-
|
991
|
-
### Network File System (née Shared volume)
|
992
|
-
|
993
|
-
async def SharedVolumeCreate(self, stream):
|
994
|
-
nfs_id = f"sv-{len(self.nfs_files)}"
|
995
|
-
self.nfs_files[nfs_id] = {}
|
996
|
-
await stream.send_message(api_pb2.SharedVolumeCreateResponse(shared_volume_id=nfs_id))
|
997
|
-
|
998
|
-
async def SharedVolumeGetOrCreate(self, stream):
|
999
|
-
request: api_pb2.SharedVolumeGetOrCreateRequest = await stream.recv_message()
|
1000
|
-
k = (request.deployment_name, request.namespace, request.environment_name)
|
1001
|
-
if request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
|
1002
|
-
if k not in self.deployed_nfss:
|
1003
|
-
raise GRPCError(Status.NOT_FOUND, "NFS not found")
|
1004
|
-
nfs_id = self.deployed_nfss[k]
|
1005
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
|
1006
|
-
nfs_id = f"sv-{len(self.nfs_files)}"
|
1007
|
-
self.nfs_files[nfs_id] = {}
|
1008
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
|
1009
|
-
if k not in self.deployed_nfss:
|
1010
|
-
nfs_id = f"sv-{len(self.nfs_files)}"
|
1011
|
-
self.nfs_files[nfs_id] = {}
|
1012
|
-
self.deployed_nfss[k] = nfs_id
|
1013
|
-
nfs_id = self.deployed_nfss[k]
|
1014
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
|
1015
|
-
if k in self.deployed_nfss:
|
1016
|
-
raise GRPCError(Status.ALREADY_EXISTS, "NFS already exists")
|
1017
|
-
nfs_id = f"sv-{len(self.nfs_files)}"
|
1018
|
-
self.nfs_files[nfs_id] = {}
|
1019
|
-
self.deployed_nfss[k] = nfs_id
|
1020
|
-
else:
|
1021
|
-
raise GRPCError(Status.INVALID_ARGUMENT, "unsupported object creation type")
|
1022
|
-
|
1023
|
-
await stream.send_message(api_pb2.SharedVolumeGetOrCreateResponse(shared_volume_id=nfs_id))
|
1024
|
-
|
1025
|
-
async def SharedVolumeHeartbeat(self, stream):
|
1026
|
-
await stream.recv_message()
|
1027
|
-
self.n_nfs_heartbeats += 1
|
1028
|
-
await stream.send_message(Empty())
|
1029
|
-
|
1030
|
-
async def SharedVolumePutFile(self, stream):
|
1031
|
-
req = await stream.recv_message()
|
1032
|
-
self.nfs_files[req.shared_volume_id][req.path] = req
|
1033
|
-
await stream.send_message(api_pb2.SharedVolumePutFileResponse(exists=True))
|
1034
|
-
|
1035
|
-
async def SharedVolumeGetFile(self, stream):
|
1036
|
-
req = await stream.recv_message()
|
1037
|
-
put_req = self.nfs_files.get(req.shared_volume_id, {}).get(req.path)
|
1038
|
-
if not put_req:
|
1039
|
-
raise GRPCError(Status.NOT_FOUND, f"No such file: {req.path}")
|
1040
|
-
if put_req.data_blob_id:
|
1041
|
-
await stream.send_message(api_pb2.SharedVolumeGetFileResponse(data_blob_id=put_req.data_blob_id))
|
1042
|
-
else:
|
1043
|
-
await stream.send_message(api_pb2.SharedVolumeGetFileResponse(data=put_req.data))
|
1044
|
-
|
1045
|
-
async def SharedVolumeListFilesStream(self, stream):
|
1046
|
-
req: api_pb2.SharedVolumeListFilesRequest = await stream.recv_message()
|
1047
|
-
for path in self.nfs_files[req.shared_volume_id].keys():
|
1048
|
-
entry = api_pb2.FileEntry(path=path, type=api_pb2.FileEntry.FileType.FILE)
|
1049
|
-
response = api_pb2.SharedVolumeListFilesResponse(entries=[entry])
|
1050
|
-
if req.path == "**" or req.path == "/" or req.path == path: # hack
|
1051
|
-
await stream.send_message(response)
|
1052
|
-
|
1053
|
-
### Task
|
1054
|
-
|
1055
|
-
async def TaskCurrentInputs(
|
1056
|
-
self, stream: "grpclib.server.Stream[Empty, api_pb2.TaskCurrentInputsResponse]"
|
1057
|
-
) -> None:
|
1058
|
-
await stream.send_message(api_pb2.TaskCurrentInputsResponse(input_ids=[])) # dummy implementation
|
1059
|
-
|
1060
|
-
async def TaskResult(self, stream):
|
1061
|
-
request: api_pb2.TaskResultRequest = await stream.recv_message()
|
1062
|
-
self.task_result = request.result
|
1063
|
-
await stream.send_message(Empty())
|
1064
|
-
|
1065
|
-
### Token flow
|
1066
|
-
|
1067
|
-
async def TokenFlowCreate(self, stream):
|
1068
|
-
request: api_pb2.TokenFlowCreateRequest = await stream.recv_message()
|
1069
|
-
self.token_flow_localhost_port = request.localhost_port
|
1070
|
-
await stream.send_message(
|
1071
|
-
api_pb2.TokenFlowCreateResponse(token_flow_id="tc-123", web_url="https://localhost/xyz/abc")
|
1072
|
-
)
|
1073
|
-
|
1074
|
-
async def TokenFlowWait(self, stream):
|
1075
|
-
await stream.send_message(
|
1076
|
-
api_pb2.TokenFlowWaitResponse(
|
1077
|
-
token_id="abc",
|
1078
|
-
token_secret="xyz",
|
1079
|
-
)
|
1080
|
-
)
|
1081
|
-
|
1082
|
-
async def WorkspaceNameLookup(self, stream):
|
1083
|
-
await stream.send_message(api_pb2.WorkspaceNameLookupResponse(username="test-username"))
|
1084
|
-
|
1085
|
-
### Tunnel
|
1086
|
-
|
1087
|
-
async def TunnelStart(self, stream):
|
1088
|
-
request: api_pb2.TunnelStartRequest = await stream.recv_message()
|
1089
|
-
port = request.port
|
1090
|
-
await stream.send_message(api_pb2.TunnelStartResponse(host=f"{port}.modal.test", port=443))
|
1091
|
-
|
1092
|
-
async def TunnelStop(self, stream):
|
1093
|
-
await stream.recv_message()
|
1094
|
-
await stream.send_message(api_pb2.TunnelStopResponse(exists=True))
|
1095
|
-
|
1096
|
-
### Volume
|
1097
|
-
|
1098
|
-
async def VolumeCreate(self, stream):
|
1099
|
-
req = await stream.recv_message()
|
1100
|
-
self.requests.append(req)
|
1101
|
-
self.volume_counter += 1
|
1102
|
-
volume_id = f"vo-{self.volume_counter}"
|
1103
|
-
self.volume_files[volume_id] = {}
|
1104
|
-
await stream.send_message(api_pb2.VolumeCreateResponse(volume_id=volume_id))
|
1105
|
-
|
1106
|
-
async def VolumeGetOrCreate(self, stream):
|
1107
|
-
request: api_pb2.VolumeGetOrCreateRequest = await stream.recv_message()
|
1108
|
-
k = (request.deployment_name, request.namespace, request.environment_name)
|
1109
|
-
if request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
|
1110
|
-
if k not in self.deployed_volumes:
|
1111
|
-
raise GRPCError(Status.NOT_FOUND, "Volume not found")
|
1112
|
-
volume_id = self.deployed_volumes[k]
|
1113
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
|
1114
|
-
volume_id = f"vo-{len(self.volume_files)}"
|
1115
|
-
self.volume_files[volume_id] = {}
|
1116
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
|
1117
|
-
if k not in self.deployed_volumes:
|
1118
|
-
volume_id = f"vo-{len(self.volume_files)}"
|
1119
|
-
self.volume_files[volume_id] = {}
|
1120
|
-
self.deployed_volumes[k] = volume_id
|
1121
|
-
volume_id = self.deployed_volumes[k]
|
1122
|
-
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
|
1123
|
-
if k in self.deployed_volumes:
|
1124
|
-
raise GRPCError(Status.ALREADY_EXISTS, "Volume already exists")
|
1125
|
-
volume_id = f"vo-{len(self.volume_files)}"
|
1126
|
-
self.volume_files[volume_id] = {}
|
1127
|
-
self.deployed_volumes[k] = volume_id
|
1128
|
-
else:
|
1129
|
-
raise GRPCError(Status.INVALID_ARGUMENT, "unsupported object creation type")
|
1130
|
-
|
1131
|
-
await stream.send_message(api_pb2.VolumeGetOrCreateResponse(volume_id=volume_id))
|
1132
|
-
|
1133
|
-
async def VolumeHeartbeat(self, stream):
|
1134
|
-
await stream.recv_message()
|
1135
|
-
self.n_vol_heartbeats += 1
|
1136
|
-
await stream.send_message(Empty())
|
1137
|
-
|
1138
|
-
async def VolumeCommit(self, stream):
|
1139
|
-
req = await stream.recv_message()
|
1140
|
-
self.requests.append(req)
|
1141
|
-
if not req.volume_id.startswith("vo-"):
|
1142
|
-
raise GRPCError(Status.NOT_FOUND, f"invalid volume ID {req.volume_id}")
|
1143
|
-
self.volume_commits[req.volume_id] += 1
|
1144
|
-
await stream.send_message(api_pb2.VolumeCommitResponse(skip_reload=False))
|
1145
|
-
|
1146
|
-
async def VolumeDelete(self, stream):
|
1147
|
-
req: api_pb2.VolumeDeleteRequest = await stream.recv_message()
|
1148
|
-
self.volume_files.pop(req.volume_id)
|
1149
|
-
self.deployed_volumes = {k: vol_id for k, vol_id in self.deployed_volumes.items() if vol_id != req.volume_id}
|
1150
|
-
await stream.send_message(Empty())
|
1151
|
-
|
1152
|
-
async def VolumeReload(self, stream):
|
1153
|
-
req = await stream.recv_message()
|
1154
|
-
self.requests.append(req)
|
1155
|
-
self.volume_reloads[req.volume_id] += 1
|
1156
|
-
await stream.send_message(Empty())
|
1157
|
-
|
1158
|
-
async def VolumeGetFile(self, stream):
|
1159
|
-
req = await stream.recv_message()
|
1160
|
-
if req.path.decode("utf-8") not in self.volume_files[req.volume_id]:
|
1161
|
-
raise GRPCError(Status.NOT_FOUND, "File not found")
|
1162
|
-
vol_file = self.volume_files[req.volume_id][req.path.decode("utf-8")]
|
1163
|
-
if vol_file.data_blob_id:
|
1164
|
-
await stream.send_message(api_pb2.VolumeGetFileResponse(data_blob_id=vol_file.data_blob_id))
|
1165
|
-
else:
|
1166
|
-
size = len(vol_file.data)
|
1167
|
-
if req.start or req.len:
|
1168
|
-
start = req.start
|
1169
|
-
len_ = req.len or len(vol_file.data)
|
1170
|
-
await stream.send_message(
|
1171
|
-
api_pb2.VolumeGetFileResponse(data=vol_file.data[start : start + len_], size=size)
|
1172
|
-
)
|
1173
|
-
else:
|
1174
|
-
await stream.send_message(api_pb2.VolumeGetFileResponse(data=vol_file.data, size=size))
|
1175
|
-
|
1176
|
-
async def VolumeRemoveFile(self, stream):
|
1177
|
-
req = await stream.recv_message()
|
1178
|
-
if req.path.decode("utf-8") not in self.volume_files[req.volume_id]:
|
1179
|
-
raise GRPCError(Status.INVALID_ARGUMENT, "File not found")
|
1180
|
-
del self.volume_files[req.volume_id][req.path.decode("utf-8")]
|
1181
|
-
await stream.send_message(Empty())
|
1182
|
-
|
1183
|
-
async def VolumeListFiles(self, stream):
|
1184
|
-
req = await stream.recv_message()
|
1185
|
-
path = req.path if req.path else "/"
|
1186
|
-
if path.startswith("/"):
|
1187
|
-
path = path[1:]
|
1188
|
-
if path.endswith("/"):
|
1189
|
-
path = path[:-1]
|
1190
|
-
|
1191
|
-
found_file = False # empty directory detection is not handled here!
|
1192
|
-
for k, vol_file in self.volume_files[req.volume_id].items():
|
1193
|
-
if not path or k == path or (k.startswith(path + "/") and (req.recursive or "/" not in k[len(path) + 1 :])):
|
1194
|
-
entry = api_pb2.FileEntry(path=k, type=api_pb2.FileEntry.FileType.FILE, size=len(vol_file.data))
|
1195
|
-
await stream.send_message(api_pb2.VolumeListFilesResponse(entries=[entry]))
|
1196
|
-
found_file = True
|
1197
|
-
|
1198
|
-
if path and not found_file:
|
1199
|
-
raise GRPCError(Status.NOT_FOUND, "No such file")
|
1200
|
-
|
1201
|
-
async def VolumePutFiles(self, stream):
|
1202
|
-
req = await stream.recv_message()
|
1203
|
-
for file in req.files:
|
1204
|
-
blob_data = self.files_sha2data[file.sha256_hex]
|
1205
|
-
|
1206
|
-
if file.filename in self.volume_files[req.volume_id] and req.disallow_overwrite_existing_files:
|
1207
|
-
raise GRPCError(
|
1208
|
-
Status.ALREADY_EXISTS,
|
1209
|
-
f"{file.filename}: already exists (disallow_overwrite_existing_files={req.disallow_overwrite_existing_files}",
|
1210
|
-
)
|
1211
|
-
|
1212
|
-
self.volume_files[req.volume_id][file.filename] = VolumeFile(
|
1213
|
-
data=blob_data["data"],
|
1214
|
-
data_blob_id=blob_data["data_blob_id"],
|
1215
|
-
mode=file.mode,
|
1216
|
-
)
|
1217
|
-
await stream.send_message(Empty())
|
1218
|
-
|
1219
|
-
async def VolumeCopyFiles(self, stream):
|
1220
|
-
req = await stream.recv_message()
|
1221
|
-
for src_path in req.src_paths:
|
1222
|
-
if src_path.decode("utf-8") not in self.volume_files[req.volume_id]:
|
1223
|
-
raise GRPCError(Status.NOT_FOUND, f"Source file not found: {src_path}")
|
1224
|
-
src_file = self.volume_files[req.volume_id][src_path.decode("utf-8")]
|
1225
|
-
if len(req.src_paths) > 1:
|
1226
|
-
# check to make sure dst is a directory
|
1227
|
-
if (
|
1228
|
-
req.dst_path.decode("utf-8").endswith(("/", "\\"))
|
1229
|
-
or not os.path.splitext(os.path.basename(req.dst_path))[1]
|
1230
|
-
):
|
1231
|
-
dst_path = os.path.join(req.dst_path, os.path.basename(src_path))
|
1232
|
-
else:
|
1233
|
-
raise GRPCError(Status.INVALID_ARGUMENT, f"{dst_path} is not a directory.")
|
1234
|
-
else:
|
1235
|
-
dst_path = req.dst_path
|
1236
|
-
self.volume_files[req.volume_id][dst_path.decode("utf-8")] = src_file
|
1237
|
-
await stream.send_message(Empty())
|
1238
|
-
|
1239
|
-
|
1240
|
-
@pytest_asyncio.fixture
async def blob_server():
    """Run an in-process HTTP blob store mimicking an S3-style upload/download API.

    Yields ``(host, blobs)`` where ``blobs`` is the backing dict, so tests can
    seed or inspect stored content directly.
    """
    blobs = {}
    # Multipart parts accumulate here per blob_id until /complete_multipart
    # stitches them together.
    blob_parts: Dict[str, Dict[int, bytes]] = defaultdict(dict)

    async def upload(request):
        # Single-shot or per-part upload; a literal b"FAILURE" body simulates
        # a server-side error (used by retry tests).
        blob_id = request.query["blob_id"]
        content = await request.content.read()
        if content == b"FAILURE":
            return aiohttp.web.Response(status=500)
        content_md5 = hashlib.md5(content).hexdigest()
        etag = f'"{content_md5}"'
        if "part_number" in request.query:
            # Part of a multipart upload: stash the part, don't finalize yet.
            part_number = int(request.query["part_number"])
            blob_parts[blob_id][part_number] = content
        else:
            blobs[blob_id] = content
        return aiohttp.web.Response(text="Hello, world", headers={"ETag": etag})

    async def complete_multipart(request):
        # Concatenate parts in part-number order, then compute an S3-style
        # multipart ETag: md5 of the concatenated per-part md5 digests,
        # suffixed with "-<part count>".
        blob_id = request.query["blob_id"]
        blob_nums = range(min(blob_parts[blob_id].keys()), max(blob_parts[blob_id].keys()) + 1)
        content = b""
        part_hashes = b""
        for num in blob_nums:
            part_content = blob_parts[blob_id][num]
            content += part_content
            part_hashes += hashlib.md5(part_content).digest()

        content_md5 = hashlib.md5(part_hashes).hexdigest()
        etag = f'"{content_md5}-{len(blob_parts[blob_id])}"'
        blobs[blob_id] = content
        return aiohttp.web.Response(text=f"<etag>{etag}</etag>")

    async def download(request):
        # The magic id "bl-failure" simulates a download-side server error.
        blob_id = request.query["blob_id"]
        if blob_id == "bl-failure":
            return aiohttp.web.Response(status=500)
        return aiohttp.web.Response(body=blobs[blob_id])

    app = aiohttp.web.Application()
    app.add_routes([aiohttp.web.put("/upload", upload)])
    app.add_routes([aiohttp.web.get("/download", download)])
    app.add_routes([aiohttp.web.post("/complete_multipart", complete_multipart)])

    async with run_temporary_http_server(app) as host:
        yield host, blobs
|
1287
|
-
|
1288
|
-
|
1289
|
-
@pytest_asyncio.fixture(scope="function")
async def servicer_factory(blob_server):
    """Yield a factory (async context manager) that starts a MockClientServicer.

    The factory accepts either ``host``/``port`` (TCP) or ``path`` (unix
    socket) and tears the gRPC server down on exit.
    """

    @contextlib.asynccontextmanager
    async def create_server(host=None, port=None, path=None):
        blob_host, blobs = blob_server
        servicer = MockClientServicer(blob_host, blobs)  # type: ignore
        server = None

        async def _start_servicer():
            nonlocal server
            server = grpclib.server.Server([servicer])
            await server.start(host=host, port=port, path=path)

        async def _stop_servicer():
            # Unblock any pending heartbeat waiters before closing.
            servicer.container_heartbeat_abort.set()
            server.close()
            # This is the proper way to close down the asyncio server,
            # but it causes our tests to hang on 3.12+ because client connections
            # for clients created through _Client.from_env don't get closed until
            # asyncio event loop shutdown. Commenting out but perhaps revisit if we
            # refactor the way that _Client cleanup happens.
            # await server.wait_closed()

        # Wrap in synchronize_api so start/stop run on the synchronizer's
        # event loop (same loop the client under test uses).
        start_servicer = synchronize_api(_start_servicer)
        stop_servicer = synchronize_api(_stop_servicer)

        await start_servicer.aio()
        try:
            yield servicer
        finally:
            await stop_servicer.aio()

    yield create_server
|
1322
|
-
|
1323
|
-
|
1324
|
-
@pytest_asyncio.fixture(scope="function")
async def servicer(servicer_factory):
    """Yield a mock servicer listening on a free local TCP port."""
    free_port = find_free_port()
    async with servicer_factory(host="0.0.0.0", port=free_port) as mock_servicer:
        mock_servicer.remote_addr = f"http://127.0.0.1:{free_port}"
        yield mock_servicer
|
1330
|
-
|
1331
|
-
|
1332
|
-
@pytest_asyncio.fixture(scope="function")
async def unix_servicer(servicer_factory):
    """Yield a mock servicer bound to a unix domain socket in a temp dir."""
    with tempfile.TemporaryDirectory() as socket_dir:
        socket_path = os.path.join(socket_dir, "servicer.sock")
        async with servicer_factory(path=socket_path) as mock_servicer:
            mock_servicer.remote_addr = f"unix://{socket_path}"
            yield mock_servicer
|
1339
|
-
|
1340
|
-
|
1341
|
-
@pytest_asyncio.fixture(scope="function")
async def client(servicer):
    """Yield a client-typed Client connected to the TCP mock servicer."""
    credentials = ("foo-id", "foo-secret")
    with Client(servicer.remote_addr, api_pb2.CLIENT_TYPE_CLIENT, credentials) as c:
        yield c
|
1345
|
-
|
1346
|
-
|
1347
|
-
@pytest_asyncio.fixture(scope="function")
async def container_client(unix_servicer):
    """Yield a container-typed Client connected over the unix-socket servicer."""
    credentials = ("ta-123", "task-secret")
    async with Client(unix_servicer.remote_addr, api_pb2.CLIENT_TYPE_CONTAINER, credentials) as c:
        yield c
|
1351
|
-
|
1352
|
-
|
1353
|
-
@pytest_asyncio.fixture(scope="function")
async def server_url_env(servicer, monkeypatch):
    """Point MODAL_SERVER_URL at the mock servicer for the duration of a test."""
    monkeypatch.setenv("MODAL_SERVER_URL", servicer.remote_addr)
    yield
|
1357
|
-
|
1358
|
-
|
1359
|
-
@pytest_asyncio.fixture(scope="function", autouse=True)
async def reset_default_client():
    """Autouse: clear any cached env client before each test runs."""
    Client.set_env_client(None)
|
1362
|
-
|
1363
|
-
|
1364
|
-
@pytest.fixture(name="mock_dir", scope="session")
def mock_dir_factory():
    """Sets up a temp dir with content as specified in a nested dict.

    Strings in the spec become file contents; nested dicts become
    subdirectories. The context manager chdirs into the temp dir for its
    duration and removes it afterwards.

    Example usage:
    spec = {
        "foo": {
            "bar.txt": "some content"
        },
    }

    with mock_dir(spec) as root_dir:
        assert os.path.exists(os.path.join(root_dir, "foo", "bar.txt"))
    """

    @contextlib.contextmanager
    def mock_dir(root_spec):
        def rec_make(dir, dir_spec):
            # Recursively materialize the spec: str -> file, dict -> subdir.
            for filename, spec in dir_spec.items():
                path = os.path.join(dir, filename)
                if isinstance(spec, str):
                    with open(path, "w") as f:
                        f.write(spec)
                else:
                    os.mkdir(path)
                    rec_make(path, spec)

        # Windows has issues cleaning up TempDirectory: https://www.scivision.dev/python-tempfile-permission-error-windows
        # Seems to have been fixed for some python versions in https://github.com/python/cpython/pull/10320.
        root_dir = tempfile.mkdtemp()
        rec_make(root_dir, root_spec)
        cwd = os.getcwd()
        try:
            os.chdir(root_dir)
            # Bug fix: previously a bare `yield`, so `with mock_dir(spec) as root_dir:`
            # (the documented usage above) bound None. Yield the temp root instead;
            # callers that ignore the value are unaffected.
            yield root_dir
        finally:
            os.chdir(cwd)
            shutil.rmtree(root_dir, ignore_errors=True)

    return mock_dir
|
1404
|
-
|
1405
|
-
|
1406
|
-
@pytest.fixture(autouse=True)
def reset_sys_modules():
    """Autouse: restore sys.modules after each test, since some tests import dynamic modules."""
    snapshot = sys.modules.copy()
    try:
        yield
    finally:
        sys.modules = snapshot
|
1414
|
-
|
1415
|
-
|
1416
|
-
@pytest.fixture(autouse=True)
def reset_container_app():
    """Autouse: tear down the ContainerIOManager singleton after each test."""
    try:
        yield
    finally:
        _ContainerIOManager._reset_singleton()
|
1422
|
-
|
1423
|
-
|
1424
|
-
@pytest.fixture
def repo_root(request):
    """Absolute path to the repository root (pytest rootdir)."""
    return Path(request.config.rootdir)
|
1427
|
-
|
1428
|
-
|
1429
|
-
@pytest.fixture(scope="module")
def test_dir(request):
    """Absolute path to directory containing test file."""
    # PYTEST_CURRENT_TEST looks like "tests/foo_test.py::test_name (phase)";
    # treating it as a path, .parent is the directory holding the test file.
    current_test = Path(os.getenv("PYTEST_CURRENT_TEST"))
    return Path(request.config.rootdir) / current_test.parent
|
1435
|
-
|
1436
|
-
|
1437
|
-
@pytest.fixture(scope="function")
def modal_config():
    """Return a context manager with a temporary modal.toml file"""

    @contextlib.contextmanager
    def mock_modal_toml(contents: str = "", show_on_error: bool = False):
        # Some of the cli tests run within the main process
        # so we need to modify the config singletons to pick up any changes
        orig_config_path_env = os.environ.get("MODAL_CONFIG_PATH")
        orig_config_path = config.user_config_path
        orig_profile = config._profile
        try:
            # delete=False so the file survives the `with` and can be re-read
            # by config (and by show_on_error below); removed in `finally`.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".toml", mode="w") as t:
                t.write(textwrap.dedent(contents.strip("\n")))
            os.environ["MODAL_CONFIG_PATH"] = t.name
            config.user_config_path = t.name
            # Force the config module to re-read state from the temp file.
            config._user_config = config._read_user_config()
            config._profile = config._config_active_profile()
            yield t.name
        except Exception:
            if show_on_error:
                with open(t.name) as f:
                    print(f"Test config file contents:\n\n{f.read()}", file=sys.stderr)
            raise
        finally:
            # Restore the env var, config singletons, and delete the temp file.
            if orig_config_path_env:
                os.environ["MODAL_CONFIG_PATH"] = orig_config_path_env
            else:
                del os.environ["MODAL_CONFIG_PATH"]
            config.user_config_path = orig_config_path
            config._user_config = config._read_user_config()
            config._profile = orig_profile
            os.remove(t.name)

    return mock_modal_toml
|
1472
|
-
|
1473
|
-
|
1474
|
-
@pytest.fixture
def supports_dir(test_dir):
    """Path to the `supports` directory holding test asset files."""
    return test_dir / Path("supports")
|
1477
|
-
|
1478
|
-
|
1479
|
-
@pytest_asyncio.fixture
async def set_env_client(client):
    """Install `client` as the process-wide env client, resetting it afterwards."""
    try:
        Client.set_env_client(client)
        yield
    finally:
        Client.set_env_client(None)
|