modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +17 -13
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +420 -937
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -59
  11. modal/_resources.py +51 -0
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/_runtime/execution_context.py +89 -0
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +134 -9
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +52 -16
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +479 -100
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +460 -171
  29. modal/_utils/grpc_testing.py +47 -31
  30. modal/_utils/grpc_utils.py +62 -109
  31. modal/_utils/hash_utils.py +61 -19
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +5 -7
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +14 -12
  43. modal/app.py +1003 -314
  44. modal/app.pyi +540 -264
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +63 -53
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +205 -45
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +62 -14
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +64 -58
  55. modal/cli/launch.py +32 -18
  56. modal/cli/network_file_system.py +64 -83
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +35 -10
  59. modal/cli/programs/vscode.py +60 -10
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +234 -131
  62. modal/cli/secret.py +8 -7
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +79 -10
  65. modal/cli/volume.py +110 -109
  66. modal/client.py +250 -144
  67. modal/client.pyi +157 -118
  68. modal/cloud_bucket_mount.py +108 -34
  69. modal/cloud_bucket_mount.pyi +32 -38
  70. modal/cls.py +535 -148
  71. modal/cls.pyi +190 -146
  72. modal/config.py +41 -19
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +111 -65
  76. modal/dict.pyi +136 -131
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +34 -43
  80. modal/experimental.py +61 -2
  81. modal/extensions/ipython.py +5 -5
  82. modal/file_io.py +537 -0
  83. modal/file_io.pyi +235 -0
  84. modal/file_pattern_matcher.py +197 -0
  85. modal/functions.py +906 -911
  86. modal/functions.pyi +466 -430
  87. modal/gpu.py +57 -44
  88. modal/image.py +1089 -479
  89. modal/image.pyi +584 -228
  90. modal/io_streams.py +434 -0
  91. modal/io_streams.pyi +122 -0
  92. modal/mount.py +314 -101
  93. modal/mount.pyi +241 -235
  94. modal/network_file_system.py +92 -92
  95. modal/network_file_system.pyi +152 -110
  96. modal/object.py +67 -36
  97. modal/object.pyi +166 -143
  98. modal/output.py +63 -0
  99. modal/parallel_map.py +434 -0
  100. modal/parallel_map.pyi +75 -0
  101. modal/partial_function.py +282 -117
  102. modal/partial_function.pyi +222 -129
  103. modal/proxy.py +15 -12
  104. modal/proxy.pyi +3 -8
  105. modal/queue.py +182 -65
  106. modal/queue.pyi +218 -118
  107. modal/requirements/2024.04.txt +29 -0
  108. modal/requirements/2024.10.txt +16 -0
  109. modal/requirements/README.md +21 -0
  110. modal/requirements/base-images.json +22 -0
  111. modal/retries.py +48 -7
  112. modal/runner.py +459 -156
  113. modal/runner.pyi +135 -71
  114. modal/running_app.py +38 -0
  115. modal/sandbox.py +514 -236
  116. modal/sandbox.pyi +397 -169
  117. modal/schedule.py +4 -4
  118. modal/scheduler_placement.py +20 -3
  119. modal/secret.py +56 -31
  120. modal/secret.pyi +62 -42
  121. modal/serving.py +51 -56
  122. modal/serving.pyi +44 -36
  123. modal/stream_type.py +15 -0
  124. modal/token_flow.py +5 -3
  125. modal/token_flow.pyi +37 -32
  126. modal/volume.py +285 -157
  127. modal/volume.pyi +249 -184
  128. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
  129. modal-0.72.11.dist-info/RECORD +174 -0
  130. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  131. modal_docs/gen_reference_docs.py +3 -1
  132. modal_docs/mdmd/mdmd.py +0 -1
  133. modal_docs/mdmd/signatures.py +5 -2
  134. modal_global_objects/images/base_images.py +28 -0
  135. modal_global_objects/mounts/python_standalone.py +2 -2
  136. modal_proto/__init__.py +1 -1
  137. modal_proto/api.proto +1288 -533
  138. modal_proto/api_grpc.py +856 -456
  139. modal_proto/api_pb2.py +2165 -1157
  140. modal_proto/api_pb2.pyi +8859 -0
  141. modal_proto/api_pb2_grpc.py +1674 -855
  142. modal_proto/api_pb2_grpc.pyi +1416 -0
  143. modal_proto/modal_api_grpc.py +149 -0
  144. modal_proto/modal_options_grpc.py +3 -0
  145. modal_proto/options_pb2.pyi +20 -0
  146. modal_proto/options_pb2_grpc.pyi +7 -0
  147. modal_proto/py.typed +0 -0
  148. modal_version/__init__.py +1 -1
  149. modal_version/_version_generated.py +2 -2
  150. modal/_asgi.py +0 -370
  151. modal/_container_entrypoint.pyi +0 -378
  152. modal/_container_exec.py +0 -128
  153. modal/_sandbox_shell.py +0 -49
  154. modal/shared_volume.py +0 -23
  155. modal/shared_volume.pyi +0 -24
  156. modal/stub.py +0 -783
  157. modal/stub.pyi +0 -332
  158. modal-0.62.16.dist-info/RECORD +0 -198
  159. modal_global_objects/images/conda.py +0 -15
  160. modal_global_objects/images/debian_slim.py +0 -15
  161. modal_global_objects/images/micromamba.py +0 -15
  162. test/__init__.py +0 -1
  163. test/aio_test.py +0 -12
  164. test/async_utils_test.py +0 -262
  165. test/blob_test.py +0 -67
  166. test/cli_imports_test.py +0 -149
  167. test/cli_test.py +0 -659
  168. test/client_test.py +0 -194
  169. test/cls_test.py +0 -630
  170. test/config_test.py +0 -137
  171. test/conftest.py +0 -1420
  172. test/container_app_test.py +0 -32
  173. test/container_test.py +0 -1389
  174. test/cpu_test.py +0 -23
  175. test/decorator_test.py +0 -85
  176. test/deprecation_test.py +0 -34
  177. test/dict_test.py +0 -33
  178. test/e2e_test.py +0 -68
  179. test/error_test.py +0 -7
  180. test/function_serialization_test.py +0 -32
  181. test/function_test.py +0 -653
  182. test/function_utils_test.py +0 -101
  183. test/gpu_test.py +0 -159
  184. test/grpc_utils_test.py +0 -141
  185. test/helpers.py +0 -42
  186. test/image_test.py +0 -669
  187. test/live_reload_test.py +0 -80
  188. test/lookup_test.py +0 -70
  189. test/mdmd_test.py +0 -329
  190. test/mount_test.py +0 -162
  191. test/mounted_files_test.py +0 -329
  192. test/network_file_system_test.py +0 -181
  193. test/notebook_test.py +0 -66
  194. test/object_test.py +0 -41
  195. test/package_utils_test.py +0 -25
  196. test/queue_test.py +0 -97
  197. test/resolver_test.py +0 -58
  198. test/retries_test.py +0 -67
  199. test/runner_test.py +0 -85
  200. test/sandbox_test.py +0 -191
  201. test/schedule_test.py +0 -15
  202. test/scheduler_placement_test.py +0 -29
  203. test/secret_test.py +0 -78
  204. test/serialization_test.py +0 -42
  205. test/stub_composition_test.py +0 -10
  206. test/stub_test.py +0 -360
  207. test/test_asgi_wrapper.py +0 -234
  208. test/token_flow_test.py +0 -18
  209. test/traceback_test.py +0 -135
  210. test/tunnel_test.py +0 -29
  211. test/utils_test.py +0 -88
  212. test/version_test.py +0 -14
  213. test/volume_test.py +0 -341
  214. test/watcher_test.py +0 -30
  215. test/webhook_test.py +0 -146
  216. /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
  217. /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
  218. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
