modal 0.62.115__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220):
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +407 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +197 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +946 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/METADATA +5 -5
  128. modal-0.72.11.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
test/conftest.py DELETED
@@ -1,1485 +0,0 @@
1
- # Copyright Modal Labs 2024
2
- from __future__ import annotations
3
-
4
- import asyncio
5
- import contextlib
6
- import dataclasses
7
- import hashlib
8
- import inspect
9
- import os
10
- import pytest
11
- import shutil
12
- import sys
13
- import tempfile
14
- import textwrap
15
- import threading
16
- import traceback
17
- from collections import defaultdict
18
- from pathlib import Path
19
- from typing import Dict, Iterator, Optional, get_args
20
-
21
- import aiohttp.web
22
- import aiohttp.web_runner
23
- import grpclib.server
24
- import pkg_resources
25
- import pytest_asyncio
26
- from google.protobuf.empty_pb2 import Empty
27
- from grpclib import GRPCError, Status
28
-
29
- import modal._serialization
30
- from modal import __version__, config
31
- from modal._container_io_manager import _ContainerIOManager
32
- from modal._serialization import serialize_data_format
33
- from modal._utils.async_utils import asyncify, synchronize_api
34
- from modal._utils.grpc_testing import patch_mock_servicer
35
- from modal._utils.grpc_utils import find_free_port
36
- from modal._utils.http_utils import run_temporary_http_server
37
- from modal._vendor import cloudpickle
38
- from modal.client import Client
39
- from modal.image import ImageBuilderVersion
40
- from modal.mount import client_mount_name
41
- from modal_proto import api_grpc, api_pb2
42
-
43
-
44
@dataclasses.dataclass
class VolumeFile:
    """One file held by the mock volume store."""

    # File contents as raw bytes.
    data: bytes
    # Id of the blob associated with the contents (semantics when empty: TODO confirm).
    data_blob_id: str
    # Integer mode value for the file (presumably POSIX permission bits; confirm).
    mode: int
49
-
50
-
51
# TODO: Isolate all test config from the host
@pytest.fixture(scope="session", autouse=True)
def set_env():
    """Pin MODAL_ENVIRONMENT to "main" once for the whole test session.

    This writes to the real process environment and does not restore it afterwards.
    """
    os.environ.update({"MODAL_ENVIRONMENT": "main"})
55
-
56
-
57
- @patch_mock_servicer
58
- class MockClientServicer(api_grpc.ModalClientBase):
59
- # TODO(erikbern): add more annotations
60
- container_inputs: list[api_pb2.FunctionGetInputsResponse]
61
- container_outputs: list[api_pb2.FunctionPutOutputsRequest]
62
- fc_data_in: defaultdict[str, asyncio.Queue[api_pb2.DataChunk]]
63
- fc_data_out: defaultdict[str, asyncio.Queue[api_pb2.DataChunk]]
64
-
65
def __init__(self, blob_host, blobs):
    """Initialise all in-memory state the mock servicer uses to emulate the Modal API.

    Args:
        blob_host: base URL prefixed to the blob upload/download URLs handed out by
            BlobCreate/BlobGet.
        blobs: dict stored by reference, not copied (presumably shared with the blob
            HTTP server; confirm).
    """
    self.use_blob_outputs = False
    # Size-1 barriers never block. Tests swap in size-2 barriers to get lock-step
    # output pushing / input releases within a test.
    self.put_outputs_barrier = threading.Barrier(1, timeout=10)
    self.get_inputs_barrier = threading.Barrier(1, timeout=10)

    self.app_state_history = defaultdict(list)
    self.app_heartbeats: Dict[str, int] = defaultdict(int)
    self.container_checkpoint_requests = 0
    self.n_blobs = 0
    self.blob_host = blob_host
    self.blobs = blobs  # shared dict
    self.requests = []
    self.done = False
    self.rate_limit_sleep_duration = None
    self.fail_get_inputs = False
    self.slow_put_inputs = False
    self.container_inputs = []
    self.container_outputs = []
    self.fc_data_in = defaultdict(asyncio.Queue)  # unbounded
    self.fc_data_out = defaultdict(asyncio.Queue)  # unbounded
    self.queue = []
    self.deployed_apps = {
        client_mount_name(): "ap-x",
    }
    self.app_objects = {}
    self.app_single_objects = {}
    self.app_unindexed_objects = {
        "ap-1": ["im-1", "vo-1"],
    }
    self.n_inputs = 0
    self.n_queues = 0
    self.n_dict_heartbeats = 0
    self.n_queue_heartbeats = 0
    self.n_nfs_heartbeats = 0
    self.n_vol_heartbeats = 0
    self.n_mounts = 0
    self.n_mount_files = 0
    self.mount_contents = {}
    self.files_name2sha = {}
    self.files_sha2data = {}
    self.function_id_for_function_call = {}
    self.client_calls = {}
    self.function_is_running = False
    self.n_functions = 0
    self.n_schedules = 0
    self.function2schedule = {}
    self.function_create_error = False
    self.heartbeat_status_code = None
    self.n_apps = 0
    self.classes = {}

    self.task_result = None

    self.nfs_files: Dict[str, Dict[str, api_pb2.SharedVolumePutFileRequest]] = defaultdict(dict)
    self.volume_files: Dict[str, Dict[str, VolumeFile]] = defaultdict(dict)
    self.images = {}
    self.image_build_function_ids = {}
    self.image_builder_versions = {}
    self.force_built_images = []
    self.fail_blob_create = []
    self.blob_create_metadata = None
    self.blob_multipart_threshold = 10_000_000

    self.precreated_functions = set()
    self.app_functions = {}
    self.fcidx = 0

    self.function_serialized = None
    self.class_serialized = None

    self.client_hello_metadata = None
    # Written by ClientHello. Initialised here so it can be read before any hello
    # has happened without raising AttributeError.
    self.client_create_metadata = None

    self.dicts = {}
    self.secrets = {}

    self.deployed_dicts = {}
    self.deployed_mounts = {
        (client_mount_name(), api_pb2.DEPLOYMENT_NAMESPACE_GLOBAL): "mo-123",
    }
    self.deployed_nfss = {}
    self.deployed_queues = {}
    self.deployed_secrets = {}
    self.deployed_volumes = {}

    self.cleared_function_calls = set()

    self.cancelled_calls = []

    self.app_client_disconnect_count = 0
    self.app_get_logs_initial_count = 0
    self.app_set_objects_count = 0

    self.volume_counter = 0
    # Volume-id -> commit/reload count
    self.volume_commits: Dict[str, int] = defaultdict(int)
    self.volume_reloads: Dict[str, int] = defaultdict(int)

    self.sandbox_defs = []
    self.sandbox: Optional[asyncio.subprocess.Process] = None

    # Whether the sandbox is executing a shell program in interactive mode.
    self.sandbox_is_interactive = False
    self.sandbox_shell_prompt = "TEST_PROMPT# "
    self.sandbox_result: Optional[api_pb2.GenericResult] = None

    self.token_flow_localhost_port = None
    self.queue_max_len = 100

    # Set by container_heartbeat_return_now to make ContainerHeartbeat reply early.
    self.container_heartbeat_response = None
    self.container_heartbeat_abort = threading.Event()

    @self.function_body
    def default_function_body(*args, **kwargs):
        # Default body: sum of squares of all positional and keyword argument values.
        return sum(arg**2 for arg in args) + sum(value**2 for value in kwargs.values())
183
-
184
def function_body(self, func):
    """Decorator for setting the function that will be called for any FunctionGetOutputs calls.

    Stores ``func`` on ``self._function_body`` and hands it back unchanged, so the
    decorated name still refers to the original function.
    """
    self._function_body = func
    return self._function_body
188
-
189
def container_heartbeat_return_now(self, response: api_pb2.ContainerHeartbeatResponse):
    """Queue ``response`` for ContainerHeartbeat and wake any handler waiting on the abort event.

    The response is stored before the event is set, so a woken handler observes it.
    """
    self.container_heartbeat_response = response
    abort_event = self.container_heartbeat_abort
    abort_event.set()
192
-
193
def get_function_metadata(self, object_id: str) -> api_pb2.FunctionHandleMetadata:
    """Build FunctionHandleMetadata for ``object_id`` from its stored Function definition.

    Raises KeyError if ``object_id`` is not in ``app_functions`` (which FunctionCreate populates).
    """
    fn: api_pb2.Function = self.app_functions[object_id]
    fields = dict(
        function_name=fn.function_name,
        function_type=fn.function_type,
        web_url=fn.web_url,
        is_method=fn.is_method,
    )
    return api_pb2.FunctionHandleMetadata(**fields)
201
-
202
def get_class_metadata(self, object_id: str) -> api_pb2.ClassHandleMetadata:
    """Build ClassHandleMetadata with one ClassMethod per method registered for class ``object_id``.

    Raises KeyError if the class id, or any of its function ids, is unknown.
    """
    methods = [
        api_pb2.ClassMethod(
            function_name=method_name,
            function_id=fn_id,
            function_handle_metadata=self.get_function_metadata(fn_id),
        )
        for method_name, fn_id in self.classes[object_id].items()
    ]
    return api_pb2.ClassHandleMetadata(methods=methods)
212
-
213
def get_object_metadata(self, object_id) -> api_pb2.Object:
    """Return an Object proto for ``object_id``, with metadata chosen by its id prefix.

    - ``fu-``: function handle metadata
    - ``cs-``: class handle metadata
    - ``mo-``: mount metadata with a fixed checksum
    - ``sb-``: sandbox metadata carrying ``self.sandbox_result``

    Any other prefix yields an Object with only ``object_id`` set.
    """
    metadata_fields = {}
    if object_id.startswith("fu-"):
        metadata_fields["function_handle_metadata"] = self.get_function_metadata(object_id)
    elif object_id.startswith("cs-"):
        metadata_fields["class_handle_metadata"] = self.get_class_metadata(object_id)
    elif object_id.startswith("mo-"):
        metadata_fields["mount_handle_metadata"] = api_pb2.MountHandleMetadata(content_checksum_sha256_hex="abc123")
    elif object_id.startswith("sb-"):
        metadata_fields["sandbox_handle_metadata"] = api_pb2.SandboxHandleMetadata(result=self.sandbox_result)
    return api_pb2.Object(object_id=object_id, **metadata_fields)
