modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +402 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1025 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +196 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +943 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
  128. modal-0.72.13.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
test/container_test.py DELETED
@@ -1,1405 +0,0 @@
1
- # Copyright Modal Labs 2022
2
-
3
- import asyncio
4
- import base64
5
- import dataclasses
6
- import json
7
- import os
8
- import pathlib
9
- import pickle
10
- import pytest
11
- import signal
12
- import subprocess
13
- import sys
14
- import tempfile
15
- import time
16
- import uuid
17
- from typing import Any, Dict, List, Optional, Tuple
18
- from unittest import mock
19
- from unittest.mock import MagicMock
20
-
21
- from grpclib import Status
22
- from grpclib.exceptions import GRPCError
23
-
24
- from modal import Client, is_local
25
- from modal._container_entrypoint import UserException, main
26
- from modal._serialization import (
27
- deserialize,
28
- deserialize_data_format,
29
- serialize,
30
- serialize_data_format,
31
- )
32
- from modal._utils import async_utils
33
- from modal.app import _App
34
- from modal.exception import InvalidError
35
- from modal.partial_function import enter
36
- from modal_proto import api_pb2
37
-
38
- from .helpers import deploy_app_externally
39
- from .supports.skip import skip_github_non_linux
40
-
41
# Slack added to wall-clock timing assertions; non-Linux runners get extra room.
EXTRA_TOLERANCE_DELAY = 5.0 if sys.platform != "linux" else 2.0
# Function-call id attached to the synthetic inputs used throughout this module.
FUNCTION_CALL_ID = "fc-123"
# Lower bound (seconds) asserted on the async test function's runtime in test_async.
SLEEP_DELAY = 0.1
44
-
45
-
46
def _get_inputs(
    args: Tuple[Tuple, Dict] = ((42,), {}), n: int = 1, kill_switch: bool = True
) -> List[api_pb2.FunctionGetInputsResponse]:
    """Build the FunctionGetInputs responses that the fake servicer hands to the container.

    Args:
        args: ``(positional_args, keyword_args)`` pickled into every input.
        n: Number of identical inputs to emit, with ids ``in-xyz0`` .. ``in-xyz{n-1}``.
        kill_switch: When true, append a trailing kill-switch item (presumably what tells the
            container to stop fetching inputs — confirm against the entrypoint).

    Returns:
        One ``FunctionGetInputsResponse`` per item, each wrapping exactly one item.
    """
    input_pb = api_pb2.FunctionInput(args=serialize(args), data_format=api_pb2.DATA_FORMAT_PICKLE)
    # Use the shared constant rather than repeating the literal, so every helper in this
    # module keys on the same function-call id.
    items = [
        api_pb2.FunctionGetInputsItem(input_id=f"in-xyz{i}", function_call_id=FUNCTION_CALL_ID, input=input_pb)
        for i in range(n)
    ]
    if kill_switch:
        items.append(api_pb2.FunctionGetInputsItem(kill_switch=True))
    return [api_pb2.FunctionGetInputsResponse(inputs=[item]) for item in items]
58
-
59
-
60
@dataclasses.dataclass
class ContainerResult:
    """Everything captured from one in-process container run."""

    # Client the container talked to; the unwrap helpers also use it to deserialize results.
    client: Client
    # FunctionPutOutputs items, flattened across all output requests.
    items: List[api_pb2.FunctionPutOutputsItem]
    # Chunks drained from the servicer's data-out queue for the run's function call.
    data_chunks: List[api_pb2.DataChunk]
    # Task-level result recorded on the servicer.
    task_result: api_pb2.GenericResult
66
-
67
-
68
def _get_multi_inputs(
    args: Optional[List[Tuple[Tuple, Dict]]] = None,
) -> List[api_pb2.FunctionGetInputsResponse]:
    """Build one FunctionGetInputs response per argument tuple, followed by a kill switch.

    Args:
        args: Sequence of ``(positional_args, keyword_args)``. Entry ``i`` becomes the input
            ``in-{i:03}``. ``None`` (the default) means no inputs. Using ``None`` instead of a
            shared mutable ``[]`` default avoids the classic shared-default pitfall.

    Returns:
        The per-input responses, then a final response carrying only a kill-switch item.
        Unlike ``_get_inputs``, these inputs set neither ``function_call_id`` nor ``data_format``.
    """
    responses = [
        api_pb2.FunctionGetInputsResponse(
            inputs=[
                api_pb2.FunctionGetInputsItem(
                    input_id=f"in-{input_n:03}", input=api_pb2.FunctionInput(args=serialize(input_args))
                )
            ]
        )
        for input_n, input_args in enumerate(args or [])
    ]
    responses.append(api_pb2.FunctionGetInputsResponse(inputs=[api_pb2.FunctionGetInputsItem(kill_switch=True)]))
    return responses
81
-
82
-
83
def _container_args(
    module_name,
    function_name,
    function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION,
    webhook_type=api_pb2.WEBHOOK_TYPE_UNSPECIFIED,
    definition_type=api_pb2.Function.DEFINITION_TYPE_FILE,
    app_name: str = "",
    is_builder_function: bool = False,
    allow_concurrent_inputs: Optional[int] = None,
    serialized_params: Optional[bytes] = None,
    is_checkpointing_function: bool = False,
    deps: Optional[List[str]] = None,
    volume_mounts: Optional[List[api_pb2.VolumeMount]] = None,
    is_auto_snapshot: bool = False,
    max_inputs: Optional[int] = None,
):
    """Build the ``ContainerArguments`` proto passed to the container entrypoint in tests.

    Args:
        module_name / function_name: Location of the user function to load.
        webhook_type: When non-zero (i.e. not ``WEBHOOK_TYPE_UNSPECIFIED``), a GET webhook
            config in auto async mode is attached.
        deps: Object ids recorded as the function's object dependencies. ``None`` means
            ``["im-1"]``. ``None`` replaces the former mutable list default.
        Remaining keyword arguments are copied onto the ``api_pb2.Function`` definition
        (or, for ``serialized_params``, onto the ``ContainerArguments``).

    Returns:
        ``api_pb2.ContainerArguments`` with fixed task/function/app ids and a fresh random checkpoint id.
    """
    if deps is None:
        deps = ["im-1"]

    webhook_config = None
    if webhook_type:
        webhook_config = api_pb2.WebhookConfig(
            type=webhook_type,
            method="GET",
            async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
        )

    function_def = api_pb2.Function(
        module_name=module_name,
        function_name=function_name,
        function_type=function_type,
        volume_mounts=volume_mounts,
        webhook_config=webhook_config,
        definition_type=definition_type,
        app_name=app_name or "",
        is_builder_function=is_builder_function,
        is_auto_snapshot=is_auto_snapshot,
        allow_concurrent_inputs=allow_concurrent_inputs,
        is_checkpointing_function=is_checkpointing_function,
        object_dependencies=[api_pb2.ObjectDependency(object_id=object_id) for object_id in deps],
        max_inputs=max_inputs,
    )

    return api_pb2.ContainerArguments(
        task_id="ta-123",
        function_id="fu-123",
        app_id="ap-1",
        function_def=function_def,
        serialized_params=serialized_params,
        # A unique id per run so checkpoint requests can be told apart.
        checkpoint_id=f"ch-{uuid.uuid4()}",
    )
132
-
133
-
134
def _flatten_outputs(outputs) -> List[api_pb2.FunctionPutOutputsItem]:
    """Concatenate the ``.outputs`` of every recorded output request into one flat list, in order."""
    return [output_item for request in outputs for output_item in request.outputs]
139
-
140
-
141
def _drain_data_chunks(servicer, function_call_id) -> List[api_pb2.DataChunk]:
    """Pop and return every chunk currently queued on the servicer's data-out queue for ``function_call_id``."""
    chunks: List[api_pb2.DataChunk] = []
    if function_call_id not in servicer.fc_data_out:
        return chunks
    queue = servicer.fc_data_out[function_call_id]
    while True:
        try:
            chunks.append(queue.get_nowait())
        except asyncio.QueueEmpty:
            return chunks


