modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +17 -13
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +420 -937
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -59
  11. modal/_resources.py +51 -0
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/_runtime/execution_context.py +89 -0
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +134 -9
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +52 -16
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +479 -100
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +460 -171
  29. modal/_utils/grpc_testing.py +47 -31
  30. modal/_utils/grpc_utils.py +62 -109
  31. modal/_utils/hash_utils.py +61 -19
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +5 -7
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +14 -12
  43. modal/app.py +1003 -314
  44. modal/app.pyi +540 -264
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +63 -53
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +205 -45
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +62 -14
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +64 -58
  55. modal/cli/launch.py +32 -18
  56. modal/cli/network_file_system.py +64 -83
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +35 -10
  59. modal/cli/programs/vscode.py +60 -10
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +234 -131
  62. modal/cli/secret.py +8 -7
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +79 -10
  65. modal/cli/volume.py +110 -109
  66. modal/client.py +250 -144
  67. modal/client.pyi +157 -118
  68. modal/cloud_bucket_mount.py +108 -34
  69. modal/cloud_bucket_mount.pyi +32 -38
  70. modal/cls.py +535 -148
  71. modal/cls.pyi +190 -146
  72. modal/config.py +41 -19
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +111 -65
  76. modal/dict.pyi +136 -131
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +34 -43
  80. modal/experimental.py +61 -2
  81. modal/extensions/ipython.py +5 -5
  82. modal/file_io.py +537 -0
  83. modal/file_io.pyi +235 -0
  84. modal/file_pattern_matcher.py +197 -0
  85. modal/functions.py +906 -911
  86. modal/functions.pyi +466 -430
  87. modal/gpu.py +57 -44
  88. modal/image.py +1089 -479
  89. modal/image.pyi +584 -228
  90. modal/io_streams.py +434 -0
  91. modal/io_streams.pyi +122 -0
  92. modal/mount.py +314 -101
  93. modal/mount.pyi +241 -235
  94. modal/network_file_system.py +92 -92
  95. modal/network_file_system.pyi +152 -110
  96. modal/object.py +67 -36
  97. modal/object.pyi +166 -143
  98. modal/output.py +63 -0
  99. modal/parallel_map.py +434 -0
  100. modal/parallel_map.pyi +75 -0
  101. modal/partial_function.py +282 -117
  102. modal/partial_function.pyi +222 -129
  103. modal/proxy.py +15 -12
  104. modal/proxy.pyi +3 -8
  105. modal/queue.py +182 -65
  106. modal/queue.pyi +218 -118
  107. modal/requirements/2024.04.txt +29 -0
  108. modal/requirements/2024.10.txt +16 -0
  109. modal/requirements/README.md +21 -0
  110. modal/requirements/base-images.json +22 -0
  111. modal/retries.py +48 -7
  112. modal/runner.py +459 -156
  113. modal/runner.pyi +135 -71
  114. modal/running_app.py +38 -0
  115. modal/sandbox.py +514 -236
  116. modal/sandbox.pyi +397 -169
  117. modal/schedule.py +4 -4
  118. modal/scheduler_placement.py +20 -3
  119. modal/secret.py +56 -31
  120. modal/secret.pyi +62 -42
  121. modal/serving.py +51 -56
  122. modal/serving.pyi +44 -36
  123. modal/stream_type.py +15 -0
  124. modal/token_flow.py +5 -3
  125. modal/token_flow.pyi +37 -32
  126. modal/volume.py +285 -157
  127. modal/volume.pyi +249 -184
  128. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
  129. modal-0.72.11.dist-info/RECORD +174 -0
  130. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  131. modal_docs/gen_reference_docs.py +3 -1
  132. modal_docs/mdmd/mdmd.py +0 -1
  133. modal_docs/mdmd/signatures.py +5 -2
  134. modal_global_objects/images/base_images.py +28 -0
  135. modal_global_objects/mounts/python_standalone.py +2 -2
  136. modal_proto/__init__.py +1 -1
  137. modal_proto/api.proto +1288 -533
  138. modal_proto/api_grpc.py +856 -456
  139. modal_proto/api_pb2.py +2165 -1157
  140. modal_proto/api_pb2.pyi +8859 -0
  141. modal_proto/api_pb2_grpc.py +1674 -855
  142. modal_proto/api_pb2_grpc.pyi +1416 -0
  143. modal_proto/modal_api_grpc.py +149 -0
  144. modal_proto/modal_options_grpc.py +3 -0
  145. modal_proto/options_pb2.pyi +20 -0
  146. modal_proto/options_pb2_grpc.pyi +7 -0
  147. modal_proto/py.typed +0 -0
  148. modal_version/__init__.py +1 -1
  149. modal_version/_version_generated.py +2 -2
  150. modal/_asgi.py +0 -370
  151. modal/_container_entrypoint.pyi +0 -378
  152. modal/_container_exec.py +0 -128
  153. modal/_sandbox_shell.py +0 -49
  154. modal/shared_volume.py +0 -23
  155. modal/shared_volume.pyi +0 -24
  156. modal/stub.py +0 -783
  157. modal/stub.pyi +0 -332
  158. modal-0.62.16.dist-info/RECORD +0 -198
  159. modal_global_objects/images/conda.py +0 -15
  160. modal_global_objects/images/debian_slim.py +0 -15
  161. modal_global_objects/images/micromamba.py +0 -15
  162. test/__init__.py +0 -1
  163. test/aio_test.py +0 -12
  164. test/async_utils_test.py +0 -262
  165. test/blob_test.py +0 -67
  166. test/cli_imports_test.py +0 -149
  167. test/cli_test.py +0 -659
  168. test/client_test.py +0 -194
  169. test/cls_test.py +0 -630
  170. test/config_test.py +0 -137
  171. test/conftest.py +0 -1420
  172. test/container_app_test.py +0 -32
  173. test/container_test.py +0 -1389
  174. test/cpu_test.py +0 -23
  175. test/decorator_test.py +0 -85
  176. test/deprecation_test.py +0 -34
  177. test/dict_test.py +0 -33
  178. test/e2e_test.py +0 -68
  179. test/error_test.py +0 -7
  180. test/function_serialization_test.py +0 -32
  181. test/function_test.py +0 -653
  182. test/function_utils_test.py +0 -101
  183. test/gpu_test.py +0 -159
  184. test/grpc_utils_test.py +0 -141
  185. test/helpers.py +0 -42
  186. test/image_test.py +0 -669
  187. test/live_reload_test.py +0 -80
  188. test/lookup_test.py +0 -70
  189. test/mdmd_test.py +0 -329
  190. test/mount_test.py +0 -162
  191. test/mounted_files_test.py +0 -329
  192. test/network_file_system_test.py +0 -181
  193. test/notebook_test.py +0 -66
  194. test/object_test.py +0 -41
  195. test/package_utils_test.py +0 -25
  196. test/queue_test.py +0 -97
  197. test/resolver_test.py +0 -58
  198. test/retries_test.py +0 -67
  199. test/runner_test.py +0 -85
  200. test/sandbox_test.py +0 -191
  201. test/schedule_test.py +0 -15
  202. test/scheduler_placement_test.py +0 -29
  203. test/secret_test.py +0 -78
  204. test/serialization_test.py +0 -42
  205. test/stub_composition_test.py +0 -10
  206. test/stub_test.py +0 -360
  207. test/test_asgi_wrapper.py +0 -234
  208. test/token_flow_test.py +0 -18
  209. test/traceback_test.py +0 -135
  210. test/tunnel_test.py +0 -29
  211. test/utils_test.py +0 -88
  212. test/version_test.py +0 -14
  213. test/volume_test.py +0 -341
  214. test/watcher_test.py +0 -30
  215. test/webhook_test.py +0 -146
  216. /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
  217. /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
  218. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