233
-
234
- ### App
235
-
236
async def AppCreate(self, stream):
    """Handle AppCreate by allocating a sequential ``ap-N`` id and marking it INITIALIZING.

    The request is recorded in ``self.requests``. The reply carries a fixed logs URL.
    """
    request: api_pb2.AppCreateRequest = await stream.recv_message()
    self.requests.append(request)
    self.n_apps += 1
    new_app_id = f"ap-{self.n_apps}"
    self.app_state_history[new_app_id].append(api_pb2.APP_STATE_INITIALIZING)
    response = api_pb2.AppCreateResponse(app_id=new_app_id, app_logs_url="https://modaltest.com/apps/ap-123")
    await stream.send_message(response)
245
-
246
async def AppClientDisconnect(self, stream):
    """Handle AppClientDisconnect.

    Records the request, sets ``self.done`` (which makes AppGetLogs emit an
    ``app_done`` batch) and bumps the disconnect counter. The app is then marked
    STOPPED unless its latest state is DETACHED or DEPLOYED.
    """
    request: api_pb2.AppClientDisconnectRequest = await stream.recv_message()
    self.requests.append(request)
    self.done = True
    self.app_client_disconnect_count += 1
    state_history = self.app_state_history[request.app_id]
    # app_state_history is a defaultdict(list), so an app id with no recorded
    # state yields an empty list. Guard so [-1] cannot raise IndexError.
    if not state_history or state_history[-1] not in (api_pb2.APP_STATE_DETACHED, api_pb2.APP_STATE_DEPLOYED):
        state_history.append(api_pb2.APP_STATE_STOPPED)
    await stream.send_message(Empty())
255
-
256
async def AppGetLogs(self, stream):
    """Handle AppGetLogs by streaming one synthetic stdout line per call.

    The first call (no ``last_entry_id``) is counted and gets entry id "1". Later
    calls get the previous id plus one. After a 0.5 s delay the log batch is sent,
    followed by an ``app_done`` batch if ``self.done`` is set.
    """
    request: api_pb2.AppGetLogsRequest = await stream.recv_message()
    if request.last_entry_id:
        entry_id = str(int(request.last_entry_id) + 1)
    else:
        # Just count initial requests
        self.app_get_logs_initial_count += 1
        entry_id = "1"
    await asyncio.sleep(0.5)
    line = api_pb2.TaskLogs(data=f"hello, world ({entry_id})\n", file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT)
    await stream.send_message(api_pb2.TaskLogsBatch(entry_id=entry_id, items=[line]))
    if self.done:
        await stream.send_message(api_pb2.TaskLogsBatch(app_done=True))
269
-
270
async def AppGetObjects(self, stream):
    """Handle AppGetObjects by listing the app's tagged objects.

    If requested, untagged (unindexed) objects are appended with a ``None`` tag.
    """
    request: api_pb2.AppGetObjectsRequest = await stream.recv_message()
    tagged_ids = list(self.app_objects.get(request.app_id, {}).items())
    if request.include_unindexed:
        tagged_ids.extend((None, oid) for oid in self.app_unindexed_objects.get(request.app_id, []))
    response = api_pb2.AppGetObjectsResponse(
        items=[api_pb2.AppGetObjectsItem(tag=tag, object=self.get_object_metadata(oid)) for tag, oid in tagged_ids]
    )
    await stream.send_message(response)
281
-
282
async def AppSetObjects(self, stream):
    """Handle AppSetObjects.

    Replaces the app's indexed and unindexed object ids. Also records the single
    object id and the new app state when either is set.
    """
    request: api_pb2.AppSetObjectsRequest = await stream.recv_message()
    app_id = request.app_id
    self.app_objects[app_id] = dict(request.indexed_object_ids)
    self.app_unindexed_objects[app_id] = list(request.unindexed_object_ids)
    if request.single_object_id:
        self.app_single_objects[app_id] = request.single_object_id
    self.app_set_objects_count += 1
    if request.new_app_state:
        self.app_state_history[app_id].append(request.new_app_state)
    await stream.send_message(Empty())
292
-
293
async def AppDeploy(self, stream):
    """Handle AppDeploy by registering ``name -> app_id`` and marking the app DEPLOYED."""
    request: api_pb2.AppDeployRequest = await stream.recv_message()
    name, app_id = request.name, request.app_id
    self.deployed_apps[name] = app_id
    self.app_state_history[app_id].append(api_pb2.APP_STATE_DEPLOYED)
    response = api_pb2.AppDeployResponse(url="http://test.modal.com/foo/bar")
    await stream.send_message(response)
298
-
299
async def AppGetByDeploymentName(self, stream):
    """Handle AppGetByDeploymentName by returning the deployed app id, or None if the name is unknown."""
    request: api_pb2.AppGetByDeploymentNameRequest = await stream.recv_message()
    found_app_id = self.deployed_apps.get(request.name)
    await stream.send_message(api_pb2.AppGetByDeploymentNameResponse(app_id=found_app_id))
302
-
303
async def AppHeartbeat(self, stream):
    """Handle AppHeartbeat by recording the request and counting heartbeats per app id."""
    heartbeat: api_pb2.AppHeartbeatRequest = await stream.recv_message()
    self.requests.append(heartbeat)
    self.app_heartbeats[heartbeat.app_id] += 1
    await stream.send_message(Empty())
308
-
309
async def AppList(self, stream):
    """Handle AppList by returning one AppStats per deployed app; the name doubles as the description."""
    await stream.recv_message()
    stats = [
        api_pb2.AppStats(name=name, description=name, app_id=app_id)
        for name, app_id in self.deployed_apps.items()
    ]
    await stream.send_message(api_pb2.AppListResponse(apps=stats))
315
-
316
- ### Checkpoint
317
-
318
async def ContainerCheckpoint(self, stream):
    """Handle ContainerCheckpoint by recording the request and counting checkpoint calls."""
    checkpoint_request: api_pb2.ContainerCheckpointRequest = await stream.recv_message()
    self.requests.append(checkpoint_request)
    self.container_checkpoint_requests += 1
    await stream.send_message(Empty())
323
-
324
- ### Blob
325
-
326
async def BlobCreate(self, stream):
    """Handle BlobCreate by handing out upload URL(s) for a new blob.

    ``stream.metadata`` is recorded; this is used to test retry_transient_errors
    (see grpc_utils_test.py). While ``fail_blob_create`` is non-empty, a status
    code is popped from its end and raised.

    Content larger than ``blob_multipart_threshold`` gets a multipart plan: one
    URL per threshold-sized part plus a completion URL. Smaller content gets a
    single upload URL.
    """
    req = await stream.recv_message()
    self.blob_create_metadata = stream.metadata
    if self.fail_blob_create:
        raise GRPCError(self.fail_blob_create.pop(), "foobar")

    part_size = self.blob_multipart_threshold
    blob_id = await self.next_blob_id()
    if req.content_length <= part_size:
        single_url = f"{self.blob_host}/upload?blob_id={blob_id}"
        await stream.send_message(api_pb2.BlobCreateResponse(blob_id=blob_id, upload_url=single_url))
        return

    # Ceiling division: enough parts to cover every byte.
    part_count = (req.content_length + part_size - 1) // part_size
    part_urls = [f"{self.blob_host}/upload?blob_id={blob_id}&part_number={n}" for n in range(part_count)]
    plan = api_pb2.MultiPartUpload(
        part_length=part_size,
        upload_urls=part_urls,
        completion_url=f"{self.blob_host}/complete_multipart?blob_id={blob_id}",
    )
    await stream.send_message(api_pb2.BlobCreateResponse(blob_id=blob_id, multipart=plan))
355
-
356
async def next_blob_id(self):
    """Allocate the next sequential blob id: ``bl-1``, ``bl-2``, and so on."""
    self.n_blobs += 1
    return f"bl-{self.n_blobs}"
360
-
361
async def BlobGet(self, stream):
    """Handle BlobGet by returning a download URL for the blob on ``self.blob_host``."""
    request: api_pb2.BlobGetRequest = await stream.recv_message()
    response = api_pb2.BlobGetResponse(download_url=f"{self.blob_host}/download?blob_id={request.blob_id}")
    await stream.send_message(response)
365
-
366
- ### Class
367
-
368
async def ClassCreate(self, stream):
    """Handle ClassCreate by registering the method-name -> function-id map under a new ``cs-N`` id."""
    request: api_pb2.ClassCreateRequest = await stream.recv_message()
    assert request.app_id
    class_id = f"cs-{len(self.classes)}"
    self.classes[class_id] = {m.function_name: m.function_id for m in request.methods}
    metadata = self.get_class_metadata(class_id)
    await stream.send_message(api_pb2.ClassCreateResponse(class_id=class_id, handle_metadata=metadata))
377
-
378
async def ClassGet(self, stream):
    """Handle ClassGet by resolving ``object_tag`` within the deployed app named ``app_name``.

    Raises GRPCError(NOT_FOUND) if the app is unknown, has no objects, or has no
    object with that tag.
    """
    request: api_pb2.ClassGetRequest = await stream.recv_message()
    app_id = self.deployed_apps.get(request.app_name)
    # Unknown apps have no entry in app_objects. Fall back to {} so the client gets
    # NOT_FOUND rather than a server-side KeyError.
    object_id = self.app_objects.get(app_id, {}).get(request.object_tag)
    if object_id is None:
        raise GRPCError(Status.NOT_FOUND, f"can't find object {request.object_tag}")
    await stream.send_message(
        api_pb2.ClassGetResponse(class_id=object_id, handle_metadata=self.get_class_metadata(object_id))
    )