test/conftest.py DELETED
@@ -1,1420 +0,0 @@
1
- # Copyright Modal Labs 2024
2
- from __future__ import annotations
3
-
4
- import asyncio
5
- import contextlib
6
- import dataclasses
7
- import hashlib
8
- import inspect
9
- import os
10
- import pytest
11
- import shutil
12
- import sys
13
- import tempfile
14
- import textwrap
15
- import threading
16
- import traceback
17
- from collections import defaultdict
18
- from pathlib import Path
19
- from typing import Dict, Iterator, Optional
20
-
21
- import aiohttp.web
22
- import aiohttp.web_runner
23
- import grpclib.server
24
- import pkg_resources
25
- import pytest_asyncio
26
- from google.protobuf.empty_pb2 import Empty
27
- from grpclib import GRPCError, Status
28
-
29
- import modal._serialization
30
- from modal import __version__, config
31
- from modal._serialization import serialize_data_format
32
- from modal._utils.async_utils import asyncify, synchronize_api
33
- from modal._utils.grpc_testing import patch_mock_servicer
34
- from modal._utils.grpc_utils import find_free_port
35
- from modal._utils.http_utils import run_temporary_http_server
36
- from modal._vendor import cloudpickle
37
- from modal.app import _ContainerApp
38
- from modal.client import Client
39
- from modal.mount import client_mount_name
40
- from modal_proto import api_grpc, api_pb2
41
-
42
-
43
- @dataclasses.dataclass
44
- class VolumeFile:
45
- data: bytes
46
- data_blob_id: str
47
- mode: int
48
-
49
-
50
- # TODO: Isolate all test config from the host
51
- @pytest.fixture(scope="session", autouse=True)
52
- def set_env():
53
- os.environ["MODAL_ENVIRONMENT"] = "main"
54
-
55
-
56
- @patch_mock_servicer
57
- class MockClientServicer(api_grpc.ModalClientBase):
58
- # TODO(erikbern): add more annotations
59
- container_inputs: list[api_pb2.FunctionGetInputsResponse]
60
- container_outputs: list[api_pb2.FunctionPutOutputsRequest]
61
- fc_data_in: defaultdict[str, asyncio.Queue[api_pb2.DataChunk]]
62
- fc_data_out: defaultdict[str, asyncio.Queue[api_pb2.DataChunk]]
63
-
64
- def __init__(self, blob_host, blobs):
65
- self.put_outputs_barrier = threading.Barrier(
66
- 1, timeout=10
67
- ) # set to non-1 to get lock-step of output pushing within a test
68
- self.get_inputs_barrier = threading.Barrier(
69
- 1, timeout=10
70
- ) # set to non-1 to get lock-step of input releases within a test
71
-
72
- self.app_state_history = defaultdict(list)
73
- self.app_heartbeats: Dict[str, int] = defaultdict(int)
74
- self.container_checkpoint_requests = 0
75
- self.n_blobs = 0
76
- self.blob_host = blob_host
77
- self.blobs = blobs # shared dict
78
- self.requests = []
79
- self.done = False
80
- self.rate_limit_sleep_duration = None
81
- self.fail_get_inputs = False
82
- self.slow_put_inputs = False
83
- self.container_inputs = []
84
- self.container_outputs = []
85
- self.fc_data_in = defaultdict(lambda: asyncio.Queue()) # unbounded
86
- self.fc_data_out = defaultdict(lambda: asyncio.Queue()) # unbounded
87
- self.queue = []
88
- self.deployed_apps = {
89
- client_mount_name(): "ap-x",
90
- }
91
- self.app_objects = {}
92
- self.app_single_objects = {}
93
- self.app_unindexed_objects = {
94
- "ap-1": ["im-1", "vo-1"],
95
- }
96
- self.n_inputs = 0
97
- self.n_queues = 0
98
- self.n_dict_heartbeats = 0
99
- self.n_queue_heartbeats = 0
100
- self.n_nfs_heartbeats = 0
101
- self.n_vol_heartbeats = 0
102
- self.n_mounts = 0
103
- self.n_mount_files = 0
104
- self.mount_contents = {}
105
- self.files_name2sha = {}
106
- self.files_sha2data = {}
107
- self.function_id_for_function_call = {}
108
- self.client_calls = {}
109
- self.function_is_running = False
110
- self.n_functions = 0
111
- self.n_schedules = 0
112
- self.function2schedule = {}
113
- self.function_create_error = False
114
- self.heartbeat_status_code = None
115
- self.n_apps = 0
116
- self.classes = {}
117
-
118
- self.task_result = None
119
-
120
- self.nfs_files: Dict[str, Dict[str, api_pb2.SharedVolumePutFileRequest]] = defaultdict(dict)
121
- self.volume_files: Dict[str, Dict[str, VolumeFile]] = defaultdict(dict)
122
- self.images = {}
123
- self.image_build_function_ids = {}
124
- self.force_built_images = []
125
- self.fail_blob_create = []
126
- self.blob_create_metadata = None
127
- self.blob_multipart_threshold = 10_000_000
128
-
129
- self.precreated_functions = set()
130
- self.app_functions = {}
131
- self.fcidx = 0
132
-
133
- self.function_serialized = None
134
- self.class_serialized = None
135
-
136
- self.client_hello_metadata = None
137
-
138
- self.dicts = {}
139
- self.secrets = {}
140
-
141
- self.deployed_dicts = {}
142
- self.deployed_mounts = {
143
- (client_mount_name(), api_pb2.DEPLOYMENT_NAMESPACE_GLOBAL): "mo-123",
144
- }
145
- self.deployed_nfss = {}
146
- self.deployed_queues = {}
147
- self.deployed_secrets = {}
148
- self.deployed_volumes = {}
149
-
150
- self.cleared_function_calls = set()
151
-
152
- self.cancelled_calls = []
153
-
154
- self.app_client_disconnect_count = 0
155
- self.app_get_logs_initial_count = 0
156
- self.app_set_objects_count = 0
157
-
158
- self.volume_counter = 0
159
- # Volume-id -> commit/reload count
160
- self.volume_commits: Dict[str, int] = defaultdict(lambda: 0)
161
- self.volume_reloads: Dict[str, int] = defaultdict(lambda: 0)
162
-
163
- self.sandbox_defs = []
164
- self.sandbox: asyncio.subprocess.Process = None
165
-
166
- # Whether the sandbox is executing a shell program in interactive mode.
167
- self.sandbox_is_interactive = False
168
- self.sandbox_shell_prompt = "TEST_PROMPT# "
169
- self.sandbox_result: Optional[api_pb2.GenericResult] = None
170
-
171
- self.token_flow_localhost_port = None
172
- self.queue_max_len = 1_00
173
-
174
- self.container_heartbeat_response = None
175
- self.container_heartbeat_abort = threading.Event()
176
-
177
- @self.function_body
178
- def default_function_body(*args, **kwargs):
179
- return sum(arg**2 for arg in args) + sum(value**2 for key, value in kwargs.items())
180
-
181
- def function_body(self, func):
182
- """Decorator for setting the function that will be called for any FunctionGetOutputs calls"""
183
- self._function_body = func
184
- return func
185
-
186
- def container_heartbeat_return_now(self, response: api_pb2.ContainerHeartbeatResponse):
187
- self.container_heartbeat_response = response
188
- self.container_heartbeat_abort.set()
189
-
190
- def get_function_metadata(self, object_id: str) -> api_pb2.FunctionHandleMetadata:
191
- definition: api_pb2.Function = self.app_functions[object_id]
192
- return api_pb2.FunctionHandleMetadata(
193
- function_name=definition.function_name,
194
- function_type=definition.function_type,
195
- web_url=definition.web_url,
196
- is_method=definition.is_method,
197
- )
198
-
199
- def get_class_metadata(self, object_id: str) -> api_pb2.ClassHandleMetadata:
200
- class_handle_metadata = api_pb2.ClassHandleMetadata()
201
- for f_name, f_id in self.classes[object_id].items():
202
- function_handle_metadata = self.get_function_metadata(f_id)
203
- class_handle_metadata.methods.append(
204
- api_pb2.ClassMethod(
205
- function_name=f_name, function_id=f_id, function_handle_metadata=function_handle_metadata
206
- )
207
- )
208
- return class_handle_metadata
209
-
210
- def get_object_metadata(self, object_id) -> api_pb2.Object:
211
- if object_id.startswith("fu-"):
212
- res = api_pb2.Object(function_handle_metadata=self.get_function_metadata(object_id))
213
-
214
- elif object_id.startswith("cs-"):
215
- res = api_pb2.Object(class_handle_metadata=self.get_class_metadata(object_id))
216
-
217
- elif object_id.startswith("mo-"):
218
- mount_handle_metadata = api_pb2.MountHandleMetadata(content_checksum_sha256_hex="abc123")
219
- res = api_pb2.Object(mount_handle_metadata=mount_handle_metadata)
220
-
221
- elif object_id.startswith("sb-"):
222
- sandbox_handle_metadata = api_pb2.SandboxHandleMetadata(result=self.sandbox_result)
223
- res = api_pb2.Object(sandbox_handle_metadata=sandbox_handle_metadata)
224
-
225
- else:
226
- res = api_pb2.Object()
227
-
228
- res.object_id = object_id
229
- return res
230
-
231
- ### App
232
-
233
- async def AppCreate(self, stream):
234
- request: api_pb2.AppCreateRequest = await stream.recv_message()
235
- self.requests.append(request)
236
- self.n_apps += 1
237
- app_id = f"ap-{self.n_apps}"
238
- self.app_state_history[app_id].append(api_pb2.APP_STATE_INITIALIZING)
239
- await stream.send_message(
240
- api_pb2.AppCreateResponse(app_id=app_id, app_logs_url="https://modaltest.com/apps/ap-123")
241
- )
242
-
243
- async def AppClientDisconnect(self, stream):
244
- request: api_pb2.AppClientDisconnectRequest = await stream.recv_message()
245
- self.requests.append(request)
246
- self.done = True
247
- self.app_client_disconnect_count += 1
248
- state_history = self.app_state_history[request.app_id]
249
- if state_history[-1] not in [api_pb2.APP_STATE_DETACHED, api_pb2.APP_STATE_DEPLOYED]:
250
- state_history.append(api_pb2.APP_STATE_STOPPED)
251
- await stream.send_message(Empty())
252
-
253
- async def AppGetLogs(self, stream):
254
- request: api_pb2.AppGetLogsRequest = await stream.recv_message()
255
- if not request.last_entry_id:
256
- # Just count initial requests
257
- self.app_get_logs_initial_count += 1
258
- last_entry_id = "1"
259
- else:
260
- last_entry_id = str(int(request.last_entry_id) + 1)
261
- await asyncio.sleep(0.5)
262
- log = api_pb2.TaskLogs(data=f"hello, world ({last_entry_id})\n", file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT)
263
- await stream.send_message(api_pb2.TaskLogsBatch(entry_id=last_entry_id, items=[log]))
264
- if self.done:
265
- await stream.send_message(api_pb2.TaskLogsBatch(app_done=True))
266
-
267
- async def AppGetObjects(self, stream):
268
- request: api_pb2.AppGetObjectsRequest = await stream.recv_message()
269
- object_ids = self.app_objects.get(request.app_id, {})
270
- objects = list(object_ids.items())
271
- if request.include_unindexed:
272
- unindexed_object_ids = self.app_unindexed_objects.get(request.app_id, [])
273
- objects += [(None, object_id) for object_id in unindexed_object_ids]
274
- items = [
275
- api_pb2.AppGetObjectsItem(tag=tag, object=self.get_object_metadata(object_id)) for tag, object_id in objects
276
- ]
277
- await stream.send_message(api_pb2.AppGetObjectsResponse(items=items))
278
-
279
- async def AppSetObjects(self, stream):
280
- request: api_pb2.AppSetObjectsRequest = await stream.recv_message()
281
- self.app_objects[request.app_id] = dict(request.indexed_object_ids)
282
- self.app_unindexed_objects[request.app_id] = list(request.unindexed_object_ids)
283
- if request.single_object_id:
284
- self.app_single_objects[request.app_id] = request.single_object_id
285
- self.app_set_objects_count += 1
286
- if request.new_app_state:
287
- self.app_state_history[request.app_id].append(request.new_app_state)
288
- await stream.send_message(Empty())
289
-
290
- async def AppDeploy(self, stream):
291
- request: api_pb2.AppDeployRequest = await stream.recv_message()
292
- self.deployed_apps[request.name] = request.app_id
293
- self.app_state_history[request.app_id].append(api_pb2.APP_STATE_DEPLOYED)
294
- await stream.send_message(api_pb2.AppDeployResponse(url="http://test.modal.com/foo/bar"))
295
-
296
- async def AppGetByDeploymentName(self, stream):
297
- request: api_pb2.AppGetByDeploymentNameRequest = await stream.recv_message()
298
- await stream.send_message(api_pb2.AppGetByDeploymentNameResponse(app_id=self.deployed_apps.get(request.name)))
299
-
300
- async def AppHeartbeat(self, stream):
301
- request: api_pb2.AppHeartbeatRequest = await stream.recv_message()
302
- self.requests.append(request)
303
- self.app_heartbeats[request.app_id] += 1
304
- await stream.send_message(Empty())
305
-
306
- ### Checkpoint
307
-
308
- async def ContainerCheckpoint(self, stream):
309
- request: api_pb2.ContainerCheckpointRequest = await stream.recv_message()
310
- self.requests.append(request)
311
- self.container_checkpoint_requests += 1
312
- await stream.send_message(Empty())
313
-
314
- ### Blob
315
-
316
- async def BlobCreate(self, stream):
317
- req = await stream.recv_message()
318
- # This is used to test retry_transient_errors, see grpc_utils_test.py
319
- self.blob_create_metadata = stream.metadata
320
- if len(self.fail_blob_create) > 0:
321
- status_code = self.fail_blob_create.pop()
322
- raise GRPCError(status_code, "foobar")
323
- elif req.content_length > self.blob_multipart_threshold:
324
- self.n_blobs += 1
325
- blob_id = f"bl-{self.n_blobs}"
326
- num_parts = (req.content_length + self.blob_multipart_threshold - 1) // self.blob_multipart_threshold
327
- upload_urls = []
328
- for part_number in range(num_parts):
329
- upload_url = f"{self.blob_host}/upload?blob_id={blob_id}&part_number={part_number}"
330
- upload_urls.append(upload_url)
331
-
332
- await stream.send_message(
333
- api_pb2.BlobCreateResponse(
334
- blob_id=blob_id,
335
- multipart=api_pb2.MultiPartUpload(
336
- part_length=self.blob_multipart_threshold,
337
- upload_urls=upload_urls,
338
- completion_url=f"{self.blob_host}/complete_multipart?blob_id={blob_id}",
339
- ),
340
- )
341
- )
342
- else:
343
- self.n_blobs += 1
344
- blob_id = f"bl-{self.n_blobs}"
345
- upload_url = f"{self.blob_host}/upload?blob_id={blob_id}"
346
- await stream.send_message(api_pb2.BlobCreateResponse(blob_id=blob_id, upload_url=upload_url))
347
-
348
- async def BlobGet(self, stream):
349
- request: api_pb2.BlobGetRequest = await stream.recv_message()
350
- download_url = f"{self.blob_host}/download?blob_id={request.blob_id}"
351
- await stream.send_message(api_pb2.BlobGetResponse(download_url=download_url))
352
-
353
- ### Class
354
-
355
- async def ClassCreate(self, stream):
356
- request: api_pb2.ClassCreateRequest = await stream.recv_message()
357
- assert request.app_id
358
- methods: dict[str, str] = {method.function_name: method.function_id for method in request.methods}
359
- class_id = "cs-" + str(len(self.classes))
360
- self.classes[class_id] = methods
361
- await stream.send_message(
362
- api_pb2.ClassCreateResponse(class_id=class_id, handle_metadata=self.get_class_metadata(class_id))
363
- )
364
-
365
- async def ClassGet(self, stream):
366
- request: api_pb2.ClassGetRequest = await stream.recv_message()
367
- app_id = self.deployed_apps.get(request.app_name)
368
- app_objects = self.app_objects[app_id]
369
- object_id = app_objects.get(request.object_tag)
370
- if object_id is None:
371
- raise GRPCError(Status.NOT_FOUND, f"can't find object {request.object_tag}")
372
- await stream.send_message(
373
- api_pb2.ClassGetResponse(class_id=object_id, handle_metadata=self.get_class_metadata(object_id))
374
- )
375
-
376
- ### Client
377
-
378
- async def ClientHello(self, stream):
379
- request: Empty = await stream.recv_message()
380
- self.requests.append(request)
381
- self.client_create_metadata = stream.metadata
382
- client_version = stream.metadata["x-modal-client-version"]
383
- assert stream.user_agent.startswith(f"modal-client/{__version__} ")
384
- if stream.metadata.get("x-modal-token-id") == "bad":
385
- raise GRPCError(Status.UNAUTHENTICATED, "bad bad bad")
386
- elif client_version == "timeout":
387
- await asyncio.sleep(60)
388
- await stream.send_message(api_pb2.ClientHelloResponse())
389
- elif client_version == "unauthenticated":
390
- raise GRPCError(Status.UNAUTHENTICATED, "failed authentication")
391
- elif client_version == "deprecated":
392
- await stream.send_message(api_pb2.ClientHelloResponse(warning="SUPER OLD"))
393
- elif pkg_resources.parse_version(client_version) < pkg_resources.parse_version(__version__):
394
- raise GRPCError(Status.FAILED_PRECONDITION, "Old client")
395
- else:
396
- await stream.send_message(api_pb2.ClientHelloResponse())
397
-
398
- # Container
399
-
400
- async def ContainerHeartbeat(self, stream):
401
- request: api_pb2.ContainerHeartbeatRequest = await stream.recv_message()
402
- self.requests.append(request)
403
- # Return earlier than the usual 15-second heartbeat to avoid suspending tests.
404
- await asyncify(self.container_heartbeat_abort.wait)(5)
405
- if self.container_heartbeat_response:
406
- await stream.send_message(self.container_heartbeat_response)
407
- self.container_heartbeat_response = None
408
- else:
409
- await stream.send_message(api_pb2.ContainerHeartbeatResponse())
410
-
411
- async def ContainerExec(self, stream):
412
- _request: api_pb2.ContainerExecRequest = await stream.recv_message()
413
- await stream.send_message(api_pb2.ContainerExecResponse(exec_id="container_exec_id"))
414
-
415
- async def ContainerExecGetOutput(self, stream):
416
- _request: api_pb2.ContainerExecGetOutputRequest = await stream.recv_message()
417
- await stream.send_message(
418
- api_pb2.RuntimeOutputBatch(
419
- items=[
420
- api_pb2.RuntimeOutputMessage(
421
- file_descriptor=api_pb2.FileDescriptor.FILE_DESCRIPTOR_STDOUT, message="Hello World"
422
- )
423
- ]
424
- )
425
- )
426
- await stream.send_message(api_pb2.RuntimeOutputBatch(exit_code=0))
427
-
428
- ### Dict
429
-
430
- async def DictCreate(self, stream):
431
- request: api_pb2.DictCreateRequest = await stream.recv_message()
432
- if request.existing_dict_id:
433
- dict_id = request.existing_dict_id
434
- else:
435
- dict_id = f"di-{len(self.dicts)}"
436
- self.dicts[dict_id] = {}
437
- await stream.send_message(api_pb2.DictCreateResponse(dict_id=dict_id))
438
-
439
- async def DictGetOrCreate(self, stream):
440
- request: api_pb2.DictGetOrCreateRequest = await stream.recv_message()
441
- k = (request.deployment_name, request.namespace, request.environment_name)
442
- if k in self.deployed_dicts:
443
- dict_id = self.deployed_dicts[k]
444
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
445
- dict_id = f"di-{len(self.dicts)}"
446
- self.dicts[dict_id] = {entry.key: entry.value for entry in request.data}
447
- self.deployed_dicts[k] = dict_id
448
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
449
- dict_id = f"di-{len(self.dicts)}"
450
- self.dicts[dict_id] = {entry.key: entry.value for entry in request.data}
451
- else:
452
- raise GRPCError(Status.NOT_FOUND, "Queue not found")
453
- await stream.send_message(api_pb2.DictGetOrCreateResponse(dict_id=dict_id))
454
-
455
- async def DictHeartbeat(self, stream):
456
- await stream.recv_message()
457
- self.n_dict_heartbeats += 1
458
- await stream.send_message(Empty())
459
-
460
- async def DictClear(self, stream):
461
- request: api_pb2.DictGetRequest = await stream.recv_message()
462
- self.dicts[request.dict_id] = {}
463
- await stream.send_message(Empty())
464
-
465
- async def DictGet(self, stream):
466
- request: api_pb2.DictGetRequest = await stream.recv_message()
467
- d = self.dicts[request.dict_id]
468
- await stream.send_message(api_pb2.DictGetResponse(value=d.get(request.key), found=bool(request.key in d)))
469
-
470
- async def DictLen(self, stream):
471
- request: api_pb2.DictLenRequest = await stream.recv_message()
472
- await stream.send_message(api_pb2.DictLenResponse(len=len(self.dicts[request.dict_id])))
473
-
474
- async def DictUpdate(self, stream):
475
- request: api_pb2.DictUpdateRequest = await stream.recv_message()
476
- for update in request.updates:
477
- self.dicts[request.dict_id][update.key] = update.value
478
- await stream.send_message(api_pb2.DictUpdateResponse())
479
-
480
- ### Function
481
-
482
- async def FunctionBindParams(self, stream):
483
- request: api_pb2.FunctionBindParamsRequest = await stream.recv_message()
484
- assert request.function_id
485
- assert request.serialized_params
486
- self.n_functions += 1
487
- function_id = f"fu-{self.n_functions}"
488
-
489
- await stream.send_message(api_pb2.FunctionBindParamsResponse(bound_function_id=function_id))
490
-
491
- @contextlib.contextmanager
492
- def input_lockstep(self) -> Iterator[threading.Barrier]:
493
- self.get_inputs_barrier = threading.Barrier(2, timeout=10)
494
- yield self.get_inputs_barrier
495
- self.get_inputs_barrier = threading.Barrier(1)
496
-
497
- @contextlib.contextmanager
498
- def output_lockstep(self) -> Iterator[threading.Barrier]:
499
- self.put_outputs_barrier = threading.Barrier(2, timeout=10)
500
- yield self.put_outputs_barrier
501
- self.put_outputs_barrier = threading.Barrier(1)
502
-
503
- async def FunctionGetInputs(self, stream):
504
- self.get_inputs_barrier.wait()
505
- request: api_pb2.FunctionGetInputsRequest = await stream.recv_message()
506
- assert request.function_id
507
- if self.fail_get_inputs:
508
- raise GRPCError(Status.INTERNAL)
509
- elif self.rate_limit_sleep_duration is not None:
510
- s = self.rate_limit_sleep_duration
511
- self.rate_limit_sleep_duration = None
512
- await stream.send_message(api_pb2.FunctionGetInputsResponse(rate_limit_sleep_duration=s))
513
- elif not self.container_inputs:
514
- await asyncio.sleep(1.0)
515
- await stream.send_message(api_pb2.FunctionGetInputsResponse(inputs=[]))
516
- else:
517
- await stream.send_message(self.container_inputs.pop(0))
518
-
519
- async def FunctionPutOutputs(self, stream):
520
- self.put_outputs_barrier.wait()
521
- request: api_pb2.FunctionPutOutputsRequest = await stream.recv_message()
522
- self.container_outputs.append(request)
523
- await stream.send_message(Empty())
524
-
525
- async def FunctionPrecreate(self, stream):
526
- req: api_pb2.FunctionPrecreateRequest = await stream.recv_message()
527
- if not req.existing_function_id:
528
- self.n_functions += 1
529
- function_id = f"fu-{self.n_functions}"
530
- else:
531
- function_id = req.existing_function_id
532
-
533
- self.precreated_functions.add(function_id)
534
-
535
- web_url = "http://xyz.internal" if req.HasField("webhook_config") and req.webhook_config.type else None
536
- await stream.send_message(
537
- api_pb2.FunctionPrecreateResponse(
538
- function_id=function_id,
539
- handle_metadata=api_pb2.FunctionHandleMetadata(
540
- function_name=req.function_name,
541
- function_type=req.function_type,
542
- web_url=web_url,
543
- ),
544
- )
545
- )
546
-
547
- async def FunctionCreate(self, stream):
548
- request: api_pb2.FunctionCreateRequest = await stream.recv_message()
549
- if self.function_create_error:
550
- raise GRPCError(Status.INTERNAL, "Function create failed")
551
- if request.existing_function_id:
552
- function_id = request.existing_function_id
553
- else:
554
- self.n_functions += 1
555
- function_id = f"fu-{self.n_functions}"
556
- if request.schedule:
557
- self.function2schedule[function_id] = request.schedule
558
- function = api_pb2.Function()
559
- function.CopyFrom(request.function)
560
- if function.webhook_config.type:
561
- function.web_url = "http://xyz.internal"
562
-
563
- self.app_functions[function_id] = function
564
- await stream.send_message(
565
- api_pb2.FunctionCreateResponse(
566
- function_id=function_id,
567
- function=function,
568
- handle_metadata=api_pb2.FunctionHandleMetadata(
569
- function_name=function.function_name,
570
- function_type=function.function_type,
571
- web_url=function.web_url,
572
- ),
573
- )
574
- )
575
-
576
- async def FunctionGet(self, stream):
577
- request: api_pb2.FunctionGetRequest = await stream.recv_message()
578
- app_id = self.deployed_apps.get(request.app_name)
579
- app_objects = self.app_objects[app_id]
580
- object_id = app_objects.get(request.object_tag)
581
- if object_id is None:
582
- raise GRPCError(Status.NOT_FOUND, f"can't find object {request.object_tag}")
583
- await stream.send_message(
584
- api_pb2.FunctionGetResponse(function_id=object_id, handle_metadata=self.get_function_metadata(object_id))
585
- )
586
-
587
- async def FunctionMap(self, stream):
588
- self.fcidx += 1
589
- request: api_pb2.FunctionMapRequest = await stream.recv_message()
590
- function_call_id = f"fc-{self.fcidx}"
591
- self.function_id_for_function_call[function_call_id] = request.function_id
592
- await stream.send_message(api_pb2.FunctionMapResponse(function_call_id=function_call_id))
593
-
594
- async def FunctionPutInputs(self, stream):
595
- request: api_pb2.FunctionPutInputsRequest = await stream.recv_message()
596
- response_items = []
597
- function_call_inputs = self.client_calls.setdefault(request.function_call_id, [])
598
- for item in request.inputs:
599
- args, kwargs = modal._serialization.deserialize(item.input.args, None) if item.input.args else ((), {})
600
- input_id = f"in-{self.n_inputs}"
601
- self.n_inputs += 1
602
- response_items.append(api_pb2.FunctionPutInputsResponseItem(input_id=input_id, idx=item.idx))
603
- function_call_inputs.append(((item.idx, input_id), (args, kwargs)))
604
- if self.slow_put_inputs:
605
- await asyncio.sleep(0.001)
606
- await stream.send_message(api_pb2.FunctionPutInputsResponse(inputs=response_items))
607
-
608
- async def FunctionGetOutputs(self, stream):
609
- request: api_pb2.FunctionGetOutputsRequest = await stream.recv_message()
610
- if request.clear_on_success:
611
- self.cleared_function_calls.add(request.function_call_id)
612
-
613
- client_calls = self.client_calls.get(request.function_call_id, [])
614
- if client_calls and not self.function_is_running:
615
- popidx = len(client_calls) // 2 # simulate that results don't always come in order
616
- (idx, input_id), (args, kwargs) = client_calls.pop(popidx)
617
- output_exc = None
618
- try:
619
- res = self._function_body(*args, **kwargs)
620
-
621
- if inspect.iscoroutine(res):
622
- result = await res
623
- result_data_format = api_pb2.DATA_FORMAT_PICKLE
624
- elif inspect.isgenerator(res):
625
- count = 0
626
- for item in res:
627
- count += 1
628
- await self.fc_data_out[request.function_call_id].put(
629
- api_pb2.DataChunk(
630
- data_format=api_pb2.DATA_FORMAT_PICKLE,
631
- data=serialize_data_format(item, api_pb2.DATA_FORMAT_PICKLE),
632
- index=count,
633
- )
634
- )
635
- result = api_pb2.GeneratorDone(items_total=count)
636
- result_data_format = api_pb2.DATA_FORMAT_GENERATOR_DONE
637
- else:
638
- result = res
639
- result_data_format = api_pb2.DATA_FORMAT_PICKLE
640
- except Exception as exc:
641
- serialized_exc = cloudpickle.dumps(exc)
642
- result = api_pb2.GenericResult(
643
- status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
644
- data=serialized_exc,
645
- exception=repr(exc),
646
- traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
647
- )
648
- output_exc = api_pb2.FunctionGetOutputsItem(
649
- input_id=input_id, idx=idx, result=result, gen_index=0, data_format=api_pb2.DATA_FORMAT_PICKLE
650
- )
651
-
652
- if output_exc:
653
- output = output_exc
654
- else:
655
- output = api_pb2.FunctionGetOutputsItem(
656
- input_id=input_id,
657
- idx=idx,
658
- result=api_pb2.GenericResult(
659
- status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
660
- data=serialize_data_format(result, result_data_format),
661
- ),
662
- data_format=result_data_format,
663
- )
664
-
665
- await stream.send_message(api_pb2.FunctionGetOutputsResponse(outputs=[output]))
666
- else:
667
- await stream.send_message(api_pb2.FunctionGetOutputsResponse(outputs=[]))
668
-
669
- async def FunctionGetSerialized(self, stream):
670
- await stream.send_message(
671
- api_pb2.FunctionGetSerializedResponse(
672
- function_serialized=self.function_serialized,
673
- class_serialized=self.class_serialized,
674
- )
675
- )
676
-
677
- async def FunctionCallCancel(self, stream):
678
- req = await stream.recv_message()
679
- self.cancelled_calls.append(req.function_call_id)
680
- await stream.send_message(Empty())
681
-
682
- async def FunctionCallGetDataIn(self, stream):
683
- req: api_pb2.FunctionCallGetDataRequest = await stream.recv_message()
684
- while True:
685
- chunk = await self.fc_data_in[req.function_call_id].get()
686
- await stream.send_message(chunk)
687
-
688
- async def FunctionCallGetDataOut(self, stream):
689
- req: api_pb2.FunctionCallGetDataRequest = await stream.recv_message()
690
- while True:
691
- chunk = await self.fc_data_out[req.function_call_id].get()
692
- await stream.send_message(chunk)
693
-
694
- async def FunctionCallPutDataOut(self, stream):
695
- req: api_pb2.FunctionCallPutDataRequest = await stream.recv_message()
696
- for chunk in req.data_chunks:
697
- await self.fc_data_out[req.function_call_id].put(chunk)
698
- await stream.send_message(Empty())
699
-
700
- ### Image
701
-
702
- async def ImageGetOrCreate(self, stream):
703
- request: api_pb2.ImageGetOrCreateRequest = await stream.recv_message()
704
- idx = len(self.images) + 1
705
- image_id = f"im-{idx}"
706
-
707
- self.images[image_id] = request.image
708
- self.image_build_function_ids[image_id] = request.build_function_id
709
- if request.force_build:
710
- self.force_built_images.append(image_id)
711
- await stream.send_message(api_pb2.ImageGetOrCreateResponse(image_id=image_id))
712
-
713
- async def ImageJoinStreaming(self, stream):
714
- await stream.recv_message()
715
- task_log_1 = api_pb2.TaskLogs(data="hello, world\n", file_descriptor=api_pb2.FILE_DESCRIPTOR_INFO)
716
- task_log_2 = api_pb2.TaskLogs(
717
- task_progress=api_pb2.TaskProgress(
718
- len=1, pos=0, progress_type=api_pb2.IMAGE_SNAPSHOT_UPLOAD, description="xyz"
719
- )
720
- )
721
- await stream.send_message(api_pb2.ImageJoinStreamingResponse(task_logs=[task_log_1, task_log_2]))
722
- await stream.send_message(
723
- api_pb2.ImageJoinStreamingResponse(
724
- result=api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS)
725
- )
726
- )
727
-
728
- ### Mount
729
-
730
- async def MountPutFile(self, stream):
731
- request: api_pb2.MountPutFileRequest = await stream.recv_message()
732
- if request.WhichOneof("data_oneof") is not None:
733
- self.files_sha2data[request.sha256_hex] = {"data": request.data, "data_blob_id": request.data_blob_id}
734
- self.n_mount_files += 1
735
- await stream.send_message(api_pb2.MountPutFileResponse(exists=True))
736
- else:
737
- await stream.send_message(api_pb2.MountPutFileResponse(exists=False))
738
-
739
- async def MountGetOrCreate(self, stream):
740
- request: api_pb2.MountGetOrCreateRequest = await stream.recv_message()
741
- k = (request.deployment_name, request.namespace)
742
- if request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
743
- if k not in self.deployed_mounts:
744
- raise GRPCError(Status.NOT_FOUND, "Mount not found")
745
- mount_id = self.deployed_mounts[k]
746
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
747
- self.n_mounts += 1
748
- mount_id = f"mo-{self.n_mounts}"
749
- self.deployed_mounts[k] = mount_id
750
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_ANONYMOUS_OWNED_BY_APP:
751
- self.n_mounts += 1
752
- mount_id = f"mo-{self.n_mounts}"
753
-
754
- else:
755
- raise Exception("unsupported creation type")
756
-
757
- mount_content = self.mount_contents[mount_id] = {}
758
- for file in request.files:
759
- mount_content[file.filename] = self.files_name2sha[file.filename] = file.sha256_hex
760
-
761
- await stream.send_message(
762
- api_pb2.MountGetOrCreateResponse(
763
- mount_id=mount_id, handle_metadata=api_pb2.MountHandleMetadata(content_checksum_sha256_hex="deadbeef")
764
- )
765
- )
766
-
767
- ### Proxy
768
-
769
- async def ProxyGetOrCreate(self, stream):
770
- await stream.recv_message()
771
- await stream.send_message(api_pb2.ProxyGetOrCreateResponse(proxy_id="pr-123"))
772
-
773
- ### Queue
774
-
775
- async def QueueCreate(self, stream):
776
- request: api_pb2.QueueCreateRequest = await stream.recv_message()
777
- if request.existing_queue_id:
778
- queue_id = request.existing_queue_id
779
- else:
780
- self.n_queues += 1
781
- queue_id = f"qu-{self.n_queues}"
782
- await stream.send_message(api_pb2.QueueCreateResponse(queue_id=queue_id))
783
-
784
- async def QueueGetOrCreate(self, stream):
785
- request: api_pb2.QueueGetOrCreateRequest = await stream.recv_message()
786
- k = (request.deployment_name, request.namespace, request.environment_name)
787
- if k in self.deployed_queues:
788
- queue_id = self.deployed_queues[k]
789
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
790
- self.n_queues += 1
791
- queue_id = f"qu-{self.n_queues}"
792
- self.deployed_queues[k] = queue_id
793
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
794
- self.n_queues += 1
795
- queue_id = f"qu-{self.n_queues}"
796
- else:
797
- raise GRPCError(Status.NOT_FOUND, "Queue not found")
798
- await stream.send_message(api_pb2.QueueGetOrCreateResponse(queue_id=queue_id))
799
-
800
- async def QueueHeartbeat(self, stream):
801
- await stream.recv_message()
802
- self.n_queue_heartbeats += 1
803
- await stream.send_message(Empty())
804
-
805
- async def QueuePut(self, stream):
806
- request: api_pb2.QueuePutRequest = await stream.recv_message()
807
- if len(self.queue) >= self.queue_max_len:
808
- raise GRPCError(Status.RESOURCE_EXHAUSTED, f"Hit servicer's max len for Queues: {self.queue_max_len}")
809
- self.queue += request.values
810
- await stream.send_message(Empty())
811
-
812
- async def QueueGet(self, stream):
813
- await stream.recv_message()
814
- if len(self.queue) > 0:
815
- values = [self.queue.pop(0)]
816
- else:
817
- values = []
818
- await stream.send_message(api_pb2.QueueGetResponse(values=values))
819
-
820
- async def QueueLen(self, stream):
821
- await stream.recv_message()
822
- await stream.send_message(api_pb2.QueueLenResponse(len=len(self.queue)))
823
-
824
- ### Sandbox
825
-
826
- async def SandboxCreate(self, stream):
827
- request: api_pb2.SandboxCreateRequest = await stream.recv_message()
828
- if request.definition.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
829
- self.sandbox_is_interactive = True
830
-
831
- self.sandbox = await asyncio.subprocess.create_subprocess_exec(
832
- *request.definition.entrypoint_args,
833
- stdout=asyncio.subprocess.PIPE,
834
- stderr=asyncio.subprocess.PIPE,
835
- stdin=asyncio.subprocess.PIPE,
836
- )
837
-
838
- self.sandbox_defs.append(request.definition)
839
-
840
- await stream.send_message(api_pb2.SandboxCreateResponse(sandbox_id="sb-123"))
841
-
842
- async def SandboxGetLogs(self, stream):
843
- request: api_pb2.SandboxGetLogsRequest = await stream.recv_message()
844
- f: asyncio.StreamReader
845
- if self.sandbox_is_interactive:
846
- # sends an empty message to simulate PTY
847
- await stream.send_message(
848
- api_pb2.TaskLogsBatch(
849
- items=[api_pb2.TaskLogs(data=self.sandbox_shell_prompt, file_descriptor=request.file_descriptor)]
850
- )
851
- )
852
-
853
- if request.file_descriptor == api_pb2.FILE_DESCRIPTOR_STDOUT:
854
- # Blocking read until EOF is returned.
855
- f = self.sandbox.stdout
856
- else:
857
- f = self.sandbox.stderr
858
-
859
- async for message in f:
860
- await stream.send_message(
861
- api_pb2.TaskLogsBatch(
862
- items=[api_pb2.TaskLogs(data=message.decode("utf-8"), file_descriptor=request.file_descriptor)]
863
- )
864
- )
865
-
866
- await stream.send_message(api_pb2.TaskLogsBatch(eof=True))
867
-
868
- async def SandboxWait(self, stream):
869
- request: api_pb2.SandboxWaitRequest = await stream.recv_message()
870
- try:
871
- await asyncio.wait_for(self.sandbox.wait(), request.timeout)
872
- except asyncio.TimeoutError:
873
- pass
874
-
875
- if self.sandbox.returncode is None:
876
- # This happens when request.timeout is 0 and the sandbox hasn't completed.
877
- await stream.send_message(api_pb2.SandboxWaitResponse())
878
- return
879
- elif self.sandbox.returncode != 0:
880
- result = api_pb2.GenericResult(
881
- status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, exitcode=self.sandbox.returncode
882
- )
883
- else:
884
- result = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS)
885
- self.sandbox_result = result
886
- await stream.send_message(api_pb2.SandboxWaitResponse(result=result))
887
-
888
- async def SandboxTerminate(self, stream):
889
- self.sandbox.terminate()
890
- await stream.send_message(api_pb2.SandboxTerminateResponse())
891
-
892
- async def SandboxGetTaskId(self, stream):
893
- # only used for `modal shell` / `modal container exec`
894
- _request: api_pb2.SandboxGetTaskIdRequest = await stream.recv_message()
895
- await stream.send_message(api_pb2.SandboxGetTaskIdResponse(task_id="modal_container_exec"))
896
-
897
- async def SandboxStdinWrite(self, stream):
898
- request: api_pb2.SandboxStdinWriteRequest = await stream.recv_message()
899
-
900
- self.sandbox.stdin.write(request.input)
901
- await self.sandbox.stdin.drain()
902
-
903
- if request.eof:
904
- self.sandbox.stdin.close()
905
- await stream.send_message(api_pb2.SandboxStdinWriteResponse())
906
-
907
- ### Secret
908
-
909
- async def SecretGetOrCreate(self, stream):
910
- request: api_pb2.SecretGetOrCreateRequest = await stream.recv_message()
911
- k = (request.deployment_name, request.namespace, request.environment_name)
912
- if request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_ANONYMOUS_OWNED_BY_APP:
913
- secret_id = "st-" + str(len(self.secrets))
914
- self.secrets[secret_id] = request.env_dict
915
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
916
- if k in self.deployed_secrets:
917
- raise GRPCError(Status.ALREADY_EXISTS, "Already exists")
918
- secret_id = None
919
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_OVERWRITE_IF_EXISTS:
920
- secret_id = None
921
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
922
- if k not in self.deployed_secrets:
923
- raise GRPCError(Status.NOT_FOUND, "No such secret")
924
- secret_id = self.deployed_secrets[k]
925
- else:
926
- raise Exception("unsupported creation type")
927
-
928
- if secret_id is None: # Create one
929
- secret_id = "st-" + str(len(self.secrets))
930
- self.secrets[secret_id] = request.env_dict
931
- self.deployed_secrets[k] = secret_id
932
-
933
- await stream.send_message(api_pb2.SecretGetOrCreateResponse(secret_id=secret_id))
934
-
935
- async def SecretList(self, stream):
936
- await stream.recv_message()
937
- items = [api_pb2.SecretListItem(label=f"dummy-secret-{i}") for i, _ in enumerate(self.secrets)]
938
- await stream.send_message(api_pb2.SecretListResponse(items=items))
939
-
940
- ### Network File System (née Shared volume)
941
-
942
- async def SharedVolumeCreate(self, stream):
943
- nfs_id = f"sv-{len(self.nfs_files)}"
944
- self.nfs_files[nfs_id] = {}
945
- await stream.send_message(api_pb2.SharedVolumeCreateResponse(shared_volume_id=nfs_id))
946
-
947
- async def SharedVolumeGetOrCreate(self, stream):
948
- request: api_pb2.SharedVolumeGetOrCreateRequest = await stream.recv_message()
949
- k = (request.deployment_name, request.namespace, request.environment_name)
950
- if request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
951
- if k not in self.deployed_nfss:
952
- raise GRPCError(Status.NOT_FOUND, "NFS not found")
953
- nfs_id = self.deployed_nfss[k]
954
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
955
- nfs_id = f"sv-{len(self.nfs_files)}"
956
- self.nfs_files[nfs_id] = {}
957
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
958
- if k not in self.deployed_nfss:
959
- nfs_id = f"sv-{len(self.nfs_files)}"
960
- self.nfs_files[nfs_id] = {}
961
- self.deployed_nfss[k] = nfs_id
962
- nfs_id = self.deployed_nfss[k]
963
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
964
- if k in self.deployed_nfss:
965
- raise GRPCError(Status.ALREADY_EXISTS, "NFS already exists")
966
- nfs_id = f"sv-{len(self.nfs_files)}"
967
- self.nfs_files[nfs_id] = {}
968
- self.deployed_nfss[k] = nfs_id
969
- else:
970
- raise GRPCError(Status.INVALID_ARGUMENT, "unsupported object creation type")
971
-
972
- await stream.send_message(api_pb2.SharedVolumeGetOrCreateResponse(shared_volume_id=nfs_id))
973
-
974
- async def SharedVolumeHeartbeat(self, stream):
975
- await stream.recv_message()
976
- self.n_nfs_heartbeats += 1
977
- await stream.send_message(Empty())
978
-
979
- async def SharedVolumePutFile(self, stream):
980
- req = await stream.recv_message()
981
- self.nfs_files[req.shared_volume_id][req.path] = req
982
- await stream.send_message(api_pb2.SharedVolumePutFileResponse(exists=True))
983
-
984
- async def SharedVolumeGetFile(self, stream):
985
- req = await stream.recv_message()
986
- put_req = self.nfs_files.get(req.shared_volume_id, {}).get(req.path)
987
- if not put_req:
988
- raise GRPCError(Status.NOT_FOUND, f"No such file: {req.path}")
989
- if put_req.data_blob_id:
990
- await stream.send_message(api_pb2.SharedVolumeGetFileResponse(data_blob_id=put_req.data_blob_id))
991
- else:
992
- await stream.send_message(api_pb2.SharedVolumeGetFileResponse(data=put_req.data))
993
-
994
- async def SharedVolumeListFilesStream(self, stream):
995
- req: api_pb2.SharedVolumeListFilesRequest = await stream.recv_message()
996
- for path in self.nfs_files[req.shared_volume_id].keys():
997
- entry = api_pb2.SharedVolumeListFilesEntry(path=path)
998
- response = api_pb2.SharedVolumeListFilesResponse(entries=[entry])
999
- await stream.send_message(response)
1000
-
1001
- ### Task
1002
-
1003
- async def TaskCurrentInputs(
1004
- self, stream: "grpclib.server.Stream[Empty, api_pb2.TaskCurrentInputsResponse]"
1005
- ) -> None:
1006
- await stream.send_message(api_pb2.TaskCurrentInputsResponse(input_ids=[])) # dummy implementation
1007
-
1008
- async def TaskResult(self, stream):
1009
- request: api_pb2.TaskResultRequest = await stream.recv_message()
1010
- self.task_result = request.result
1011
- await stream.send_message(Empty())
1012
-
1013
- ### Token flow
1014
-
1015
- async def TokenFlowCreate(self, stream):
1016
- request: api_pb2.TokenFlowCreateRequest = await stream.recv_message()
1017
- self.token_flow_localhost_port = request.localhost_port
1018
- await stream.send_message(
1019
- api_pb2.TokenFlowCreateResponse(token_flow_id="tc-123", web_url="https://localhost/xyz/abc")
1020
- )
1021
-
1022
- async def TokenFlowWait(self, stream):
1023
- await stream.send_message(
1024
- api_pb2.TokenFlowWaitResponse(
1025
- token_id="abc",
1026
- token_secret="xyz",
1027
- )
1028
- )
1029
-
1030
- async def WorkspaceNameLookup(self, stream):
1031
- await stream.send_message(
1032
- api_pb2.WorkspaceNameLookupResponse(workspace_name="test-workspace", username="test-username")
1033
- )
1034
-
1035
- ### Tunnel
1036
-
1037
- async def TunnelStart(self, stream):
1038
- request: api_pb2.TunnelStartRequest = await stream.recv_message()
1039
- port = request.port
1040
- await stream.send_message(api_pb2.TunnelStartResponse(host=f"{port}.modal.test", port=443))
1041
-
1042
- async def TunnelStop(self, stream):
1043
- await stream.recv_message()
1044
- await stream.send_message(api_pb2.TunnelStopResponse(exists=True))
1045
-
1046
- ### Volume
1047
-
1048
- async def VolumeCreate(self, stream):
1049
- req = await stream.recv_message()
1050
- self.requests.append(req)
1051
- self.volume_counter += 1
1052
- volume_id = f"vo-{self.volume_counter}"
1053
- self.volume_files[volume_id] = {}
1054
- await stream.send_message(api_pb2.VolumeCreateResponse(volume_id=volume_id))
1055
-
1056
- async def VolumeGetOrCreate(self, stream):
1057
- request: api_pb2.VolumeGetOrCreateRequest = await stream.recv_message()
1058
- k = (request.deployment_name, request.namespace, request.environment_name)
1059
- if request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED:
1060
- if k not in self.deployed_volumes:
1061
- raise GRPCError(Status.NOT_FOUND, "Volume not found")
1062
- volume_id = self.deployed_volumes[k]
1063
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
1064
- volume_id = f"vo-{len(self.volume_files)}"
1065
- self.volume_files[volume_id] = {}
1066
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING:
1067
- if k not in self.deployed_volumes:
1068
- volume_id = f"vo-{len(self.volume_files)}"
1069
- self.volume_files[volume_id] = {}
1070
- self.deployed_volumes[k] = volume_id
1071
- volume_id = self.deployed_volumes[k]
1072
- elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS:
1073
- if k in self.deployed_volumes:
1074
- raise GRPCError(Status.ALREADY_EXISTS, "Volume already exists")
1075
- volume_id = f"vo-{len(self.volume_files)}"
1076
- self.volume_files[volume_id] = {}
1077
- self.deployed_volumes[k] = volume_id
1078
- else:
1079
- raise GRPCError(Status.INVALID_ARGUMENT, "unsupported object creation type")
1080
-
1081
- await stream.send_message(api_pb2.VolumeGetOrCreateResponse(volume_id=volume_id))
1082
-
1083
- async def VolumeHeartbeat(self, stream):
1084
- await stream.recv_message()
1085
- self.n_vol_heartbeats += 1
1086
- await stream.send_message(Empty())
1087
-
1088
- async def VolumeCommit(self, stream):
1089
- req = await stream.recv_message()
1090
- self.requests.append(req)
1091
- if not req.volume_id.startswith("vo-"):
1092
- raise GRPCError(Status.NOT_FOUND, f"invalid volume ID {req.volume_id}")
1093
- self.volume_commits[req.volume_id] += 1
1094
- await stream.send_message(api_pb2.VolumeCommitResponse(skip_reload=False))
1095
-
1096
- async def VolumeDelete(self, stream):
1097
- req: api_pb2.VolumeDeleteRequest = await stream.recv_message()
1098
- self.volume_files.pop(req.volume_id)
1099
- self.deployed_volumes = {k: vol_id for k, vol_id in self.deployed_volumes.items() if vol_id != req.volume_id}
1100
- await stream.send_message(Empty())
1101
-
1102
- async def VolumeReload(self, stream):
1103
- req = await stream.recv_message()
1104
- self.requests.append(req)
1105
- self.volume_reloads[req.volume_id] += 1
1106
- await stream.send_message(Empty())
1107
-
1108
- async def VolumeGetFile(self, stream):
1109
- req = await stream.recv_message()
1110
- if req.path.decode("utf-8") not in self.volume_files[req.volume_id]:
1111
- raise GRPCError(Status.NOT_FOUND, "File not found")
1112
- vol_file = self.volume_files[req.volume_id][req.path.decode("utf-8")]
1113
- if vol_file.data_blob_id:
1114
- await stream.send_message(api_pb2.VolumeGetFileResponse(data_blob_id=vol_file.data_blob_id))
1115
- else:
1116
- size = len(vol_file.data)
1117
- if req.start or req.len:
1118
- start = req.start
1119
- len_ = req.len or len(vol_file.data)
1120
- await stream.send_message(
1121
- api_pb2.VolumeGetFileResponse(data=vol_file.data[start : start + len_], size=size)
1122
- )
1123
- else:
1124
- await stream.send_message(api_pb2.VolumeGetFileResponse(data=vol_file.data, size=size))
1125
-
1126
- async def VolumeRemoveFile(self, stream):
1127
- req = await stream.recv_message()
1128
- if req.path.decode("utf-8") not in self.volume_files[req.volume_id]:
1129
- raise GRPCError(Status.INVALID_ARGUMENT, "File not found")
1130
- del self.volume_files[req.volume_id][req.path.decode("utf-8")]
1131
- await stream.send_message(Empty())
1132
-
1133
- async def VolumeListFiles(self, stream):
1134
- req = await stream.recv_message()
1135
- if req.path != "**":
1136
- raise NotImplementedError("Only '**' listing is supported.")
1137
- for k, vol_file in self.volume_files[req.volume_id].items():
1138
- entries = [
1139
- api_pb2.VolumeListFilesEntry(
1140
- path=k, type=api_pb2.VolumeListFilesEntry.FileType.FILE, size=len(vol_file.data)
1141
- )
1142
- ]
1143
- await stream.send_message(api_pb2.VolumeListFilesResponse(entries=entries))
1144
-
1145
- async def VolumePutFiles(self, stream):
1146
- req = await stream.recv_message()
1147
- for file in req.files:
1148
- blob_data = self.files_sha2data[file.sha256_hex]
1149
-
1150
- if file.filename in self.volume_files[req.volume_id] and req.disallow_overwrite_existing_files:
1151
- raise GRPCError(
1152
- Status.ALREADY_EXISTS,
1153
- f"{file.filename}: already exists (disallow_overwrite_existing_files={req.disallow_overwrite_existing_files}",
1154
- )
1155
-
1156
- self.volume_files[req.volume_id][file.filename] = VolumeFile(
1157
- data=blob_data["data"],
1158
- data_blob_id=blob_data["data_blob_id"],
1159
- mode=file.mode,
1160
- )
1161
- await stream.send_message(Empty())
1162
-
1163
- async def VolumeCopyFiles(self, stream):
1164
- req = await stream.recv_message()
1165
- for src_path in req.src_paths:
1166
- if src_path.decode("utf-8") not in self.volume_files[req.volume_id]:
1167
- raise GRPCError(Status.NOT_FOUND, f"Source file not found: {src_path}")
1168
- src_file = self.volume_files[req.volume_id][src_path.decode("utf-8")]
1169
- if len(req.src_paths) > 1:
1170
- # check to make sure dst is a directory
1171
- if (
1172
- req.dst_path.decode("utf-8").endswith(("/", "\\"))
1173
- or not os.path.splitext(os.path.basename(req.dst_path))[1]
1174
- ):
1175
- dst_path = os.path.join(req.dst_path, os.path.basename(src_path))
1176
- else:
1177
- raise GRPCError(Status.INVALID_ARGUMENT, f"{dst_path} is not a directory.")
1178
- else:
1179
- dst_path = req.dst_path
1180
- self.volume_files[req.volume_id][dst_path.decode("utf-8")] = src_file
1181
- await stream.send_message(Empty())
1182
-
1183
-
1184
@pytest_asyncio.fixture
async def blob_server():
    """Local HTTP server emulating blob storage: simple and multipart uploads + download."""
    blobs = {}
    blob_parts: Dict[str, Dict[int, bytes]] = defaultdict(dict)

    async def upload(request):
        blob_id = request.query["blob_id"]
        content = await request.content.read()
        if content == b"FAILURE":
            # Magic payload that makes the server fail, for retry tests.
            return aiohttp.web.Response(status=500)
        etag = f'"{hashlib.md5(content).hexdigest()}"'
        if "part_number" in request.query:
            blob_parts[blob_id][int(request.query["part_number"])] = content
        else:
            blobs[blob_id] = content
        return aiohttp.web.Response(text="Hello, world", headers={"ETag": etag})

    async def complete_multipart(request):
        blob_id = request.query["blob_id"]
        parts = blob_parts[blob_id]
        # Stitch the parts together in part-number order and compute an S3-style etag.
        content = b""
        part_hashes = b""
        for num in range(min(parts.keys()), max(parts.keys()) + 1):
            part = parts[num]
            content += part
            part_hashes += hashlib.md5(part).digest()
        etag = f'"{hashlib.md5(part_hashes).hexdigest()}-{len(parts)}"'
        blobs[blob_id] = content
        return aiohttp.web.Response(text=f"<etag>{etag}</etag>")

    async def download(request):
        blob_id = request.query["blob_id"]
        if blob_id == "bl-failure":
            # Magic id that makes downloads fail, for retry tests.
            return aiohttp.web.Response(status=500)
        return aiohttp.web.Response(body=blobs[blob_id])

    app = aiohttp.web.Application()
    app.add_routes(
        [
            aiohttp.web.put("/upload", upload),
            aiohttp.web.get("/download", download),
            aiohttp.web.post("/complete_multipart", complete_multipart),
        ]
    )

    async with run_temporary_http_server(app) as host:
        yield host, blobs