test/container_test.py DELETED
@@ -1,1389 +0,0 @@
1
- # Copyright Modal Labs 2022
2
-
3
- import asyncio
4
- import base64
5
- import dataclasses
6
- import json
7
- import os
8
- import pathlib
9
- import pickle
10
- import pytest
11
- import signal
12
- import subprocess
13
- import sys
14
- import tempfile
15
- import time
16
- import uuid
17
- from typing import Any, Dict, List, Optional, Tuple
18
- from unittest import mock
19
- from unittest.mock import MagicMock
20
-
21
- from grpclib import Status
22
- from grpclib.exceptions import GRPCError
23
-
24
- from modal import Client
25
- from modal._container_entrypoint import UserException, main
26
- from modal._serialization import (
27
- deserialize,
28
- deserialize_data_format,
29
- serialize,
30
- serialize_data_format,
31
- )
32
- from modal._utils import async_utils
33
- from modal.exception import InvalidError
34
- from modal.partial_function import enter
35
- from modal.stub import _Stub
36
- from modal_proto import api_pb2
37
-
38
- from .helpers import deploy_stub_externally
39
- from .supports.skip import skip_windows_signals, skip_windows_unix_socket
40
-
41
- EXTRA_TOLERANCE_DELAY = 2.0 if sys.platform == "linux" else 5.0
42
- FUNCTION_CALL_ID = "fc-123"
43
- SLEEP_DELAY = 0.1
44
-
45
-
46
- def _get_inputs(
47
- args: Tuple[Tuple, Dict] = ((42,), {}), n: int = 1, kill_switch=True
48
- ) -> List[api_pb2.FunctionGetInputsResponse]:
49
- input_pb = api_pb2.FunctionInput(args=serialize(args), data_format=api_pb2.DATA_FORMAT_PICKLE)
50
- inputs = [
51
- *(
52
- api_pb2.FunctionGetInputsItem(input_id=f"in-xyz{i}", function_call_id="fc-123", input=input_pb)
53
- for i in range(n)
54
- ),
55
- *([api_pb2.FunctionGetInputsItem(kill_switch=True)] if kill_switch else []),
56
- ]
57
- return [api_pb2.FunctionGetInputsResponse(inputs=[x]) for x in inputs]
58
-
59
-
60
- @dataclasses.dataclass
61
- class ContainerResult:
62
- client: Client
63
- items: List[api_pb2.FunctionPutOutputsItem]
64
- data_chunks: List[api_pb2.DataChunk]
65
- task_result: api_pb2.GenericResult
66
-
67
-
68
- def _get_multi_inputs(args: List[Tuple[Tuple, Dict]] = []) -> List[api_pb2.FunctionGetInputsResponse]:
69
- responses = []
70
- for input_n, input_args in enumerate(args):
71
- resp = api_pb2.FunctionGetInputsResponse(
72
- inputs=[
73
- api_pb2.FunctionGetInputsItem(
74
- input_id=f"in-{input_n:03}", input=api_pb2.FunctionInput(args=serialize(input_args))
75
- )
76
- ]
77
- )
78
- responses.append(resp)
79
-
80
- return responses + [api_pb2.FunctionGetInputsResponse(inputs=[api_pb2.FunctionGetInputsItem(kill_switch=True)])]
81
-
82
-
83
- def _container_args(
84
- module_name,
85
- function_name,
86
- function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION,
87
- webhook_type=api_pb2.WEBHOOK_TYPE_UNSPECIFIED,
88
- definition_type=api_pb2.Function.DEFINITION_TYPE_FILE,
89
- stub_name: str = "",
90
- is_builder_function: bool = False,
91
- allow_concurrent_inputs: Optional[int] = None,
92
- serialized_params: Optional[bytes] = None,
93
- is_checkpointing_function: bool = False,
94
- deps: List[str] = ["im-1"],
95
- volume_mounts: Optional[List[api_pb2.VolumeMount]] = None,
96
- is_auto_snapshot: bool = False,
97
- max_inputs: Optional[int] = None,
98
- ):
99
- if webhook_type:
100
- webhook_config = api_pb2.WebhookConfig(
101
- type=webhook_type,
102
- method="GET",
103
- async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
104
- )
105
- else:
106
- webhook_config = None
107
-
108
- function_def = api_pb2.Function(
109
- module_name=module_name,
110
- function_name=function_name,
111
- function_type=function_type,
112
- volume_mounts=volume_mounts,
113
- webhook_config=webhook_config,
114
- definition_type=definition_type,
115
- stub_name=stub_name or "",
116
- is_builder_function=is_builder_function,
117
- is_auto_snapshot=is_auto_snapshot,
118
- allow_concurrent_inputs=allow_concurrent_inputs,
119
- is_checkpointing_function=is_checkpointing_function,
120
- object_dependencies=[api_pb2.ObjectDependency(object_id=object_id) for object_id in deps],
121
- max_inputs=max_inputs,
122
- )
123
-
124
- return api_pb2.ContainerArguments(
125
- task_id="ta-123",
126
- function_id="fu-123",
127
- app_id="ap-1",
128
- function_def=function_def,
129
- serialized_params=serialized_params,
130
- checkpoint_id=f"ch-{uuid.uuid4()}",
131
- )
132
-
133
-
134
- def _flatten_outputs(outputs) -> List[api_pb2.FunctionPutOutputsItem]:
135
- items: List[api_pb2.FunctionPutOutputsItem] = []
136
- for req in outputs:
137
- items += list(req.outputs)
138
- return items
139
-
140
-
141
- def _run_container(
142
- servicer,
143
- module_name,
144
- function_name,
145
- fail_get_inputs=False,
146
- inputs=None,
147
- function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION,
148
- webhook_type=api_pb2.WEBHOOK_TYPE_UNSPECIFIED,
149
- definition_type=api_pb2.Function.DEFINITION_TYPE_FILE,
150
- stub_name: str = "",
151
- is_builder_function: bool = False,
152
- allow_concurrent_inputs: Optional[int] = None,
153
- serialized_params: Optional[bytes] = None,
154
- is_checkpointing_function: bool = False,
155
- deps: List[str] = ["im-1"],
156
- volume_mounts: Optional[List[api_pb2.VolumeMount]] = None,
157
- is_auto_snapshot: bool = False,
158
- max_inputs: Optional[int] = None,
159
- ) -> ContainerResult:
160
- container_args = _container_args(
161
- module_name,
162
- function_name,
163
- function_type,
164
- webhook_type,
165
- definition_type,
166
- stub_name,
167
- is_builder_function,
168
- allow_concurrent_inputs,
169
- serialized_params,
170
- is_checkpointing_function,
171
- deps,
172
- volume_mounts,
173
- is_auto_snapshot,
174
- max_inputs,
175
- )
176
- with Client(servicer.remote_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret")) as client:
177
- if inputs is None:
178
- servicer.container_inputs = _get_inputs()
179
- else:
180
- servicer.container_inputs = inputs
181
- function_call_id = servicer.container_inputs[0].inputs[0].function_call_id
182
- servicer.fail_get_inputs = fail_get_inputs
183
-
184
- if module_name in sys.modules:
185
- # Drop the module from sys.modules since some function code relies on the
186
- # assumption that the app is created before the user code is imported.
187
- # This is really only an issue for tests.
188
- sys.modules.pop(module_name)
189
-
190
- env = os.environ.copy()
191
- temp_restore_file_path = tempfile.NamedTemporaryFile()
192
- if is_checkpointing_function:
193
- # State file is written to allow for a restore to happen.
194
- tmp_file_name = temp_restore_file_path.name
195
- with pathlib.Path(tmp_file_name).open("w") as target:
196
- json.dump({}, target)
197
- env["MODAL_RESTORE_STATE_PATH"] = tmp_file_name
198
-
199
- # Override server URL to reproduce restore behavior.
200
- env["MODAL_SERVER_URL"] = servicer.remote_addr
201
-
202
- # reset _Stub tracking state between runs
203
- _Stub._all_stubs = {}
204
-
205
- try:
206
- with mock.patch.dict(os.environ, env):
207
- main(container_args, client)
208
- except UserException:
209
- # Handle it gracefully
210
- pass
211
- finally:
212
- temp_restore_file_path.close()
213
-
214
- # Flatten outputs
215
- items = _flatten_outputs(servicer.container_outputs)
216
-
217
- # Get data chunks
218
- data_chunks: List[api_pb2.DataChunk] = []
219
- if function_call_id in servicer.fc_data_out:
220
- try:
221
- while True:
222
- chunk = servicer.fc_data_out[function_call_id].get_nowait()
223
- data_chunks.append(chunk)
224
- except asyncio.QueueEmpty:
225
- pass
226
-
227
- return ContainerResult(client, items, data_chunks, servicer.task_result)
228
-
229
-
230
- def _unwrap_scalar(ret: ContainerResult):
231
- assert len(ret.items) == 1
232
- assert ret.items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
233
- return deserialize(ret.items[0].result.data, ret.client)
234
-
235
-
236
- def _unwrap_exception(ret: ContainerResult):
237
- assert len(ret.items) == 1
238
- assert ret.items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE
239
- assert "Traceback" in ret.items[0].result.traceback
240
- return ret.items[0].result.exception
241
-
242
-
243
- def _unwrap_generator(ret: ContainerResult) -> Tuple[List[Any], Optional[Exception]]:
244
- assert len(ret.items) == 1
245
- item = ret.items[0]
246
- assert item.result.gen_status == api_pb2.GenericResult.GENERATOR_STATUS_UNSPECIFIED
247
-
248
- values: List[Any] = [deserialize_data_format(chunk.data, chunk.data_format, None) for chunk in ret.data_chunks]
249
-
250
- if item.result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE:
251
- exc = deserialize(item.result.data, ret.client)
252
- return values, exc
253
- elif item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
254
- assert item.data_format == api_pb2.DATA_FORMAT_GENERATOR_DONE
255
- done: api_pb2.GeneratorDone = deserialize_data_format(item.result.data, item.data_format, None)
256
- assert done.items_total == len(values)
257
- return values, None
258
- else:
259
- raise RuntimeError("unknown result type")
260
-
261
-
262
- def _unwrap_asgi(ret: ContainerResult):
263
- values, exc = _unwrap_generator(ret)
264
- assert exc is None, "web endpoint raised exception"
265
- return values
266
-
267
-
268
- @skip_windows_unix_socket
269
- def test_success(unix_servicer, event_loop):
270
- t0 = time.time()
271
- ret = _run_container(unix_servicer, "test.supports.functions", "square")
272
- assert 0 <= time.time() - t0 < EXTRA_TOLERANCE_DELAY
273
- assert _unwrap_scalar(ret) == 42**2
274
-
275
-
276
- @skip_windows_unix_socket
277
- def test_generator_success(unix_servicer, event_loop):
278
- ret = _run_container(
279
- unix_servicer,
280
- "test.supports.functions",
281
- "gen_n",
282
- function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
283
- )
284
-
285
- items, exc = _unwrap_generator(ret)
286
- assert items == [i**2 for i in range(42)]
287
- assert exc is None
288
-
289
-
290
- @skip_windows_unix_socket
291
- def test_generator_failure(unix_servicer, capsys):
292
- inputs = _get_inputs(((10, 5), {}))
293
- ret = _run_container(
294
- unix_servicer,
295
- "test.supports.functions",
296
- "gen_n_fail_on_m",
297
- function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
298
- inputs=inputs,
299
- )
300
- items, exc = _unwrap_generator(ret)
301
- assert items == [i**2 for i in range(5)]
302
- assert isinstance(exc, Exception)
303
- assert exc.args == ("bad",)
304
- assert 'raise Exception("bad")' in capsys.readouterr().err
305
-
306
-
307
- @skip_windows_unix_socket
308
- def test_async(unix_servicer):
309
- t0 = time.time()
310
- ret = _run_container(unix_servicer, "test.supports.functions", "square_async")
311
- assert SLEEP_DELAY <= time.time() - t0 < SLEEP_DELAY + EXTRA_TOLERANCE_DELAY
312
- assert _unwrap_scalar(ret) == 42**2
313
-
314
-
315
- @skip_windows_unix_socket
316
- def test_failure(unix_servicer, capsys):
317
- ret = _run_container(unix_servicer, "test.supports.functions", "raises")
318
- assert _unwrap_exception(ret) == "Exception('Failure!')"
319
- assert 'raise Exception("Failure!")' in capsys.readouterr().err # traceback
320
-
321
-
322
- @skip_windows_unix_socket
323
- def test_raises_base_exception(unix_servicer, capsys):
324
- ret = _run_container(unix_servicer, "test.supports.functions", "raises_sysexit")
325
- assert _unwrap_exception(ret) == "SystemExit(1)"
326
- assert "raise SystemExit(1)" in capsys.readouterr().err # traceback
327
-
328
-
329
- @skip_windows_unix_socket
330
- def test_keyboardinterrupt(unix_servicer):
331
- with pytest.raises(KeyboardInterrupt):
332
- _run_container(unix_servicer, "test.supports.functions", "raises_keyboardinterrupt")
333
-
334
-
335
- @skip_windows_unix_socket
336
- def test_rate_limited(unix_servicer, event_loop):
337
- t0 = time.time()
338
- unix_servicer.rate_limit_sleep_duration = 0.25
339
- ret = _run_container(unix_servicer, "test.supports.functions", "square")
340
- assert 0.25 <= time.time() - t0 < 0.25 + EXTRA_TOLERANCE_DELAY
341
- assert _unwrap_scalar(ret) == 42**2
342
-
343
-
344
- @skip_windows_unix_socket
345
- def test_grpc_failure(unix_servicer, event_loop):
346
- # An error in "Modal code" should cause the entire container to fail
347
- with pytest.raises(GRPCError):
348
- _run_container(
349
- unix_servicer,
350
- "test.supports.functions",
351
- "square",
352
- fail_get_inputs=True,
353
- )
354
-
355
- # assert unix_servicer.task_result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE
356
- # assert "GRPCError" in unix_servicer.task_result.exception
357
-
358
-
359
- @skip_windows_unix_socket
360
- def test_missing_main_conditional(unix_servicer, capsys):
361
- _run_container(unix_servicer, "test.supports.missing_main_conditional", "square")
362
- output = capsys.readouterr()
363
- assert "Can not run an app from within a container" in output.err
364
-
365
- assert unix_servicer.task_result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE
366
- assert "modal run" in unix_servicer.task_result.traceback
367
-
368
- exc = deserialize(unix_servicer.task_result.data, None)
369
- assert isinstance(exc, InvalidError)
370
-
371
-
372
- @skip_windows_unix_socket
373
- def test_startup_failure(unix_servicer, capsys):
374
- _run_container(unix_servicer, "test.supports.startup_failure", "f")
375
-
376
- assert unix_servicer.task_result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE
377
-
378
- exc = deserialize(unix_servicer.task_result.data, None)
379
- assert isinstance(exc, ImportError)
380
- assert "ModuleNotFoundError: No module named 'nonexistent_package'" in capsys.readouterr().err
381
-
382
-
383
- @skip_windows_unix_socket
384
- def test_from_local_python_packages_inside_container(unix_servicer):
385
- """`from_local_python_packages` shouldn't actually collect modules inside the container, because it's possible
386
- that there are modules that were present locally for the user that didn't get mounted into
387
- all the containers."""
388
- ret = _run_container(unix_servicer, "test.supports.package_mount", "num_mounts")
389
- assert _unwrap_scalar(ret) == 0
390
-
391
-
392
- def _get_web_inputs(path="/"):
393
- scope = {
394
- "method": "GET",
395
- "type": "http",
396
- "path": path,
397
- "headers": {},
398
- "query_string": b"arg=space",
399
- "http_version": "2",
400
- }
401
- return _get_inputs(((scope,), {}))
402
-
403
-
404
- @async_utils.synchronize_api # needs to be synchronized so the asyncio.Queue gets used from the same event loop as the servicer
405
- async def _put_web_body(servicer, body: bytes):
406
- asgi = {"type": "http.request", "body": body, "more_body": False}
407
- data = serialize_data_format(asgi, api_pb2.DATA_FORMAT_ASGI)
408
-
409
- q = servicer.fc_data_in.setdefault("fc-123", asyncio.Queue())
410
- q.put_nowait(api_pb2.DataChunk(data_format=api_pb2.DATA_FORMAT_ASGI, data=data, index=1))
411
-
412
-
413
- @skip_windows_unix_socket
414
- def test_webhook(unix_servicer):
415
- inputs = _get_web_inputs()
416
- _put_web_body(unix_servicer, b"")
417
- ret = _run_container(
418
- unix_servicer,
419
- "test.supports.functions",
420
- "webhook",
421
- inputs=inputs,
422
- webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION,
423
- )
424
- items = _unwrap_asgi(ret)
425
-
426
- # There should be one message for the header, one for the body, one for the EOF
427
- first_message, second_message = items # _unwrap_asgi ignores the eof
428
-
429
- # Check the headers
430
- assert first_message["status"] == 200
431
- headers = dict(first_message["headers"])
432
- assert headers[b"content-type"] == b"application/json"
433
-
434
- # Check body
435
- assert json.loads(second_message["body"]) == {"hello": "space"}
436
-
437
-
438
- @skip_windows_unix_socket
439
- def test_serialized_function(unix_servicer):
440
- def triple(x):
441
- return 3 * x
442
-
443
- unix_servicer.function_serialized = serialize(triple)
444
- ret = _run_container(
445
- unix_servicer,
446
- "foo.bar.baz",
447
- "f",
448
- definition_type=api_pb2.Function.DEFINITION_TYPE_SERIALIZED,
449
- )
450
- assert _unwrap_scalar(ret) == 3 * 42
451
-
452
-
453
- @skip_windows_unix_socket
454
- def test_webhook_serialized(unix_servicer):
455
- inputs = _get_web_inputs()
456
- _put_web_body(unix_servicer, b"")
457
-
458
- # Store a serialized webhook function on the servicer
459
- def webhook(arg="world"):
460
- return f"Hello, {arg}"
461
-
462
- unix_servicer.function_serialized = serialize(webhook)
463
-
464
- ret = _run_container(
465
- unix_servicer,
466
- "foo.bar.baz",
467
- "f",
468
- inputs=inputs,
469
- webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION,
470
- definition_type=api_pb2.Function.DEFINITION_TYPE_SERIALIZED,
471
- )
472
-
473
- _, second_message = _unwrap_asgi(ret)
474
- assert second_message["body"] == b'"Hello, space"' # Note: JSON-encoded
475
-
476
-
477
- @skip_windows_unix_socket
478
- def test_function_returning_generator(unix_servicer):
479
- ret = _run_container(
480
- unix_servicer,
481
- "test.supports.functions",
482
- "fun_returning_gen",
483
- function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
484
- )
485
- items, exc = _unwrap_generator(ret)
486
- assert len(items) == 42
487
-
488
-
489
- @skip_windows_unix_socket
490
- def test_asgi(unix_servicer):
491
- inputs = _get_web_inputs(path="/foo")
492
- _put_web_body(unix_servicer, b"")
493
- ret = _run_container(
494
- unix_servicer,
495
- "test.supports.functions",
496
- "fastapi_app",
497
- inputs=inputs,
498
- webhook_type=api_pb2.WEBHOOK_TYPE_ASGI_APP,
499
- )
500
-
501
- # There should be one message for the header, and one for the body
502
- first_message, second_message = _unwrap_asgi(ret)
503
-
504
- # Check the headers
505
- assert first_message["status"] == 200
506
- headers = dict(first_message["headers"])
507
- assert headers[b"content-type"] == b"application/json"
508
-
509
- # Check body
510
- assert json.loads(second_message["body"]) == {"hello": "space"}
511
-
512
-
513
- @skip_windows_unix_socket
514
- def test_wsgi(unix_servicer):
515
- inputs = _get_web_inputs(path="/")
516
- _put_web_body(unix_servicer, b"my wsgi body")
517
- ret = _run_container(
518
- unix_servicer,
519
- "test.supports.functions",
520
- "basic_wsgi_app",
521
- inputs=inputs,
522
- webhook_type=api_pb2.WEBHOOK_TYPE_WSGI_APP,
523
- )
524
-
525
- # There should be one message for headers, one for the body, and one for the end-of-body.
526
- first_message, second_message, third_message = _unwrap_asgi(ret)
527
-
528
- # Check the headers
529
- assert first_message["status"] == 200
530
- headers = dict(first_message["headers"])
531
- assert headers[b"content-type"] == b"text/plain; charset=utf-8"
532
-
533
- # Check body
534
- assert second_message["body"] == b"got body: my wsgi body"
535
- assert second_message.get("more_body", False) is True
536
- assert third_message["body"] == b""
537
- assert third_message.get("more_body", False) is False
538
-
539
-
540
- @skip_windows_unix_socket
541
- def test_webhook_streaming_sync(unix_servicer):
542
- inputs = _get_web_inputs()
543
- _put_web_body(unix_servicer, b"")
544
- ret = _run_container(
545
- unix_servicer,
546
- "test.supports.functions",
547
- "webhook_streaming",
548
- inputs=inputs,
549
- webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION,
550
- function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
551
- )
552
- data = _unwrap_asgi(ret)
553
- bodies = [d["body"].decode() for d in data if d.get("body")]
554
- assert bodies == [f"{i}..." for i in range(10)]
555
-
556
-
557
- @skip_windows_unix_socket
558
- def test_webhook_streaming_async(unix_servicer):
559
- inputs = _get_web_inputs()
560
- _put_web_body(unix_servicer, b"")
561
- ret = _run_container(
562
- unix_servicer,
563
- "test.supports.functions",
564
- "webhook_streaming_async",
565
- inputs=inputs,
566
- webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION,
567
- function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
568
- )
569
-
570
- data = _unwrap_asgi(ret)
571
- bodies = [d["body"].decode() for d in data if d.get("body")]
572
- assert bodies == [f"{i}..." for i in range(10)]
573
-
574
-
575
- @skip_windows_unix_socket
576
- def test_cls_function(unix_servicer):
577
- ret = _run_container(unix_servicer, "test.supports.functions", "Cls.f")
578
- assert _unwrap_scalar(ret) == 42 * 111
579
-
580
-
581
- @skip_windows_unix_socket
582
- def test_lifecycle_enter_sync(unix_servicer):
583
- ret = _run_container(unix_servicer, "test.supports.functions", "LifecycleCls.f_sync", inputs=_get_inputs(((), {})))
584
- assert _unwrap_scalar(ret) == ["enter_sync", "enter_async", "f_sync"]
585
-
586
-
587
- @skip_windows_unix_socket
588
- def test_lifecycle_enter_async(unix_servicer):
589
- ret = _run_container(unix_servicer, "test.supports.functions", "LifecycleCls.f_async", inputs=_get_inputs(((), {})))
590
- assert _unwrap_scalar(ret) == ["enter_sync", "enter_async", "f_async"]
591
-
592
-
593
- @skip_windows_unix_socket
594
- def test_param_cls_function(unix_servicer):
595
- serialized_params = pickle.dumps(([111], {"y": "foo"}))
596
- ret = _run_container(
597
- unix_servicer,
598
- "test.supports.functions",
599
- "ParamCls.f",
600
- serialized_params=serialized_params,
601
- )
602
- assert _unwrap_scalar(ret) == "111 foo 42"
603
-
604
-
605
- @skip_windows_unix_socket
606
- def test_cls_web_endpoint(unix_servicer):
607
- inputs = _get_web_inputs()
608
- ret = _run_container(
609
- unix_servicer,
610
- "test.supports.functions",
611
- "Cls.web",
612
- inputs=inputs,
613
- webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION,
614
- )
615
-
616
- _, second_message = _unwrap_asgi(ret)
617
- assert json.loads(second_message["body"]) == {"ret": "space" * 111}
618
-
619
-
620
- @skip_windows_unix_socket
621
- def test_serialized_cls(unix_servicer):
622
- class Cls:
623
- @enter()
624
- def enter(self):
625
- self.power = 5
626
-
627
- def method(self, x):
628
- return x**self.power
629
-
630
- unix_servicer.class_serialized = serialize(Cls)
631
- unix_servicer.function_serialized = serialize(Cls.method)
632
- ret = _run_container(
633
- unix_servicer,
634
- "module.doesnt.matter",
635
- "function.doesnt.matter",
636
- definition_type=api_pb2.Function.DEFINITION_TYPE_SERIALIZED,
637
- )
638
- assert _unwrap_scalar(ret) == 42**5
639
-
640
-
641
- @skip_windows_unix_socket
642
- def test_cls_generator(unix_servicer):
643
- ret = _run_container(
644
- unix_servicer,
645
- "test.supports.functions",
646
- "Cls.generator",
647
- function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR,
648
- )
649
- items, exc = _unwrap_generator(ret)
650
- assert items == [42**3]
651
- assert exc is None
652
-
653
-
654
- @skip_windows_unix_socket
655
- def test_checkpointing_cls_function(unix_servicer):
656
- ret = _run_container(
657
- unix_servicer,
658
- "test.supports.functions",
659
- "CheckpointingCls.f",
660
- inputs=_get_inputs((("D",), {})),
661
- is_checkpointing_function=True,
662
- )
663
- assert any(isinstance(request, api_pb2.ContainerCheckpointRequest) for request in unix_servicer.requests)
664
- for request in unix_servicer.requests:
665
- if isinstance(request, api_pb2.ContainerCheckpointRequest):
666
- assert request.checkpoint_id
667
- assert _unwrap_scalar(ret) == "ABCD"
668
-
669
-
670
- @skip_windows_unix_socket
671
- def test_cls_enter_uses_event_loop(unix_servicer):
672
- ret = _run_container(
673
- unix_servicer,
674
- "test.supports.functions",
675
- "EventLoopCls.f",
676
- inputs=_get_inputs(((), {})),
677
- )
678
- assert _unwrap_scalar(ret) == True
679
-
680
-
681
- @skip_windows_unix_socket
682
- def test_container_heartbeats(unix_servicer):
683
- _run_container(unix_servicer, "test.supports.functions", "square")
684
- assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in unix_servicer.requests)
685
-
686
-
687
- @skip_windows_unix_socket
688
- def test_cli(unix_servicer):
689
- # This tests the container being invoked as a subprocess (the if __name__ == "__main__" block)
690
-
691
- # Build up payload we pass through sys args
692
- function_def = api_pb2.Function(
693
- module_name="test.supports.functions",
694
- function_name="square",
695
- function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION,
696
- definition_type=api_pb2.Function.DEFINITION_TYPE_FILE,
697
- object_dependencies=[api_pb2.ObjectDependency(object_id="im-123")],
698
- )
699
- container_args = api_pb2.ContainerArguments(
700
- task_id="ta-123",
701
- function_id="fu-123",
702
- app_id="ap-123",
703
- function_def=function_def,
704
- )
705
- data_base64: str = base64.b64encode(container_args.SerializeToString()).decode("ascii")
706
-
707
- # Needed for function hydration
708
- unix_servicer.app_objects["ap-123"] = {"": "im-123"}
709
-
710
- # Inputs that will be consumed by the container
711
- unix_servicer.container_inputs = _get_inputs()
712
-
713
- # Launch subprocess
714
- env = {"MODAL_SERVER_URL": unix_servicer.remote_addr}
715
- lib_dir = pathlib.Path(__file__).parent.parent
716
- args: List[str] = [sys.executable, "-m", "modal._container_entrypoint", data_base64]
717
- ret = subprocess.run(args, cwd=lib_dir, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
718
- stdout = ret.stdout.decode()
719
- stderr = ret.stderr.decode()
720
- if ret.returncode != 0:
721
- raise Exception(f"Failed with {ret.returncode} stdout: {stdout} stderr: {stderr}")
722
-
723
- assert stdout == ""
724
- assert stderr == ""
725
-
726
-
727
- @skip_windows_unix_socket
728
- def test_function_sibling_hydration(unix_servicer):
729
- deploy_stub_externally(unix_servicer, "test.supports.functions", "stub")
730
- ret = _run_container(unix_servicer, "test.supports.functions", "check_sibling_hydration")
731
- assert _unwrap_scalar(ret) is None
732
-
733
-
734
- @skip_windows_unix_socket
735
- def test_multistub(unix_servicer, caplog):
736
- deploy_stub_externally(unix_servicer, "test.supports.multistub", "a")
737
- ret = _run_container(unix_servicer, "test.supports.multistub", "a_func")
738
- assert _unwrap_scalar(ret) is None
739
- assert len(caplog.messages) == 0
740
- # Note that the stub can be inferred from the function, even though there are multiple
741
- # stubs present in the file
742
-
743
-
744
- @skip_windows_unix_socket
745
- def test_multistub_privately_decorated(unix_servicer, caplog):
746
- # function handle does not override the original function, so we can't find the stub
747
- # and the two stubs are not named
748
- ret = _run_container(unix_servicer, "test.supports.multistub_privately_decorated", "foo")
749
- assert _unwrap_scalar(ret) == 1
750
- assert "You have more than one unnamed stub." in caplog.text
751
-
752
-
753
- @skip_windows_unix_socket
754
- def test_multistub_privately_decorated_named_stub(unix_servicer, caplog):
755
- # function handle does not override the original function, so we can't find the stub
756
- # but we can use the names of the stubs to determine the active stub
757
- ret = _run_container(
758
- unix_servicer,
759
- "test.supports.multistub_privately_decorated_named_stub",
760
- "foo",
761
- stub_name="dummy",
762
- )
763
- assert _unwrap_scalar(ret) == 1
764
- assert len(caplog.messages) == 0 # no warnings, since target stub is named
765
-
766
-
767
- @skip_windows_unix_socket
768
- def test_multistub_same_name_warning(unix_servicer, caplog, capsys):
769
- # function handle does not override the original function, so we can't find the stub
770
- # two stubs with the same name - warn since we won't know which one to hydrate
771
- ret = _run_container(
772
- unix_servicer,
773
- "test.supports.multistub_same_name",
774
- "foo",
775
- stub_name="dummy",
776
- )
777
- assert _unwrap_scalar(ret) == 1
778
- assert "You have more than one stub with the same name ('dummy')" in caplog.text
779
- capsys.readouterr()
780
-
781
-
782
- @skip_windows_unix_socket
783
- def test_multistub_serialized_func(unix_servicer, caplog):
784
- # serialized functions shouldn't warn about multiple/not finding stubs, since they shouldn't load the module to begin with
785
- def dummy(x):
786
- return x
787
-
788
- unix_servicer.function_serialized = serialize(dummy)
789
- ret = _run_container(
790
- unix_servicer,
791
- "test.supports.multistub_serialized_func",
792
- "foo",
793
- definition_type=api_pb2.Function.DEFINITION_TYPE_SERIALIZED,
794
- )
795
- assert _unwrap_scalar(ret) == 42
796
- assert len(caplog.messages) == 0
797
-
798
-
799
- @skip_windows_unix_socket
800
- def test_image_run_function_no_warn(unix_servicer, caplog):
801
- # builder functions currently aren't tied to any modal stub, so they shouldn't need to warn if they can't determine a stub to use
802
- ret = _run_container(
803
- unix_servicer,
804
- "test.supports.image_run_function",
805
- "builder_function",
806
- inputs=_get_inputs(((), {})),
807
- is_builder_function=True,
808
- )
809
- assert _unwrap_scalar(ret) is None
810
- assert len(caplog.messages) == 0
811
-
812
-
813
- SLEEP_TIME = 0.7
814
-
815
-
816
- def _unwrap_concurrent_input_outputs(n_inputs: int, n_parallel: int, ret: ContainerResult):
817
- # Ensure that outputs align with expectation of running concurrent inputs
818
-
819
- # Each group of n_parallel inputs should start together of each other
820
- # and different groups should start SLEEP_TIME apart.
821
- assert len(ret.items) == n_inputs
822
- for i in range(1, len(ret.items)):
823
- diff = ret.items[i].input_started_at - ret.items[i - 1].input_started_at
824
- expected_diff = SLEEP_TIME if i % n_parallel == 0 else 0
825
- assert diff == pytest.approx(expected_diff, abs=0.3)
826
-
827
- outputs = []
828
- for item in ret.items:
829
- assert item.output_created_at - item.input_started_at == pytest.approx(SLEEP_TIME, abs=0.3)
830
- assert item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
831
- outputs.append(deserialize(item.result.data, ret.client))
832
- return outputs
833
-
834
-
835
- @skip_windows_unix_socket
836
- def test_concurrent_inputs_sync_function(unix_servicer):
837
- n_inputs = 18
838
- n_parallel = 6
839
-
840
- t0 = time.time()
841
- ret = _run_container(
842
- unix_servicer,
843
- "test.supports.functions",
844
- "sleep_700_sync",
845
- inputs=_get_inputs(n=n_inputs),
846
- allow_concurrent_inputs=n_parallel,
847
- )
848
-
849
- expected_execution = n_inputs / n_parallel * SLEEP_TIME
850
- assert expected_execution <= time.time() - t0 < expected_execution + EXTRA_TOLERANCE_DELAY
851
- outputs = _unwrap_concurrent_input_outputs(n_inputs, n_parallel, ret)
852
- for i, (squared, input_id, function_call_id) in enumerate(outputs):
853
- assert squared == 42**2
854
- assert input_id and input_id != outputs[i - 1][1]
855
- assert function_call_id and function_call_id == outputs[i - 1][2]
856
-
857
-
858
- @skip_windows_unix_socket
859
- def test_concurrent_inputs_async_function(unix_servicer):
860
- n_inputs = 18
861
- n_parallel = 6
862
-
863
- t0 = time.time()
864
- ret = _run_container(
865
- unix_servicer,
866
- "test.supports.functions",
867
- "sleep_700_async",
868
- inputs=_get_inputs(n=n_inputs),
869
- allow_concurrent_inputs=n_parallel,
870
- )
871
-
872
- expected_execution = n_inputs / n_parallel * SLEEP_TIME
873
- assert expected_execution <= time.time() - t0 < expected_execution + EXTRA_TOLERANCE_DELAY
874
- outputs = _unwrap_concurrent_input_outputs(n_inputs, n_parallel, ret)
875
- for i, (squared, input_id, function_call_id) in enumerate(outputs):
876
- assert squared == 42**2
877
- assert input_id and input_id != outputs[i - 1][1]
878
- assert function_call_id and function_call_id == outputs[i - 1][2]
879
-
880
-
881
- @skip_windows_unix_socket
882
- def test_unassociated_function(unix_servicer):
883
- ret = _run_container(unix_servicer, "test.supports.functions", "unassociated_function")
884
- assert _unwrap_scalar(ret) == 58
885
-
886
-
887
- @skip_windows_unix_socket
888
- def test_param_cls_function_calling_local(unix_servicer):
889
- serialized_params = pickle.dumps(([111], {"y": "foo"}))
890
- ret = _run_container(
891
- unix_servicer,
892
- "test.supports.functions",
893
- "ParamCls.g",
894
- serialized_params=serialized_params,
895
- )
896
- assert _unwrap_scalar(ret) == "111 foo 42"
897
-
898
-
899
- @skip_windows_unix_socket
900
- def test_derived_cls(unix_servicer):
901
- ret = _run_container(
902
- unix_servicer,
903
- "test.supports.functions",
904
- "DerivedCls.run",
905
- inputs=_get_inputs(((3,), {})),
906
- )
907
- assert _unwrap_scalar(ret) == 6
908
-
909
-
910
- @skip_windows_unix_socket
911
- def test_call_function_that_calls_function(unix_servicer):
912
- deploy_stub_externally(unix_servicer, "test.supports.functions", "stub")
913
- ret = _run_container(
914
- unix_servicer,
915
- "test.supports.functions",
916
- "cube",
917
- inputs=_get_inputs(((42,), {})),
918
- )
919
- assert _unwrap_scalar(ret) == 42**3
920
-
921
-
922
- @skip_windows_unix_socket
923
- def test_call_function_that_calls_method(unix_servicer):
924
- deploy_stub_externally(unix_servicer, "test.supports.functions", "stub")
925
- ret = _run_container(
926
- unix_servicer,
927
- "test.supports.functions",
928
- "function_calling_method",
929
- inputs=_get_inputs(((42, "abc", 123), {})),
930
- )
931
- assert _unwrap_scalar(ret) == 123**2 # servicer's implementation of function calling
932
-
933
-
934
- @skip_windows_unix_socket
935
- def test_checkpoint_and_restore_success(unix_servicer):
936
- """Functions send a checkpointing request and continue to execute normally,
937
- simulating a restore operation."""
938
- ret = _run_container(
939
- unix_servicer,
940
- "test.supports.functions",
941
- "square",
942
- is_checkpointing_function=True,
943
- )
944
- assert any(isinstance(request, api_pb2.ContainerCheckpointRequest) for request in unix_servicer.requests)
945
- for request in unix_servicer.requests:
946
- if isinstance(request, api_pb2.ContainerCheckpointRequest):
947
- assert request.checkpoint_id
948
-
949
- assert _unwrap_scalar(ret) == 42**2
950
-
951
-
952
- @skip_windows_unix_socket
953
- def test_volume_commit_on_exit(unix_servicer):
954
- volume_mounts = [
955
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-123", allow_background_commits=True),
956
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-456", allow_background_commits=True),
957
- ]
958
- ret = _run_container(
959
- unix_servicer,
960
- "test.supports.functions",
961
- "square",
962
- volume_mounts=volume_mounts,
963
- )
964
- volume_commit_rpcs = [r for r in unix_servicer.requests if isinstance(r, api_pb2.VolumeCommitRequest)]
965
- assert volume_commit_rpcs
966
- assert {"vo-123", "vo-456"} == set(r.volume_id for r in volume_commit_rpcs)
967
- assert _unwrap_scalar(ret) == 42**2
968
-
969
-
970
- @skip_windows_unix_socket
971
- def test_volume_commit_on_error(unix_servicer, capsys):
972
- volume_mounts = [
973
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-foo", allow_background_commits=True),
974
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-bar", allow_background_commits=True),
975
- ]
976
- _run_container(
977
- unix_servicer,
978
- "test.supports.functions",
979
- "raises",
980
- volume_mounts=volume_mounts,
981
- )
982
- volume_commit_rpcs = [r for r in unix_servicer.requests if isinstance(r, api_pb2.VolumeCommitRequest)]
983
- assert {"vo-foo", "vo-bar"} == set(r.volume_id for r in volume_commit_rpcs)
984
- assert 'raise Exception("Failure!")' in capsys.readouterr().err
985
-
986
-
987
- @skip_windows_unix_socket
988
- def test_no_volume_commit_on_exit(unix_servicer):
989
- volume_mounts = [api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-999", allow_background_commits=False)]
990
- ret = _run_container(
991
- unix_servicer,
992
- "test.supports.functions",
993
- "square",
994
- volume_mounts=volume_mounts,
995
- )
996
- volume_commit_rpcs = [r for r in unix_servicer.requests if isinstance(r, api_pb2.VolumeCommitRequest)]
997
- assert not volume_commit_rpcs # No volume commit on exit for legacy volumes
998
- assert _unwrap_scalar(ret) == 42**2
999
-
1000
-
1001
- @skip_windows_unix_socket
1002
- def test_volume_commit_on_exit_doesnt_fail_container(unix_servicer):
1003
- volume_mounts = [
1004
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vo-999", allow_background_commits=True),
1005
- api_pb2.VolumeMount(
1006
- mount_path="/var/foo",
1007
- volume_id="BAD-ID-FOR-VOL",
1008
- allow_background_commits=True,
1009
- ),
1010
- api_pb2.VolumeMount(mount_path="/var/foo", volume_id="vol-111", allow_background_commits=True),
1011
- ]
1012
- ret = _run_container(
1013
- unix_servicer,
1014
- "test.supports.functions",
1015
- "square",
1016
- volume_mounts=volume_mounts,
1017
- )
1018
- volume_commit_rpcs = [r for r in unix_servicer.requests if isinstance(r, api_pb2.VolumeCommitRequest)]
1019
- assert len(volume_commit_rpcs) == 3
1020
- assert _unwrap_scalar(ret) == 42**2
1021
-
1022
-
1023
- @skip_windows_unix_socket
1024
- def test_function_dep_hydration(unix_servicer):
1025
- deploy_stub_externally(unix_servicer, "test.supports.functions", "stub")
1026
- ret = _run_container(
1027
- unix_servicer,
1028
- "test.supports.functions",
1029
- "check_dep_hydration",
1030
- deps=["im-1", "vo-0", "im-1", "im-2", "vo-0", "vo-1"],
1031
- )
1032
- assert _unwrap_scalar(ret) is None
1033
-
1034
-
1035
- @skip_windows_unix_socket
1036
- def test_build_decorator_cls(unix_servicer):
1037
- ret = _run_container(
1038
- unix_servicer,
1039
- "test.supports.functions",
1040
- "BuildCls.build1",
1041
- inputs=_get_inputs(((), {})),
1042
- is_builder_function=True,
1043
- is_auto_snapshot=True,
1044
- )
1045
- assert _unwrap_scalar(ret) == 101
1046
- # TODO: this is GENERIC_STATUS_FAILURE when `@exit` fails,
1047
- # but why is it not set when `@exit` is successful?
1048
- # assert ret.task_result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
1049
- assert ret.task_result is None
1050
-
1051
-
1052
- @skip_windows_unix_socket
1053
- def test_multiple_build_decorator_cls(unix_servicer):
1054
- ret = _run_container(
1055
- unix_servicer,
1056
- "test.supports.functions",
1057
- "BuildCls.build2",
1058
- inputs=_get_inputs(((), {})),
1059
- is_builder_function=True,
1060
- is_auto_snapshot=True,
1061
- )
1062
- assert _unwrap_scalar(ret) == 1001
1063
- assert ret.task_result is None
1064
-
1065
-
1066
- @skip_windows_unix_socket
1067
- @pytest.mark.timeout(10.0)
1068
- def test_function_io_doesnt_inspect_args_or_return_values(monkeypatch, unix_servicer):
1069
- synchronizer = async_utils.synchronizer
1070
-
1071
- # set up spys to track synchronicity calls to _translate_scalar_in/out
1072
- translate_in_spy = MagicMock(wraps=synchronizer._translate_scalar_in)
1073
- monkeypatch.setattr(synchronizer, "_translate_scalar_in", translate_in_spy)
1074
- translate_out_spy = MagicMock(wraps=synchronizer._translate_scalar_out)
1075
- monkeypatch.setattr(synchronizer, "_translate_scalar_out", translate_out_spy)
1076
-
1077
- # don't do blobbing for this test
1078
- monkeypatch.setattr("modal._container_entrypoint.MAX_OBJECT_SIZE_BYTES", 1e100)
1079
-
1080
- large_data_list = list(range(int(1e6))) # large data set
1081
-
1082
- t0 = time.perf_counter()
1083
- # pr = cProfile.Profile()
1084
- # pr.enable()
1085
- _run_container(
1086
- unix_servicer,
1087
- "test.supports.functions",
1088
- "ident",
1089
- inputs=_get_inputs(((large_data_list,), {})),
1090
- )
1091
- # pr.disable()
1092
- # pr.print_stats()
1093
- duration = time.perf_counter() - t0
1094
- assert duration < 5.0 # TODO (elias): might be able to get this down significantly more by improving serialization
1095
-
1096
- # function_io_manager.serialize(large_data_list)
1097
- in_translations = []
1098
- out_translations = []
1099
- for call in translate_in_spy.call_args_list:
1100
- in_translations += list(call.args)
1101
- for call in translate_out_spy.call_args_list:
1102
- out_translations += list(call.args)
1103
-
1104
- assert len(in_translations) < 1000 # typically 136 or something
1105
- assert len(out_translations) < 2000
1106
-
1107
-
1108
- def _run_container_process(
1109
- servicer,
1110
- module_name,
1111
- function_name,
1112
- *,
1113
- inputs: List[Tuple[Tuple, Dict[str, Any]]],
1114
- allow_concurrent_inputs: Optional[int] = None,
1115
- cls_params: Tuple[Tuple, Dict[str, Any]] = ((), {}),
1116
- print=False, # for debugging - print directly to stdout/stderr instead of pipeing
1117
- env={},
1118
- ) -> subprocess.Popen:
1119
- container_args = _container_args(
1120
- module_name,
1121
- function_name,
1122
- allow_concurrent_inputs=allow_concurrent_inputs,
1123
- serialized_params=serialize(cls_params),
1124
- )
1125
- encoded_container_args = base64.b64encode(container_args.SerializeToString())
1126
- servicer.container_inputs = _get_multi_inputs(inputs)
1127
- return subprocess.Popen(
1128
- [sys.executable, "-m", "modal._container_entrypoint", encoded_container_args],
1129
- env={**os.environ, **env},
1130
- stdout=subprocess.PIPE if not print else None,
1131
- stderr=subprocess.PIPE if not print else None,
1132
- )
1133
-
1134
-
1135
- @skip_windows_signals
1136
- @pytest.mark.usefixtures("server_url_env")
1137
- @pytest.mark.parametrize(
1138
- ["function_name", "input_args", "cancelled_input_ids", "expected_container_output", "live_cancellations"],
1139
- [
1140
- # the 10 second inputs here are to be cancelled:
1141
- ("delay", [0.01, 20, 0.02], ["in-001"], [0.01, 0.02], 1), # cancel second input
1142
- ("delay_async", [0.01, 20, 0.02], ["in-001"], [0.01, 0.02], 1), # async variant
1143
- # cancel first input, but it has already been processed, so all three should come through:
1144
- ("delay", [0.01, 0.5, 0.03], ["in-000"], [0.01, 0.5, 0.03], 0),
1145
- ("delay_async", [0.01, 0.5, 0.03], ["in-000"], [0.01, 0.5, 0.03], 0),
1146
- ],
1147
- )
1148
- def test_cancellation_aborts_current_input_on_match(
1149
- servicer, function_name, input_args, cancelled_input_ids, expected_container_output, live_cancellations
1150
- ):
1151
- # NOTE: for a cancellation to actually happen in this test, it needs to be
1152
- # triggered while the relevant input is being processed. A future input
1153
- # would not be cancelled, since those are expected to be handled by
1154
- # the backend
1155
- with servicer.input_lockstep() as input_lock:
1156
- container_process = _run_container_process(
1157
- servicer,
1158
- "test.supports.functions",
1159
- function_name,
1160
- inputs=[((arg,), {}) for arg in input_args],
1161
- )
1162
- time.sleep(1)
1163
- input_lock.wait()
1164
- input_lock.wait()
1165
- # second input has been sent to container here
1166
- time.sleep(0.05) # give it a little time to start processing
1167
-
1168
- # now let container receive container heartbeat indicating there is a cancellation
1169
- t0 = time.monotonic()
1170
- num_prior_outputs = len(_flatten_outputs(servicer.container_outputs))
1171
- assert num_prior_outputs == 1 # the second input shouldn't have completed yet
1172
-
1173
- servicer.container_heartbeat_return_now(
1174
- api_pb2.ContainerHeartbeatResponse(cancel_input_event=api_pb2.CancelInputEvent(input_ids=cancelled_input_ids))
1175
- )
1176
- stdout, stderr = container_process.communicate()
1177
- assert stderr.decode().count("was cancelled by a user request") == live_cancellations
1178
- assert "Traceback" not in stderr.decode()
1179
- assert container_process.returncode == 0 # wait for container to exit
1180
- duration = time.monotonic() - t0 # time from heartbeat to container exit
1181
-
1182
- items = _flatten_outputs(servicer.container_outputs)
1183
- assert len(items) == len(expected_container_output)
1184
- data = [deserialize(i.result.data, client=None) for i in items]
1185
- assert data == expected_container_output
1186
- # should never run for ~20s, which is what the input would take if the sleep isn't interrupted
1187
- assert duration < 10 # should typically be < 1s, but for some reason in gh actions, it takes a really long time!
1188
-
1189
-
1190
- @skip_windows_signals
1191
- @pytest.mark.usefixtures("server_url_env")
1192
- @pytest.mark.parametrize(
1193
- ["function_name"],
1194
- [("delay",), ("delay_async",)],
1195
- )
1196
- def test_cancellation_stops_task_with_concurrent_inputs(servicer, function_name):
1197
- # send three inputs in container: in-100, in-101, in-102
1198
- with servicer.input_lockstep() as input_lock:
1199
- container_process = _run_container_process(
1200
- servicer, "test.supports.functions", function_name, inputs=[((20,), {})], allow_concurrent_inputs=2
1201
- )
1202
- input_lock.wait()
1203
-
1204
- time.sleep(0.05) # let the container get and start processing the input
1205
- servicer.container_heartbeat_return_now(
1206
- api_pb2.ContainerHeartbeatResponse(cancel_input_event=api_pb2.CancelInputEvent(input_ids=["in-000"]))
1207
- )
1208
- # container should exit soon!
1209
- exit_code = container_process.wait(5)
1210
- assert exit_code == 0 # container should exit gracefully
1211
-
1212
-
1213
- @skip_windows_signals
1214
- @pytest.mark.usefixtures("server_url_env")
1215
- def test_lifecycle_full(servicer):
1216
- # Sync and async container lifecycle methods on a sync function.
1217
- container_process = _run_container_process(
1218
- servicer, "test.supports.functions", "LifecycleCls.f_sync", inputs=[((), {})], cls_params=((True,), {})
1219
- )
1220
- stdout, _ = container_process.communicate(timeout=5)
1221
- assert container_process.returncode == 0
1222
- assert "[events:enter_sync,enter_async,f_sync,exit_sync,exit_async]" in stdout.decode()
1223
-
1224
- # Sync and async container lifecycle methods on an async function.
1225
- container_process = _run_container_process(
1226
- servicer, "test.supports.functions", "LifecycleCls.f_async", inputs=[((), {})], cls_params=((True,), {})
1227
- )
1228
- stdout, _ = container_process.communicate(timeout=5)
1229
- assert container_process.returncode == 0
1230
- assert "[events:enter_sync,enter_async,f_async,exit_sync,exit_async]" in stdout.decode()
1231
-
1232
-
1233
- ## modal.experimental functionality ##
1234
-
1235
-
1236
- @skip_windows_unix_socket
1237
- def test_stop_fetching_inputs(unix_servicer):
1238
- ret = _run_container(
1239
- unix_servicer,
1240
- "test.supports.experimental",
1241
- "StopFetching.after_two",
1242
- inputs=_get_inputs(((42,), {}), n=4, kill_switch=False),
1243
- )
1244
-
1245
- assert len(ret.items) == 2
1246
- assert ret.items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
1247
-
1248
-
1249
- @skip_windows_unix_socket
1250
- def test_container_heartbeat_survives_grpc_deadlines(servicer, caplog, monkeypatch):
1251
- monkeypatch.setattr("modal._container_entrypoint.HEARTBEAT_INTERVAL", 0.01)
1252
- num_heartbeats = 0
1253
-
1254
- async def heartbeat_responder(servicer, stream):
1255
- nonlocal num_heartbeats
1256
- num_heartbeats += 1
1257
- await stream.recv_message()
1258
- raise GRPCError(Status.DEADLINE_EXCEEDED)
1259
-
1260
- with servicer.intercept() as ctx:
1261
- ctx.set_responder("ContainerHeartbeat", heartbeat_responder)
1262
- ret = _run_container(
1263
- servicer,
1264
- "test.supports.functions",
1265
- "delay",
1266
- inputs=_get_inputs(((2,), {})),
1267
- )
1268
- assert ret.task_result is None # should not cause a failure result
1269
- loop_iteration_failures = caplog.text.count("Heartbeat attempt failed")
1270
- assert "Traceback" not in caplog.text # should not print a full traceback - don't scare users!
1271
- assert (
1272
- loop_iteration_failures > 1
1273
- ) # one occurence per failing `retry_transient_errors()`, so fewer than the number of failing requests!
1274
- assert loop_iteration_failures < num_heartbeats
1275
- assert num_heartbeats > 4 # more than the default number of retries per heartbeat attempt + 1
1276
-
1277
-
1278
- @skip_windows_unix_socket
1279
- def test_container_heartbeat_survives_local_exceptions(servicer, caplog, monkeypatch):
1280
- numcalls = 0
1281
-
1282
- async def custom_heartbeater(self):
1283
- nonlocal numcalls
1284
- numcalls += 1
1285
- raise Exception("oops")
1286
-
1287
- monkeypatch.setattr("modal._container_entrypoint.HEARTBEAT_INTERVAL", 0.01)
1288
- monkeypatch.setattr(
1289
- "modal._container_entrypoint._FunctionIOManager._heartbeat_handle_cancellations", custom_heartbeater
1290
- )
1291
-
1292
- ret = _run_container(
1293
- servicer,
1294
- "test.supports.functions",
1295
- "delay",
1296
- inputs=_get_inputs(((0.5,), {})),
1297
- )
1298
- assert ret.task_result is None # should not cause a failure result
1299
- loop_iteration_failures = caplog.text.count("Heartbeat attempt failed")
1300
- assert loop_iteration_failures > 5
1301
- assert "error=Exception('oops')" in caplog.text
1302
- assert "Traceback" not in caplog.text # should not print a full traceback - don't scare users!
1303
-
1304
-
1305
- @skip_windows_signals
1306
- @pytest.mark.usefixtures("server_url_env")
1307
- @pytest.mark.parametrize("method", ["delay", "delay_async"])
1308
- def test_sigint_termination_input(servicer, method):
1309
- # Sync and async container lifecycle methods on a sync function.
1310
- with servicer.input_lockstep() as input_barrier:
1311
- container_process = _run_container_process(
1312
- servicer,
1313
- "test.supports.functions",
1314
- f"LifecycleCls.{method}",
1315
- inputs=[((5,), {})],
1316
- cls_params=((), {"print_at_exit": True}),
1317
- )
1318
- input_barrier.wait() # get input
1319
- time.sleep(0.5)
1320
- signal_time = time.monotonic()
1321
- os.kill(container_process.pid, signal.SIGINT)
1322
-
1323
- stdout, stderr = container_process.communicate(timeout=5)
1324
- stop_duration = time.monotonic() - signal_time
1325
- assert len(servicer.container_outputs) == 0
1326
- assert (
1327
- container_process.returncode == 0
1328
- ) # container should catch and indicate successful termination by exiting cleanly when possible
1329
- assert f"[events:enter_sync,enter_async,{method},exit_sync,exit_async]" in stdout.decode()
1330
- assert "Traceback" not in stderr.decode()
1331
- assert stop_duration < 2.0 # if this would be ~4.5s, then the input isn't getting terminated
1332
- assert servicer.task_result is None
1333
-
1334
-
1335
- @skip_windows_signals
1336
- @pytest.mark.usefixtures("server_url_env")
1337
- @pytest.mark.parametrize("enter_type", ["sync_enter", "async_enter"])
1338
- @pytest.mark.parametrize("method", ["delay", "delay_async"])
1339
- def test_sigint_termination_enter_handler(servicer, method, enter_type):
1340
- # Sync and async container lifecycle methods on a sync function.
1341
- container_process = _run_container_process(
1342
- servicer,
1343
- "test.supports.functions",
1344
- f"LifecycleCls.{method}",
1345
- inputs=[((5,), {})],
1346
- cls_params=((), {"print_at_exit": True, f"{enter_type}_duration": 10}),
1347
- )
1348
- time.sleep(1) # should be enough to start the enter method
1349
- signal_time = time.monotonic()
1350
- os.kill(container_process.pid, signal.SIGINT)
1351
- stdout, stderr = container_process.communicate(timeout=5)
1352
- stop_duration = time.monotonic() - signal_time
1353
- assert len(servicer.container_outputs) == 0
1354
- assert container_process.returncode == 0
1355
- if enter_type == "sync_enter":
1356
- assert "[events:enter_sync]" in stdout.decode()
1357
- else:
1358
- # enter_sync should run in 0s, and then we interrupt during the async enter
1359
- assert "[events:enter_sync,enter_async]" in stdout.decode()
1360
-
1361
- assert "Traceback" not in stderr.decode()
1362
- assert stop_duration < 2.0 # if this would be ~4.5s, then the task isn't being terminated timely
1363
- assert servicer.task_result is None
1364
-
1365
-
1366
- @skip_windows_signals
1367
- @pytest.mark.usefixtures("server_url_env")
1368
- @pytest.mark.parametrize("exit_type", ["sync_exit", "async_exit"])
1369
- def test_sigint_termination_exit_handler(servicer, exit_type):
1370
- # Sync and async container lifecycle methods on a sync function.
1371
- with servicer.output_lockstep() as outputs:
1372
- container_process = _run_container_process(
1373
- servicer,
1374
- "test.supports.functions",
1375
- "LifecycleCls.delay",
1376
- inputs=[((0,), {})],
1377
- cls_params=((), {"print_at_exit": True, f"{exit_type}_duration": 2}),
1378
- )
1379
- outputs.wait() # wait for first output to be emitted
1380
- time.sleep(1) # give some time for container to end up in the exit handler
1381
- os.kill(container_process.pid, signal.SIGINT)
1382
-
1383
- stdout, stderr = container_process.communicate(timeout=5)
1384
-
1385
- assert len(servicer.container_outputs) == 1
1386
- assert container_process.returncode == 0
1387
- assert "[events:enter_sync,enter_async,delay,exit_sync,exit_async]" in stdout.decode()
1388
- assert "Traceback" not in stderr.decode()
1389
- assert servicer.task_result is None