388
-
389
- ### Client
390
-
391
async def ClientHello(self, stream):
    """Handle ClientHello by validating the caller and reporting the image builder version.

    The request is recorded and ``stream.metadata`` is stored on
    ``self.client_create_metadata``. Special ``x-modal-client-version`` values
    drive test scenarios:
    - "unauthenticated" -> UNAUTHENTICATED
    - "deprecated" -> warning "SUPER OLD"
    - "timeout" -> a 60 s stall
    Any other version older than ``__version__`` is rejected with
    FAILED_PRECONDITION. A token id of "bad" is rejected first.
    """
    request: Empty = await stream.recv_message()
    self.requests.append(request)
    self.client_create_metadata = stream.metadata
    client_version = stream.metadata["x-modal-client-version"]
    # Highest value allowed by the client's ImageBuilderVersion literal type.
    builder_version = max(get_args(ImageBuilderVersion))
    assert stream.user_agent.startswith(f"modal-client/{__version__} ")

    warning = ""
    if stream.metadata.get("x-modal-token-id") == "bad":
        raise GRPCError(Status.UNAUTHENTICATED, "bad bad bad")
    if client_version == "unauthenticated":
        raise GRPCError(Status.UNAUTHENTICATED, "failed authentication")
    if client_version == "deprecated":
        warning = "SUPER OLD"
    elif client_version == "timeout":
        await asyncio.sleep(60)
    elif pkg_resources.parse_version(client_version) < pkg_resources.parse_version(__version__):
        raise GRPCError(Status.FAILED_PRECONDITION, "Old client")
    await stream.send_message(api_pb2.ClientHelloResponse(warning=warning, image_builder_version=builder_version))
411
-
412
- # Container
413
-
414
async def ContainerHeartbeat(self, stream):
    """Handle ContainerHeartbeat, replying early if a test queued a response.

    Waits up to 5 s on ``container_heartbeat_abort`` through ``asyncify``.
    Returning earlier than the usual 15-second heartbeat avoids suspending tests.
    If a response was queued, it is sent once and then cleared; otherwise an empty
    response is sent. This handler does not clear the abort event.
    """
    request: api_pb2.ContainerHeartbeatRequest = await stream.recv_message()
    self.requests.append(request)
    await asyncify(self.container_heartbeat_abort.wait)(5)
    queued = self.container_heartbeat_response
    if not queued:
        await stream.send_message(api_pb2.ContainerHeartbeatResponse())
        return
    await stream.send_message(queued)
    self.container_heartbeat_response = None
424
-
425
async def ContainerExec(self, stream):
    """Handle ContainerExec by ignoring the request and returning a fixed exec id."""
    await stream.recv_message()
    reply = api_pb2.ContainerExecResponse(exec_id="container_exec_id")
    await stream.send_message(reply)
428
-
429
async def ContainerExecGetOutput(self, stream):
    """Handle ContainerExecGetOutput by streaming a canned "Hello World" stdout batch, then exit code 0."""
    await stream.recv_message()
    hello = api_pb2.RuntimeOutputMessage(
        file_descriptor=api_pb2.FileDescriptor.FILE_DESCRIPTOR_STDOUT, message="Hello World"
    )
    await stream.send_message(api_pb2.RuntimeOutputBatch(items=[hello]))
    await stream.send_message(api_pb2.RuntimeOutputBatch(exit_code=0))
441
-
442
- ### Dict
443
-
444
async def DictCreate(self, stream):
    """Handle DictCreate by reusing ``existing_dict_id`` or allocating ``di-N``.

    In both cases the dict's contents are reset to empty.
    """
    request: api_pb2.DictCreateRequest = await stream.recv_message()
    dict_id = request.existing_dict_id or f"di-{len(self.dicts)}"
    self.dicts[dict_id] = {}
    await stream.send_message(api_pb2.DictCreateResponse(dict_id=dict_id))
452
-
453
async def DictGetOrCreate(self, stream):
    """Handle DictGetOrCreate by resolving a dict by (deployment_name, namespace, environment_name).

    - An existing registered dict is returned as-is.
    - CREATE_IF_MISSING creates a new dict seeded from ``request.data`` and
      registers it under the lookup key.
    - EPHEMERAL creates a seeded dict without registering it.
    - Anything else raises GRPCError(NOT_FOUND).
    """
    request: api_pb2.DictGetOrCreateRequest = await stream.recv_message()
    k = (request.deployment_name, request.namespace, request.environment_name)
    creation_type = request.object_creation_type
    if k in self.deployed_dicts:
        dict_id = self.deployed_dicts[k]
    elif creation_type in (api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING, api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL):
        dict_id = f"di-{len(self.dicts)}"
        self.dicts[dict_id] = {entry.key: entry.value for entry in request.data}
        if creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
            # Ephemeral dicts are not registered under the lookup key.
            self.deployed_dicts[k] = dict_id
    else:
        raise GRPCError(Status.NOT_FOUND, "Dict not found")
    await stream.send_message(api_pb2.DictGetOrCreateResponse(dict_id=dict_id))
468
-
469
async def DictHeartbeat(self, stream):
    """Handle DictHeartbeat by counting heartbeats (the request content is ignored)."""
    await stream.recv_message()
    self.n_dict_heartbeats += 1
    await stream.send_message(Empty())
473
-
474
async def DictDelete(self, stream):
    """Handle DictDelete by dropping every registered name that points at ``dict_id``.

    Rebinds ``deployed_dicts`` to a new mapping. The contents stored in ``self.dicts`` are left in place.
    """
    request: api_pb2.DictDeleteRequest = await stream.recv_message()
    doomed_id = request.dict_id
    self.deployed_dicts = {key: did for key, did in self.deployed_dicts.items() if did != doomed_id}
    await stream.send_message(Empty())
478
-
479
async def DictClear(self, stream):
    """Handle DictClear by replacing the dict's contents with an empty mapping."""
    # Annotation corrected from DictGetRequest. This is a local annotation, so it has no runtime effect.
    request: api_pb2.DictClearRequest = await stream.recv_message()
    self.dicts[request.dict_id] = {}
    await stream.send_message(Empty())
483
-
484
async def DictGet(self, stream):
    """Handle DictGet by returning the value for ``key`` (None if absent) and whether it was found.

    Raises KeyError if the dict id is unknown.
    """
    request: api_pb2.DictGetRequest = await stream.recv_message()
    contents = self.dicts[request.dict_id]
    key = request.key
    await stream.send_message(api_pb2.DictGetResponse(value=contents.get(key), found=key in contents))
488
-
489
async def DictLen(self, stream):
    """Handle DictLen by reporting the number of entries in the dict."""
    request: api_pb2.DictLenRequest = await stream.recv_message()
    size = len(self.dicts[request.dict_id])
    await stream.send_message(api_pb2.DictLenResponse(len=size))
492
-
493
async def DictUpdate(self, stream):
    """Handle DictUpdate by writing each (key, value) update into the dict in order; later duplicates win."""
    request: api_pb2.DictUpdateRequest = await stream.recv_message()
    for entry in request.updates:
        # Look up the dict per entry, so an empty update list never touches self.dicts.
        self.dicts[request.dict_id][entry.key] = entry.value
    await stream.send_message(api_pb2.DictUpdateResponse())
498
-
499
async def DictContents(self, stream):
    """Handle DictContents by streaming one DictEntry message per key/value pair."""
    # Annotation corrected from DictGetRequest. This is a local annotation, so it has no runtime effect.
    request: api_pb2.DictContentsRequest = await stream.recv_message()
    for k, v in self.dicts[request.dict_id].items():
        await stream.send_message(api_pb2.DictEntry(key=k, value=v))
503
-
504
- ### Function
505
-
506
async def FunctionBindParams(self, stream):
    """Handle FunctionBindParams by checking the request is populated and minting a fresh ``fu-N`` id."""
    request: api_pb2.FunctionBindParamsRequest = await stream.recv_message()
    assert request.function_id
    assert request.serialized_params
    self.n_functions += 1
    bound_id = f"fu-{self.n_functions}"
    await stream.send_message(api_pb2.FunctionBindParamsResponse(bound_function_id=bound_id))
514
-
515
@contextlib.contextmanager
def input_lockstep(self) -> Iterator[threading.Barrier]:
    """Step FunctionGetInputs in lock-step with the test for the duration of the block.

    Installs and yields a 2-party barrier with a 10 s timeout. FunctionGetInputs
    waits on it, so a second party (typically the test) must also call ``wait()``
    to release each call. The non-blocking 1-party barrier is restored on exit,
    even if the block raises.
    """
    self.get_inputs_barrier = threading.Barrier(2, timeout=10)
    try:
        yield self.get_inputs_barrier
    finally:
        # Same non-blocking barrier that __init__ installs.
        self.get_inputs_barrier = threading.Barrier(1, timeout=10)
520
-
521
@contextlib.contextmanager
def output_lockstep(self) -> Iterator[threading.Barrier]:
    """Step FunctionPutOutputs in lock-step with the test for the duration of the block.

    Installs and yields a 2-party barrier with a 10 s timeout. FunctionPutOutputs
    waits on it, so a second party (typically the test) must also call ``wait()``
    to release each call. The non-blocking 1-party barrier is restored on exit,
    even if the block raises.
    """
    self.put_outputs_barrier = threading.Barrier(2, timeout=10)
    try:
        yield self.put_outputs_barrier
    finally:
        # Same non-blocking barrier that __init__ installs.
        self.put_outputs_barrier = threading.Barrier(1, timeout=10)
526
-
527
async def FunctionGetInputs(self, stream):
    """Handle FunctionGetInputs by serving queued container inputs one response at a time.

    First waits on ``get_inputs_barrier``. This is a blocking threading.Barrier
    call, so with a 2-party barrier it holds the handler until another party
    also waits. Then, in priority order:
    1. fail with INTERNAL if ``fail_get_inputs`` is set;
    2. return a one-shot ``rate_limit_sleep_duration`` hint;
    3. pop and return the oldest queued input;
    4. otherwise sleep 1 s and return an empty batch.
    """
    self.get_inputs_barrier.wait()
    request: api_pb2.FunctionGetInputsRequest = await stream.recv_message()
    assert request.function_id

    if self.fail_get_inputs:
        raise GRPCError(Status.INTERNAL)

    sleep_hint = self.rate_limit_sleep_duration
    if sleep_hint is not None:
        # One-shot: clear the hint so the next call proceeds normally.
        self.rate_limit_sleep_duration = None
        reply = api_pb2.FunctionGetInputsResponse(rate_limit_sleep_duration=sleep_hint)
    elif self.container_inputs:
        reply = self.container_inputs.pop(0)
    else:
        await asyncio.sleep(1.0)
        reply = api_pb2.FunctionGetInputsResponse(inputs=[])
    await stream.send_message(reply)