1231
-
1232
-
1233
@pytest_asyncio.fixture(scope="function")
async def servicer_factory(blob_server):
    """Yield an async context manager that runs a MockClientServicer on host/port or a unix path."""

    @contextlib.asynccontextmanager
    async def create_server(host=None, port=None, path=None):
        blob_host, blobs = blob_server
        mock_servicer = MockClientServicer(blob_host, blobs)  # type: ignore
        server = None

        async def _start_servicer():
            nonlocal server
            server = grpclib.server.Server([mock_servicer])
            await server.start(host=host, port=port, path=path)

        async def _stop_servicer():
            mock_servicer.container_heartbeat_abort.set()
            server.close()
            # This is the proper way to close down the asyncio server,
            # but it causes our tests to hang on 3.12+ because client connections
            # for clients created through _Client.from_env don't get closed until
            # asyncio event loop shutdown. Commenting out but perhaps revisit if we
            # refactor the way that _Client cleanup happens.
            # await server.wait_closed()

        # Start/stop through synchronize_api so the servicer runs on the synchronizer loop.
        start_servicer = synchronize_api(_start_servicer)
        stop_servicer = synchronize_api(_stop_servicer)

        await start_servicer.aio()
        try:
            yield mock_servicer
        finally:
            await stop_servicer.aio()

    yield create_server
