modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +402 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1025 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +196 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +943 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
  128. modal-0.72.13.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,1025 @@
1
+ # Copyright Modal Labs 2024
2
+ import asyncio
3
+ import importlib.metadata
4
+ import inspect
5
+ import json
6
+ import math
7
+ import os
8
+ import signal
9
+ import sys
10
+ import time
11
+ import traceback
12
+ from collections.abc import AsyncGenerator, AsyncIterator
13
+ from contextlib import AsyncExitStack
14
+ from pathlib import Path
15
+ from typing import (
16
+ TYPE_CHECKING,
17
+ Any,
18
+ Callable,
19
+ ClassVar,
20
+ Optional,
21
+ )
22
+
23
+ from google.protobuf.empty_pb2 import Empty
24
+ from grpclib import Status
25
+ from synchronicity.async_wrap import asynccontextmanager
26
+
27
+ import modal_proto.api_pb2
28
+ from modal._serialization import deserialize, serialize, serialize_data_format
29
+ from modal._traceback import extract_traceback, print_exception
30
+ from modal._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer
31
+ from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
32
+ from modal._utils.function_utils import _stream_function_call_data
33
+ from modal._utils.grpc_utils import retry_transient_errors
34
+ from modal._utils.package_utils import parse_major_minor_version
35
+ from modal.client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
36
+ from modal.config import config, logger
37
+ from modal.exception import ClientClosed, InputCancellation, InvalidError, SerializationError
38
+ from modal_proto import api_pb2
39
+
40
+ if TYPE_CHECKING:
41
+ import modal._runtime.asgi
42
+ import modal._runtime.user_code_imports
43
+
44
+
45
+ DYNAMIC_CONCURRENCY_INTERVAL_SECS = 3
46
+ DYNAMIC_CONCURRENCY_TIMEOUT_SECS = 10
47
+ MAX_OUTPUT_BATCH_SIZE: int = 49
48
+
49
+ RTT_S: float = 0.5 # conservative estimate of RTT in seconds.
50
+
51
+
52
+ class UserException(Exception):
53
+ """Used to shut down the task gracefully."""
54
+
55
+
56
+ class Sentinel:
57
+ """Used to get type-stubs to work with this object."""
58
+
59
+
60
+ class IOContext:
61
+ """Context object for managing input, function calls, and function executions
62
+ in a batched or single input context.
63
+ """
64
+
65
+ input_ids: list[str]
66
+ function_call_ids: list[str]
67
+ finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"
68
+
69
+ _cancel_issued: bool = False
70
+ _cancel_callback: Optional[Callable[[], None]] = None
71
+
72
+ def __init__(
73
+ self,
74
+ input_ids: list[str],
75
+ function_call_ids: list[str],
76
+ finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
77
+ function_inputs: list[api_pb2.FunctionInput],
78
+ is_batched: bool,
79
+ client: _Client,
80
+ ):
81
+ self.input_ids = input_ids
82
+ self.function_call_ids = function_call_ids
83
+ self.finalized_function = finalized_function
84
+ self._function_inputs = function_inputs
85
+ self._is_batched = is_batched
86
+ self._client = client
87
+
88
+ @classmethod
89
+ async def create(
90
+ cls,
91
+ client: _Client,
92
+ finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
93
+ inputs: list[tuple[str, str, api_pb2.FunctionInput]],
94
+ is_batched: bool,
95
+ ) -> "IOContext":
96
+ assert len(inputs) >= 1 if is_batched else len(inputs) == 1
97
+ input_ids, function_call_ids, function_inputs = zip(*inputs)
98
+
99
+ async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput:
100
+ # If we got a pointer to a blob, download it from S3.
101
+ if input.WhichOneof("args_oneof") == "args_blob_id":
102
+ args = await blob_download(input.args_blob_id, client.stub)
103
+ # Mutating
104
+ input.ClearField("args_blob_id")
105
+ input.args = args
106
+
107
+ return input
108
+
109
+ function_inputs = await asyncio.gather(*[_populate_input_blobs(client, input) for input in function_inputs])
110
+ # check that every input in the batch executes the same function
111
+ method_name = function_inputs[0].method_name
112
+ assert all(method_name == input.method_name for input in function_inputs)
113
+ finalized_function = finalized_functions[method_name]
114
+ return cls(input_ids, function_call_ids, finalized_function, function_inputs, is_batched, client)
115
+
116
+ def set_cancel_callback(self, cb: Callable[[], None]):
117
+ self._cancel_callback = cb
118
+
119
+ def cancel(self):
120
+ # Ensure we only issue the cancellation once.
121
+ if self._cancel_issued:
122
+ return
123
+
124
+ if self._cancel_callback:
125
+ logger.warning(f"Received a cancellation signal while processing input {self.input_ids}")
126
+ self._cancel_issued = True
127
+ self._cancel_callback()
128
+ else:
129
+ # TODO (elias): This should not normally happen but there is a small chance of a race
130
+ # between creating a new task for an input and attaching the cancellation callback
131
+ logger.warning("Unexpected: Could not cancel input")
132
+
133
+ def _args_and_kwargs(self) -> tuple[tuple[Any, ...], dict[str, list[Any]]]:
134
+ # deserializing here instead of the constructor
135
+ # to make sure we handle user exceptions properly
136
+ # and don't retry
137
+ deserialized_args = [
138
+ deserialize(input.args, self._client) if input.args else ((), {}) for input in self._function_inputs
139
+ ]
140
+ if not self._is_batched:
141
+ return deserialized_args[0]
142
+
143
+ func_name = self.finalized_function.callable.__name__
144
+
145
+ param_names = []
146
+ for param in inspect.signature(self.finalized_function.callable).parameters.values():
147
+ param_names.append(param.name)
148
+
149
+ # aggregate args and kwargs of all inputs into a kwarg dict
150
+ kwargs_by_inputs: list[dict[str, Any]] = [{} for _ in range(len(self.input_ids))]
151
+
152
+ for i, (args, kwargs) in enumerate(deserialized_args):
153
+ # check that all batched inputs have the same number of args and kwargs
154
+ if (num_params := len(args) + len(kwargs)) != len(param_names):
155
+ raise InvalidError(
156
+ f"Modal batched function {func_name} takes {len(param_names)} positional arguments, but one invocation in the batch has {num_params}." # noqa
157
+ )
158
+
159
+ for j, arg in enumerate(args):
160
+ kwargs_by_inputs[i][param_names[j]] = arg
161
+ for k, v in kwargs.items():
162
+ if k not in param_names:
163
+ raise InvalidError(
164
+ f"Modal batched function {func_name} got unexpected keyword argument {k} in one invocation in the batch." # noqa
165
+ )
166
+ if k in kwargs_by_inputs[i]:
167
+ raise InvalidError(
168
+ f"Modal batched function {func_name} got multiple values for argument {k} in one invocation in the batch." # noqa
169
+ )
170
+ kwargs_by_inputs[i][k] = v
171
+
172
+ formatted_kwargs = {
173
+ param_name: [kwargs[param_name] for kwargs in kwargs_by_inputs] for param_name in param_names
174
+ }
175
+ return (), formatted_kwargs
176
+
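The reshaping above turns a batch of per-invocation (args, kwargs) pairs into a single call whose keyword arguments are lists, one entry per input. A minimal standalone sketch of the same idea, using hypothetical names rather than anything from this module:

    import inspect
    from typing import Any

    def aggregate_batch(fn, calls: list[tuple[tuple, dict]]) -> dict[str, list[Any]]:
        # Map each invocation's positional and keyword args onto fn's parameter
        # names, then collect one list per parameter across the whole batch.
        names = list(inspect.signature(fn).parameters)
        per_call = []
        for args, kwargs in calls:
            bound = dict(zip(names, args))
            bound.update(kwargs)
            per_call.append(bound)
        return {name: [c[name] for c in per_call] for name in names}

    def add(x, y):
        return [a + b for a, b in zip(x, y)]

    # Two invocations add(1, y=2) and add(3, y=4) become one call add(x=[1, 3], y=[2, 4]).
    print(add(**aggregate_batch(add, [((1,), {"y": 2}), ((3,), {"y": 4})])))  # [3, 7]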
177
+ def call_finalized_function(self) -> Any:
178
+ logger.debug(f"Starting input {self.input_ids}")
179
+ args, kwargs = self._args_and_kwargs()
180
+ res = self.finalized_function.callable(*args, **kwargs)
181
+ logger.debug(f"Finished input {self.input_ids}")
182
+ return res
183
+
184
+ def validate_output_data(self, data: Any) -> list[Any]:
185
+ if not self._is_batched:
186
+ return [data]
187
+
188
+ function_name = self.finalized_function.callable.__name__
189
+ if not isinstance(data, list):
190
+ raise InvalidError(f"Output of batched function {function_name} must be a list.")
191
+ if len(data) != len(self.input_ids):
192
+ raise InvalidError(
193
+ f"Output of batched function {function_name} must be a list of equal length as its inputs."
194
+ )
195
+ return data
196
+
197
+
198
+ class InputSlots:
199
+ """A semaphore that allows dynamically adjusting the concurrency."""
200
+
201
+ active: int
202
+ value: int
203
+ waiter: Optional[asyncio.Future]
204
+ closed: bool
205
+
206
+ def __init__(self, value: int) -> None:
207
+ self.active = 0
208
+ self.value = value
209
+ self.waiter = None
210
+ self.closed = False
211
+
212
+ async def acquire(self) -> None:
213
+ if self.active < self.value:
214
+ self.active += 1
215
+ elif self.waiter is None:
216
+ self.waiter = asyncio.get_running_loop().create_future()
217
+ await self.waiter
218
+ else:
219
+ raise RuntimeError("Concurrent waiters are not supported.")
220
+
221
+ def _wake_waiter(self) -> None:
222
+ if self.active < self.value and self.waiter is not None:
223
+ if not self.waiter.cancelled(): # could have been cancelled during interpreter shutdown
224
+ self.waiter.set_result(None)
225
+ self.waiter = None
226
+ self.active += 1
227
+
228
+ def release(self) -> None:
229
+ self.active -= 1
230
+ self._wake_waiter()
231
+
232
+ def set_value(self, value: int) -> None:
233
+ if self.closed:
234
+ return
235
+ self.value = value
236
+ self._wake_waiter()
237
+
238
+ async def close(self) -> None:
239
+ self.closed = True
240
+ for _ in range(self.value):
241
+ await self.acquire()
242
+
243
+
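As a usage sketch of the adjustable semaphore above: raising the value with set_value() wakes a parked acquire(), which is how the dynamic concurrency loop later in this file grows the number of in-flight inputs. This assumes InputSlots is importable from the new modal._runtime.container_io_manager module added in this release:

    import asyncio

    from modal._runtime.container_io_manager import InputSlots

    async def demo() -> None:
        slots = InputSlots(1)
        await slots.acquire()                  # takes the only slot
        waiter = asyncio.create_task(slots.acquire())
        await asyncio.sleep(0)                 # let the second acquire park on its future
        slots.set_value(2)                     # raising concurrency wakes the parked waiter
        await waiter
        slots.release()
        slots.release()

    asyncio.run(demo())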
244
+ class _ContainerIOManager:
245
+ """Synchronizes all RPC calls and network operations for a running container.
246
+
247
+ TODO: maybe we shouldn't synchronize the whole class.
248
+ Then we could potentially move a bunch of the global functions onto it.
249
+ """
250
+
251
+ task_id: str
252
+ function_id: str
253
+ app_id: str
254
+ function_def: api_pb2.Function
255
+ checkpoint_id: Optional[str]
256
+
257
+ calls_completed: int
258
+ total_user_time: float
259
+ current_input_id: Optional[str]
260
+ current_inputs: dict[str, IOContext] # input_id -> IOContext
261
+ current_input_started_at: Optional[float]
262
+
263
+ _target_concurrency: int
264
+ _max_concurrency: int
265
+ _concurrency_loop: Optional[asyncio.Task]
266
+ _input_slots: InputSlots
267
+
268
+ _environment_name: str
269
+ _heartbeat_loop: Optional[asyncio.Task]
270
+ _heartbeat_condition: Optional[asyncio.Condition]
271
+ _waiting_for_memory_snapshot: bool
272
+
273
+ _is_interactivity_enabled: bool
274
+ _fetching_inputs: bool
275
+
276
+ _client: _Client
277
+
278
+ _GENERATOR_STOP_SENTINEL: ClassVar[Sentinel] = Sentinel()
279
+ _singleton: ClassVar[Optional["_ContainerIOManager"]] = None
280
+
281
+ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):
282
+ self.task_id = container_args.task_id
283
+ self.function_id = container_args.function_id
284
+ self.app_id = container_args.app_id
285
+ self.function_def = container_args.function_def
286
+ self.checkpoint_id = container_args.checkpoint_id or None
287
+
288
+ self.calls_completed = 0
289
+ self.total_user_time = 0.0
290
+ self.current_input_id = None
291
+ self.current_inputs = {}
292
+ self.current_input_started_at = None
293
+
294
+ if container_args.function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
295
+ target_concurrency = 1
296
+ max_concurrency = 1
297
+ else:
298
+ target_concurrency = container_args.function_def.target_concurrent_inputs or 1
299
+ max_concurrency = container_args.function_def.max_concurrent_inputs or target_concurrency
300
+
301
+ self._target_concurrency = target_concurrency
302
+ self._max_concurrency = max_concurrency
303
+ self._concurrency_loop = None
304
+ self._stop_concurrency_loop = False
305
+ self._input_slots = InputSlots(target_concurrency)
306
+
307
+ self._environment_name = container_args.environment_name
308
+ self._heartbeat_loop = None
309
+ self._heartbeat_condition = None
310
+ self._waiting_for_memory_snapshot = False
311
+
312
+ self._is_interactivity_enabled = False
313
+ self._fetching_inputs = True
314
+
315
+ self._client = client
316
+ assert isinstance(self._client, _Client)
317
+
318
+ @property
319
+ def heartbeat_condition(self) -> asyncio.Condition:
320
+ # ensures that heartbeat condition isn't assigned to an event loop until it's used for the first time
321
+ # (On Python 3.9 and below it would be assigned to the current thread's event loop on creation)
322
+ if self._heartbeat_condition is None:
323
+ self._heartbeat_condition = asyncio.Condition()
324
+ return self._heartbeat_condition
325
+
326
+ def __new__(cls, container_args: api_pb2.ContainerArguments, client: _Client) -> "_ContainerIOManager":
327
+ cls._singleton = super().__new__(cls)
328
+ cls._singleton._init(container_args, client)
329
+ return cls._singleton
330
+
331
+ @classmethod
332
+ def _reset_singleton(cls):
333
+ """Only used for tests."""
334
+ cls._singleton = None
335
+
336
+ async def hello(self):
337
+ await self._client.stub.ContainerHello(Empty())
338
+
339
+ async def _run_heartbeat_loop(self):
340
+ while 1:
341
+ t0 = time.monotonic()
342
+ try:
343
+ if await self._heartbeat_handle_cancellations():
344
+ # got a cancellation event, fine to start another heartbeat immediately
345
+ # since the cancellation queue should be empty on the worker server
346
+ # however, we wait at least 1s to prevent short-circuiting the heartbeat loop
347
+ # in case there is ever a bug. This means it will take at least 1s between
348
+ # two subsequent cancellations on the same task at the moment
349
+ await asyncio.sleep(1.0)
350
+ continue
351
+ except ClientClosed:
352
+ logger.info("Stopping heartbeat loop due to client shutdown")
353
+ break
354
+ except Exception as exc:
355
+ # don't stop heartbeat loop if there are transient exceptions!
356
+ time_elapsed = time.monotonic() - t0
357
+ error = exc
358
+ logger.warning(f"Heartbeat attempt failed ({time_elapsed=}, {error=})")
359
+
360
+ heartbeat_duration = time.monotonic() - t0
361
+ time_until_next_heartbeat = max(0.0, HEARTBEAT_INTERVAL - heartbeat_duration)
362
+ await asyncio.sleep(time_until_next_heartbeat)
363
+
364
+ async def _heartbeat_handle_cancellations(self) -> bool:
365
+ # Return True if a cancellation event was received, in that case
366
+ # we shouldn't wait too long for another heartbeat
367
+ async with self.heartbeat_condition:
368
+ # Continuously wait until `waiting_for_memory_snapshot` is false.
369
+ # TODO(matt): Verify that a `while` is necessary over an `if`. Spurious
370
+ # wakeups could allow execution to continue despite `_waiting_for_memory_snapshot`
371
+ # being true.
372
+ while self._waiting_for_memory_snapshot:
373
+ await self.heartbeat_condition.wait()
374
+
375
+ request = api_pb2.ContainerHeartbeatRequest(canceled_inputs_return_outputs_v2=True)
376
+ response = await retry_transient_errors(
377
+ self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
378
+ )
379
+
380
+ if response.HasField("cancel_input_event"):
381
+ # response.cancel_input_event.terminate_containers is never set, the server gets the worker to handle it.
382
+ input_ids_to_cancel = response.cancel_input_event.input_ids
383
+ if input_ids_to_cancel:
384
+ if self._max_concurrency > 1:
385
+ for input_id in input_ids_to_cancel:
386
+ if input_id in self.current_inputs:
387
+ self.current_inputs[input_id].cancel()
388
+
389
+ elif self.current_input_id and self.current_input_id in input_ids_to_cancel:
390
+ # This goes to a registered signal handler for sync Modal functions, or to the
391
+ # `SignalHandlingEventLoop` for async functions.
392
+ #
393
+ # We only send this signal on functions that do not have concurrent inputs enabled.
394
+ # This allows us to do fine-grained input cancellation. On sync functions, the
395
+ # SIGUSR1 signal should interrupt the main thread where user code is running,
396
+ # raising an InputCancellation() exception. On async functions, the signal should
397
+ # reach a handler in SignalHandlingEventLoop, which cancels the task.
398
+ logger.warning(f"Received a cancellation signal while processing input {self.current_input_id}")
399
+ os.kill(os.getpid(), signal.SIGUSR1)
400
+ return True
401
+ return False
402
+
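The branch above relies on SIGUSR1 interrupting the main thread where sync user code runs. A minimal, Unix-only sketch of that pattern, with a stand-in exception class (the real one is modal.exception.InputCancellation, imported at the top of this file):

    import os
    import signal

    class InputCancellation(BaseException):
        # Stand-in for modal.exception.InputCancellation.
        pass

    def _raise_cancellation(signum, frame):
        raise InputCancellation("input cancelled")

    signal.signal(signal.SIGUSR1, _raise_cancellation)

    try:
        os.kill(os.getpid(), signal.SIGUSR1)   # what the heartbeat handler does on cancellation
    except InputCancellation:
        print("user code interrupted; input can be reported as terminated")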
403
+ @asynccontextmanager
404
+ async def heartbeats(self, wait_for_mem_snap: bool) -> AsyncGenerator[None, None]:
405
+ async with TaskContext() as tc:
406
+ self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
407
+ t.set_name("heartbeat loop")
408
+ self._waiting_for_memory_snapshot = wait_for_mem_snap
409
+ try:
410
+ yield
411
+ finally:
412
+ t.cancel()
413
+
414
+ def stop_heartbeat(self):
415
+ if self._heartbeat_loop:
416
+ self._heartbeat_loop.cancel()
417
+
418
+ @asynccontextmanager
419
+ async def dynamic_concurrency_manager(self) -> AsyncGenerator[None, None]:
420
+ async with TaskContext() as tc:
421
+ self._concurrency_loop = t = tc.create_task(self._dynamic_concurrency_loop())
422
+ t.set_name("dynamic concurrency loop")
423
+ try:
424
+ yield
425
+ finally:
426
+ t.cancel()
427
+
428
+ async def _dynamic_concurrency_loop(self):
429
+ logger.debug(f"Starting dynamic concurrency loop for task {self.task_id}")
430
+ while not self._stop_concurrency_loop:
431
+ try:
432
+ request = api_pb2.FunctionGetDynamicConcurrencyRequest(
433
+ function_id=self.function_id,
434
+ target_concurrency=self._target_concurrency,
435
+ max_concurrency=self._max_concurrency,
436
+ )
437
+ resp = await retry_transient_errors(
438
+ self._client.stub.FunctionGetDynamicConcurrency,
439
+ request,
440
+ attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS,
441
+ )
442
+ if resp.concurrency != self._input_slots.value and not self._stop_concurrency_loop:
443
+ logger.debug(f"Dynamic concurrency set from {self._input_slots.value} to {resp.concurrency}")
444
+ self._input_slots.set_value(resp.concurrency)
445
+
446
+ except Exception as exc:
447
+ logger.debug(f"Failed to get dynamic concurrency for task {self.task_id}, {exc}")
448
+
449
+ await asyncio.sleep(DYNAMIC_CONCURRENCY_INTERVAL_SECS)
450
+
451
+ async def get_serialized_function(self) -> tuple[Optional[Any], Optional[Callable[..., Any]]]:
452
+ # Fetch the serialized function definition
453
+ request = api_pb2.FunctionGetSerializedRequest(function_id=self.function_id)
454
+ response = await self._client.stub.FunctionGetSerialized(request)
455
+ if response.function_serialized:
456
+ fun = self.deserialize(response.function_serialized)
457
+ else:
458
+ fun = None
459
+
460
+ if response.class_serialized:
461
+ cls = self.deserialize(response.class_serialized)
462
+ else:
463
+ cls = None
464
+
465
+ return cls, fun
466
+
467
+ def serialize(self, obj: Any) -> bytes:
468
+ return serialize(obj)
469
+
470
+ def deserialize(self, data: bytes) -> Any:
471
+ return deserialize(data, self._client)
472
+
473
+ @synchronizer.no_io_translation
474
+ def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
475
+ return serialize_data_format(obj, data_format)
476
+
477
+ async def format_blob_data(self, data: bytes) -> dict[str, Any]:
478
+ return (
479
+ {"data_blob_id": await blob_upload(data, self._client.stub)}
480
+ if len(data) > MAX_OBJECT_SIZE_BYTES
481
+ else {"data": data}
482
+ )
483
+
484
+ async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
485
+ """Read from the `data_in` stream of a function call."""
486
+ async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
487
+ yield data
488
+
489
+ async def put_data_out(
490
+ self,
491
+ function_call_id: str,
492
+ start_index: int,
493
+ data_format: int,
494
+ messages_bytes: list[Any],
495
+ ) -> None:
496
+ """Put data onto the `data_out` stream of a function call.
497
+
498
+ This is used for generator outputs, which includes web endpoint responses. Note that this
499
+ was introduced as a performance optimization in client version 0.57, so older clients will
500
+ still use the previous Postgres-backed system based on `FunctionPutOutputs()`.
501
+ """
502
+ data_chunks: list[api_pb2.DataChunk] = []
503
+ for i, message_bytes in enumerate(messages_bytes):
504
+ chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i) # type: ignore
505
+ if len(message_bytes) > MAX_OBJECT_SIZE_BYTES:
506
+ chunk.data_blob_id = await blob_upload(message_bytes, self._client.stub)
507
+ else:
508
+ chunk.data = message_bytes
509
+ data_chunks.append(chunk)
510
+
511
+ req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks)
512
+ await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req)
513
+
514
+ async def generator_output_task(self, function_call_id: str, data_format: int, message_rx: asyncio.Queue) -> None:
515
+ """Task that feeds generator outputs into a function call's `data_out` stream."""
516
+ index = 1
517
+ received_sentinel = False
518
+ while not received_sentinel:
519
+ message = await message_rx.get()
520
+ if message is self._GENERATOR_STOP_SENTINEL:
521
+ break
522
+ # ASGI 'http.response.start' and 'http.response.body' msgs are observed to be separated by 1ms.
523
+ # If we don't sleep here for 1ms we end up with an extra call to .put_data_out().
524
+ if index == 1:
525
+ await asyncio.sleep(0.001)
526
+ messages_bytes = [serialize_data_format(message, data_format)]
527
+ total_size = len(messages_bytes[0]) + 512
528
+ while total_size < 16 * 1024 * 1024: # 16 MiB, maximum size in a single message
529
+ try:
530
+ message = message_rx.get_nowait()
531
+ except asyncio.QueueEmpty:
532
+ break
533
+ if message is self._GENERATOR_STOP_SENTINEL:
534
+ received_sentinel = True
535
+ break
536
+ else:
537
+ messages_bytes.append(serialize_data_format(message, data_format))
538
+ total_size += len(messages_bytes[-1]) + 512 # 512 bytes for estimated framing overhead
539
+ await self.put_data_out(function_call_id, index, data_format, messages_bytes)
540
+ index += len(messages_bytes)
541
+
542
+ async def _queue_create(self, size: int) -> asyncio.Queue:
543
+ """Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
544
+ return asyncio.Queue(size)
545
+
546
+ async def _queue_put(self, queue: asyncio.Queue, value: Any) -> None:
547
+ """Put a value onto a queue, using the synchronicity event loop."""
548
+ await queue.put(value)
549
+
550
+ def get_average_call_time(self) -> float:
551
+ if self.calls_completed == 0:
552
+ return 0
553
+
554
+ return self.total_user_time / self.calls_completed
555
+
556
+ def get_max_inputs_to_fetch(self):
557
+ if self.calls_completed == 0:
558
+ return 1
559
+
560
+ return math.ceil(RTT_S / max(self.get_average_call_time(), 1e-6))
561
+
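For example, with the default RTT_S of 0.5 s and an average call time of 0.1 s this evaluates to ceil(0.5 / 0.1) = 5 inputs per fetch, while average call times of 0.5 s or longer bring it down to a single input.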
562
+ @synchronizer.no_io_translation
563
+ async def _generate_inputs(
564
+ self,
565
+ batch_max_size: int,
566
+ batch_wait_ms: int,
567
+ ) -> AsyncIterator[list[tuple[str, str, api_pb2.FunctionInput]]]:
568
+ request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
569
+ iteration = 0
570
+ while self._fetching_inputs:
571
+ await self._input_slots.acquire()
572
+
573
+ request.average_call_time = self.get_average_call_time()
574
+ request.max_values = self.get_max_inputs_to_fetch() # Deprecated; remove.
575
+ request.input_concurrency = self.get_input_concurrency()
576
+ request.batch_max_size, request.batch_linger_ms = batch_max_size, batch_wait_ms
577
+
578
+ yielded = False
579
+ try:
580
+ # If number of active inputs is at max queue size, this will block.
581
+ iteration += 1
582
+ response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors(
583
+ self._client.stub.FunctionGetInputs, request
584
+ )
585
+
586
+ if response.rate_limit_sleep_duration:
587
+ logger.info(
588
+ "Task exceeded rate limit, sleeping for %.2fs before trying again."
589
+ % response.rate_limit_sleep_duration
590
+ )
591
+ await asyncio.sleep(response.rate_limit_sleep_duration)
592
+ elif response.inputs:
593
+ # for input cancellations and concurrency logic we currently assume
594
+ # that there is no input buffering in the container
595
+ assert 0 < len(response.inputs) <= max(1, request.batch_max_size)
596
+ inputs = []
597
+ final_input_received = False
598
+ for item in response.inputs:
599
+ if item.kill_switch:
600
+ logger.debug(f"Task {self.task_id} input kill signal input.")
601
+ return
602
+
603
+ inputs.append((item.input_id, item.function_call_id, item.input))
604
+ if item.input.final_input:
605
+ if request.batch_max_size > 0:
606
+ logger.debug(f"Task {self.task_id} Final input not expected in batch input stream")
607
+ final_input_received = True
608
+ break
609
+
610
+ # If yielded, allow input slots to be released via exit_context
611
+ yield inputs
612
+ yielded = True
613
+
614
+ # We only support max_inputs = 1 at the moment
615
+ if final_input_received or self.function_def.max_inputs == 1:
616
+ return
617
+ finally:
618
+ if not yielded:
619
+ self._input_slots.release()
620
+
621
+ @synchronizer.no_io_translation
622
+ async def run_inputs_outputs(
623
+ self,
624
+ finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
625
+ batch_max_size: int = 0,
626
+ batch_wait_ms: int = 0,
627
+ ) -> AsyncIterator[IOContext]:
628
+ # Ensure we do not fetch new inputs when container is too busy.
629
+ # Before trying to fetch an input, acquire an input slot:
630
+ # - if no input is fetched, release the input slot.
631
+ # - or, when the output for the fetched input is sent, release the input slot.
632
+ dynamic_concurrency_manager = (
633
+ self.dynamic_concurrency_manager() if self._max_concurrency > self._target_concurrency else AsyncExitStack()
634
+ )
635
+ async with dynamic_concurrency_manager:
636
+ async for inputs in self._generate_inputs(batch_max_size, batch_wait_ms):
637
+ io_context = await IOContext.create(self._client, finalized_functions, inputs, batch_max_size > 0)
638
+ for input_id in io_context.input_ids:
639
+ self.current_inputs[input_id] = io_context
640
+
641
+ self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time()
642
+ yield io_context
643
+ self.current_input_id, self.current_input_started_at = (None, None)
644
+
645
+ # collect all active input slots, meaning all inputs have wrapped up.
646
+ await self._input_slots.close()
647
+
648
+ @synchronizer.no_io_translation
649
+ async def _push_outputs(
650
+ self,
651
+ io_context: IOContext,
652
+ started_at: float,
653
+ data_format: "modal_proto.api_pb2.DataFormat.ValueType",
654
+ results: list[api_pb2.GenericResult],
655
+ ) -> None:
656
+ output_created_at = time.time()
657
+ outputs = [
658
+ api_pb2.FunctionPutOutputsItem(
659
+ input_id=input_id,
660
+ input_started_at=started_at,
661
+ output_created_at=output_created_at,
662
+ result=result,
663
+ data_format=data_format,
664
+ )
665
+ for input_id, result in zip(io_context.input_ids, results)
666
+ ]
667
+ await retry_transient_errors(
668
+ self._client.stub.FunctionPutOutputs,
669
+ api_pb2.FunctionPutOutputsRequest(outputs=outputs),
670
+ additional_status_codes=[Status.RESOURCE_EXHAUSTED],
671
+ max_retries=None, # Retry indefinitely, trying every 1s.
672
+ )
673
+
674
+ def serialize_exception(self, exc: BaseException) -> bytes:
675
+ try:
676
+ return self.serialize(exc)
677
+ except Exception as serialization_exc:
678
+ # We can't always serialize exceptions.
679
+ err = f"Failed to serialize exception {exc} of type {type(exc)}: {serialization_exc}"
680
+ logger.info(err)
681
+ return self.serialize(SerializationError(err))
682
+
683
+ def serialize_traceback(self, exc: BaseException) -> tuple[Optional[bytes], Optional[bytes]]:
684
+ serialized_tb, tb_line_cache = None, None
685
+
686
+ try:
687
+ tb_dict, line_cache = extract_traceback(exc, self.task_id)
688
+ serialized_tb = self.serialize(tb_dict)
689
+ tb_line_cache = self.serialize(line_cache)
690
+ except Exception:
691
+ logger.info("Failed to serialize exception traceback.")
692
+
693
+ return serialized_tb, tb_line_cache
694
+
695
+ @asynccontextmanager
696
+ async def handle_user_exception(self) -> AsyncGenerator[None, None]:
697
+ """Sets the task as failed in a way where it's not retried.
698
+
699
+ Used for handling exceptions from container lifecycle methods at the moment, which should
700
+ trigger a task failure state.
701
+ """
702
+ try:
703
+ yield
704
+ except KeyboardInterrupt:
705
+ # Send no task result in case we get sigint:ed by the runner
706
+ # The status of the input should have been handled externally already in that case
707
+ raise
708
+ except BaseException as exc:
709
+ if isinstance(exc, ImportError):
710
+ # Catches errors raised by global scope imports
711
+ check_fastapi_pydantic_compatibility(exc)
712
+
713
+ # Since this is on a different thread, sys.exc_info() can't find the exception in the stack.
714
+ print_exception(type(exc), exc, exc.__traceback__)
715
+
716
+ serialized_tb, tb_line_cache = self.serialize_traceback(exc)
717
+
718
+ result = api_pb2.GenericResult(
719
+ status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
720
+ data=self.serialize_exception(exc),
721
+ exception=repr(exc),
722
+ traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
723
+ serialized_tb=serialized_tb or b"",
724
+ tb_line_cache=tb_line_cache or b"",
725
+ )
726
+
727
+ req = api_pb2.TaskResultRequest(result=result)
728
+ await retry_transient_errors(self._client.stub.TaskResult, req)
729
+
730
+ # Shut down the task gracefully
731
+ raise UserException()
732
+
733
+ @asynccontextmanager
734
+ async def handle_input_exception(
735
+ self,
736
+ io_context: IOContext,
737
+ started_at: float,
738
+ ) -> AsyncGenerator[None, None]:
739
+ """Handle an exception while processing a function input."""
740
+ try:
741
+ yield
742
+ except (KeyboardInterrupt, GeneratorExit):
743
+ # We need to explicitly reraise these BaseExceptions to not handle them in the catch-all:
744
+ # 1. KeyboardInterrupt can end up here even though this runs on non-main thread, since the
745
+ # code block yielded to could be sending back a main thread exception
746
+ # 2. GeneratorExit - raised if this (async) generator is garbage collected while waiting
747
+ # for the yield. Typically on event loop shutdown
748
+ raise
749
+ except (InputCancellation, asyncio.CancelledError):
750
+ # Create terminated outputs for these inputs to signal that the cancellations have been completed.
751
+ results = [
752
+ api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED)
753
+ for _ in io_context.input_ids
754
+ ]
755
+ await self._push_outputs(
756
+ io_context=io_context,
757
+ started_at=started_at,
758
+ data_format=api_pb2.DATA_FORMAT_PICKLE,
759
+ results=results,
760
+ )
761
+ self.exit_context(started_at, io_context.input_ids)
762
+ logger.warning(f"Successfully canceled input {io_context.input_ids}")
763
+ return
764
+ except BaseException as exc:
765
+ if isinstance(exc, ImportError):
766
+ # Catches errors raised by imports from within function body
767
+ check_fastapi_pydantic_compatibility(exc)
768
+
769
+ # print exception so it's logged
770
+ print_exception(*sys.exc_info())
771
+
772
+ serialized_tb, tb_line_cache = self.serialize_traceback(exc)
773
+
774
+ # Note: we're not serializing the traceback since it contains
775
+ local references, which means we can't unpickle it. We *are*
776
+ # serializing the exception, which may have some issues (there
777
+ # was an earlier note about it that it might not be possible
778
+ # to unpickle it in some cases). Let's watch out for issues.
779
+
780
+ repr_exc = repr(exc)
781
+ if len(repr_exc) >= MAX_OBJECT_SIZE_BYTES:
782
+ # We prevent large exception messages to avoid
783
+ # unhandled exceptions causing infinite loops
785
+ # and just send back a trimmed version
785
+ trimmed_bytes = len(repr_exc) - MAX_OBJECT_SIZE_BYTES - 1000
786
+ repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
787
+ repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
788
+
789
+ data: bytes = self.serialize_exception(exc) or b""
790
+ data_result_part = await self.format_blob_data(data)
791
+ results = [
792
+ api_pb2.GenericResult(
793
+ status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
794
+ exception=repr_exc,
795
+ traceback=traceback.format_exc(),
796
+ serialized_tb=serialized_tb or b"",
797
+ tb_line_cache=tb_line_cache or b"",
798
+ **data_result_part,
799
+ )
800
+ for _ in io_context.input_ids
801
+ ]
802
+ await self._push_outputs(
803
+ io_context=io_context,
804
+ started_at=started_at,
805
+ data_format=api_pb2.DATA_FORMAT_PICKLE,
806
+ results=results,
807
+ )
808
+ self.exit_context(started_at, io_context.input_ids)
809
+
810
+ def exit_context(self, started_at, input_ids: list[str]):
811
+ self.total_user_time += time.time() - started_at
812
+ self.calls_completed += 1
813
+
814
+ for input_id in input_ids:
815
+ self.current_inputs.pop(input_id)
816
+
817
+ self._input_slots.release()
818
+
819
+ @synchronizer.no_io_translation
820
+ async def push_outputs(
821
+ self,
822
+ io_context: IOContext,
823
+ started_at: float,
824
+ data: Any,
825
+ data_format: "modal_proto.api_pb2.DataFormat.ValueType",
826
+ ) -> None:
827
+ data = io_context.validate_output_data(data)
828
+ formatted_data = await asyncio.gather(
829
+ *[self.format_blob_data(self.serialize_data_format(d, data_format)) for d in data]
830
+ )
831
+ results = [
832
+ api_pb2.GenericResult(
833
+ status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
834
+ **d,
835
+ )
836
+ for d in formatted_data
837
+ ]
838
+ await self._push_outputs(
839
+ io_context=io_context,
840
+ started_at=started_at,
841
+ data_format=data_format,
842
+ results=results,
843
+ )
844
+ self.exit_context(started_at, io_context.input_ids)
845
+
846
+ async def memory_restore(self) -> None:
847
+ # Busy-wait for restore. `/__modal/restore-state.json` is created
848
+ # by the worker process with updates to the container config.
849
+ restored_path = Path(config.get("restore_state_path"))
850
+ start = time.perf_counter()
851
+ while not restored_path.exists():
852
+ logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)")
853
+ await asyncio.sleep(0.01)
854
+ continue
855
+
856
+ logger.debug("Container: restored")
857
+
858
+ # Look for state file and create new client with updated credentials.
859
+ # State data is serialized with key-value pairs, example: {"task_id": "tk-000"}
860
+ with restored_path.open("r") as file:
861
+ restored_state = json.load(file)
862
+
863
+ # Start a debugger if the worker tells us to
864
+ if int(restored_state.get("snapshot_debug", 0)):
865
+ logger.debug("Entering snapshot debugger")
866
+ breakpoint()
867
+
868
+ # Local ContainerIOManager state.
869
+ for key in ["task_id", "function_id"]:
870
+ if value := restored_state.get(key):
871
+ logger.debug(f"Updating ContainerIOManager.{key} = {value}")
872
+ setattr(self, key, restored_state[key])
873
+
874
+ # Env vars and global state.
875
+ for key, value in restored_state.items():
876
+ # Empty string indicates that value does not need to be updated.
877
+ if value != "":
878
+ config.override_locally(key, value)
879
+
880
+ # Restore input to default state.
881
+ self.current_input_id = None
882
+ self.current_inputs = {}
883
+ self.current_input_started_at = None
884
+ self._client = await _Client.from_env()
885
+
886
+ async def memory_snapshot(self) -> None:
887
+ """Message server indicating that function is ready to be checkpointed."""
888
+ if self.checkpoint_id:
889
+ logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)")
890
+ else:
891
+ raise ValueError("No checkpoint ID provided for memory snapshot")
892
+
893
+ # Pause heartbeats since they keep the client connection open which causes the snapshotter to crash
894
+ async with self.heartbeat_condition:
895
+ # Notify the heartbeat loop that the snapshot phase has begun in order to
896
+ # prevent it from sending heartbeat RPCs
897
+ self._waiting_for_memory_snapshot = True
898
+ self.heartbeat_condition.notify_all()
899
+
900
+ await self._client.stub.ContainerCheckpoint(
901
+ api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
902
+ )
903
+
904
+ await self._client._close(prep_for_restore=True)
905
+
906
+ logger.debug("Memory snapshot request sent. Connection closed.")
907
+ await self.memory_restore()
908
+ # Turn heartbeats back on. This is safe since the snapshot RPC
909
+ # and the restore phase have finished.
910
+ self._waiting_for_memory_snapshot = False
911
+ self.heartbeat_condition.notify_all()
912
+
913
+ async def volume_commit(self, volume_ids: list[str]) -> None:
914
+ """
915
+ Perform volume commit for given `volume_ids`.
916
+ Only used on container exit to persist uncommitted changes on behalf of user.
917
+ """
918
+ if not volume_ids:
919
+ return
920
+ await asyncify(os.sync)()
921
+ results = await asyncio.gather(
922
+ *[
923
+ retry_transient_errors(
924
+ self._client.stub.VolumeCommit,
925
+ api_pb2.VolumeCommitRequest(volume_id=v_id),
926
+ max_retries=9,
927
+ base_delay=0.25,
928
+ max_delay=256,
929
+ delay_factor=2,
930
+ )
931
+ for v_id in volume_ids
932
+ ],
933
+ return_exceptions=True,
934
+ )
935
+ for volume_id, res in zip(volume_ids, results):
936
+ if isinstance(res, Exception):
937
+ logger.error(f"modal.Volume background commit failed for {volume_id}. Exception: {res}")
938
+ else:
939
+ logger.debug(f"modal.Volume background commit success for {volume_id}.")
940
+
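The retry parameters above give each volume commit exponential backoff: up to 9 retries with delays starting near base_delay of 0.25 s and doubling each attempt (delay_factor=2), never exceeding the 256 s cap.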
941
+ async def interact(self, from_breakpoint: bool = False):
942
+ if self._is_interactivity_enabled:
943
+ # Currently, interactivity is enabled forever
944
+ return
945
+ self._is_interactivity_enabled = True
946
+
947
+ if not self.function_def.pty_info.pty_type:
948
+ trigger = "breakpoint()" if from_breakpoint else "modal.interact()"
949
+ raise InvalidError(f"Cannot use {trigger} without running Modal in interactive mode.")
950
+
951
+ try:
952
+ await self._client.stub.FunctionStartPtyShell(Empty())
953
+ except Exception as e:
954
+ logger.error("Failed to start PTY shell.")
955
+ raise e
956
+
957
+ @property
958
+ def target_concurrency(self) -> int:
959
+ return self._target_concurrency
960
+
961
+ @property
962
+ def max_concurrency(self) -> int:
963
+ return self._max_concurrency
964
+
965
+ @classmethod
966
+ def get_input_concurrency(cls) -> int:
967
+ """
968
+ Returns the number of usable input slots.
969
+
970
+ If concurrency is reduced, active slots can exceed allotted slots. Returns the larger value
971
+ in this case.
972
+ """
973
+
974
+ io_manager = cls._singleton
975
+ assert io_manager
976
+ return max(io_manager._input_slots.active, io_manager._input_slots.value)
977
+
978
+ @classmethod
979
+ def set_input_concurrency(cls, concurrency: int):
980
+ """
981
+ Edit the number of input slots.
982
+
983
+ This disables the background loop which automatically adjusts concurrency
984
+ within [target_concurrency, max_concurrency].
985
+ """
986
+ io_manager = cls._singleton
987
+ assert io_manager
988
+ io_manager._stop_concurrency_loop = True
989
+ concurrency = min(concurrency, io_manager._max_concurrency)
990
+ io_manager._input_slots.set_value(concurrency)
991
+
992
+ @classmethod
993
+ def stop_fetching_inputs(cls):
994
+ assert cls._singleton
995
+ cls._singleton._fetching_inputs = False
996
+
997
+
998
+ ContainerIOManager = synchronize_api(_ContainerIOManager)
999
+
1000
+
1001
+ def check_fastapi_pydantic_compatibility(exc: ImportError) -> None:
1002
+ """Add a helpful note to an exception that is likely caused by a pydantic<>fastapi version incompatibility.
1003
+
1004
+ We need this because the legacy set of container requirements (image_builder_version=2023.12) contains a
1005
+ version of fastapi that is not forwards-compatible with pydantic 2.0+, and users commonly run into issues
1006
+ building an image that specifies a more recent version only for pydantic.
1007
+ """
1008
+ note = (
1009
+ "Please ensure that your Image contains compatible versions of fastapi and pydantic."
1010
+ " If using pydantic>=2.0, you must also install fastapi>=0.100."
1011
+ )
1012
+ name = exc.name or ""
1013
+ if name.startswith("pydantic"):
1014
+ try:
1015
+ fastapi_version = parse_major_minor_version(importlib.metadata.version("fastapi"))
1016
+ pydantic_version = parse_major_minor_version(importlib.metadata.version("pydantic"))
1017
+ if pydantic_version >= (2, 0) and fastapi_version < (0, 100):
1018
+ if sys.version_info < (3, 11):
1019
+ # https://peps.python.org/pep-0678/
1020
+ exc.__notes__ = [note]
1021
+ else:
1022
+ exc.add_note(note)
1023
+ except Exception:
1024
+ # Since we're just trying to add a helpful message, don't fail here
1025
+ pass
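The helper above attaches its hint using PEP 678 exception notes. A tiny standalone illustration of the two paths; on Python 3.11+ the stdlib traceback output includes notes added with add_note(), while on older interpreters the __notes__ attribute is only visible to printers that look for it:

    import sys
    import traceback

    note = "Please ensure that your Image contains compatible versions of fastapi and pydantic."
    try:
        raise ImportError("cannot import name 'X' from 'pydantic'")
    except ImportError as exc:
        if sys.version_info >= (3, 11):
            exc.add_note(note)        # printed by the stdlib traceback machinery
        else:
            exc.__notes__ = [note]    # pre-3.11 fallback: stored for custom printers
        traceback.print_exception(type(exc), exc, exc.__traceback__)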