542
-
543
async def FunctionPutOutputs(self, stream):
    """Handle FunctionPutOutputs by recording the outputs request after a blocking ``put_outputs_barrier`` wait."""
    self.put_outputs_barrier.wait()
    outputs_request: api_pb2.FunctionPutOutputsRequest = await stream.recv_message()
    self.container_outputs.append(outputs_request)
    await stream.send_message(Empty())
548
-
549
async def FunctionPrecreate(self, stream):
    """Handle FunctionPrecreate by reserving a function id and returning provisional handle metadata.

    Reuses ``existing_function_id`` if set; otherwise mints ``fu-N``. The id is
    added to ``precreated_functions``. A fixed web URL is reported only when a
    typed webhook config is present.
    """
    req: api_pb2.FunctionPrecreateRequest = await stream.recv_message()
    if req.existing_function_id:
        function_id = req.existing_function_id
    else:
        self.n_functions += 1
        function_id = f"fu-{self.n_functions}"
    self.precreated_functions.add(function_id)

    has_webhook = req.HasField("webhook_config") and req.webhook_config.type
    metadata = api_pb2.FunctionHandleMetadata(
        function_name=req.function_name,
        function_type=req.function_type,
        web_url="http://xyz.internal" if has_webhook else None,
    )
    await stream.send_message(api_pb2.FunctionPrecreateResponse(function_id=function_id, handle_metadata=metadata))
570
-
571
async def FunctionCreate(self, stream):
    """Store a function definition and echo it back with handle metadata.

    Raises INTERNAL when ``function_create_error`` is set. The ID is
    ``existing_function_id`` or a fresh ``fu-<n>``. A schedule, if present, is
    recorded in ``function2schedule``. Webhook functions get a fake web URL.
    The stored copy lives in ``app_functions``.
    """
    request: api_pb2.FunctionCreateRequest = await stream.recv_message()
    if self.function_create_error:
        raise GRPCError(Status.INTERNAL, "Function create failed")

    function_id = request.existing_function_id
    if not function_id:
        self.n_functions += 1
        function_id = f"fu-{self.n_functions}"

    if request.schedule:
        self.function2schedule[function_id] = request.schedule

    # Work on a copy so the stored definition is independent of the request.
    stored = api_pb2.Function()
    stored.CopyFrom(request.function)
    if stored.webhook_config.type:
        stored.web_url = "http://xyz.internal"
    self.app_functions[function_id] = stored

    metadata = api_pb2.FunctionHandleMetadata(
        function_name=stored.function_name,
        function_type=stored.function_type,
        web_url=stored.web_url,
    )
    await stream.send_message(
        api_pb2.FunctionCreateResponse(function_id=function_id, function=stored, handle_metadata=metadata)
    )
599
-
600
async def FunctionGet(self, stream):
    """Resolve a deployed function by app name and object tag.

    Raises NOT_FOUND when the tag is not among the app's objects.
    """
    request: api_pb2.FunctionGetRequest = await stream.recv_message()
    # NOTE(review): an undeployed app_name yields None here; whether indexing
    # app_objects with None raises depends on its type (not visible here).
    objects_for_app = self.app_objects[self.deployed_apps.get(request.app_name)]
    function_id = objects_for_app.get(request.object_tag)
    if function_id is None:
        raise GRPCError(Status.NOT_FOUND, f"can't find object {request.object_tag}")
    metadata = self.get_function_metadata(function_id)
    await stream.send_message(api_pb2.FunctionGetResponse(function_id=function_id, handle_metadata=metadata))
610
-
611
async def FunctionMap(self, stream):
    """Allocate a function call ID ``fc-<n>`` and remember which function it targets."""
    self.fcidx += 1
    # Derive the ID before awaiting. Another FunctionMap handler can bump
    # ``fcidx`` while this one waits for its request, and reading the counter
    # afterwards would hand both calls the same ID.
    function_call_id = f"fc-{self.fcidx}"
    request: api_pb2.FunctionMapRequest = await stream.recv_message()
    self.function_id_for_function_call[function_call_id] = request.function_id
    await stream.send_message(api_pb2.FunctionMapResponse(function_call_id=function_call_id))
617
-
618
async def FunctionPutInputs(self, stream):
    """Accept inputs for a function call and queue them for FunctionGetOutputs.

    Each input's args are deserialized, either inline or from ``blobs`` by
    blob ID. The inputs are appended to ``client_calls[function_call_id]`` as
    ``((idx, input_id), (args, kwargs))`` with sequential ``in-<n>`` IDs.
    ``slow_put_inputs`` adds a short delay before the reply.
    """
    request: api_pb2.FunctionPutInputsRequest = await stream.recv_message()
    pending = self.client_calls.setdefault(request.function_call_id, [])
    accepted = []
    for item in request.inputs:
        if item.input.WhichOneof("args_oneof") == "args":
            payload = item.input.args
        else:
            payload = self.blobs[item.input.args_blob_id]
        args, kwargs = modal._serialization.deserialize(payload, None)

        input_id = f"in-{self.n_inputs}"
        self.n_inputs += 1
        accepted.append(api_pb2.FunctionPutInputsResponseItem(input_id=input_id, idx=item.idx))
        pending.append(((item.idx, input_id), (args, kwargs)))

    if self.slow_put_inputs:
        await asyncio.sleep(0.001)
    await stream.send_message(api_pb2.FunctionPutInputsResponse(inputs=accepted))
635
-
636
async def FunctionGetOutputs(self, stream):
    """Run one pending input of a function call in-process and return its output.

    ``clear_on_success`` requests are recorded in ``cleared_function_calls``.
    If the call has no queued inputs, or ``function_is_running`` is set, an
    empty output list is sent.

    Otherwise an input is popped from the middle of the queue, so results do
    not always arrive in order, and run through ``_function_body``:

    * a coroutine result is awaited;
    * a generator is drained into ``fc_data_out`` as pickled DataChunks and
      summarized by a GeneratorDone result;
    * any other value is returned as-is.

    An exception is pickled into a failure GenericResult. A successful result
    is serialized inline, or stored in ``blobs`` when ``use_blob_outputs`` is
    set.
    """
    request: api_pb2.FunctionGetOutputsRequest = await stream.recv_message()
    if request.clear_on_success:
        self.cleared_function_calls.add(request.function_call_id)

    pending = self.client_calls.get(request.function_call_id, [])
    if not pending or self.function_is_running:
        await stream.send_message(api_pb2.FunctionGetOutputsResponse(outputs=[]))
        return

    # Take from the middle to simulate results not always coming back in order.
    (idx, input_id), (args, kwargs) = pending.pop(len(pending) // 2)
    try:
        res = self._function_body(*args, **kwargs)

        if inspect.iscoroutine(res):
            result = await res
            result_data_format = api_pb2.DATA_FORMAT_PICKLE
        elif inspect.isgenerator(res):
            n_items = 0
            for value in res:
                n_items += 1
                await self.fc_data_out[request.function_call_id].put(
                    api_pb2.DataChunk(
                        data_format=api_pb2.DATA_FORMAT_PICKLE,
                        data=serialize_data_format(value, api_pb2.DATA_FORMAT_PICKLE),
                        index=n_items,
                    )
                )
            result = api_pb2.GeneratorDone(items_total=n_items)
            result_data_format = api_pb2.DATA_FORMAT_GENERATOR_DONE
        else:
            result = res
            result_data_format = api_pb2.DATA_FORMAT_PICKLE
    except Exception as exc:
        failure = api_pb2.GenericResult(
            status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
            data=cloudpickle.dumps(exc),
            exception=repr(exc),
            traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
        )
        output = api_pb2.FunctionGetOutputsItem(
            input_id=input_id, idx=idx, result=failure, gen_index=0, data_format=api_pb2.DATA_FORMAT_PICKLE
        )
    else:
        # Serialization happens outside the try, so its errors propagate.
        serialized_data = serialize_data_format(result, result_data_format)
        if self.use_blob_outputs:
            blob_id = await self.next_blob_id()
            self.blobs[blob_id] = serialized_data
            data_kwargs = {"data_blob_id": blob_id}
        else:
            data_kwargs = {"data": serialized_data}
        output = api_pb2.FunctionGetOutputsItem(
            input_id=input_id,
            idx=idx,
            result=api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS, **data_kwargs),
            data_format=result_data_format,
        )

    await stream.send_message(api_pb2.FunctionGetOutputsResponse(outputs=[output]))
702
-
703
async def FunctionGetSerialized(self, stream):
    """Return the preconfigured serialized function and class payloads.

    The request message is not read.
    """
    response = api_pb2.FunctionGetSerializedResponse(
        function_serialized=self.function_serialized,
        class_serialized=self.class_serialized,
    )
    await stream.send_message(response)
710
-
711
async def FunctionCallCancel(self, stream):
    """Record the cancelled function call ID in ``cancelled_calls`` and acknowledge."""
    cancel_req = await stream.recv_message()
    self.cancelled_calls.append(cancel_req.function_call_id)
    await stream.send_message(Empty())
715
-
716
async def FunctionCallGetDataIn(self, stream):
    """Forward input data chunks for a call from ``fc_data_in``, looping indefinitely."""
    req: api_pb2.FunctionCallGetDataRequest = await stream.recv_message()
    call_id = req.function_call_id
    while True:
        next_chunk = await self.fc_data_in[call_id].get()
        await stream.send_message(next_chunk)
721
-
722
async def FunctionCallGetDataOut(self, stream):
    """Forward output data chunks for a call from ``fc_data_out``, looping indefinitely."""
    req: api_pb2.FunctionCallGetDataRequest = await stream.recv_message()
    call_id = req.function_call_id
    while True:
        next_chunk = await self.fc_data_out[call_id].get()
        await stream.send_message(next_chunk)
727
-
728
async def FunctionCallPutDataOut(self, stream):
    """Enqueue every data chunk of the request onto the call's ``fc_data_out`` queue."""
    req: api_pb2.FunctionCallPutDataRequest = await stream.recv_message()
    call_id = req.function_call_id
    for data_chunk in req.data_chunks:
        await self.fc_data_out[call_id].put(data_chunk)
    await stream.send_message(Empty())