1266
-
1267
-
1268
@pytest_asyncio.fixture(scope="function")
async def servicer(servicer_factory):
    """TCP-backed mock servicer, with its HTTP address stored on `remote_addr`."""
    port = find_free_port()
    async with servicer_factory(host="0.0.0.0", port=port) as s:
        s.remote_addr = f"http://127.0.0.1:{port}"
        yield s
1274
-
1275
-
1276
@pytest_asyncio.fixture(scope="function")
async def unix_servicer(servicer_factory):
    """Unix-socket-backed mock servicer, with its address stored on `remote_addr`."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        socket_path = os.path.join(tmpdirname, "servicer.sock")
        async with servicer_factory(path=socket_path) as s:
            s.remote_addr = f"unix://{socket_path}"
            yield s
1283
-
1284
-
1285
@pytest_asyncio.fixture(scope="function")
async def client(servicer):
    """Regular client connected to the TCP mock servicer."""
    with Client(servicer.remote_addr, api_pb2.CLIENT_TYPE_CLIENT, ("foo-id", "foo-secret")) as c:
        yield c
1289
-
1290
-
1291
@pytest_asyncio.fixture(scope="function")
async def container_client(unix_servicer):
    """Container-type client connected over the unix-socket mock servicer."""
    async with Client(unix_servicer.remote_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret")) as c:
        yield c
1295
-
1296
-
1297
@pytest_asyncio.fixture(scope="function")
async def server_url_env(servicer, monkeypatch):
    # Point MODAL_SERVER_URL at the running mock servicer for the test's duration.
    monkeypatch.setenv("MODAL_SERVER_URL", servicer.remote_addr)
    yield
1301
-
1302
-
1303
@pytest_asyncio.fixture(scope="function", autouse=True)
async def reset_default_client():
    # Make sure no test inherits a cached env client from a previous test.
    Client.set_env_client(None)
1306
-
1307
-
1308
@pytest.fixture(name="mock_dir", scope="session")
def mock_dir_factory():
    """Sets up a temp dir with content as specified in a nested dict

    Example usage:
    spec = {
        "foo": {
            "bar.txt": "some content"
        },
    }

    with mock_dir(spec) as root_dir:
        assert os.path.exists(os.path.join(root_dir, "foo", "bar.txt"))
    """

    @contextlib.contextmanager
    def mock_dir(root_spec):
        def rec_make(dir, dir_spec):
            # String values become file contents; dict values become subdirectories.
            for filename, spec in dir_spec.items():
                path = os.path.join(dir, filename)
                if isinstance(spec, str):
                    with open(path, "w") as f:
                        f.write(spec)
                else:
                    os.mkdir(path)
                    rec_make(path, spec)

        # Windows has issues cleaning up TempDirectory: https://www.scivision.dev/python-tempfile-permission-error-windows
        # Seems to have been fixed for some python versions in https://github.com/python/cpython/pull/10320.
        root_dir = tempfile.mkdtemp()
        rec_make(root_dir, root_spec)
        cwd = os.getcwd()
        try:
            os.chdir(root_dir)
            # BUG FIX: previously a bare `yield` bound None, so the documented
            # `with mock_dir(spec) as root_dir:` usage got no usable path.
            yield root_dir
        finally:
            os.chdir(cwd)
            shutil.rmtree(root_dir, ignore_errors=True)

    return mock_dir
1348
-
1349
-
1350
- @pytest.fixture(autouse=True)
1351
- def reset_sys_modules():
1352
- # Needed since some tests will import dynamic modules
1353
- backup = sys.modules.copy()
1354
- try:
1355
- yield
1356
- finally:
1357
- sys.modules = backup
1358
-
1359
-
1360
- @pytest.fixture(autouse=True)
1361
- def reset_container_app():
1362
- try:
1363
- yield
1364
- finally:
1365
- _ContainerApp._reset_container()
1366
-
1367
-
1368
- @pytest.fixture
1369
- def repo_root(request):
1370
- return Path(request.config.rootdir)
1371
-
1372
-
1373
- @pytest.fixture(scope="module")
1374
- def test_dir(request):
1375
- """Absolute path to directory containing test file."""
1376
- root_dir = Path(request.config.rootdir)
1377
- test_dir = Path(os.getenv("PYTEST_CURRENT_TEST")).parent
1378
- return root_dir / test_dir
1379
-
1380
-
1381
- @pytest.fixture(scope="function")
1382
- def modal_config():
1383
- """Return a context manager with a temporary modal.toml file"""
1384
-
1385
- @contextlib.contextmanager
1386
- def mock_modal_toml(contents: str = "", show_on_error: bool = False):
1387
- # Some of the cli tests run within within the main process
1388
- # so we need to modify the config singletons to pick up any changes
1389
- orig_config_path_env = os.environ.get("MODAL_CONFIG_PATH")
1390
- orig_config_path = config.user_config_path
1391
- orig_profile = config._profile
1392
- try:
1393
- with tempfile.NamedTemporaryFile(delete=False, suffix=".toml", mode="w") as t:
1394
- t.write(textwrap.dedent(contents.strip("\n")))
1395
- os.environ["MODAL_CONFIG_PATH"] = t.name
1396
- config.user_config_path = t.name
1397
- config._user_config = config._read_user_config()
1398
- config._profile = config._config_active_profile()
1399
- yield t.name
1400
- except Exception:
1401
- if show_on_error:
1402
- with open(t.name) as f:
1403
- print(f"Test config file contents:\n\n{f.read()}", file=sys.stderr)
1404
- raise
1405
- finally:
1406
- if orig_config_path_env:
1407
- os.environ["MODAL_CONFIG_PATH"] = orig_config_path_env
1408
- else:
1409
- del os.environ["MODAL_CONFIG_PATH"]
1410
- config.user_config_path = orig_config_path
1411
- config._user_config = config._read_user_config()
1412
- config._profile = orig_profile
1413
- os.remove(t.name)
1414
-
1415
- return mock_modal_toml
1416
-
1417
-
1418
- @pytest.fixture
1419
- def supports_dir(test_dir):
1420
- return test_dir / Path("supports")