def _run_container(
    servicer,
    module_name,
    function_name,
    fail_get_inputs=False,
    inputs=None,
    function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION,
    webhook_type=api_pb2.WEBHOOK_TYPE_UNSPECIFIED,
    definition_type=api_pb2.Function.DEFINITION_TYPE_FILE,
    app_name: str = "",
    is_builder_function: bool = False,
    allow_concurrent_inputs: Optional[int] = None,
    serialized_params: Optional[bytes] = None,
    is_checkpointing_function: bool = False,
    deps: Optional[List[str]] = None,
    volume_mounts: Optional[List[api_pb2.VolumeMount]] = None,
    is_auto_snapshot: bool = False,
    max_inputs: Optional[int] = None,
) -> ContainerResult:
    """Run the container entrypoint ``main`` in-process against ``servicer`` and collect its results.

    Args:
        servicer: Fake API servicer the container client connects to.
        fail_get_inputs: Copied to ``servicer.fail_get_inputs``.
        inputs: FunctionGetInputs responses to serve. ``None`` means ``_get_inputs()``, i.e. a
            single input with args ``(42,)``.
        deps: Object dependency ids. ``None`` means ``["im-1"]``. ``None`` replaces the former
            mutable list default.
        Remaining arguments are forwarded to ``_container_args``.

    Returns:
        A ``ContainerResult`` with the client, the flattened outputs, the data chunks produced
        for the first input's function call, and the servicer's task result.

    Raises:
        Anything ``main`` raises other than ``UserException``.
    """
    if deps is None:
        deps = ["im-1"]
    container_args = _container_args(
        module_name,
        function_name,
        function_type,
        webhook_type,
        definition_type,
        app_name,
        is_builder_function,
        allow_concurrent_inputs,
        serialized_params,
        is_checkpointing_function,
        deps,
        volume_mounts,
        is_auto_snapshot,
        max_inputs,
    )
    with Client(servicer.remote_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret")) as client:
        servicer.container_inputs = _get_inputs() if inputs is None else inputs
        function_call_id = servicer.container_inputs[0].inputs[0].function_call_id
        servicer.fail_get_inputs = fail_get_inputs

        # Drop the module from sys.modules since some function code relies on the
        # assumption that the app is created before the user code is imported.
        # This is really only an issue for tests.
        sys.modules.pop(module_name, None)

        env = os.environ.copy()
        # The context manager guarantees the temp file is closed (and deleted) even if
        # something fails before ``main`` runs.
        with tempfile.NamedTemporaryFile() as restore_file:
            if is_checkpointing_function:
                # State file is written to allow for a restore to happen.
                with pathlib.Path(restore_file.name).open("w") as target:
                    json.dump({}, target)
                env["MODAL_RESTORE_STATE_PATH"] = restore_file.name

            # Override server URL to reproduce restore behavior.
            env["MODAL_SERVER_URL"] = servicer.remote_addr

            # Reset _App tracking state between runs.
            _App._all_apps.clear()

            try:
                with mock.patch.dict(os.environ, env):
                    main(container_args, client)
            except UserException:
                # User-code failures are reported through outputs / task_result; handle gracefully.
                pass

        items = _flatten_outputs(servicer.container_outputs)
        data_chunks = _drain_data_chunks(servicer, function_call_id)

        return ContainerResult(client, items, data_chunks, servicer.task_result)
228
-
229
-
230
def _unwrap_scalar(ret: ContainerResult):
    """Assert the run produced exactly one successful output and return its deserialized value."""
    assert len(ret.items) == 1
    result = ret.items[0].result
    assert result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
    return deserialize(result.data, ret.client)
234
-
235
-
236
def _unwrap_exception(ret: ContainerResult):
    """Assert the run produced exactly one failed output that carries a traceback.

    Returns the output's ``exception`` field (the exception's repr string, e.g. ``"Exception('Failure!')"``).
    """
    assert len(ret.items) == 1
    result = ret.items[0].result
    assert result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE
    assert "Traceback" in result.traceback
    return result.exception
241
-
242
-
243
def _unwrap_generator(ret: ContainerResult) -> Tuple[List[Any], Optional[Exception]]:
    """Decode a generator run into ``(yielded_values, exception_or_None)``.

    Yielded values come from the data chunks. The single output item holds either the
    serialized exception (failure) or a ``GeneratorDone`` whose ``items_total`` must match the
    number of chunks (success). Any other status raises ``RuntimeError``.
    """
    assert len(ret.items) == 1
    item = ret.items[0]
    result = item.result
    assert result.gen_status == api_pb2.GenericResult.GENERATOR_STATUS_UNSPECIFIED

    values: List[Any] = [deserialize_data_format(c.data, c.data_format, None) for c in ret.data_chunks]

    if result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE:
        return values, deserialize(result.data, ret.client)

    if result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
        raise RuntimeError("unknown result type")

    assert item.data_format == api_pb2.DATA_FORMAT_GENERATOR_DONE
    done: api_pb2.GeneratorDone = deserialize_data_format(result.data, item.data_format, None)
    assert done.items_total == len(values)
    return values, None
260
-
261
-
262
def _unwrap_asgi(ret: ContainerResult):
    """Return the ASGI messages a web endpoint emitted, asserting the endpoint did not raise."""
    messages, error = _unwrap_generator(ret)
    assert error is None, "web endpoint raised exception"
    return messages
266
-
267
-
268
@skip_github_non_linux
def test_success(unix_servicer, event_loop):
    """A plain synchronous function returns its result within the timing tolerance."""
    # time.monotonic() cannot jump with system-clock adjustments, unlike time.time(),
    # so elapsed-time bounds stay reliable.
    start = time.monotonic()
    ret = _run_container(unix_servicer, "test.supports.functions", "square")
    assert 0 <= time.monotonic() - start < EXTRA_TOLERANCE_DELAY
    assert _unwrap_scalar(ret) == 42**2
274
-
275
-
276
@skip_github_non_linux
def test_generator_success(unix_servicer, event_loop):
    """A generator function streams all of its yielded values and finishes cleanly."""
    expected = [i**2 for i in range(42)]
    ret = _run_container(
        unix_servicer,
        module_name="test.supports.functions",
        function_name="gen_n",
        function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
    )

    yielded, error = _unwrap_generator(ret)
    assert yielded == expected
    assert error is None
288
-
289
-
290
@skip_github_non_linux
def test_generator_failure(unix_servicer, capsys):
    """A generator that raises mid-stream keeps its earlier values and reports the exception."""
    ret = _run_container(
        unix_servicer,
        module_name="test.supports.functions",
        function_name="gen_n_fail_on_m",
        function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
        inputs=_get_inputs(((10, 5), {})),
    )
    yielded, error = _unwrap_generator(ret)

    assert yielded == [i**2 for i in range(5)]
    assert isinstance(error, Exception)
    assert error.args == ("bad",)
    # The traceback of the user exception is printed to stderr.
    stderr = capsys.readouterr().err
    assert 'raise Exception("bad")' in stderr
305
-
306
-
307
@skip_github_non_linux
def test_async(unix_servicer):
    """An async function runs for at least SLEEP_DELAY and returns its result."""
    # Use a monotonic clock for elapsed-time bounds; time.time() can jump.
    start = time.monotonic()
    ret = _run_container(unix_servicer, "test.supports.functions", "square_async")
    assert SLEEP_DELAY <= time.monotonic() - start < SLEEP_DELAY + EXTRA_TOLERANCE_DELAY
    assert _unwrap_scalar(ret) == 42**2
313
-
314
-
315
@skip_github_non_linux
def test_failure(unix_servicer, capsys):
    """A raising function yields a failure output and prints its traceback to stderr."""
    ret = _run_container(unix_servicer, module_name="test.supports.functions", function_name="raises")
    exception_repr = _unwrap_exception(ret)
    assert exception_repr == "Exception('Failure!')"
    stderr = capsys.readouterr().err
    assert 'raise Exception("Failure!")' in stderr  # traceback
320
-
321
-
322
@skip_github_non_linux
def test_raises_base_exception(unix_servicer, capsys):
    """A SystemExit raised by user code is reported as a failed output, not propagated."""
    ret = _run_container(unix_servicer, module_name="test.supports.functions", function_name="raises_sysexit")
    exception_repr = _unwrap_exception(ret)
    assert exception_repr == "SystemExit(1)"
    stderr = capsys.readouterr().err
    assert "raise SystemExit(1)" in stderr  # traceback
327
-
328
-
329
@skip_github_non_linux
def test_keyboardinterrupt(unix_servicer):
    """Unlike other BaseExceptions, a KeyboardInterrupt from user code escapes the container."""
    module, function = "test.supports.functions", "raises_keyboardinterrupt"
    with pytest.raises(KeyboardInterrupt):
        _run_container(unix_servicer, module, function)
333
-
334
-
335
@skip_github_non_linux
def test_rate_limited(unix_servicer, event_loop):
    """When the servicer rate-limits, the container waits the requested duration before succeeding."""
    rate_limit_delay = 0.25
    # Use a monotonic clock for elapsed-time bounds; time.time() can jump.
    start = time.monotonic()
    unix_servicer.rate_limit_sleep_duration = rate_limit_delay
    ret = _run_container(unix_servicer, "test.supports.functions", "square")
    assert rate_limit_delay <= time.monotonic() - start < rate_limit_delay + EXTRA_TOLERANCE_DELAY
    assert _unwrap_scalar(ret) == 42**2
342
-
343
-
344
@skip_github_non_linux
def test_grpc_failure(unix_servicer, event_loop):
    """An error in "Modal code" (here: fetching inputs) should fail the entire container."""
    with pytest.raises(GRPCError):
        _run_container(
            unix_servicer,
            module_name="test.supports.functions",
            function_name="square",
            fail_get_inputs=True,
        )

    # TODO: previously-disabled checks — task_result.status should be GENERIC_STATUS_FAILURE
    # and task_result.exception should mention "GRPCError".
357
-
358
-
359
@skip_github_non_linux
def test_missing_main_conditional(unix_servicer, capsys):
    """Running an app at import time inside a container fails the task with an InvalidError."""
    _run_container(unix_servicer, "test.supports.missing_main_conditional", "square")
    stderr = capsys.readouterr().err
    assert "Can not run an app from within a container" in stderr

    task_result = unix_servicer.task_result
    assert task_result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE
    assert "modal run" in task_result.traceback
    assert isinstance(deserialize(task_result.data, None), InvalidError)
370
-
371
-
372
@skip_github_non_linux
def test_startup_failure(unix_servicer, capsys):
    """An import error in user code fails the task and prints the original error."""
    _run_container(unix_servicer, module_name="test.supports.startup_failure", function_name="f")

    task_result = unix_servicer.task_result
    assert task_result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE

    assert isinstance(deserialize(task_result.data, None), ImportError)
    stderr = capsys.readouterr().err
    assert "ModuleNotFoundError: No module named 'nonexistent_package'" in stderr
381
-
382
-
383
@skip_github_non_linux
def test_from_local_python_packages_inside_container(unix_servicer):
    """`from_local_python_packages` shouldn't actually collect modules inside the container, because it's possible
    that there are modules that were present locally for the user that didn't get mounted into
    all the containers."""
    result = _run_container(unix_servicer, module_name="test.supports.package_mount", function_name="num_mounts")
    mount_count = _unwrap_scalar(result)
    assert mount_count == 0
390
-
391
-
392
def _get_web_inputs(path="/"):
    """Build a single input whose only positional argument is an HTTP GET ASGI scope for ``path``."""
    asgi_scope = dict(
        method="GET",
        type="http",
        path=path,
        headers={},
        query_string=b"arg=space",
        http_version="2",
    )
    call_args = ((asgi_scope,), {})
    return _get_inputs(call_args)
402
-
403
-
404
@async_utils.synchronize_api  # needs to be synchronized so the asyncio.Queue gets used from the same event loop as the servicer
async def _put_web_body(servicer, body: bytes):
    """Enqueue ``body`` as a single, final ASGI ``http.request`` message on the servicer's data-in queue.

    The chunk is stored under ``FUNCTION_CALL_ID`` (the function-call id attached to the
    synthetic inputs) with index 1. Using the shared constant, rather than repeating the
    literal, keeps it in sync with those inputs.
    """
    asgi = {"type": "http.request", "body": body, "more_body": False}
    data = serialize_data_format(asgi, api_pb2.DATA_FORMAT_ASGI)

    q = servicer.fc_data_in.setdefault(FUNCTION_CALL_ID, asyncio.Queue())
    q.put_nowait(api_pb2.DataChunk(data_format=api_pb2.DATA_FORMAT_ASGI, data=data, index=1))
411
-
412
-
413
@skip_github_non_linux
def test_webhook(unix_servicer):
    """A function web endpoint returns a JSON response built from the query string."""
    web_inputs = _get_web_inputs()
    _put_web_body(unix_servicer, b"")
    ret = _run_container(
        unix_servicer,
        "test.supports.functions",
        "webhook",
        inputs=web_inputs,
        webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION,
    )
    messages = _unwrap_asgi(ret)

    # Response start (headers) and body; the EOF marker is not among the data chunks.
    start, body = messages

    assert start["status"] == 200
    assert dict(start["headers"])[b"content-type"] == b"application/json"

    assert json.loads(body["body"]) == {"hello": "space"}
436
-
437
-
438
@skip_github_non_linux
def test_serialized_function(unix_servicer):
    """A function shipped as a serialized object (not importable by module path) still runs."""

    def triple(x):
        return 3 * x

    unix_servicer.function_serialized = serialize(triple)
    # Module/function names are placeholders: the serialized definition is used instead.
    ret = _run_container(
        unix_servicer,
        module_name="foo.bar.baz",
        function_name="f",
        definition_type=api_pb2.Function.DEFINITION_TYPE_SERIALIZED,
    )
    expected = 3 * 42
    assert _unwrap_scalar(ret) == expected
451
-
452
-
453
@skip_github_non_linux
def test_webhook_serialized(unix_servicer):
    """A serialized function can serve as a web endpoint; query params bind to its arguments."""
    web_inputs = _get_web_inputs()
    _put_web_body(unix_servicer, b"")

    def webhook(arg="world"):
        return f"Hello, {arg}"

    # The container uses this serialized definition rather than importing by name.
    unix_servicer.function_serialized = serialize(webhook)

    ret = _run_container(
        unix_servicer,
        module_name="foo.bar.baz",
        function_name="f",
        inputs=web_inputs,
        webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION,
        definition_type=api_pb2.Function.DEFINITION_TYPE_SERIALIZED,
    )

    body_message = _unwrap_asgi(ret)[1]
    assert body_message["body"] == b'"Hello, space"'  # Note: JSON-encoded
475
-
476
-
477
@skip_github_non_linux
def test_function_returning_generator(unix_servicer):
    """A function that returns (rather than is) a generator streams all 42 items and succeeds."""
    ret = _run_container(
        unix_servicer,
        "test.supports.functions",
        "fun_returning_gen",
        function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
    )
    items, exc = _unwrap_generator(ret)
    assert len(items) == 42
    # Without this check, a generator that fails after yielding 42 items would still pass.
    assert exc is None
487
-
488
-
489
@skip_github_non_linux
def test_asgi(unix_servicer):
    """A FastAPI ASGI app routes /foo and returns a JSON body."""
    web_inputs = _get_web_inputs(path="/foo")
    _put_web_body(unix_servicer, b"")
    ret = _run_container(
        unix_servicer,
        "test.supports.functions",
        "fastapi_app",
        inputs=web_inputs,
        webhook_type=api_pb2.WEBHOOK_TYPE_ASGI_APP,
    )

    # Exactly two messages: the response start (headers) and the body.
    start, body = _unwrap_asgi(ret)

    assert start["status"] == 200
    assert dict(start["headers"])[b"content-type"] == b"application/json"

    assert json.loads(body["body"]) == {"hello": "space"}
511
-
512
-
513
@skip_github_non_linux
def test_wsgi(unix_servicer):
    """A WSGI app sees the request body and replies with headers, a body chunk, and an end-of-body chunk."""
    web_inputs = _get_web_inputs(path="/")
    _put_web_body(unix_servicer, b"my wsgi body")
    ret = _run_container(
        unix_servicer,
        "test.supports.functions",
        "basic_wsgi_app",
        inputs=web_inputs,
        webhook_type=api_pb2.WEBHOOK_TYPE_WSGI_APP,
    )

    start, body_chunk, end_chunk = _unwrap_asgi(ret)

    assert start["status"] == 200
    assert dict(start["headers"])[b"content-type"] == b"text/plain; charset=utf-8"

    assert body_chunk["body"] == b"got body: my wsgi body"
    assert body_chunk.get("more_body", False) is True
    assert end_chunk["body"] == b""
    assert end_chunk.get("more_body", False) is False
538
-
539
-
540
@skip_github_non_linux
def test_webhook_streaming_sync(unix_servicer):
    """A sync generator web endpoint streams each yielded piece as its own body chunk."""
    expected_bodies = [f"{i}..." for i in range(10)]
    web_inputs = _get_web_inputs()
    _put_web_body(unix_servicer, b"")
    ret = _run_container(
        unix_servicer,
        "test.supports.functions",
        "webhook_streaming",
        inputs=web_inputs,
        webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION,
        function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
    )
    messages = _unwrap_asgi(ret)
    # Skip messages without a (non-empty) body, such as the response start.
    bodies = [message["body"].decode() for message in messages if message.get("body")]
    assert bodies == expected_bodies
555
-
556
-
557
@skip_github_non_linux
def test_webhook_streaming_async(unix_servicer):
    """An async generator web endpoint streams each yielded piece as its own body chunk."""
    expected_bodies = [f"{i}..." for i in range(10)]
    web_inputs = _get_web_inputs()
    _put_web_body(unix_servicer, b"")
    ret = _run_container(
        unix_servicer,
        "test.supports.functions",
        "webhook_streaming_async",
        inputs=web_inputs,
        webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION,
        function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
    )

    messages = _unwrap_asgi(ret)
    # Skip messages without a (non-empty) body, such as the response start.
    bodies = [message["body"].decode() for message in messages if message.get("body")]
    assert bodies == expected_bodies
573
-
574
-
575
@skip_github_non_linux
def test_cls_function(unix_servicer):
    """A method on a class-based function can be invoked via its ``Cls.method`` tag."""
    result = _run_container(unix_servicer, module_name="test.supports.functions", function_name="Cls.f")
    assert _unwrap_scalar(result) == 42 * 111
579
-
580
-
581
@skip_github_non_linux
def test_lifecycle_enter_sync(unix_servicer):
    """Both sync and async enter hooks run, in order, before a sync method is called."""
    no_args = _get_inputs(((), {}))
    result = _run_container(unix_servicer, "test.supports.functions", "LifecycleCls.f_sync", inputs=no_args)
    assert _unwrap_scalar(result) == ["enter_sync", "enter_async", "f_sync"]
585
-
586
-
587
@skip_github_non_linux
def test_lifecycle_enter_async(unix_servicer):
    """Both sync and async enter hooks run, in order, before an async method is called."""
    no_args = _get_inputs(((), {}))
    result = _run_container(unix_servicer, "test.supports.functions", "LifecycleCls.f_async", inputs=no_args)
    assert _unwrap_scalar(result) == ["enter_sync", "enter_async", "f_async"]
591
-
592
-
593
@skip_github_non_linux
def test_param_cls_function(unix_servicer):
    """Pickled constructor params are passed to a parametrized class before its method runs."""
    ctor_args, ctor_kwargs = [111], {"y": "foo"}
    ret = _run_container(
        unix_servicer,
        module_name="test.supports.functions",
        function_name="ParamCls.f",
        serialized_params=pickle.dumps((ctor_args, ctor_kwargs)),
    )
    assert _unwrap_scalar(ret) == "111 foo 42"
603
-
604
-
605
@skip_github_non_linux
def test_cls_web_endpoint(unix_servicer):
    """A web endpoint defined as a class method can read instance state."""
    ret = _run_container(
        unix_servicer,
        module_name="test.supports.functions",
        function_name="Cls.web",
        inputs=_get_web_inputs(),
        webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION,
    )

    body_message = _unwrap_asgi(ret)[1]
    assert json.loads(body_message["body"]) == {"ret": "space" * 111}
618
-
619
-
620
@skip_github_non_linux
def test_serialized_cls(unix_servicer):
    """A serialized class and method run in the container, including the ``@enter`` hook."""

    class Cls:
        @enter()
        def enter(self):
            self.power = 5

        def method(self, x):
            return x**self.power

    unix_servicer.class_serialized = serialize(Cls)
    unix_servicer.function_serialized = serialize(Cls.method)
    # Names are placeholders; the serialized definitions above are what actually runs.
    ret = _run_container(
        unix_servicer,
        module_name="module.doesnt.matter",
        function_name="function.doesnt.matter",
        definition_type=api_pb2.Function.DEFINITION_TYPE_SERIALIZED,
    )
    expected = 42**5
    assert _unwrap_scalar(ret) == expected
639
-
640
-
641
@skip_github_non_linux
def test_cls_generator(unix_servicer):
    """A generator method on a class streams its values and finishes cleanly."""
    ret = _run_container(
        unix_servicer,
        module_name="test.supports.functions",
        function_name="Cls.generator",
        function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
    )
    yielded, error = _unwrap_generator(ret)
    assert yielded == [42**3]
    assert error is None
652
-
653
-
654
@skip_github_non_linux
def test_checkpointing_cls_function(unix_servicer):
    """A checkpointing class issues at least one checkpoint request (each with an id) and returns correctly."""
    ret = _run_container(
        unix_servicer,
        module_name="test.supports.functions",
        function_name="CheckpointingCls.f",
        inputs=_get_inputs((("D",), {})),
        is_checkpointing_function=True,
    )
    checkpoint_requests = [
        request for request in unix_servicer.requests if isinstance(request, api_pb2.ContainerCheckpointRequest)
    ]
    assert checkpoint_requests
    assert all(request.checkpoint_id for request in checkpoint_requests)
    assert _unwrap_scalar(ret) == "ABCD"
668
-
669
-
670
@skip_github_non_linux
def test_cls_enter_uses_event_loop(unix_servicer):
    """The class's enter hook and method run such that ``EventLoopCls.f`` reports ``True``."""
    ret = _run_container(
        unix_servicer,
        "test.supports.functions",
        "EventLoopCls.f",
        inputs=_get_inputs(((), {})),
    )
    # Identity check (PEP 8 / E712): require exactly True, not merely a truthy/equal value.
    assert _unwrap_scalar(ret) is True
679
-
680
-
681
- @skip_github_non_linux
682
- def test_container_heartbeats(unix_servicer):
683
- _run_container(unix_servicer, "test.supports.functions", "square")
684
- assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in unix_servicer.requests)
685
-
686
-
687
- @skip_github_non_linux
688
- def test_cli(unix_servicer):
689
- # This tests the container being invoked as a subprocess (the if __name__ == "__main__" block)
690
-
691
- # Build up payload we pass through sys args
692
- function_def = api_pb2.Function(
693
- module_name="test.supports.functions",
694
- function_name="square",
695
- function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION,
696
- definition_type=api_pb2.Function.DEFINITION_TYPE_FILE,
697
- object_dependencies=[api_pb2.ObjectDependency(object_id="im-123")],
698
- )
699
- container_args = api_pb2.ContainerArguments(
700
- task_id="ta-123",
701
- function_id="fu-123",
702
- app_id="ap-123",
703
- function_def=function_def,
704
- )
705
- data_base64: str = base64.b64encode(container_args.SerializeToString()).decode("ascii")
706
-
707
- # Needed for function hydration
708
- unix_servicer.app_objects["ap-123"] = {"": "im-123"}
709
-
710
- # Inputs that will be consumed by the container
711
- unix_servicer.container_inputs = _get_inputs()
712
-
713
- # Launch subprocess
714
- env = {"MODAL_SERVER_URL": unix_servicer.remote_addr}
715
- lib_dir = pathlib.Path(__file__).parent.parent
716
- args: List[str] = [sys.executable, "-m", "modal._container_entrypoint", data_base64]
717
- ret = subprocess.run(args, cwd=lib_dir, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
718
- stdout = ret.stdout.decode()
719
- stderr = ret.stderr.decode()
720
- if ret.returncode != 0:
721
- raise Exception(f"Failed with {ret.returncode} stdout: {stdout} stderr: {stderr}")
722
-
723
- assert stdout == ""
724
- assert stderr == ""
725
-
726
-
727
- @skip_github_non_linux
728
- def test_function_sibling_hydration(unix_servicer):
729
- deploy_app_externally(unix_servicer, "test.supports.functions", "app")
730
- ret = _run_container(unix_servicer, "test.supports.functions", "check_sibling_hydration")
731
- assert _unwrap_scalar(ret) is None
732
-
733
-
734
- @skip_github_non_linux
735
- def test_multiapp(unix_servicer, caplog):
736
- deploy_app_externally(unix_servicer, "test.supports.multiapp", "a")
737
- ret = _run_container(unix_servicer, "test.supports.multiapp", "a_func")
738
- assert _unwrap_scalar(ret) is None
739
- assert len(caplog.messages) == 0
740
- # Note that the app can be inferred from the function, even though there are multiple
741
- # apps present in the file
742
-
743
-
744
- @skip_github_non_linux
745
- def test_multiapp_privately_decorated(unix_servicer, caplog):
746
- # function handle does not override the original function, so we can't find the app
747
- # and the two apps are not named
748
- ret = _run_container(unix_servicer, "test.supports.multiapp_privately_decorated", "foo")
749
- assert _unwrap_scalar(ret) == 1
750
- assert "You have more than one unnamed app." in caplog.text
751
-
752
-
753
- @skip_github_non_linux
754
- def test_multiapp_privately_decorated_named_app(unix_servicer, caplog):
755
- # function handle does not override the original function, so we can't find the app
756
- # but we can use the names of the apps to determine the active app
757
- ret = _run_container(
758
- unix_servicer,
759
- "test.supports.multiapp_privately_decorated_named_app",
760
- "foo",
761
- app_name="dummy",
762
- )
763
- assert _unwrap_scalar(ret) == 1
764
- assert len(caplog.messages) == 0 # no warnings, since target app is named
765
-
766
-
767
- @skip_github_non_linux
768
- def test_multiapp_same_name_warning(unix_servicer, caplog, capsys):
769
- # function handle does not override the original function, so we can't find the app
770
- # two apps with the same name - warn since we won't know which one to hydrate
771
- ret = _run_container(
772
- unix_servicer,
773
- "test.supports.multiapp_same_name",
774
- "foo",
775
- app_name="dummy",
776
- )
777
- assert _unwrap_scalar(ret) == 1
778
- assert "You have more than one app with the same name ('dummy')" in caplog.text
779
- capsys.readouterr()
780
-
781
-
782
- @skip_github_non_linux
783
- def test_multiapp_serialized_func(unix_servicer, caplog):
784
- # serialized functions shouldn't warn about multiple/not finding apps, since they shouldn't load the module to begin with
785
- def dummy(x):
786
- return x
787
-
788
- unix_servicer.function_serialized = serialize(dummy)
789
- ret = _run_container(
790
- unix_servicer,
791
- "test.supports.multiapp_serialized_func",
792
- "foo",
793
- definition_type=api_pb2.Function.DEFINITION_TYPE_SERIALIZED,
794
- )
795
- assert _unwrap_scalar(ret) == 42
796
- assert len(caplog.messages) == 0
797
-
798
-
799
- @skip_github_non_linux
800
- def test_image_run_function_no_warn(unix_servicer, caplog):
801
- # builder functions currently aren't tied to any modal app,
802
- # so they shouldn't need to warn if they can't determine which app to use
803
- ret = _run_container(
804
- unix_servicer,
805
- "test.supports.image_run_function",
806
- "builder_function",
807
- inputs=_get_inputs(((), {})),
808
- is_builder_function=True,
809
- )
810
- assert _unwrap_scalar(ret) is None
811
- assert len(caplog.messages) == 0
812
-
813
-
814
- SLEEP_TIME = 0.7
815
-
816
-
817
- def _unwrap_concurrent_input_outputs(n_inputs: int, n_parallel: int, ret: ContainerResult):
818
- # Ensure that outputs align with expectation of running concurrent inputs
819
-
820
- # Each group of n_parallel inputs should start together of each other
821
- # and different groups should start SLEEP_TIME apart.
822
- assert len(ret.items) == n_inputs
823
- for i in range(1, len(ret.items)):
824
- diff = ret.items[i].input_started_at - ret.items[i - 1].input_started_at
825
- expected_diff = SLEEP_TIME if i % n_parallel == 0 else 0
826
- assert diff == pytest.approx(expected_diff, abs=0.3)
827
-
828
- outputs = []
829
- for item in ret.items:
830
- assert item.output_created_at - item.input_started_at == pytest.approx(SLEEP_TIME, abs=0.3)
831
- assert item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
832
- outputs.append(deserialize(item.result.data, ret.client))
833
- return outputs
834
-
835
-
836
- @skip_github_non_linux
837
- def test_concurrent_inputs_sync_function(unix_servicer):
838
- n_inputs = 18
839
- n_parallel = 6
840
-
841
- t0 = time.time()
842
- ret = _run_container(
843
- unix_servicer,
844
- "test.supports.functions",
845
- "sleep_700_sync",
846
- inputs=_get_inputs(n=n_inputs),
847
- allow_concurrent_inputs=n_parallel,
848
- )
849
-
850
- expected_execution = n_inputs / n_parallel * SLEEP_TIME
851
- assert expected_execution <= time.time() - t0 < expected_execution + EXTRA_TOLERANCE_DELAY
852
- outputs = _unwrap_concurrent_input_outputs(n_inputs, n_parallel, ret)
853
- for i, (squared, input_id, function_call_id) in enumerate(outputs):
854
- assert squared == 42**2
855
- assert input_id and input_id != outputs[i - 1][1]
856
- assert function_call_id and function_call_id == outputs[i - 1][2]
857
-
858
-
859
- @skip_github_non_linux
860
- def test_concurrent_inputs_async_function(unix_servicer):
861
- n_inputs = 18
862
- n_parallel = 6
863
-
864
- t0 = time.time()
865
- ret = _run_container(
866
- unix_servicer,
867
- "test.supports.functions",
868
- "sleep_700_async",
869
- inputs=_get_inputs(n=n_inputs),
870
- allow_concurrent_inputs=n_parallel,
871
- )
872
-
873
- expected_execution = n_inputs / n_parallel * SLEEP_TIME
874
- assert expected_execution <= time.time() - t0 < expected_execution + EXTRA_TOLERANCE_DELAY
875
- outputs = _unwrap_concurrent_input_outputs(n_inputs, n_parallel, ret)
876
- for i, (squared, input_id, function_call_id) in enumerate(outputs):
877
- assert squared == 42**2
878
- assert input_id and input_id != outputs[i - 1][1]
879
- assert function_call_id and function_call_id == outputs[i - 1][2]
880
-
881
-
882
- @skip_github_non_linux
883
- def test_unassociated_function(unix_servicer):
884
- ret = _run_container(unix_servicer, "test.supports.functions", "unassociated_function")
885
- assert _unwrap_scalar(ret) == 58
886
-
887
-
888
- @skip_github_non_linux
889
- def test_param_cls_function_calling_local(unix_servicer):
890
- serialized_params = pickle.dumps(([111], {"y": "foo"}))
891
- ret = _run_container(
892
- unix_servicer,
893
- "test.supports.functions",
894
- "ParamCls.g",
895
- serialized_params=serialized_params,
896
- )
897
- assert _unwrap_scalar(ret) == "111 foo 42"
898
-
899
-
900
- @skip_github_non_linux
901
- def test_derived_cls(unix_servicer):
902
- ret = _run_container(
903
- unix_servicer,
904
- "test.supports.functions",
905
- "DerivedCls.run",
906
- inputs=_get_inputs(((3,), {})),
907
- )
908
- assert _unwrap_scalar(ret) == 6
909
-
910
-
911
- @skip_github_non_linux
912
- def test_call_function_that_calls_function(unix_servicer):
913
- deploy_app_externally(unix_servicer, "test.supports.functions", "app")
914
- ret = _run_container(
915
- unix_servicer,
916
- "test.supports.functions",
917
- "cube",
918
- inputs=_get_inputs(((42,), {})),
919
- )
920
- assert _unwrap_scalar(ret) == 42**3
921
-
922
-
923
- @skip_github_non_linux
924
- def test_call_function_that_calls_method(unix_servicer, set_env_client):
925
- # TODO (elias): Remove set_env_client fixture dependency - shouldn't need an env client here?
926
- deploy_app_externally(unix_servicer, "test.supports.functions", "app")
927
- ret = _run_container(
928
- unix_servicer,
929
- "test.supports.functions",
930
- "function_calling_method",
931
- inputs=_get_inputs(((42, "abc", 123), {})),
932
- )
933
- assert _unwrap_scalar(ret) == 123**2 # servicer's implementation of function calling
934
-
935
-
936
- @skip_github_non_linux
937
- def test_checkpoint_and_restore_success(unix_servicer):
938
- """Functions send a checkpointing request and continue to execute normally,
939
- simulating a restore operation."""
940
- ret = _run_container(
941
- unix_servicer,
942
- "test.supports.functions",
943
- "square",
944
- is_checkpointing_function=True,
945
- )
946
- assert any(isinstance(request, api_pb2.ContainerCheckpointRequest) for request in unix_servicer.requests)
947
- for request in unix_servicer.requests:
948
- if isinstance(request, api_pb2.ContainerCheckpointRequest):
949
- assert request.checkpoint_id
950
-
951
- assert _unwrap_scalar(ret) == 42**2
952
-
953
-
954
- @skip_github_non_linux
955
- def test_volume_commit_on_exit(unix_servicer):
956
- volume_mounts = [
957
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-123", allow_background_commits=True),
958
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-456", allow_background_commits=True),
959
- ]
960
- ret = _run_container(
961
- unix_servicer,
962
- "test.supports.functions",
963
- "square",
964
- volume_mounts=volume_mounts,
965
- )
966
- volume_commit_rpcs = [r for r in unix_servicer.requests if isinstance(r, api_pb2.VolumeCommitRequest)]
967
- assert volume_commit_rpcs
968
- assert {"vo-123", "vo-456"} == set(r.volume_id for r in volume_commit_rpcs)
969
- assert _unwrap_scalar(ret) == 42**2
970
-
971
-
972
- @skip_github_non_linux
973
- def test_volume_commit_on_error(unix_servicer, capsys):
974
- volume_mounts = [
975
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-foo", allow_background_commits=True),
976
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-bar", allow_background_commits=True),
977
- ]
978
- _run_container(
979
- unix_servicer,
980
- "test.supports.functions",
981
- "raises",
982
- volume_mounts=volume_mounts,
983
- )
984
- volume_commit_rpcs = [r for r in unix_servicer.requests if isinstance(r, api_pb2.VolumeCommitRequest)]
985
- assert {"vo-foo", "vo-bar"} == set(r.volume_id for r in volume_commit_rpcs)
986
- assert 'raise Exception("Failure!")' in capsys.readouterr().err
987
-
988
-
989
- @skip_github_non_linux
990
- def test_no_volume_commit_on_exit(unix_servicer):
991
- volume_mounts = [api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-999", allow_background_commits=False)]
992
- ret = _run_container(
993
- unix_servicer,
994
- "test.supports.functions",
995
- "square",
996
- volume_mounts=volume_mounts,
997
- )
998
- volume_commit_rpcs = [r for r in unix_servicer.requests if isinstance(r, api_pb2.VolumeCommitRequest)]
999
- assert not volume_commit_rpcs # No volume commit on exit for legacy volumes
1000
- assert _unwrap_scalar(ret) == 42**2
1001
-
1002
-
1003
- @skip_github_non_linux
1004
- def test_volume_commit_on_exit_doesnt_fail_container(unix_servicer):
1005
- volume_mounts = [
1006
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-999", allow_background_commits=True),
1007
- api_pb2.VolumeMount(
1008
- mount_path="/var/foo",
1009
- volume_id="BAD-ID-FOR-VOL",
1010
- allow_background_commits=True,
1011
- ),
1012
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vol-111", allow_background_commits=True),
1013
- ]
1014
- ret = _run_container(
1015
- unix_servicer,
1016
- "test.supports.functions",
1017
- "square",
1018
- volume_mounts=volume_mounts,
1019
- )
1020
- volume_commit_rpcs = [r for r in unix_servicer.requests if isinstance(r, api_pb2.VolumeCommitRequest)]
1021
- assert len(volume_commit_rpcs) == 3
1022
- assert _unwrap_scalar(ret) == 42**2
1023
-
1024
-
1025
- @skip_github_non_linux
1026
- def test_function_dep_hydration(unix_servicer):
1027
- deploy_app_externally(unix_servicer, "test.supports.functions", "app")
1028
- ret = _run_container(
1029
- unix_servicer,
1030
- "test.supports.functions",
1031
- "check_dep_hydration",
1032
- deps=["im-1", "vo-0", "im-1", "im-2", "vo-0", "vo-1"],
1033
- )
1034
- assert _unwrap_scalar(ret) is None
1035
-
1036
-
1037
- @skip_github_non_linux
1038
- def test_build_decorator_cls(unix_servicer):
1039
- ret = _run_container(
1040
- unix_servicer,
1041
- "test.supports.functions",
1042
- "BuildCls.build1",
1043
- inputs=_get_inputs(((), {})),
1044
- is_builder_function=True,
1045
- is_auto_snapshot=True,
1046
- )
1047
- assert _unwrap_scalar(ret) == 101
1048
- # TODO: this is GENERIC_STATUS_FAILURE when `@exit` fails,
1049
- # but why is it not set when `@exit` is successful?
1050
- # assert ret.task_result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
1051
- assert ret.task_result is None
1052
-
1053
-
1054
- @skip_github_non_linux
1055
- def test_multiple_build_decorator_cls(unix_servicer):
1056
- ret = _run_container(
1057
- unix_servicer,
1058
- "test.supports.functions",
1059
- "BuildCls.build2",
1060
- inputs=_get_inputs(((), {})),
1061
- is_builder_function=True,
1062
- is_auto_snapshot=True,
1063
- )
1064
- assert _unwrap_scalar(ret) == 1001
1065
- assert ret.task_result is None
1066
-
1067
-
1068
- @skip_github_non_linux
1069
- @pytest.mark.timeout(10.0)
1070
- def test_function_io_doesnt_inspect_args_or_return_values(monkeypatch, unix_servicer):
1071
- synchronizer = async_utils.synchronizer
1072
-
1073
- # set up spys to track synchronicity calls to _translate_scalar_in/out
1074
- translate_in_spy = MagicMock(wraps=synchronizer._translate_scalar_in)
1075
- monkeypatch.setattr(synchronizer, "_translate_scalar_in", translate_in_spy)
1076
- translate_out_spy = MagicMock(wraps=synchronizer._translate_scalar_out)
1077
- monkeypatch.setattr(synchronizer, "_translate_scalar_out", translate_out_spy)
1078
-
1079
- # don't do blobbing for this test
1080
- monkeypatch.setattr("modal._container_io_manager.MAX_OBJECT_SIZE_BYTES", 1e100)
1081
-
1082
- large_data_list = list(range(int(1e6))) # large data set
1083
-
1084
- t0 = time.perf_counter()
1085
- # pr = cProfile.Profile()
1086
- # pr.enable()
1087
- _run_container(
1088
- unix_servicer,
1089
- "test.supports.functions",
1090
- "ident",
1091
- inputs=_get_inputs(((large_data_list,), {})),
1092
- )
1093
- # pr.disable()
1094
- # pr.print_stats()
1095
- duration = time.perf_counter() - t0
1096
- assert duration < 5.0 # TODO (elias): might be able to get this down significantly more by improving serialization
1097
-
1098
- # function_io_manager.serialize(large_data_list)
1099
- in_translations = []
1100
- out_translations = []
1101
- for call in translate_in_spy.call_args_list:
1102
- in_translations += list(call.args)
1103
- for call in translate_out_spy.call_args_list:
1104
- out_translations += list(call.args)
1105
-
1106
- assert len(in_translations) < 1000 # typically 136 or something
1107
- assert len(out_translations) < 2000
1108
-
1109
-
1110
- def _run_container_process(
1111
- servicer,
1112
- module_name,
1113
- function_name,
1114
- *,
1115
- inputs: List[Tuple[Tuple, Dict[str, Any]]],
1116
- allow_concurrent_inputs: Optional[int] = None,
1117
- cls_params: Tuple[Tuple, Dict[str, Any]] = ((), {}),
1118
- print=False, # for debugging - print directly to stdout/stderr instead of pipeing
1119
- env={},
1120
- ) -> subprocess.Popen:
1121
- container_args = _container_args(
1122
- module_name,
1123
- function_name,
1124
- allow_concurrent_inputs=allow_concurrent_inputs,
1125
- serialized_params=serialize(cls_params),
1126
- )
1127
- encoded_container_args = base64.b64encode(container_args.SerializeToString())
1128
- servicer.container_inputs = _get_multi_inputs(inputs)
1129
- return subprocess.Popen(
1130
- [sys.executable, "-m", "modal._container_entrypoint", encoded_container_args],
1131
- env={**os.environ, **env},
1132
- stdout=subprocess.PIPE if not print else None,
1133
- stderr=subprocess.PIPE if not print else None,
1134
- )
1135
-
1136
-
1137
- @skip_github_non_linux
1138
- @pytest.mark.usefixtures("server_url_env")
1139
- @pytest.mark.parametrize(
1140
- ["function_name", "input_args", "cancelled_input_ids", "expected_container_output", "live_cancellations"],
1141
- [
1142
- # the 10 second inputs here are to be cancelled:
1143
- ("delay", [0.01, 20, 0.02], ["in-001"], [0.01, 0.02], 1), # cancel second input
1144
- ("delay_async", [0.01, 20, 0.02], ["in-001"], [0.01, 0.02], 1), # async variant
1145
- # cancel first input, but it has already been processed, so all three should come through:
1146
- ("delay", [0.01, 0.5, 0.03], ["in-000"], [0.01, 0.5, 0.03], 0),
1147
- ("delay_async", [0.01, 0.5, 0.03], ["in-000"], [0.01, 0.5, 0.03], 0),
1148
- ],
1149
- )
1150
- def test_cancellation_aborts_current_input_on_match(
1151
- servicer, function_name, input_args, cancelled_input_ids, expected_container_output, live_cancellations
1152
- ):
1153
- # NOTE: for a cancellation to actually happen in this test, it needs to be
1154
- # triggered while the relevant input is being processed. A future input
1155
- # would not be cancelled, since those are expected to be handled by
1156
- # the backend
1157
- with servicer.input_lockstep() as input_lock:
1158
- container_process = _run_container_process(
1159
- servicer,
1160
- "test.supports.functions",
1161
- function_name,
1162
- inputs=[((arg,), {}) for arg in input_args],
1163
- )
1164
- time.sleep(1)
1165
- input_lock.wait()
1166
- input_lock.wait()
1167
- # second input has been sent to container here
1168
- time.sleep(0.05) # give it a little time to start processing
1169
-
1170
- # now let container receive container heartbeat indicating there is a cancellation
1171
- t0 = time.monotonic()
1172
- num_prior_outputs = len(_flatten_outputs(servicer.container_outputs))
1173
- assert num_prior_outputs == 1 # the second input shouldn't have completed yet
1174
-
1175
- servicer.container_heartbeat_return_now(
1176
- api_pb2.ContainerHeartbeatResponse(cancel_input_event=api_pb2.CancelInputEvent(input_ids=cancelled_input_ids))
1177
- )
1178
- stdout, stderr = container_process.communicate()
1179
- assert stderr.decode().count("was cancelled by a user request") == live_cancellations
1180
- assert "Traceback" not in stderr.decode()
1181
- assert container_process.returncode == 0 # wait for container to exit
1182
- duration = time.monotonic() - t0 # time from heartbeat to container exit
1183
-
1184
- items = _flatten_outputs(servicer.container_outputs)
1185
- assert len(items) == len(expected_container_output)
1186
- data = [deserialize(i.result.data, client=None) for i in items]
1187
- assert data == expected_container_output
1188
- # should never run for ~20s, which is what the input would take if the sleep isn't interrupted
1189
- assert duration < 10 # should typically be < 1s, but for some reason in gh actions, it takes a really long time!
1190
-
1191
-
1192
- @skip_github_non_linux
1193
- @pytest.mark.usefixtures("server_url_env")
1194
- @pytest.mark.parametrize(
1195
- ["function_name"],
1196
- [("delay",), ("delay_async",)],
1197
- )
1198
- def test_cancellation_stops_task_with_concurrent_inputs(servicer, function_name):
1199
- # send three inputs in container: in-100, in-101, in-102
1200
- with servicer.input_lockstep() as input_lock:
1201
- container_process = _run_container_process(
1202
- servicer, "test.supports.functions", function_name, inputs=[((20,), {})], allow_concurrent_inputs=2
1203
- )
1204
- input_lock.wait()
1205
-
1206
- time.sleep(0.05) # let the container get and start processing the input
1207
- servicer.container_heartbeat_return_now(
1208
- api_pb2.ContainerHeartbeatResponse(cancel_input_event=api_pb2.CancelInputEvent(input_ids=["in-000"]))
1209
- )
1210
- # container should exit soon!
1211
- exit_code = container_process.wait(5)
1212
- assert exit_code == 0 # container should exit gracefully
1213
-
1214
-
1215
- @skip_github_non_linux
1216
- @pytest.mark.usefixtures("server_url_env")
1217
- def test_lifecycle_full(servicer):
1218
- # Sync and async container lifecycle methods on a sync function.
1219
- container_process = _run_container_process(
1220
- servicer, "test.supports.functions", "LifecycleCls.f_sync", inputs=[((), {})], cls_params=((True,), {})
1221
- )
1222
- stdout, _ = container_process.communicate(timeout=5)
1223
- assert container_process.returncode == 0
1224
- assert "[events:enter_sync,enter_async,f_sync,exit_sync,exit_async]" in stdout.decode()
1225
-
1226
- # Sync and async container lifecycle methods on an async function.
1227
- container_process = _run_container_process(
1228
- servicer, "test.supports.functions", "LifecycleCls.f_async", inputs=[((), {})], cls_params=((True,), {})
1229
- )
1230
- stdout, _ = container_process.communicate(timeout=5)
1231
- assert container_process.returncode == 0
1232
- assert "[events:enter_sync,enter_async,f_async,exit_sync,exit_async]" in stdout.decode()
1233
-
1234
-
1235
- ## modal.experimental functionality ##
1236
-
1237
-
1238
- @skip_github_non_linux
1239
- def test_stop_fetching_inputs(unix_servicer):
1240
- ret = _run_container(
1241
- unix_servicer,
1242
- "test.supports.experimental",
1243
- "StopFetching.after_two",
1244
- inputs=_get_inputs(((42,), {}), n=4, kill_switch=False),
1245
- )
1246
-
1247
- assert len(ret.items) == 2
1248
- assert ret.items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
1249
-
1250
-
1251
- @skip_github_non_linux
1252
- def test_container_heartbeat_survives_grpc_deadlines(servicer, caplog, monkeypatch):
1253
- monkeypatch.setattr("modal._container_io_manager.HEARTBEAT_INTERVAL", 0.01)
1254
- num_heartbeats = 0
1255
-
1256
- async def heartbeat_responder(servicer, stream):
1257
- nonlocal num_heartbeats
1258
- num_heartbeats += 1
1259
- await stream.recv_message()
1260
- raise GRPCError(Status.DEADLINE_EXCEEDED)
1261
-
1262
- with servicer.intercept() as ctx:
1263
- ctx.set_responder("ContainerHeartbeat", heartbeat_responder)
1264
- ret = _run_container(
1265
- servicer,
1266
- "test.supports.functions",
1267
- "delay",
1268
- inputs=_get_inputs(((2,), {})),
1269
- )
1270
- assert ret.task_result is None # should not cause a failure result
1271
- loop_iteration_failures = caplog.text.count("Heartbeat attempt failed")
1272
- assert "Traceback" not in caplog.text # should not print a full traceback - don't scare users!
1273
- assert (
1274
- loop_iteration_failures > 1
1275
- ) # one occurence per failing `retry_transient_errors()`, so fewer than the number of failing requests!
1276
- assert loop_iteration_failures < num_heartbeats
1277
- assert num_heartbeats > 4 # more than the default number of retries per heartbeat attempt + 1
1278
-
1279
-
1280
- @skip_github_non_linux
1281
- def test_container_heartbeat_survives_local_exceptions(servicer, caplog, monkeypatch):
1282
- numcalls = 0
1283
-
1284
- async def custom_heartbeater(self):
1285
- nonlocal numcalls
1286
- numcalls += 1
1287
- raise Exception("oops")
1288
-
1289
- monkeypatch.setattr("modal._container_io_manager.HEARTBEAT_INTERVAL", 0.01)
1290
- monkeypatch.setattr(
1291
- "modal._container_io_manager._ContainerIOManager._heartbeat_handle_cancellations", custom_heartbeater
1292
- )
1293
-
1294
- ret = _run_container(
1295
- servicer,
1296
- "test.supports.functions",
1297
- "delay",
1298
- inputs=_get_inputs(((0.5,), {})),
1299
- )
1300
- assert ret.task_result is None # should not cause a failure result
1301
- loop_iteration_failures = caplog.text.count("Heartbeat attempt failed")
1302
- assert loop_iteration_failures > 5
1303
- assert "error=Exception('oops')" in caplog.text
1304
- assert "Traceback" not in caplog.text # should not print a full traceback - don't scare users!
1305
-
1306
-
1307
- @skip_github_non_linux
1308
- @pytest.mark.usefixtures("server_url_env")
1309
- @pytest.mark.parametrize("method", ["delay", "delay_async"])
1310
- def test_sigint_termination_input(servicer, method):
1311
- # Sync and async container lifecycle methods on a sync function.
1312
- with servicer.input_lockstep() as input_barrier:
1313
- container_process = _run_container_process(
1314
- servicer,
1315
- "test.supports.functions",
1316
- f"LifecycleCls.{method}",
1317
- inputs=[((5,), {})],
1318
- cls_params=((), {"print_at_exit": True}),
1319
- )
1320
- input_barrier.wait() # get input
1321
- time.sleep(0.5)
1322
- signal_time = time.monotonic()
1323
- os.kill(container_process.pid, signal.SIGINT)
1324
-
1325
- stdout, stderr = container_process.communicate(timeout=5)
1326
- stop_duration = time.monotonic() - signal_time
1327
- assert len(servicer.container_outputs) == 0
1328
- assert (
1329
- container_process.returncode == 0
1330
- ) # container should catch and indicate successful termination by exiting cleanly when possible
1331
- assert f"[events:enter_sync,enter_async,{method},exit_sync,exit_async]" in stdout.decode()
1332
- assert "Traceback" not in stderr.decode()
1333
- assert stop_duration < 2.0 # if this would be ~4.5s, then the input isn't getting terminated
1334
- assert servicer.task_result is None
1335
-
1336
-
1337
- @skip_github_non_linux
1338
- @pytest.mark.usefixtures("server_url_env")
1339
- @pytest.mark.parametrize("enter_type", ["sync_enter", "async_enter"])
1340
- @pytest.mark.parametrize("method", ["delay", "delay_async"])
1341
- def test_sigint_termination_enter_handler(servicer, method, enter_type):
1342
- # Sync and async container lifecycle methods on a sync function.
1343
- container_process = _run_container_process(
1344
- servicer,
1345
- "test.supports.functions",
1346
- f"LifecycleCls.{method}",
1347
- inputs=[((5,), {})],
1348
- cls_params=((), {"print_at_exit": True, f"{enter_type}_duration": 10}),
1349
- )
1350
- time.sleep(1) # should be enough to start the enter method
1351
- signal_time = time.monotonic()
1352
- os.kill(container_process.pid, signal.SIGINT)
1353
- stdout, stderr = container_process.communicate(timeout=5)
1354
- stop_duration = time.monotonic() - signal_time
1355
- assert len(servicer.container_outputs) == 0
1356
- assert container_process.returncode == 0
1357
- if enter_type == "sync_enter":
1358
- assert "[events:enter_sync]" in stdout.decode()
1359
- else:
1360
- # enter_sync should run in 0s, and then we interrupt during the async enter
1361
- assert "[events:enter_sync,enter_async]" in stdout.decode()
1362
-
1363
- assert "Traceback" not in stderr.decode()
1364
- assert stop_duration < 2.0 # if this would be ~4.5s, then the task isn't being terminated timely
1365
- assert servicer.task_result is None
1366
-
1367
-
1368
- @skip_github_non_linux
1369
- @pytest.mark.usefixtures("server_url_env")
1370
- @pytest.mark.parametrize("exit_type", ["sync_exit", "async_exit"])
1371
- def test_sigint_termination_exit_handler(servicer, exit_type):
1372
- # Sync and async container lifecycle methods on a sync function.
1373
- with servicer.output_lockstep() as outputs:
1374
- container_process = _run_container_process(
1375
- servicer,
1376
- "test.supports.functions",
1377
- "LifecycleCls.delay",
1378
- inputs=[((0,), {})],
1379
- cls_params=((), {"print_at_exit": True, f"{exit_type}_duration": 2}),
1380
- )
1381
- outputs.wait() # wait for first output to be emitted
1382
- time.sleep(1) # give some time for container to end up in the exit handler
1383
- os.kill(container_process.pid, signal.SIGINT)
1384
-
1385
- stdout, stderr = container_process.communicate(timeout=5)
1386
-
1387
- assert len(servicer.container_outputs) == 1
1388
- assert container_process.returncode == 0
1389
- assert "[events:enter_sync,enter_async,delay,exit_sync,exit_async]" in stdout.decode()
1390
- assert "Traceback" not in stderr.decode()
1391
- assert servicer.task_result is None
1392
-
1393
-
1394
- @skip_github_non_linux
1395
- def test_sandbox(unix_servicer, event_loop):
1396
- ret = _run_container(unix_servicer, "test.supports.functions", "sandbox_f")
1397
- assert _unwrap_scalar(ret) == "sb-123"
1398
-
1399
-
1400
- @skip_github_non_linux
1401
- def test_is_local(unix_servicer, event_loop):
1402
- assert is_local() == True
1403
-
1404
- ret = _run_container(unix_servicer, "test.supports.functions", "is_local_f")
1405
- assert _unwrap_scalar(ret) == False