733
-
734
- ### Image
735
-
736
async def ImageGetOrCreate(self, stream):
    """Register an image definition under a fresh ``im-<n>`` ID.

    Also records the build function ID and builder version per image.
    Force-built images are noted in ``force_built_images``.
    """
    request: api_pb2.ImageGetOrCreateRequest = await stream.recv_message()
    image_id = f"im-{len(self.images) + 1}"

    self.images[image_id] = request.image
    self.image_build_function_ids[image_id] = request.build_function_id
    self.image_builder_versions[image_id] = request.builder_version
    if request.force_build:
        self.force_built_images.append(image_id)
    await stream.send_message(api_pb2.ImageGetOrCreateResponse(image_id=image_id))
747
-
748
async def ImageJoinStreaming(self, stream):
    """Stream a canned image build: one batch of logs, then a success result."""
    await stream.recv_message()
    canned_logs = [
        api_pb2.TaskLogs(data="hello, world\n", file_descriptor=api_pb2.FILE_DESCRIPTOR_INFO),
        api_pb2.TaskLogs(
            task_progress=api_pb2.TaskProgress(
                len=1, pos=0, progress_type=api_pb2.IMAGE_SNAPSHOT_UPLOAD, description="xyz"
            )
        ),
    ]
    await stream.send_message(api_pb2.ImageJoinStreamingResponse(task_logs=canned_logs))

    success = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS)
    await stream.send_message(api_pb2.ImageJoinStreamingResponse(result=success))
762
-
763
- ### Mount
764
-
765
async def MountPutFile(self, stream):
    """Store uploaded file content keyed by its SHA-256 hex digest.

    A request carrying data (inline or as a blob ID) is stored, counted in
    ``n_mount_files`` and answered with ``exists=True``. A request with no data
    is answered with ``exists=False``.
    """
    request: api_pb2.MountPutFileRequest = await stream.recv_message()
    if request.WhichOneof("data_oneof") is None:
        await stream.send_message(api_pb2.MountPutFileResponse(exists=False))
        return

    self.files_sha2data[request.sha256_hex] = {"data": request.data, "data_blob_id": request.data_blob_id}
    self.n_mount_files += 1
    await stream.send_message(api_pb2.MountPutFileResponse(exists=True))
773
-
774
async def MountGetOrCreate(self, stream):
    """Resolve or create a mount and record its filename -> SHA-256 listing.

    Behavior by creation type:

    * UNSPECIFIED looks up a deployed mount, raising NOT_FOUND if it is absent.
    * CREATE_FAIL_IF_EXISTS mints and registers a new ``mo-<n>``. It does not
      check for an existing mount.
    * ANONYMOUS_OWNED_BY_APP mints one without registering it.
    * Any other type raises a plain ``Exception``.
    """
    request: api_pb2.MountGetOrCreateRequest = await stream.recv_message()
    key = (request.deployment_name, request.namespace)
    creation_type = request.object_creation_type

    if creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
        if key not in self.deployed_mounts:
            raise GRPCError(Status.NOT_FOUND, "Mount not found")
        mount_id = self.deployed_mounts[key]
    elif creation_type in (
        api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS,
        api_pb2.OBJECT_CREATION_TYPE_ANONYMOUS_OWNED_BY_APP,
    ):
        self.n_mounts += 1
        mount_id = f"mo-{self.n_mounts}"
        if creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
            self.deployed_mounts[key] = mount_id
    else:
        raise Exception("unsupported creation type")

    contents = self.mount_contents[mount_id] = {}
    for entry in request.files:
        contents[entry.filename] = self.files_name2sha[entry.filename] = entry.sha256_hex

    metadata = api_pb2.MountHandleMetadata(content_checksum_sha256_hex="deadbeef")
    await stream.send_message(api_pb2.MountGetOrCreateResponse(mount_id=mount_id, handle_metadata=metadata))
801
-
802
- ### Proxy
803
-
804
async def ProxyGetOrCreate(self, stream):
    """Ignore the request contents and always return the proxy ID ``pr-123``."""
    await stream.recv_message()
    response = api_pb2.ProxyGetOrCreateResponse(proxy_id="pr-123")
    await stream.send_message(response)
807
-
808
- ### Queue
809
-
810
async def QueueCreate(self, stream):
    """Return ``existing_queue_id`` when set, otherwise mint ``qu-<n>``."""
    request: api_pb2.QueueCreateRequest = await stream.recv_message()
    queue_id = request.existing_queue_id
    if not queue_id:
        self.n_queues += 1
        queue_id = f"qu-{self.n_queues}"
    await stream.send_message(api_pb2.QueueCreateResponse(queue_id=queue_id))
818
-
819
async def QueueGetOrCreate(self, stream):
    """Look up a queue by (name, namespace, environment), creating it on demand.

    A registered key always wins. Otherwise:

    * CREATE_IF_MISSING mints and registers ``qu-<n>``;
    * EPHEMERAL mints one without registering it;
    * any other creation type raises NOT_FOUND.
    """
    request: api_pb2.QueueGetOrCreateRequest = await stream.recv_message()
    key = (request.deployment_name, request.namespace, request.environment_name)
    creation_type = request.object_creation_type

    if key in self.deployed_queues:
        queue_id = self.deployed_queues[key]
    elif creation_type in (
        api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING,
        api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL,
    ):
        self.n_queues += 1
        queue_id = f"qu-{self.n_queues}"
        if creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
            self.deployed_queues[key] = queue_id
    else:
        raise GRPCError(Status.NOT_FOUND, "Queue not found")
    await stream.send_message(api_pb2.QueueGetOrCreateResponse(queue_id=queue_id))
834
-
835
async def QueueDelete(self, stream):
    """Unregister every deployed queue name that points at the given queue ID."""
    request: api_pb2.QueueDeleteRequest = await stream.recv_message()
    doomed_id = request.queue_id
    self.deployed_queues = {key: qid for key, qid in self.deployed_queues.items() if qid != doomed_id}
    await stream.send_message(Empty())
839
-
840
async def QueueHeartbeat(self, stream):
    """Count a queue heartbeat in ``n_queue_heartbeats`` and acknowledge it."""
    await stream.recv_message()
    self.n_queue_heartbeats = self.n_queue_heartbeats + 1
    await stream.send_message(Empty())
844
-
845
async def QueuePut(self, stream):
    """Append the request's values to the mock's single shared ``queue``.

    Raises RESOURCE_EXHAUSTED if the queue already holds ``queue_max_len``
    items. The check runs before appending, so one batch can overshoot the
    limit.
    """
    request: api_pb2.QueuePutRequest = await stream.recv_message()
    limit = self.queue_max_len
    if len(self.queue) >= limit:
        raise GRPCError(Status.RESOURCE_EXHAUSTED, f"Hit servicer's max len for Queues: {self.queue_max_len}")
    self.queue += request.values
    await stream.send_message(Empty())
851
-
852
async def QueueGet(self, stream):
    """Pop at most one value from the front of ``queue``; reply empty if none."""
    await stream.recv_message()
    values = [self.queue.pop(0)] if self.queue else []
    await stream.send_message(api_pb2.QueueGetResponse(values=values))
859
-
860
async def QueueLen(self, stream):
    """Report how many values the shared mock ``queue`` currently holds."""
    await stream.recv_message()
    current_len = len(self.queue)
    await stream.send_message(api_pb2.QueueLenResponse(len=current_len))
863
-
864
async def QueueNextItems(self, stream):
    """Iterate ``queue`` without consuming it, returning one item per call.

    Entry IDs are stringified indices. ``last_entry_id`` is the index of the
    previously returned item, and an empty value starts from index 0. When no
    item is left, the handler sleeps briefly if ``item_poll_timeout`` is
    positive, then returns an empty batch.
    """
    request: api_pb2.QueueNextItemsRequest = await stream.recv_message()
    if request.last_entry_id:
        next_idx = int(request.last_entry_id) + 1
    else:
        next_idx = 0

    if next_idx >= len(self.queue):
        if request.item_poll_timeout > 0:
            await asyncio.sleep(0.1)
        await stream.send_message(api_pb2.QueueNextItemsResponse(items=[]))
        return

    item = api_pb2.QueueItem(value=self.queue[next_idx], entry_id=f"{next_idx}")
    await stream.send_message(api_pb2.QueueNextItemsResponse(items=[item]))
874
-
875
- ### Sandbox
876
-
877
async def SandboxCreate(self, stream):
    """Run the sandbox entrypoint as a real local subprocess with piped stdio.

    A shell PTY request sets ``sandbox_is_interactive``. The definition is
    appended to ``sandbox_defs`` and the fixed ID ``sb-123`` is returned.
    """
    request: api_pb2.SandboxCreateRequest = await stream.recv_message()
    definition = request.definition
    if definition.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
        self.sandbox_is_interactive = True

    pipe = asyncio.subprocess.PIPE
    self.sandbox = await asyncio.subprocess.create_subprocess_exec(
        *definition.entrypoint_args, stdout=pipe, stderr=pipe, stdin=pipe
    )
    self.sandbox_defs.append(definition)
    await stream.send_message(api_pb2.SandboxCreateResponse(sandbox_id="sb-123"))
892
-
893
async def SandboxGetLogs(self, stream):
    """Stream the sandbox subprocess's stdout or stderr line by line, then EOF.

    For interactive sandboxes, a batch containing ``sandbox_shell_prompt`` is
    sent before any output. Reading blocks until the selected pipe reaches EOF.
    """
    request: api_pb2.SandboxGetLogsRequest = await stream.recv_message()
    fd = request.file_descriptor

    def _batch(text):
        return api_pb2.TaskLogsBatch(items=[api_pb2.TaskLogs(data=text, file_descriptor=fd)])

    if self.sandbox_is_interactive:
        await stream.send_message(_batch(self.sandbox_shell_prompt))

    if fd == api_pb2.FILE_DESCRIPTOR_STDOUT:
        reader: asyncio.StreamReader = self.sandbox.stdout
    else:
        reader = self.sandbox.stderr

    async for line in reader:
        await stream.send_message(_batch(line.decode("utf-8")))

    await stream.send_message(api_pb2.TaskLogsBatch(eof=True))
918
-
919
async def SandboxWait(self, stream):
    """Wait up to ``request.timeout`` seconds for the sandbox process to exit.

    If the process is still running, the reply is an empty response.
    Otherwise the reply carries a result, which is also stored in
    ``sandbox_result``: success for exit code 0, or a failure carrying the
    non-zero exit code.
    """
    request: api_pb2.SandboxWaitRequest = await stream.recv_message()
    with contextlib.suppress(asyncio.TimeoutError):
        await asyncio.wait_for(self.sandbox.wait(), request.timeout)

    exit_code = self.sandbox.returncode
    if exit_code is None:
        # Still running, e.g. a zero timeout on an unfinished sandbox.
        await stream.send_message(api_pb2.SandboxWaitResponse())
        return

    if exit_code == 0:
        result = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS)
    else:
        result = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, exitcode=exit_code)
    self.sandbox_result = result
    await stream.send_message(api_pb2.SandboxWaitResponse(result=result))
938
-
939
async def SandboxTerminate(self, stream):
    """Call ``terminate()`` on the sandbox subprocess and acknowledge.

    The handler does not wait for the process to exit and does not read the
    request message.
    """
    self.sandbox.terminate()
    await stream.send_message(api_pb2.SandboxTerminateResponse())
942
-
943
async def SandboxGetTaskId(self, stream):
    """Return the fixed task ID ``modal_container_exec``.

    Only used by ``modal shell`` / ``modal container exec``.
    """
    await stream.recv_message()
    await stream.send_message(api_pb2.SandboxGetTaskIdResponse(task_id="modal_container_exec"))
947
-
948
async def SandboxStdinWrite(self, stream):
    """Write the request bytes to the sandbox's stdin and drain it.

    Stdin is closed afterwards when ``eof`` is set.
    """
    request: api_pb2.SandboxStdinWriteRequest = await stream.recv_message()
    payload = request.input
    self.sandbox.stdin.write(payload)
    await self.sandbox.stdin.drain()
    if request.eof:
        self.sandbox.stdin.close()
    await stream.send_message(api_pb2.SandboxStdinWriteResponse())
957
-
958
- ### Secret
959
-
960
async def SecretGetOrCreate(self, stream):
    """Create or resolve a secret keyed by (name, namespace, environment).

    Behavior by creation type:

    * ANONYMOUS_OWNED_BY_APP stores a new, unregistered ``st-<n>``.
    * CREATE_FAIL_IF_EXISTS raises ALREADY_EXISTS for a registered key;
      CREATE_OVERWRITE_IF_EXISTS never checks. Both then store and register a
      new ``st-<n>``.
    * UNSPECIFIED resolves a registered key or raises NOT_FOUND.
    * Any other type raises a plain ``Exception``.
    """
    request: api_pb2.SecretGetOrCreateRequest = await stream.recv_message()
    key = (request.deployment_name, request.namespace, request.environment_name)
    creation_type = request.object_creation_type

    if creation_type == api_pb2.OBJECT_CREATION_TYPE_ANONYMOUS_OWNED_BY_APP:
        secret_id = "st-" + str(len(self.secrets))
        self.secrets[secret_id] = request.env_dict
    elif creation_type in (
        api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS,
        api_pb2.OBJECT_CREATION_TYPE_CREATE_OVERWRITE_IF_EXISTS,
    ):
        if creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS and key in self.deployed_secrets:
            raise GRPCError(Status.ALREADY_EXISTS, "Already exists")
        secret_id = "st-" + str(len(self.secrets))
        self.secrets[secret_id] = request.env_dict
        self.deployed_secrets[key] = secret_id
    elif creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
        if key not in self.deployed_secrets:
            raise GRPCError(Status.NOT_FOUND, "No such secret")
        secret_id = self.deployed_secrets[key]
    else:
        raise Exception("unsupported creation type")

    await stream.send_message(api_pb2.SecretGetOrCreateResponse(secret_id=secret_id))
985
-
986
async def SecretList(self, stream):
    """List one placeholder item ``dummy-secret-<i>`` per stored secret."""
    await stream.recv_message()
    items = [api_pb2.SecretListItem(label=f"dummy-secret-{i}") for i in range(len(self.secrets))]
    await stream.send_message(api_pb2.SecretListResponse(items=items))
990
-
991
- ### Network File System (née Shared volume)
992
-
993
async def SharedVolumeCreate(self, stream):
    """Create an empty NFS named ``sv-<count>`` without reading the request."""
    new_id = f"sv-{len(self.nfs_files)}"
    self.nfs_files[new_id] = {}
    await stream.send_message(api_pb2.SharedVolumeCreateResponse(shared_volume_id=new_id))
997
-
998
async def SharedVolumeGetOrCreate(self, stream):
    """Look up or create an NFS keyed by (name, namespace, environment).

    Behavior by creation type:

    * UNSPECIFIED: lookup only; raises NOT_FOUND if the key is missing.
    * EPHEMERAL: always a fresh, unregistered NFS.
    * CREATE_IF_MISSING: create and register on first use, then return the
      registered ID.
    * CREATE_FAIL_IF_EXISTS: raises ALREADY_EXISTS if the key is registered,
      otherwise creates and registers.
    * Anything else raises INVALID_ARGUMENT.
    """
    request: api_pb2.SharedVolumeGetOrCreateRequest = await stream.recv_message()
    key = (request.deployment_name, request.namespace, request.environment_name)

    def _fresh_nfs():
        new_id = f"sv-{len(self.nfs_files)}"
        self.nfs_files[new_id] = {}
        return new_id

    creation_type = request.object_creation_type
    if creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
        if key not in self.deployed_nfss:
            raise GRPCError(Status.NOT_FOUND, "NFS not found")
        nfs_id = self.deployed_nfss[key]
    elif creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
        nfs_id = _fresh_nfs()
    elif creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
        if key not in self.deployed_nfss:
            self.deployed_nfss[key] = _fresh_nfs()
        nfs_id = self.deployed_nfss[key]
    elif creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
        if key in self.deployed_nfss:
            raise GRPCError(Status.ALREADY_EXISTS, "NFS already exists")
        nfs_id = _fresh_nfs()
        self.deployed_nfss[key] = nfs_id
    else:
        raise GRPCError(Status.INVALID_ARGUMENT, "unsupported object creation type")

    await stream.send_message(api_pb2.SharedVolumeGetOrCreateResponse(shared_volume_id=nfs_id))
1024
-
1025
async def SharedVolumeHeartbeat(self, stream):
    """Count an NFS heartbeat in ``n_nfs_heartbeats`` and acknowledge it."""
    await stream.recv_message()
    self.n_nfs_heartbeats = self.n_nfs_heartbeats + 1
    await stream.send_message(Empty())
1029
-
1030
async def SharedVolumePutFile(self, stream):
    """Store the whole put request under its path and report ``exists=True``."""
    put_req = await stream.recv_message()
    volume = self.nfs_files[put_req.shared_volume_id]
    volume[put_req.path] = put_req
    await stream.send_message(api_pb2.SharedVolumePutFileResponse(exists=True))
1034
-
1035
async def SharedVolumeGetFile(self, stream):
    """Return a stored NFS file: by blob ID if it has one, otherwise inline.

    Raises NOT_FOUND for an unknown volume or path.
    """
    req = await stream.recv_message()
    stored = self.nfs_files.get(req.shared_volume_id, {}).get(req.path)
    if not stored:
        raise GRPCError(Status.NOT_FOUND, f"No such file: {req.path}")

    if stored.data_blob_id:
        response = api_pb2.SharedVolumeGetFileResponse(data_blob_id=stored.data_blob_id)
    else:
        response = api_pb2.SharedVolumeGetFileResponse(data=stored.data)
    await stream.send_message(response)
1044
-
1045
async def SharedVolumeListFilesStream(self, stream):
    """Stream one response per stored NFS path that matches the request.

    Matching is deliberately crude: ``**`` or ``/`` list everything, and any
    other pattern must equal a stored path exactly.
    """
    req: api_pb2.SharedVolumeListFilesRequest = await stream.recv_message()
    list_everything = req.path in ("**", "/")
    for path in self.nfs_files[req.shared_volume_id].keys():
        if not (list_everything or req.path == path):
            continue
        entry = api_pb2.FileEntry(path=path, type=api_pb2.FileEntry.FileType.FILE)
        await stream.send_message(api_pb2.SharedVolumeListFilesResponse(entries=[entry]))
1052
-
1053
- ### Task
1054
-
1055
async def TaskCurrentInputs(
    self, stream: "grpclib.server.Stream[Empty, api_pb2.TaskCurrentInputsResponse]"
) -> None:
    """Placeholder: always report that the task is working on no inputs."""
    no_inputs: list = []
    await stream.send_message(api_pb2.TaskCurrentInputsResponse(input_ids=no_inputs))
1059
-
1060
async def TaskResult(self, stream):
    """Save the reported task result on ``task_result`` and acknowledge."""
    result_req: api_pb2.TaskResultRequest = await stream.recv_message()
    self.task_result = result_req.result
    await stream.send_message(Empty())
1064
-
1065
- ### Token flow
1066
-
1067
async def TokenFlowCreate(self, stream):
    """Start a fake token flow, remembering the client's ``localhost_port``."""
    request: api_pb2.TokenFlowCreateRequest = await stream.recv_message()
    self.token_flow_localhost_port = request.localhost_port
    response = api_pb2.TokenFlowCreateResponse(token_flow_id="tc-123", web_url="https://localhost/xyz/abc")
    await stream.send_message(response)
1073
-
1074
async def TokenFlowWait(self, stream):
    """Finish the token flow immediately with fixed credentials.

    The request message is not read.
    """
    response = api_pb2.TokenFlowWaitResponse(token_id="abc", token_secret="xyz")
    await stream.send_message(response)
1081
-
1082
async def WorkspaceNameLookup(self, stream):
    """Answer every workspace lookup with the username ``test-username``."""
    response = api_pb2.WorkspaceNameLookupResponse(username="test-username")
    await stream.send_message(response)
1084
-
1085
- ### Tunnel
1086
-
1087
async def TunnelStart(self, stream):
    """Pretend to open a tunnel, reporting host ``<port>.modal.test`` on port 443."""
    request: api_pb2.TunnelStartRequest = await stream.recv_message()
    response = api_pb2.TunnelStartResponse(host=f"{request.port}.modal.test", port=443)
    await stream.send_message(response)
1091
-
1092
async def TunnelStop(self, stream):
    """Acknowledge a tunnel stop, always reporting that the tunnel existed."""
    await stream.recv_message()
    response = api_pb2.TunnelStopResponse(exists=True)
    await stream.send_message(response)
1095
-
1096
- ### Volume
1097
-
1098
async def VolumeCreate(self, stream):
    """Create an empty volume ``vo-<volume_counter>``.

    The request is logged in ``requests``.
    """
    req = await stream.recv_message()
    self.requests.append(req)
    self.volume_counter += 1
    new_id = f"vo-{self.volume_counter}"
    self.volume_files[new_id] = {}
    await stream.send_message(api_pb2.VolumeCreateResponse(volume_id=new_id))
1105
-
1106
async def VolumeGetOrCreate(self, stream):
    """Look up or create a volume keyed by (name, namespace, environment).

    Behavior by creation type:

    * UNSPECIFIED: lookup only; raises NOT_FOUND if the key is missing.
    * EPHEMERAL: always a fresh, unregistered volume.
    * CREATE_IF_MISSING: create and register on first use, then return the
      registered ID.
    * CREATE_FAIL_IF_EXISTS: raises ALREADY_EXISTS if the key is registered,
      otherwise creates and registers.
    * Anything else raises INVALID_ARGUMENT.
    """
    request: api_pb2.VolumeGetOrCreateRequest = await stream.recv_message()
    k = (request.deployment_name, request.namespace, request.environment_name)

    def _new_volume() -> str:
        # ``vo-{len(self.volume_files)}`` alone is not unique. VolumeCreate
        # mints IDs from ``volume_counter``, and VolumeDelete shrinks
        # ``volume_files``, so that ID can name a live volume, whose files
        # would then be wiped. Start at the same number but skip IDs in use.
        n = len(self.volume_files)
        while f"vo-{n}" in self.volume_files:
            n += 1
        volume_id = f"vo-{n}"
        self.volume_files[volume_id] = {}
        return volume_id

    creation_type = request.object_creation_type
    if creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
        if k not in self.deployed_volumes:
            raise GRPCError(Status.NOT_FOUND, "Volume not found")
        volume_id = self.deployed_volumes[k]
    elif creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
        volume_id = _new_volume()
    elif creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
        if k not in self.deployed_volumes:
            self.deployed_volumes[k] = _new_volume()
        volume_id = self.deployed_volumes[k]
    elif creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
        if k in self.deployed_volumes:
            raise GRPCError(Status.ALREADY_EXISTS, "Volume already exists")
        volume_id = _new_volume()
        self.deployed_volumes[k] = volume_id
    else:
        raise GRPCError(Status.INVALID_ARGUMENT, "unsupported object creation type")

    await stream.send_message(api_pb2.VolumeGetOrCreateResponse(volume_id=volume_id))
1132
-
1133
async def VolumeHeartbeat(self, stream):
    """Count a volume heartbeat in ``n_vol_heartbeats`` and acknowledge it."""
    await stream.recv_message()
    self.n_vol_heartbeats = self.n_vol_heartbeats + 1
    await stream.send_message(Empty())
1137
-
1138
async def VolumeCommit(self, stream):
    """Count a commit for a volume, logging the request in ``requests``.

    IDs not starting with ``vo-`` raise NOT_FOUND.
    """
    req = await stream.recv_message()
    self.requests.append(req)
    volume_id = req.volume_id
    if not volume_id.startswith("vo-"):
        raise GRPCError(Status.NOT_FOUND, f"invalid volume ID {req.volume_id}")
    self.volume_commits[volume_id] += 1
    await stream.send_message(api_pb2.VolumeCommitResponse(skip_reload=False))
1145
-
1146
async def VolumeDelete(self, stream):
    """Drop a volume's files and unregister any deployed names pointing at it.

    An unknown volume ID raises ``KeyError``.
    """
    req: api_pb2.VolumeDeleteRequest = await stream.recv_message()
    doomed_id = req.volume_id
    del self.volume_files[doomed_id]
    self.deployed_volumes = {key: vid for key, vid in self.deployed_volumes.items() if vid != doomed_id}
    await stream.send_message(Empty())
1151
-
1152
async def VolumeReload(self, stream):
    """Count a reload in ``volume_reloads`` (request logged in ``requests``) and acknowledge."""
    reload_req = await stream.recv_message()
    self.requests.append(reload_req)
    self.volume_reloads[reload_req.volume_id] += 1
    await stream.send_message(Empty())
1157
-
1158
async def VolumeGetFile(self, stream):
    """Read a volume file by its UTF-8 path, honouring an optional byte range.

    Files stored as blobs are returned by blob ID only. Inline files return
    ``data[start:start + len]``, where a zero ``len`` means "to the end",
    together with the full file size. A missing path raises NOT_FOUND.
    """
    req = await stream.recv_message()
    path = req.path.decode("utf-8")
    if path not in self.volume_files[req.volume_id]:
        raise GRPCError(Status.NOT_FOUND, "File not found")
    vol_file = self.volume_files[req.volume_id][path]

    if vol_file.data_blob_id:
        await stream.send_message(api_pb2.VolumeGetFileResponse(data_blob_id=vol_file.data_blob_id))
        return

    payload = vol_file.data
    total_size = len(payload)
    if req.start or req.len:
        begin = req.start
        count = req.len or len(vol_file.data)
        payload = vol_file.data[begin : begin + count]
    await stream.send_message(api_pb2.VolumeGetFileResponse(data=payload, size=total_size))
1175
-
1176
async def VolumeRemoveFile(self, stream):
    """Delete a volume file by its UTF-8 path.

    A missing path raises INVALID_ARGUMENT.
    """
    req = await stream.recv_message()
    path = req.path.decode("utf-8")
    if path not in self.volume_files[req.volume_id]:
        raise GRPCError(Status.INVALID_ARGUMENT, "File not found")
    del self.volume_files[req.volume_id][path]
    await stream.send_message(Empty())
1182
-
1183
async def VolumeListFiles(self, stream):
    """Stream volume entries under a path, one response per matching file.

    The path is normalised by dropping one leading and one trailing ``/``. An
    empty result lists everything. A file matches when it equals the path or
    lives beneath it (direct children only, unless ``recursive``). A non-root
    path that matches nothing raises NOT_FOUND. Empty directories are not
    modelled.
    """
    req = await stream.recv_message()
    root = req.path if req.path else "/"
    if root.startswith("/"):
        root = root[1:]
    if root.endswith("/"):
        root = root[:-1]

    def _matches(name):
        if not root or name == root:
            return True
        if not name.startswith(root + "/"):
            return False
        return req.recursive or "/" not in name[len(root) + 1 :]

    found_any = False
    for name, vol_file in self.volume_files[req.volume_id].items():
        if not _matches(name):
            continue
        entry = api_pb2.FileEntry(path=name, type=api_pb2.FileEntry.FileType.FILE, size=len(vol_file.data))
        await stream.send_message(api_pb2.VolumeListFilesResponse(entries=[entry]))
        found_any = True

    if root and not found_any:
        raise GRPCError(Status.NOT_FOUND, "No such file")
1200
-
1201
async def VolumePutFiles(self, stream):
    """Add files to a volume, taking content from ``files_sha2data`` by SHA-256.

    ``files_sha2data`` is populated by MountPutFile. When
    ``disallow_overwrite_existing_files`` is set and a filename already exists,
    ALREADY_EXISTS is raised. Files earlier in the same request stay written.
    """
    req = await stream.recv_message()
    for file in req.files:
        blob_data = self.files_sha2data[file.sha256_hex]

        if file.filename in self.volume_files[req.volume_id] and req.disallow_overwrite_existing_files:
            raise GRPCError(
                Status.ALREADY_EXISTS,
                # The closing parenthesis was previously missing from this message.
                f"{file.filename}: already exists (disallow_overwrite_existing_files={req.disallow_overwrite_existing_files})",
            )

        self.volume_files[req.volume_id][file.filename] = VolumeFile(
            data=blob_data["data"],
            data_blob_id=blob_data["data_blob_id"],
            mode=file.mode,
        )
    await stream.send_message(Empty())
1218
-
1219
async def VolumeCopyFiles(self, stream):
    """Copy files within a volume. Paths arrive as UTF-8 ``bytes``.

    A single source is copied to ``dst_path`` verbatim. With several sources,
    ``dst_path`` must look like a directory (a trailing separator or no file
    extension), and each file lands at ``dst_path/<basename>``.

    Raises:
        NOT_FOUND: a source path does not exist in the volume.
        INVALID_ARGUMENT: several sources were given but ``dst_path`` is not a
            directory.
    """
    req = await stream.recv_message()
    many_sources = len(req.src_paths) > 1
    # The directory check does not depend on the source, so evaluate it once.
    dst_is_dir = many_sources and (
        req.dst_path.decode("utf-8").endswith(("/", "\\"))
        or not os.path.splitext(os.path.basename(req.dst_path))[1]
    )
    for src_path in req.src_paths:
        if src_path.decode("utf-8") not in self.volume_files[req.volume_id]:
            raise GRPCError(Status.NOT_FOUND, f"Source file not found: {src_path}")
        src_file = self.volume_files[req.volume_id][src_path.decode("utf-8")]
        if many_sources:
            if not dst_is_dir:
                # Report the requested destination. The local ``dst_path`` is
                # unbound on the first iteration, so formatting it raised
                # UnboundLocalError instead of this error.
                raise GRPCError(Status.INVALID_ARGUMENT, f"{req.dst_path.decode('utf-8')} is not a directory.")
            dst_path = os.path.join(req.dst_path, os.path.basename(src_path))
        else:
            dst_path = req.dst_path
        self.volume_files[req.volume_id][dst_path.decode("utf-8")] = src_file
    await stream.send_message(Empty())
1238
-
1239
-
1240
@pytest_asyncio.fixture
async def blob_server():
    """Serve a throwaway HTTP blob store for upload and download tests.

    Routes:

    * ``PUT /upload?blob_id=...[&part_number=N]`` stores a whole blob or one
      multipart part and returns the quoted MD5 hex digest as the ETag. A body
      of ``b"FAILURE"`` yields HTTP 500.
    * ``POST /complete_multipart?blob_id=...`` concatenates the parts in
      part-number order. Its ETag is the MD5 of the concatenated part digests,
      followed by ``-<part count>``.
    * ``GET /download?blob_id=...`` returns the stored blob. ``bl-failure``
      yields HTTP 500.

    Yields:
        ``(host, blobs)``, where ``blobs`` maps blob ID to bytes.
    """
    blobs = {}
    blob_parts: Dict[str, Dict[int, bytes]] = defaultdict(dict)

    async def upload(request):
        blob_id = request.query["blob_id"]
        content = await request.content.read()
        if content == b"FAILURE":
            return aiohttp.web.Response(status=500)
        etag = f'"{hashlib.md5(content).hexdigest()}"'
        if "part_number" in request.query:
            blob_parts[blob_id][int(request.query["part_number"])] = content
        else:
            blobs[blob_id] = content
        return aiohttp.web.Response(text="Hello, world", headers={"ETag": etag})

    async def complete_multipart(request):
        blob_id = request.query["blob_id"]
        parts = blob_parts[blob_id]
        ordered = [parts[num] for num in range(min(parts.keys()), max(parts.keys()) + 1)]
        digests = b"".join(hashlib.md5(part).digest() for part in ordered)
        etag = f'"{hashlib.md5(digests).hexdigest()}-{len(parts)}"'
        blobs[blob_id] = b"".join(ordered)
        return aiohttp.web.Response(text=f"<etag>{etag}</etag>")

    async def download(request):
        blob_id = request.query["blob_id"]
        if blob_id == "bl-failure":
            return aiohttp.web.Response(status=500)
        return aiohttp.web.Response(body=blobs[blob_id])

    app = aiohttp.web.Application()
    app.add_routes(
        [
            aiohttp.web.put("/upload", upload),
            aiohttp.web.get("/download", download),
            aiohttp.web.post("/complete_multipart", complete_multipart),
        ]
    )

    async with run_temporary_http_server(app) as host:
        yield host, blobs
1287
-
1288
-
1289
@pytest_asyncio.fixture(scope="function")
async def servicer_factory(blob_server):
    """Provide ``create_server``, an async context manager that runs a MockClientServicer.

    Each servicer is backed by the ``blob_server`` fixture. It listens on
    ``host``/``port`` or on a Unix socket ``path``, as given to
    ``create_server``. Start and stop are wrapped with ``synchronize_api`` and
    driven through their ``.aio`` variants.
    """

    @contextlib.asynccontextmanager
    async def create_server(host=None, port=None, path=None):
        blob_host, blobs = blob_server
        mock = MockClientServicer(blob_host, blobs)  # type: ignore
        grpc_server = None

        async def _start_servicer():
            nonlocal grpc_server
            grpc_server = grpclib.server.Server([mock])
            await grpc_server.start(host=host, port=port, path=path)

        async def _stop_servicer():
            mock.container_heartbeat_abort.set()
            grpc_server.close()
            # Deliberately not awaiting grpc_server.wait_closed(). That is the
            # proper shutdown, but on Python 3.12+ it hangs the tests because
            # clients created through _Client.from_env keep their connections
            # until event-loop shutdown. Revisit if _Client cleanup changes.

        start_servicer = synchronize_api(_start_servicer)
        stop_servicer = synchronize_api(_stop_servicer)

        await start_servicer.aio()
        try:
            yield mock
        finally:
            await stop_servicer.aio()

    yield create_server
1322
-
1323
-
1324
- @pytest_asyncio.fixture(scope="function")
1325
- async def servicer(servicer_factory):
1326
- port = find_free_port()
1327
- async with servicer_factory(host="0.0.0.0", port=port) as servicer:
1328
- servicer.remote_addr = f"http://127.0.0.1:{port}"
1329
- yield servicer
1330
-
1331
-
1332
- @pytest_asyncio.fixture(scope="function")
1333
- async def unix_servicer(servicer_factory):
1334
- with tempfile.TemporaryDirectory() as tmpdirname:
1335
- path = os.path.join(tmpdirname, "servicer.sock")
1336
- async with servicer_factory(path=path) as servicer:
1337
- servicer.remote_addr = f"unix://{path}"
1338
- yield servicer
1339
-
1340
-
1341
- @pytest_asyncio.fixture(scope="function")
1342
- async def client(servicer):
1343
- with Client(servicer.remote_addr, api_pb2.CLIENT_TYPE_CLIENT, ("foo-id", "foo-secret")) as client:
1344
- yield client
1345
-
1346
-
1347
- @pytest_asyncio.fixture(scope="function")
1348
- async def container_client(unix_servicer):
1349
- async with Client(unix_servicer.remote_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret")) as client:
1350
- yield client
1351
-
1352
-
1353
- @pytest_asyncio.fixture(scope="function")
1354
- async def server_url_env(servicer, monkeypatch):
1355
- monkeypatch.setenv("MODAL_SERVER_URL", servicer.remote_addr)
1356
- yield
1357
-
1358
-
1359
- @pytest_asyncio.fixture(scope="function", autouse=True)
1360
- async def reset_default_client():
1361
- Client.set_env_client(None)
1362
-
1363
-
1364
- @pytest.fixture(name="mock_dir", scope="session")
1365
- def mock_dir_factory():
1366
- """Sets up a temp dir with content as specified in a nested dict
1367
-
1368
- Example usage:
1369
- spec = {
1370
- "foo": {
1371
- "bar.txt": "some content"
1372
- },
1373
- }
1374
-
1375
- with mock_dir(spec) as root_dir:
1376
- assert os.path.exists(os.path.join(root_dir, "foo", "bar.txt"))
1377
- """
1378
-
1379
- @contextlib.contextmanager
1380
- def mock_dir(root_spec):
1381
- def rec_make(dir, dir_spec):
1382
- for filename, spec in dir_spec.items():
1383
- path = os.path.join(dir, filename)
1384
- if isinstance(spec, str):
1385
- with open(path, "w") as f:
1386
- f.write(spec)
1387
- else:
1388
- os.mkdir(path)
1389
- rec_make(path, spec)
1390
-
1391
- # Windows has issues cleaning up TempDirectory: https://www.scivision.dev/python-tempfile-permission-error-windows
1392
- # Seems to have been fixed for some python versions in https://github.com/python/cpython/pull/10320.
1393
- root_dir = tempfile.mkdtemp()
1394
- rec_make(root_dir, root_spec)
1395
- cwd = os.getcwd()
1396
- try:
1397
- os.chdir(root_dir)
1398
- yield
1399
- finally:
1400
- os.chdir(cwd)
1401
- shutil.rmtree(root_dir, ignore_errors=True)
1402
-
1403
- return mock_dir
1404
-
1405
-
1406
- @pytest.fixture(autouse=True)
1407
- def reset_sys_modules():
1408
- # Needed since some tests will import dynamic modules
1409
- backup = sys.modules.copy()
1410
- try:
1411
- yield
1412
- finally:
1413
- sys.modules = backup
1414
-
1415
-
1416
- @pytest.fixture(autouse=True)
1417
- def reset_container_app():
1418
- try:
1419
- yield
1420
- finally:
1421
- _ContainerIOManager._reset_singleton()
1422
-
1423
-
1424
- @pytest.fixture
1425
- def repo_root(request):
1426
- return Path(request.config.rootdir)
1427
-
1428
-
1429
- @pytest.fixture(scope="module")
1430
- def test_dir(request):
1431
- """Absolute path to directory containing test file."""
1432
- root_dir = Path(request.config.rootdir)
1433
- test_dir = Path(os.getenv("PYTEST_CURRENT_TEST")).parent
1434
- return root_dir / test_dir
1435
-
1436
-
1437
- @pytest.fixture(scope="function")
1438
- def modal_config():
1439
- """Return a context manager with a temporary modal.toml file"""
1440
-
1441
- @contextlib.contextmanager
1442
- def mock_modal_toml(contents: str = "", show_on_error: bool = False):
1443
- # Some of the cli tests run within within the main process
1444
- # so we need to modify the config singletons to pick up any changes
1445
- orig_config_path_env = os.environ.get("MODAL_CONFIG_PATH")
1446
- orig_config_path = config.user_config_path
1447
- orig_profile = config._profile
1448
- try:
1449
- with tempfile.NamedTemporaryFile(delete=False, suffix=".toml", mode="w") as t:
1450
- t.write(textwrap.dedent(contents.strip("\n")))
1451
- os.environ["MODAL_CONFIG_PATH"] = t.name
1452
- config.user_config_path = t.name
1453
- config._user_config = config._read_user_config()
1454
- config._profile = config._config_active_profile()
1455
- yield t.name
1456
- except Exception:
1457
- if show_on_error:
1458
- with open(t.name) as f:
1459
- print(f"Test config file contents:\n\n{f.read()}", file=sys.stderr)
1460
- raise
1461
- finally:
1462
- if orig_config_path_env:
1463
- os.environ["MODAL_CONFIG_PATH"] = orig_config_path_env
1464
- else:
1465
- del os.environ["MODAL_CONFIG_PATH"]
1466
- config.user_config_path = orig_config_path
1467
- config._user_config = config._read_user_config()
1468
- config._profile = orig_profile
1469
- os.remove(t.name)
1470
-
1471
- return mock_modal_toml
1472
-
1473
-
1474
- @pytest.fixture
1475
- def supports_dir(test_dir):
1476
- return test_dir / Path("supports")
1477
-
1478
-
1479
- @pytest_asyncio.fixture
1480
- async def set_env_client(client):
1481
- try:
1482
- Client.set_env_client(client)
1483
- yield
1484
- finally:
1485
- Client.set_env_client(None)