modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +17 -13
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +420 -937
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -59
  11. modal/_resources.py +51 -0
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/_runtime/execution_context.py +89 -0
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +134 -9
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +52 -16
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +479 -100
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +460 -171
  29. modal/_utils/grpc_testing.py +47 -31
  30. modal/_utils/grpc_utils.py +62 -109
  31. modal/_utils/hash_utils.py +61 -19
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +5 -7
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +14 -12
  43. modal/app.py +1003 -314
  44. modal/app.pyi +540 -264
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +63 -53
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +205 -45
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +62 -14
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +64 -58
  55. modal/cli/launch.py +32 -18
  56. modal/cli/network_file_system.py +64 -83
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +35 -10
  59. modal/cli/programs/vscode.py +60 -10
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +234 -131
  62. modal/cli/secret.py +8 -7
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +79 -10
  65. modal/cli/volume.py +110 -109
  66. modal/client.py +250 -144
  67. modal/client.pyi +157 -118
  68. modal/cloud_bucket_mount.py +108 -34
  69. modal/cloud_bucket_mount.pyi +32 -38
  70. modal/cls.py +535 -148
  71. modal/cls.pyi +190 -146
  72. modal/config.py +41 -19
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +111 -65
  76. modal/dict.pyi +136 -131
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +34 -43
  80. modal/experimental.py +61 -2
  81. modal/extensions/ipython.py +5 -5
  82. modal/file_io.py +537 -0
  83. modal/file_io.pyi +235 -0
  84. modal/file_pattern_matcher.py +197 -0
  85. modal/functions.py +906 -911
  86. modal/functions.pyi +466 -430
  87. modal/gpu.py +57 -44
  88. modal/image.py +1089 -479
  89. modal/image.pyi +584 -228
  90. modal/io_streams.py +434 -0
  91. modal/io_streams.pyi +122 -0
  92. modal/mount.py +314 -101
  93. modal/mount.pyi +241 -235
  94. modal/network_file_system.py +92 -92
  95. modal/network_file_system.pyi +152 -110
  96. modal/object.py +67 -36
  97. modal/object.pyi +166 -143
  98. modal/output.py +63 -0
  99. modal/parallel_map.py +434 -0
  100. modal/parallel_map.pyi +75 -0
  101. modal/partial_function.py +282 -117
  102. modal/partial_function.pyi +222 -129
  103. modal/proxy.py +15 -12
  104. modal/proxy.pyi +3 -8
  105. modal/queue.py +182 -65
  106. modal/queue.pyi +218 -118
  107. modal/requirements/2024.04.txt +29 -0
  108. modal/requirements/2024.10.txt +16 -0
  109. modal/requirements/README.md +21 -0
  110. modal/requirements/base-images.json +22 -0
  111. modal/retries.py +48 -7
  112. modal/runner.py +459 -156
  113. modal/runner.pyi +135 -71
  114. modal/running_app.py +38 -0
  115. modal/sandbox.py +514 -236
  116. modal/sandbox.pyi +397 -169
  117. modal/schedule.py +4 -4
  118. modal/scheduler_placement.py +20 -3
  119. modal/secret.py +56 -31
  120. modal/secret.pyi +62 -42
  121. modal/serving.py +51 -56
  122. modal/serving.pyi +44 -36
  123. modal/stream_type.py +15 -0
  124. modal/token_flow.py +5 -3
  125. modal/token_flow.pyi +37 -32
  126. modal/volume.py +285 -157
  127. modal/volume.pyi +249 -184
  128. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
  129. modal-0.72.11.dist-info/RECORD +174 -0
  130. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  131. modal_docs/gen_reference_docs.py +3 -1
  132. modal_docs/mdmd/mdmd.py +0 -1
  133. modal_docs/mdmd/signatures.py +5 -2
  134. modal_global_objects/images/base_images.py +28 -0
  135. modal_global_objects/mounts/python_standalone.py +2 -2
  136. modal_proto/__init__.py +1 -1
  137. modal_proto/api.proto +1288 -533
  138. modal_proto/api_grpc.py +856 -456
  139. modal_proto/api_pb2.py +2165 -1157
  140. modal_proto/api_pb2.pyi +8859 -0
  141. modal_proto/api_pb2_grpc.py +1674 -855
  142. modal_proto/api_pb2_grpc.pyi +1416 -0
  143. modal_proto/modal_api_grpc.py +149 -0
  144. modal_proto/modal_options_grpc.py +3 -0
  145. modal_proto/options_pb2.pyi +20 -0
  146. modal_proto/options_pb2_grpc.pyi +7 -0
  147. modal_proto/py.typed +0 -0
  148. modal_version/__init__.py +1 -1
  149. modal_version/_version_generated.py +2 -2
  150. modal/_asgi.py +0 -370
  151. modal/_container_entrypoint.pyi +0 -378
  152. modal/_container_exec.py +0 -128
  153. modal/_sandbox_shell.py +0 -49
  154. modal/shared_volume.py +0 -23
  155. modal/shared_volume.pyi +0 -24
  156. modal/stub.py +0 -783
  157. modal/stub.pyi +0 -332
  158. modal-0.62.16.dist-info/RECORD +0 -198
  159. modal_global_objects/images/conda.py +0 -15
  160. modal_global_objects/images/debian_slim.py +0 -15
  161. modal_global_objects/images/micromamba.py +0 -15
  162. test/__init__.py +0 -1
  163. test/aio_test.py +0 -12
  164. test/async_utils_test.py +0 -262
  165. test/blob_test.py +0 -67
  166. test/cli_imports_test.py +0 -149
  167. test/cli_test.py +0 -659
  168. test/client_test.py +0 -194
  169. test/cls_test.py +0 -630
  170. test/config_test.py +0 -137
  171. test/conftest.py +0 -1420
  172. test/container_app_test.py +0 -32
  173. test/container_test.py +0 -1389
  174. test/cpu_test.py +0 -23
  175. test/decorator_test.py +0 -85
  176. test/deprecation_test.py +0 -34
  177. test/dict_test.py +0 -33
  178. test/e2e_test.py +0 -68
  179. test/error_test.py +0 -7
  180. test/function_serialization_test.py +0 -32
  181. test/function_test.py +0 -653
  182. test/function_utils_test.py +0 -101
  183. test/gpu_test.py +0 -159
  184. test/grpc_utils_test.py +0 -141
  185. test/helpers.py +0 -42
  186. test/image_test.py +0 -669
  187. test/live_reload_test.py +0 -80
  188. test/lookup_test.py +0 -70
  189. test/mdmd_test.py +0 -329
  190. test/mount_test.py +0 -162
  191. test/mounted_files_test.py +0 -329
  192. test/network_file_system_test.py +0 -181
  193. test/notebook_test.py +0 -66
  194. test/object_test.py +0 -41
  195. test/package_utils_test.py +0 -25
  196. test/queue_test.py +0 -97
  197. test/resolver_test.py +0 -58
  198. test/retries_test.py +0 -67
  199. test/runner_test.py +0 -85
  200. test/sandbox_test.py +0 -191
  201. test/schedule_test.py +0 -15
  202. test/scheduler_placement_test.py +0 -29
  203. test/secret_test.py +0 -78
  204. test/serialization_test.py +0 -42
  205. test/stub_composition_test.py +0 -10
  206. test/stub_test.py +0 -360
  207. test/test_asgi_wrapper.py +0 -234
  208. test/token_flow_test.py +0 -18
  209. test/traceback_test.py +0 -135
  210. test/tunnel_test.py +0 -29
  211. test/utils_test.py +0 -88
  212. test/version_test.py +0 -14
  213. test/volume_test.py +0 -341
  214. test/watcher_test.py +0 -30
  215. test/webhook_test.py +0 -146
  216. /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
  217. /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
  218. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
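The diff below appears to be modal/_container_entrypoint.py (entry 5 above, +420 -937), which was largely rewritten around the new modal._runtime package. The file list also reflects the library-wide rename of modal.Stub to modal.App (modal/stub.py and modal/stub.pyi removed; modal/app.py and modal/app.pyi greatly expanded). As a rough, illustrative sketch of what that rename means for user code (not part of the diff; the app name and function are placeholders):

    # modal 0.62.x style
    import modal
    stub = modal.Stub("example-app")

    @stub.function()
    def f():
        return "hello"

    # modal 0.72.x style: App replaces Stub
    import modal
    app = modal.App("example-app")

    @app.function()
    def f():
        return "hello"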
@@ -1,63 +1,103 @@
  # Copyright Modal Labs 2022
- from __future__ import annotations
+ # ruff: noqa: E402
+ import os
+
+ from modal._runtime.user_code_imports import Service, import_class_service, import_single_function_service
+
+ telemetry_socket = os.environ.get("MODAL_TELEMETRY_SOCKET")
+ if telemetry_socket:
+     from ._runtime.telemetry import instrument_imports
+
+     instrument_imports(telemetry_socket)

  import asyncio
- import base64
- import contextlib
- import importlib
+ import concurrent.futures
  import inspect
- import json
- import math
- import os
+ import queue
  import signal
  import sys
  import threading
  import time
- import traceback
- from collections.abc import Iterable
- from dataclasses import dataclass
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Callable, List, Optional, Set, Tuple, Type
+ from collections.abc import Sequence
+ from typing import TYPE_CHECKING, Any, Callable, Optional

- from grpclib import Status
+ from google.protobuf.message import Message

+ from modal._clustered_functions import initialize_clustered_function
+ from modal._proxy_tunnel import proxy_tunnel
+ from modal._serialization import deserialize, deserialize_proto_params
+ from modal._utils.async_utils import TaskContext, synchronizer
+ from modal._utils.function_utils import (
+     callable_has_non_self_params,
+ )
+ from modal.app import App, _App
+ from modal.client import Client, _Client
+ from modal.config import logger
+ from modal.exception import ExecutionError, InputCancellation, InvalidError
+ from modal.partial_function import (
+     _find_callables_for_obj,
+     _PartialFunctionFlags,
+ )
+ from modal.running_app import RunningApp
  from modal_proto import api_pb2

- from ._asgi import (
-     asgi_app_wrapper,
-     get_ip_address,
-     wait_for_web_server,
-     web_server_proxy,
-     webhook_asgi_app,
-     wsgi_app_wrapper,
+ from ._runtime.container_io_manager import (
+     ContainerIOManager,
+     IOContext,
+     UserException,
+     _ContainerIOManager,
  )
- from ._proxy_tunnel import proxy_tunnel
- from ._serialization import deserialize, deserialize_data_format, serialize, serialize_data_format
- from ._traceback import extract_traceback
- from ._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer
- from ._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
- from ._utils.function_utils import LocalFunctionError, is_async as get_is_async, is_global_function, method_has_params
- from ._utils.grpc_utils import retry_transient_errors
- from .app import ContainerApp, _container_app, _ContainerApp, interact
- from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, Client, _Client
- from .cls import Cls
- from .config import config, logger
- from .exception import InputCancellation, InvalidError
- from .functions import Function, _Function, _set_current_context_ids, _stream_function_call_data
- from .partial_function import _find_callables_for_obj, _PartialFunctionFlags
- from .stub import _Stub
+ from ._runtime.execution_context import _set_current_context_ids

  if TYPE_CHECKING:
-     from types import ModuleType
+     import modal._runtime.container_io_manager
+     import modal.object
+
+
+ class DaemonizedThreadPool:
+     # Used instead of ThreadPoolExecutor, since the latter won't allow
+     # the interpreter to shut down before the currently running tasks
+     # have finished
+     def __init__(self, max_threads: int):
+         self.max_threads = max_threads
+
+     def __enter__(self):
+         self.spawned_workers = 0
+         self.inputs: queue.Queue[Any] = queue.Queue()
+         self.finished = threading.Event()
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.finished.set()

- MAX_OUTPUT_BATCH_SIZE: int = 49
+         if exc_type is None:
+             self.inputs.join()
+         else:
+             # special case - allows us to exit the
+             if self.inputs.unfinished_tasks:
+                 logger.info(
+                     f"Exiting DaemonizedThreadPool with {self.inputs.unfinished_tasks} active "
+                     f"inputs due to exception: {repr(exc_type)}"
+                 )

- RTT_S: float = 0.5 # conservative estimate of RTT in seconds.
+     def submit(self, func, *args):
+         def worker_thread():
+             while not self.finished.is_set():
+                 try:
+                     _func, _args = self.inputs.get(timeout=1)
+                 except queue.Empty:
+                     continue
+                 try:
+                     _func(*_args)
+                 except BaseException:
+                     logger.exception(f"Exception raised by {_func} in DaemonizedThreadPool worker!")
+                 self.inputs.task_done()

+         if self.spawned_workers < self.max_threads:
+             threading.Thread(target=worker_thread, daemon=True).start()
+             self.spawned_workers += 1

- class UserException(Exception):
-     # Used to shut down the task gracefully
-     pass
+         self.inputs.put((func, args))


  class UserCodeEventLoop:
@@ -76,14 +116,25 @@ class UserCodeEventLoop:

      def __enter__(self):
          self.loop = asyncio.new_event_loop()
+         self.tasks = set()
          return self

      def __exit__(self, exc_type, exc_value, traceback):
          self.loop.run_until_complete(self.loop.shutdown_asyncgens())
          if sys.version_info[:2] >= (3, 9):
              self.loop.run_until_complete(self.loop.shutdown_default_executor()) # Introduced in Python 3.9
+
+         for task in self.tasks:
+             task.cancel()
+
          self.loop.close()

+     def create_task(self, coro):
+         task = self.loop.create_task(coro)
+         self.tasks.add(task)
+         task.add_done_callback(self.tasks.discard)
+         return task
+
      def run(self, coro):
          task = asyncio.ensure_future(coro, loop=self.loop)
          self._sigints = 0
@@ -99,7 +150,9 @@ class UserCodeEventLoop:
                  # first sigint is graceful
                  task.cancel()
                  return
-             raise KeyboardInterrupt() # this should normally not happen, but the second sigint would "hard kill" the event loop!
+
+             # this should normally not happen, but the second sigint would "hard kill" the event loop!
+             raise KeyboardInterrupt()

          ignore_sigint = signal.getsignal(signal.SIGINT) == signal.SIG_IGN
          if not ignore_sigint:
@@ -122,972 +175,381 @@ class UserCodeEventLoop:
122
175
  self.loop.remove_signal_handler(signal.SIGINT)
123
176
 
124
177
 
125
- class _FunctionIOManager:
126
- """Synchronizes all RPC calls and network operations for a running container.
127
-
128
- TODO: maybe we shouldn't synchronize the whole class.
129
- Then we could potentially move a bunch of the global functions onto it.
130
- """
131
-
132
- _GENERATOR_STOP_SENTINEL = object()
133
-
134
- def __init__(self, container_args: api_pb2.ContainerArguments, client: _Client):
135
- self.cancelled_input_ids: Set[str] = set()
136
- self.task_id = container_args.task_id
137
- self.function_id = container_args.function_id
138
- self.app_id = container_args.app_id
139
- self.function_def = container_args.function_def
140
- self.checkpoint_id = container_args.checkpoint_id
141
-
142
- self.calls_completed = 0
143
- self.total_user_time: float = 0.0
144
- self.current_input_id: Optional[str] = None
145
- self.current_input_started_at: Optional[float] = None
146
-
147
- self._input_concurrency: Optional[int] = None
148
-
149
- self._semaphore: Optional[asyncio.Semaphore] = None
150
- self._environment_name = container_args.environment_name
151
- self._waiting_for_checkpoint = False
152
- self._heartbeat_loop = None
153
-
154
- self._client = client
155
- assert isinstance(self._client, _Client)
156
-
157
- async def initialize_app(self) -> _ContainerApp:
158
- await _container_app.init(self._client, self.app_id, self._environment_name, self.function_def)
159
- return _container_app
160
-
161
- async def _run_heartbeat_loop(self):
162
- while 1:
163
- t0 = time.monotonic()
164
- try:
165
- if await self._heartbeat_handle_cancellations():
166
- # got a cancellation event, fine to start another heartbeat immediately
167
- # since the cancellation queue should be empty on the worker server
168
- # however, we wait at least 1s to prevent short-circuiting the heartbeat loop
169
- # in case there is ever a bug. This means it will take at least 1s between
170
- # two subsequent cancellations on the same task at the moment
171
- await asyncio.sleep(1.0)
172
- continue
173
- except Exception as exc:
174
- # don't stop heartbeat loop if there are transient exceptions!
175
- time_elapsed = time.monotonic() - t0
176
- error = exc
177
- logger.warning(f"Heartbeat attempt failed ({time_elapsed=}, {error=})")
178
-
179
- heartbeat_duration = time.monotonic() - t0
180
- time_until_next_hearbeat = max(0.0, HEARTBEAT_INTERVAL - heartbeat_duration)
181
- await asyncio.sleep(time_until_next_hearbeat)
182
-
183
- async def _heartbeat_handle_cancellations(self) -> bool:
184
- # Return True if a cancellation event was received, in that case we shouldn't wait too long for another heartbeat
185
-
186
- # Don't send heartbeats for tasks waiting to be checkpointed.
187
- # Calling gRPC methods open new connections which block the
188
- # checkpointing process.
189
- if self._waiting_for_checkpoint:
190
- return False
191
-
192
- request = api_pb2.ContainerHeartbeatRequest(supports_graceful_input_cancellation=True)
193
- if self.current_input_id is not None:
194
- request.current_input_id = self.current_input_id
195
- if self.current_input_started_at is not None:
196
- request.current_input_started_at = self.current_input_started_at
197
-
198
- # TODO(erikbern): capture exceptions?
199
- response = await retry_transient_errors(
200
- self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
201
- )
178
+ def call_function(
179
+ user_code_event_loop: UserCodeEventLoop,
180
+ container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager",
181
+ finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
182
+ batch_max_size: int,
183
+ batch_wait_ms: int,
184
+ ):
185
+ async def run_input_async(io_context: IOContext) -> None:
186
+ started_at = time.time()
187
+ input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids
188
+ reset_context = _set_current_context_ids(input_ids, function_call_ids)
189
+ async with container_io_manager.handle_input_exception.aio(io_context, started_at):
190
+ res = io_context.call_finalized_function()
191
+ # TODO(erikbern): any exception below shouldn't be considered a user exception
192
+ if io_context.finalized_function.is_generator:
193
+ if not inspect.isasyncgen(res):
194
+ raise InvalidError(f"Async generator function returned value of type {type(res)}")
202
195
 
203
- if response.HasField("cancel_input_event"):
204
- # Pause processing of the current input by signaling self a SIGUSR1.
205
- input_ids_to_cancel = response.cancel_input_event.input_ids
206
- if input_ids_to_cancel:
207
- if self._input_concurrency > 1:
208
- logger.info(
209
- "Shutting down task to stop some subset of inputs (concurrent functions don't support fine-grained cancellation)"
196
+ # Send up to this many outputs at a time.
197
+ generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024)
198
+ generator_output_task = asyncio.create_task(
199
+ container_io_manager.generator_output_task.aio(
200
+ function_call_ids[0],
201
+ io_context.finalized_function.data_format,
202
+ generator_queue,
210
203
  )
211
- # This is equivalent to a task cancellation or preemption from worker code,
212
- # except we do not send a SIGKILL to forcefully exit after 30 seconds.
213
- #
214
- # SIGINT always interrupts the main thread, but not any auxiliary threads. On a
215
- # sync function without concurrent inputs, this raises a KeyboardInterrupt. When
216
- # there are concurrent inputs, we cannot interrupt the thread pool, but the
217
- # interpreter stops waiting for daemon threads and exits. On async functions,
218
- # this signal lands outside the event loop, stopping `run_until_complete()`.
219
- os.kill(os.getpid(), signal.SIGINT)
220
-
221
- elif self.current_input_id in input_ids_to_cancel:
222
- # This goes to a registered signal handler for sync Modal functions, or to the
223
- # `SignalHandlingEventLoop` for async functions.
224
- #
225
- # We only send this signal on functions that do not have concurrent inputs enabled.
226
- # This allows us to do fine-grained input cancellation. On sync functions, the
227
- # SIGUSR1 signal should interrupt the main thread where user code is running,
228
- # raising an InputCancellation() exception. On async functions, the signal should
229
- # reach a handler in SignalHandlingEventLoop, which cancels the task.
230
- os.kill(os.getpid(), signal.SIGUSR1)
231
- return True
232
- return False
233
-
234
- @contextlib.asynccontextmanager
235
- async def heartbeats(self):
236
- async with TaskContext() as tc:
237
- self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
238
- t.set_name("heartbeat loop")
239
- try:
240
- yield
241
- finally:
242
- t.cancel()
243
-
244
- def stop_heartbeat(self):
245
- if self._heartbeat_loop:
246
- self._heartbeat_loop.cancel()
247
-
248
- async def get_serialized_function(self) -> Tuple[Optional[Any], Callable]:
249
- # Fetch the serialized function definition
250
- request = api_pb2.FunctionGetSerializedRequest(function_id=self.function_id)
251
- response = await self._client.stub.FunctionGetSerialized(request)
252
- fun = self.deserialize(response.function_serialized)
253
-
254
- if response.class_serialized:
255
- cls = self.deserialize(response.class_serialized)
256
- else:
257
- cls = None
258
-
259
- return cls, fun
260
-
261
- def serialize(self, obj: Any) -> bytes:
262
- return serialize(obj)
263
-
264
- def deserialize(self, data: bytes) -> Any:
265
- return deserialize(data, self._client)
266
-
267
- @synchronizer.no_io_translation
268
- def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
269
- return serialize_data_format(obj, data_format)
270
-
271
- def deserialize_data_format(self, data: bytes, data_format: int) -> Any:
272
- return deserialize_data_format(data, data_format, self._client)
273
-
274
- async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
275
- """Read from the `data_in` stream of a function call."""
276
- async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
277
- yield data
278
-
279
- async def put_data_out(
280
- self,
281
- function_call_id: str,
282
- start_index: int,
283
- data_format: int,
284
- messages_bytes: List[Any],
285
- ) -> None:
286
- """Put data onto the `data_out` stream of a function call.
287
-
288
- This is used for generator outputs, which includes web endpoint responses. Note that this
289
- was introduced as a performance optimization in client version 0.57, so older clients will
290
- still use the previous Postgres-backed system based on `FunctionPutOutputs()`.
291
- """
292
- data_chunks: List[api_pb2.DataChunk] = []
293
- for i, message_bytes in enumerate(messages_bytes):
294
- chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i) # type: ignore
295
- if len(message_bytes) > MAX_OBJECT_SIZE_BYTES:
296
- chunk.data_blob_id = await blob_upload(message_bytes, self._client.stub)
297
- else:
298
- chunk.data = message_bytes
299
- data_chunks.append(chunk)
300
-
301
- req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks)
302
- await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req)
303
-
304
- async def generator_output_task(self, function_call_id: str, data_format: int, message_rx: asyncio.Queue) -> None:
305
- """Task that feeds generator outputs into a function call's `data_out` stream."""
306
- index = 1
307
- received_sentinel = False
308
- while not received_sentinel:
309
- message = await message_rx.get()
310
- if message is self._GENERATOR_STOP_SENTINEL:
311
- break
312
- # ASGI 'http.response.start' and 'http.response.body' msgs are observed to be separated by 1ms.
313
- # If we don't sleep here for 1ms we end up with an extra call to .put_data_out().
314
- if index == 1:
315
- await asyncio.sleep(0.001)
316
- messages_bytes = [serialize_data_format(message, data_format)]
317
- total_size = len(messages_bytes[0]) + 512
318
- while total_size < 16 * 1024 * 1024: # 16 MiB, maximum size in a single message
319
- try:
320
- message = message_rx.get_nowait()
321
- except asyncio.QueueEmpty:
322
- break
323
- if message is self._GENERATOR_STOP_SENTINEL:
324
- received_sentinel = True
325
- break
326
- else:
327
- messages_bytes.append(serialize_data_format(message, data_format))
328
- total_size += len(messages_bytes[-1]) + 512 # 512 bytes for estimated framing overhead
329
- await self.put_data_out(function_call_id, index, data_format, messages_bytes)
330
- index += len(messages_bytes)
331
-
332
- async def _queue_create(self, size: int) -> asyncio.Queue:
333
- """Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
334
- return asyncio.Queue(size)
335
-
336
- async def _queue_put(self, queue: asyncio.Queue, value: Any) -> None:
337
- """Put a value onto a queue, using the synchronicity event loop."""
338
- await queue.put(value)
339
-
340
- async def populate_input_blobs(self, item: api_pb2.FunctionInput):
341
- args = await blob_download(item.args_blob_id, self._client.stub)
342
-
343
- # Mutating
344
- item.ClearField("args_blob_id")
345
- item.args = args
346
- return item
347
-
348
- def get_average_call_time(self) -> float:
349
- if self.calls_completed == 0:
350
- return 0
351
-
352
- return self.total_user_time / self.calls_completed
353
-
354
- def get_max_inputs_to_fetch(self):
355
- if self.calls_completed == 0:
356
- return 1
357
-
358
- return math.ceil(RTT_S / max(self.get_average_call_time(), 1e-6))
359
-
360
- @synchronizer.no_io_translation
361
- async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.FunctionInput]]:
362
- request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
363
- eof_received = False
364
- iteration = 0
365
- while not eof_received and _container_app.fetching_inputs:
366
- request.average_call_time = self.get_average_call_time()
367
- request.max_values = self.get_max_inputs_to_fetch() # Deprecated; remove.
368
- request.input_concurrency = self._input_concurrency
369
-
370
- await self._semaphore.acquire()
371
- yielded = False
372
- try:
373
- # If number of active inputs is at max queue size, this will block.
374
- iteration += 1
375
- response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors(
376
- self._client.stub.FunctionGetInputs, request
377
204
  )
378
205
 
379
- if response.rate_limit_sleep_duration:
380
- logger.info(
381
- "Task exceeded rate limit, sleeping for %.2fs before trying again."
382
- % response.rate_limit_sleep_duration
383
- )
384
- await asyncio.sleep(response.rate_limit_sleep_duration)
385
- elif response.inputs:
386
- # for input cancellations and concurrency logic we currently assume
387
- # that there is no input buffering in the container
388
- assert len(response.inputs) == 1
389
-
390
- for item in response.inputs:
391
- if item.kill_switch:
392
- logger.debug(f"Task {self.task_id} input kill signal input.")
393
- eof_received = True
394
- break
395
- if item.input_id in self.cancelled_input_ids:
396
- continue
397
-
398
- # If we got a pointer to a blob, download it from S3.
399
- if item.input.WhichOneof("args_oneof") == "args_blob_id":
400
- input_pb = await self.populate_input_blobs(item.input)
401
- else:
402
- input_pb = item.input
403
-
404
- # If yielded, allow semaphore to be released via complete_call
405
- yield (item.input_id, item.function_call_id, input_pb)
406
- yielded = True
407
-
408
- # We only support max_inputs = 1 at the moment
409
- if item.input.final_input or self.function_def.max_inputs == 1:
410
- eof_received = True
411
- break
412
- finally:
413
- if not yielded:
414
- self._semaphore.release()
415
-
416
- @synchronizer.no_io_translation
417
- async def run_inputs_outputs(self, input_concurrency: int = 1) -> AsyncIterator[Tuple[str, str, Any, Any]]:
418
- # Ensure we do not fetch new inputs when container is too busy.
419
- # Before trying to fetch an input, acquire the semaphore:
420
- # - if no input is fetched, release the semaphore.
421
- # - or, when the output for the fetched input is sent, release the semaphore.
422
- self._input_concurrency = input_concurrency
423
- self._semaphore = asyncio.Semaphore(input_concurrency)
424
-
425
- try:
426
- async for input_id, function_call_id, input_pb in self._generate_inputs():
427
- args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {})
428
- self.current_input_id, self.current_input_started_at = (input_id, time.time())
429
- yield input_id, function_call_id, args, kwargs
430
- self.current_input_id, self.current_input_started_at = (None, None)
431
- finally:
432
- # collect all active input slots, meaning all inputs have wrapped up.
433
- for _ in range(input_concurrency):
434
- await self._semaphore.acquire()
435
-
436
- async def _push_output(self, input_id, started_at: float, data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs):
437
- # upload data to S3 if too big.
438
- if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES:
439
- data_blob_id = await blob_upload(kwargs["data"], self._client.stub)
440
- # mutating kwargs.
441
- del kwargs["data"]
442
- kwargs["data_blob_id"] = data_blob_id
443
-
444
- output = api_pb2.FunctionPutOutputsItem(
445
- input_id=input_id,
446
- input_started_at=started_at,
447
- output_created_at=time.time(),
448
- result=api_pb2.GenericResult(**kwargs),
449
- data_format=data_format,
450
- )
451
-
452
- await retry_transient_errors(
453
- self._client.stub.FunctionPutOutputs,
454
- api_pb2.FunctionPutOutputsRequest(outputs=[output]),
455
- additional_status_codes=[Status.RESOURCE_EXHAUSTED],
456
- max_retries=None, # Retry indefinitely, trying every 1s.
457
- )
458
-
459
- def serialize_exception(self, exc: BaseException) -> Optional[bytes]:
460
- try:
461
- return self.serialize(exc)
462
- except Exception as serialization_exc:
463
- logger.info(f"Failed to serialize exception {exc}: {serialization_exc}")
464
- # We can't always serialize exceptions.
465
- return None
466
-
467
- def serialize_traceback(self, exc: BaseException) -> Tuple[Optional[bytes], Optional[bytes]]:
468
- serialized_tb, tb_line_cache = None, None
469
-
470
- try:
471
- tb_dict, line_cache = extract_traceback(exc, self.task_id)
472
- serialized_tb = self.serialize(tb_dict)
473
- tb_line_cache = self.serialize(line_cache)
474
- except Exception:
475
- logger.info("Failed to serialize exception traceback.")
476
-
477
- return serialized_tb, tb_line_cache
478
-
479
- @contextlib.asynccontextmanager
480
- async def handle_user_exception(self) -> AsyncGenerator[None, None]:
481
- """Sets the task as failed in a way where it's not retried.
482
-
483
- Used for handling exceptions from container lifecycle methods at the moment, which should
484
- trigger a task failure state.
485
- """
486
- try:
487
- yield
488
- except KeyboardInterrupt:
489
- # Send no task result in case we get sigint:ed by the runner
490
- # The status of the input should have been handled externally already in that case
491
- raise
492
- except BaseException as exc:
493
- # Since this is on a different thread, sys.exc_info() can't find the exception in the stack.
494
- traceback.print_exception(type(exc), exc, exc.__traceback__)
495
-
496
- serialized_tb, tb_line_cache = self.serialize_traceback(exc)
497
-
498
- result = api_pb2.GenericResult(
499
- status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
500
- data=self.serialize_exception(exc),
501
- exception=repr(exc),
502
- traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
503
- serialized_tb=serialized_tb,
504
- tb_line_cache=tb_line_cache,
505
- )
506
-
507
- req = api_pb2.TaskResultRequest(result=result)
508
- await retry_transient_errors(self._client.stub.TaskResult, req)
509
-
510
- # Shut down the task gracefully
511
- raise UserException()
512
-
513
- @contextlib.asynccontextmanager
514
- async def handle_input_exception(self, input_id, started_at: float) -> AsyncGenerator[None, None]:
515
- """Handle an exception while processing a function input."""
516
- try:
517
- yield
518
- except KeyboardInterrupt:
519
- raise
520
- except (InputCancellation, asyncio.CancelledError):
521
- # just skip creating any output for this input and keep going with the next instead
522
- # it should have been marked as cancelled already in the backend at this point so it
523
- # won't be retried
524
- logger.warning(f"The current input ({input_id=}) was cancelled by a user request")
525
- await self.complete_call(started_at)
526
- return
527
- except BaseException as exc:
528
- # print exception so it's logged
529
- traceback.print_exc()
530
- serialized_tb, tb_line_cache = self.serialize_traceback(exc)
531
-
532
- # Note: we're not serializing the traceback since it contains
533
- # local references that means we can't unpickle it. We *are*
534
- # serializing the exception, which may have some issues (there
535
- # was an earlier note about it that it might not be possible
536
- # to unpickle it in some cases). Let's watch out for issues.
537
- await self._push_output(
538
- input_id,
539
- started_at=started_at,
540
- data_format=api_pb2.DATA_FORMAT_PICKLE,
541
- status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
542
- data=self.serialize_exception(exc),
543
- exception=repr(exc),
544
- traceback=traceback.format_exc(),
545
- serialized_tb=serialized_tb,
546
- tb_line_cache=tb_line_cache,
547
- )
548
- await self.complete_call(started_at)
549
-
550
- async def complete_call(self, started_at):
551
- self.total_user_time += time.time() - started_at
552
- self.calls_completed += 1
553
- self._semaphore.release()
554
-
555
- @synchronizer.no_io_translation
556
- async def push_output(self, input_id, started_at: float, data: Any, data_format: int) -> None:
557
- await self._push_output(
558
- input_id,
559
- started_at=started_at,
560
- data_format=data_format,
561
- status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
562
- data=self.serialize_data_format(data, data_format),
563
- )
564
- await self.complete_call(started_at)
565
-
566
- async def restore(self) -> None:
567
- # Busy-wait for restore. `/__modal/restore-state.json` is created
568
- # by the worker process with updates to the container config.
569
- restored_path = Path(config.get("restore_state_path"))
570
- start = time.perf_counter()
571
- while not restored_path.exists():
572
- logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)")
573
- await asyncio.sleep(0.01)
574
- continue
575
-
576
- logger.debug("Container: restored")
577
-
578
- # Look for state file and create new client with updated credentials.
579
- # State data is serialized with key-value pairs, example: {"task_id": "tk-000"}
580
- with restored_path.open("r") as file:
581
- restored_state = json.load(file)
582
-
583
- # Local FunctionIOManager state.
584
- for key in ["task_id", "function_id"]:
585
- if value := restored_state.get(key):
586
- logger.debug(f"Updating FunctionIOManager.{key} = {value}")
587
- setattr(self, key, restored_state[key])
588
-
589
- # Env vars and global state.
590
- for key, value in restored_state.items():
591
- # Empty string indicates that value does not need to be updated.
592
- if value != "":
593
- config.override_locally(key, value)
594
-
595
- # Restore input to default state.
596
- self.current_input_id = None
597
- self.current_input_started_at = None
598
-
599
- self._client = await _Client.from_env()
600
- self._waiting_for_checkpoint = False
601
-
602
- async def checkpoint(self) -> None:
603
- """Message server indicating that function is ready to be checkpointed."""
604
- if self.checkpoint_id:
605
- logger.debug(f"Checkpoint ID: {self.checkpoint_id}")
606
-
607
- await self._client.stub.ContainerCheckpoint(
608
- api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
609
- )
206
+ item_count = 0
207
+ async for value in res:
208
+ await container_io_manager._queue_put.aio(generator_queue, value)
209
+ item_count += 1
610
210
 
611
- self._waiting_for_checkpoint = True
612
- await self._client._close()
613
-
614
- logger.debug("Checkpointing request sent. Connection closed.")
615
- await self.restore()
616
-
617
- async def volume_commit(self, volume_ids: List[str]) -> None:
618
- """
619
- Perform volume commit for given `volume_ids`.
620
- Only used on container exit to persist uncommitted changes on behalf of user.
621
- """
622
- if not volume_ids:
623
- return
624
- await asyncify(os.sync)()
625
- results = await asyncio.gather(
626
- *[
627
- retry_transient_errors(
628
- self._client.stub.VolumeCommit,
629
- api_pb2.VolumeCommitRequest(volume_id=v_id),
630
- max_retries=9,
631
- base_delay=0.25,
632
- max_delay=256,
633
- delay_factor=2,
211
+ await container_io_manager._queue_put.aio(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
212
+ await generator_output_task # Wait to finish sending generator outputs.
213
+ message = api_pb2.GeneratorDone(items_total=item_count)
214
+ await container_io_manager.push_outputs.aio(
215
+ io_context,
216
+ started_at,
217
+ message,
218
+ api_pb2.DATA_FORMAT_GENERATOR_DONE,
634
219
  )
635
- for v_id in volume_ids
636
- ],
637
- return_exceptions=True,
638
- )
639
- for volume_id, res in zip(volume_ids, results):
640
- if isinstance(res, Exception):
641
- logger.error(f"modal.Volume background commit failed for {volume_id}. Exception: {res}")
642
220
  else:
643
- logger.debug(f"modal.Volume background commit success for {volume_id}.")
644
-
645
-
646
- FunctionIOManager = synchronize_api(_FunctionIOManager)
647
-
221
+ if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
222
+ raise InvalidError(
223
+ f"Async (non-generator) function returned value of type {type(res)}"
224
+ " You might need to use @app.function(..., is_generator=True)."
225
+ )
226
+ value = await res
227
+ await container_io_manager.push_outputs.aio(
228
+ io_context,
229
+ started_at,
230
+ value,
231
+ io_context.finalized_function.data_format,
232
+ )
233
+ reset_context()
648
234
 
649
- def call_function_sync(
650
- function_io_manager, #: FunctionIOManager, TODO: this type is generated at runtime
651
- imp_fun: ImportedFunction,
652
- ):
653
- def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
235
+ def run_input_sync(io_context: IOContext) -> None:
654
236
  started_at = time.time()
655
- reset_context = _set_current_context_ids(input_id, function_call_id)
656
- with function_io_manager.handle_input_exception(input_id, started_at):
657
- logger.debug(f"Starting input {input_id} (sync)")
658
- res = imp_fun.fun(*args, **kwargs)
659
- logger.debug(f"Finished input {input_id} (sync)")
237
+ input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids
238
+ reset_context = _set_current_context_ids(input_ids, function_call_ids)
239
+ with container_io_manager.handle_input_exception(io_context, started_at):
240
+ res = io_context.call_finalized_function()
660
241
 
661
242
  # TODO(erikbern): any exception below shouldn't be considered a user exception
662
- if imp_fun.is_generator:
243
+ if io_context.finalized_function.is_generator:
663
244
  if not inspect.isgenerator(res):
664
245
  raise InvalidError(f"Generator function returned value of type {type(res)}")
665
246
 
666
247
  # Send up to this many outputs at a time.
667
- generator_queue: asyncio.Queue[Any] = function_io_manager._queue_create(1024)
668
- generator_output_task = function_io_manager.generator_output_task(
669
- function_call_id,
670
- imp_fun.data_format,
248
+ generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
249
+ generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore
250
+ function_call_ids[0],
251
+ io_context.finalized_function.data_format,
671
252
  generator_queue,
672
- _future=True, # Synchronicity magic to return a future.
253
+ _future=True, # type: ignore # Synchronicity magic to return a future.
673
254
  )
674
255
 
675
256
  item_count = 0
676
257
  for value in res:
677
- function_io_manager._queue_put(generator_queue, value)
258
+ container_io_manager._queue_put(generator_queue, value)
678
259
  item_count += 1
679
260
 
680
- function_io_manager._queue_put(generator_queue, _FunctionIOManager._GENERATOR_STOP_SENTINEL)
261
+ container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
681
262
  generator_output_task.result() # Wait to finish sending generator outputs.
682
263
  message = api_pb2.GeneratorDone(items_total=item_count)
683
- function_io_manager.push_output(input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
264
+ container_io_manager.push_outputs(io_context, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
684
265
  else:
685
266
  if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
686
267
  raise InvalidError(
687
268
  f"Sync (non-generator) function return value of type {type(res)}."
688
- " You might need to use @stub.function(..., is_generator=True)."
269
+ " You might need to use @app.function(..., is_generator=True)."
689
270
  )
690
- function_io_manager.push_output(input_id, started_at, res, imp_fun.data_format)
271
+ container_io_manager.push_outputs(
272
+ io_context, started_at, res, io_context.finalized_function.data_format
273
+ )
691
274
  reset_context()
692
275
 
693
- if imp_fun.input_concurrency > 1:
694
- # We can't use `concurrent.futures.ThreadPoolExecutor` here because in Python 3.11+, this
695
- # class has no workaround that allows us to exit the Python interpreter process without
696
- # waiting for the worker threads to finish. We need this behavior on SIGINT.
697
-
698
- import queue
699
- import threading
700
-
701
- spawned_workers = 0
702
- inputs: queue.Queue[Any] = queue.Queue()
703
- finished = threading.Event()
704
-
705
- def worker_thread():
706
- while not finished.is_set():
707
- try:
708
- args = inputs.get(timeout=1)
709
- except queue.Empty:
710
- continue
711
- try:
712
- run_input(*args)
713
- except BaseException:
714
- # This should basically never happen, since only KeyboardInterrupt is the only error that can
715
- # bubble out of from handle_input_exception and those wouldn't be raised outside the main thread
716
- pass
717
- inputs.task_done()
718
-
719
- for input_id, function_call_id, args, kwargs in function_io_manager.run_inputs_outputs(
720
- imp_fun.input_concurrency
721
- ):
722
- if spawned_workers < imp_fun.input_concurrency:
723
- threading.Thread(target=worker_thread, daemon=True).start()
724
- spawned_workers += 1
725
- inputs.put((input_id, function_call_id, args, kwargs))
726
-
727
- finished.set()
728
- inputs.join()
729
-
730
- else:
731
- for input_id, function_call_id, args, kwargs in function_io_manager.run_inputs_outputs(
732
- imp_fun.input_concurrency
733
- ):
734
- try:
735
- run_input(input_id, function_call_id, args, kwargs)
736
- except:
737
- raise
738
-
739
-
740
- async def call_function_async(
741
- function_io_manager, #: FunctionIOManager, TODO: this type is generated at runtime
742
- imp_fun: ImportedFunction,
743
- ):
744
- async def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
745
- started_at = time.time()
746
- reset_context = _set_current_context_ids(input_id, function_call_id)
747
- async with function_io_manager.handle_input_exception.aio(input_id, started_at):
748
- logger.debug(f"Starting input {input_id} (async)")
749
- res = imp_fun.fun(*args, **kwargs)
750
- logger.debug(f"Finished input {input_id} (async)")
751
-
752
- # TODO(erikbern): any exception below shouldn't be considered a user exception
753
- if imp_fun.is_generator:
754
- if not inspect.isasyncgen(res):
755
- raise InvalidError(f"Async generator function returned value of type {type(res)}")
756
-
757
- # Send up to this many outputs at a time.
758
- generator_queue: asyncio.Queue[Any] = await function_io_manager._queue_create.aio(1024)
759
- generator_output_task = asyncio.create_task(
760
- function_io_manager.generator_output_task.aio(
761
- function_call_id,
762
- imp_fun.data_format,
763
- generator_queue,
276
+ if container_io_manager.target_concurrency > 1:
277
+ with DaemonizedThreadPool(max_threads=container_io_manager.max_concurrency) as thread_pool:
278
+
279
+ def make_async_cancel_callback(task):
280
+ def f():
281
+ user_code_event_loop.loop.call_soon_threadsafe(task.cancel)
282
+
283
+ return f
284
+
285
+ did_sigint = False
286
+
287
+ def cancel_callback_sync():
288
+ nonlocal did_sigint
289
+ # We only want one sigint even if multiple inputs are cancelled
290
+ # A second sigint would forcibly shut down the event loop and spew
291
+ # out a bunch of tracebacks, which we only want to happen in case
292
+ # the worker kills this process after a failed self-termination
293
+ if not did_sigint:
294
+ did_sigint = True
295
+ logger.warning(
296
+ "User cancelling input of non-async functions with allow_concurrent_inputs > 1.\n"
297
+ "This shuts down the container, causing concurrently running inputs to be "
298
+ "rescheduled in other containers."
764
299
  )
765
- )
766
-
767
- item_count = 0
768
- async for value in res:
769
- await function_io_manager._queue_put.aio(generator_queue, value)
770
- item_count += 1
300
+ os.kill(os.getpid(), signal.SIGINT)
771
301
 
772
- await function_io_manager._queue_put.aio(generator_queue, _FunctionIOManager._GENERATOR_STOP_SENTINEL)
773
- await generator_output_task # Wait to finish sending generator outputs.
774
- message = api_pb2.GeneratorDone(items_total=item_count)
775
- await function_io_manager.push_output.aio(
776
- input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE
777
- )
778
- else:
779
- if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
780
- raise InvalidError(
781
- f"Async (non-generator) function returned value of type {type(res)}"
782
- " You might need to use @stub.function(..., is_generator=True)."
783
- )
784
- value = await res
785
- await function_io_manager.push_output.aio(input_id, started_at, value, imp_fun.data_format)
786
- reset_context()
302
+ async def run_concurrent_inputs():
303
+ # all run_input coroutines will have completed by the time we leave the execution context
304
+ # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
305
+ # for them to resolve gracefully:
306
+ async with TaskContext(0.01) as task_context:
307
+ async for io_context in container_io_manager.run_inputs_outputs.aio(
308
+ finalized_functions, batch_max_size, batch_wait_ms
309
+ ):
310
+ # Note that run_inputs_outputs will not return until all the input slots are released
311
+ # so that they can be acquired by the run_inputs_outputs finalizer
312
+ # This prevents leaving the task_context before outputs have been created
313
+ # TODO: refactor to make this a bit more easy to follow?
314
+ if io_context.finalized_function.is_async:
315
+ input_task = task_context.create_task(run_input_async(io_context))
316
+ io_context.set_cancel_callback(make_async_cancel_callback(input_task))
317
+ else:
318
+ # run sync input in thread
319
+ thread_pool.submit(run_input_sync, io_context)
320
+ io_context.set_cancel_callback(cancel_callback_sync)
787
321
 
788
- if imp_fun.input_concurrency > 1:
789
- # all run_input coroutines will have completed by the time we leave the execution context
790
- # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
791
- # for them to resolve gracefully:
792
- async with TaskContext(0.01) as execution_context:
793
- async for input_id, function_call_id, args, kwargs in function_io_manager.run_inputs_outputs.aio(
794
- imp_fun.input_concurrency
795
- ):
796
- # Note that run_inputs_outputs will not return until the concurrency semaphore has
797
- # released all its slots so that they can be acquired by the run_inputs_outputs finalizer
798
- # This prevents leaving the execution_context before outputs have been created
799
- # TODO: refactor to make this a bit more easy to follow?
800
- execution_context.create_task(run_input(input_id, function_call_id, args, kwargs))
801
- else:
802
- async for input_id, function_call_id, args, kwargs in function_io_manager.run_inputs_outputs.aio(
803
- imp_fun.input_concurrency
804
- ):
805
- await run_input(input_id, function_call_id, args, kwargs)
806
-
807
-
808
- @dataclass
809
- class ImportedFunction:
810
- obj: Any
811
- fun: Callable
812
- stub: Optional[_Stub]
813
- is_async: bool
814
- is_generator: bool
815
- data_format: int # api_pb2.DataFormat
816
- input_concurrency: int
817
- is_auto_snapshot: bool
818
- function: _Function
819
-
820
-
821
- def import_function(
822
- function_def: api_pb2.Function,
823
- ser_cls,
824
- ser_fun,
825
- ser_params: Optional[bytes],
826
- function_io_manager,
827
- client: Client,
828
- ) -> ImportedFunction:
829
- """Imports a function dynamically, and locates the stub.
830
-
831
- This is somewhat complex because we're dealing with 3 quite different type of functions:
832
- 1. Functions defined in global scope and decorated in global scope (Function objects)
833
- 2. Functions defined in global scope but decorated elsewhere (these will be raw callables)
834
- 3. Serialized functions
835
-
836
- In addition, we also need to handle
837
- * Normal functions
838
- * Methods on classes (in which case we need to instantiate the object)
839
-
840
- This helper also handles web endpoints, ASGI/WSGI servers, and HTTP servers.
841
-
842
- In order to locate the stub, we try two things:
843
- * If the function is a Function, we can get the stub directly from it
844
- * Otherwise, use the stub name and look it up from a global list of stubs: this
845
- typically only happens in case 2 above, or in sometimes for case 3
846
-
847
- Note that `import_function` is *not* synchronized, becase we need it to run on the main
848
- thread. This is so that any user code running in global scope (which executes as a part of
849
- the import) runs on the right thread.
850
- """
851
- module: Optional[ModuleType] = None
852
- cls: Optional[Type] = None
853
- fun: Callable
854
- function: Optional[_Function] = None
855
- active_stub: Optional[_Stub] = None
856
- pty_info: api_pb2.PTYInfo = function_def.pty_info
857
-
858
- if ser_fun is not None:
859
- # This is a serialized function we already fetched from the server
860
- cls, fun = ser_cls, ser_fun
322
+ user_code_event_loop.run(run_concurrent_inputs())
861
323
  else:
862
- # Load the module dynamically
863
- module = importlib.import_module(function_def.module_name)
864
- qual_name: str = function_def.function_name
865
-
866
- if not is_global_function(qual_name):
867
- raise LocalFunctionError("Attempted to load a function defined in a function scope")
868
-
869
- parts = qual_name.split(".")
870
- if len(parts) == 1:
871
- # This is a function
872
- cls = None
873
- f = getattr(module, qual_name)
874
- if isinstance(f, Function):
875
- function = synchronizer._translate_in(f)
876
- fun = function.get_raw_f()
877
- active_stub = function._stub
878
- else:
879
- fun = f
880
- elif len(parts) == 2:
881
- # This is a method on a class
882
- cls_name, fun_name = parts
883
- cls = getattr(module, cls_name)
884
- if isinstance(cls, Cls):
885
- # The cls decorator is in global scope
886
- _cls = synchronizer._translate_in(cls)
887
- fun = _cls._callables[fun_name]
888
- function = _cls._functions.get(fun_name)
889
- active_stub = _cls._stub
890
- else:
891
- # This is a raw class
892
- fun = getattr(cls, fun_name)
893
- else:
894
- raise InvalidError(f"Invalid function qualname {qual_name}")
895
-
896
- # If the cls/function decorator was applied in local scope, but the stub is global, we can look it up
897
- if active_stub is None:
898
- # This branch is reached in the special case that the imported function is 1) not serialized, and 2) isn't a FunctionHandle - i.e, not decorated at definition time
899
- # Look at all instantiated stubs - if there is only one with the indicated name, use that one
900
- stub_name: Optional[str] = function_def.stub_name or None # coalesce protobuf field to None
901
- matching_stubs = _Stub._all_stubs.get(stub_name, [])
902
- if len(matching_stubs) > 1:
903
- if stub_name is not None:
904
- warning_sub_message = f"stub with the same name ('{stub_name}')"
324
+ for io_context in container_io_manager.run_inputs_outputs(finalized_functions, batch_max_size, batch_wait_ms):
325
+ if io_context.finalized_function.is_async:
326
+ user_code_event_loop.run(run_input_async(io_context))
905
327
  else:
906
- warning_sub_message = "unnamed stub"
907
- logger.warning(
908
- f"You have more than one {warning_sub_message}. It's recommended to name all your Stubs uniquely when using multiple stubs"
909
- )
910
- elif len(matching_stubs) == 1:
911
- (active_stub,) = matching_stubs
912
- # there could also technically be zero found stubs, but that should probably never be an issue since that would mean user won't use is_inside or other function handles anyway
913
-
914
- # Check this property before we turn it into a method (overriden by webhooks)
915
- is_async = get_is_async(fun)
916
-
917
- # Use the function definition for whether this is a generator (overriden by webhooks)
918
- is_generator = function_def.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
919
-
920
- # What data format is used for function inputs and outputs
921
- data_format = api_pb2.DATA_FORMAT_PICKLE
922
-
923
- # Container can fetch multiple inputs simultaneously
924
- if pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
925
- # Concurrency doesn't apply for `modal shell`.
926
- input_concurrency = 1
927
- else:
928
- input_concurrency = function_def.allow_concurrent_inputs or 1
328
+ # Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation
329
+ # during function execution. This is sent to cancel inputs from the user
330
+ def _cancel_input_signal_handler(signum, stackframe):
331
+ raise InputCancellation("Input was cancelled by user")
929
332
 
930
- # Instantiate the class if it's defined
931
- if cls:
932
- if ser_params:
933
- _client: _Client = synchronizer._translate_in(client)
934
- args, kwargs = deserialize(ser_params, _client)
333
+ usr1_handler = signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
334
+ # run this sync code in the main thread, blocking the "userland" event loop
335
+ # this lets us cancel it using a signal handler that raises an exception
336
+ try:
337
+ run_input_sync(io_context)
338
+ finally:
339
+ signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler
340
+
341
+
342
+ def get_active_app_fallback(function_def: api_pb2.Function) -> _App:
343
+ # This branch is reached in the special case that the imported function/class is:
344
+ # 1) not serialized, and
345
+ # 2) isn't a FunctionHandle - i.e, not decorated at definition time
346
+ # Look at all instantiated apps - if there is only one with the indicated name, use that one
347
+ app_name: Optional[str] = function_def.app_name or None # coalesce protobuf field to None
348
+ matching_apps = _App._all_apps.get(app_name, [])
349
+ if len(matching_apps) == 1:
350
+ active_app: _App = matching_apps[0]
351
+ return active_app
352
+
353
+ if len(matching_apps) > 1:
354
+ if app_name is not None:
355
+ warning_sub_message = f"app with the same name ('{app_name}')"
935
356
  else:
936
- args, kwargs = (), {}
937
- obj = cls(*args, **kwargs)
938
- if isinstance(cls, Cls):
939
- obj = obj.get_obj()
940
- # Bind the function to the instance (using the descriptor protocol!)
941
- fun = fun.__get__(obj)
942
- else:
943
- obj = None
944
-
945
- if function_def.webhook_config.type:
946
- is_async = True
947
- is_generator = True
948
- data_format = api_pb2.DATA_FORMAT_ASGI
949
-
950
- if function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP:
951
- # Function returns an asgi_app, which we can use as a callable.
952
- fun = asgi_app_wrapper(fun(), function_io_manager)
953
-
954
- elif function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP:
955
- # Function returns an wsgi_app, which we can use as a callable.
956
- fun = wsgi_app_wrapper(fun(), function_io_manager)
957
-
958
- elif function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_FUNCTION:
959
- # Function is a webhook without an ASGI app. Create one for it.
960
- fun = asgi_app_wrapper(
961
- webhook_asgi_app(fun, function_def.webhook_config.method),
962
- function_io_manager,
963
- )
357
+ warning_sub_message = "unnamed app"
358
+ logger.warning(
359
+ f"You have more than one {warning_sub_message}. "
360
+ "It's recommended to name all your Apps uniquely when using multiple apps"
361
+ )
964
362
 
965
- elif function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_WEB_SERVER:
966
- # Function spawns an HTTP web server listening at a port.
967
- fun()
968
-
969
- # We intentionally try to connect to the external interface instead of the loopback
970
- # interface here so users are forced to expose the server. This allows us to potentially
971
- # change the implementation to use an external bridge in the future.
972
- host = get_ip_address(b"eth0")
973
- port = function_def.webhook_config.web_server_port
974
- startup_timeout = function_def.webhook_config.web_server_startup_timeout
975
- wait_for_web_server(host, port, timeout=startup_timeout)
976
- fun = asgi_app_wrapper(web_server_proxy(host, port), function_io_manager)
977
-
978
- else:
979
- raise InvalidError(f"Unrecognized web endpoint type {function_def.webhook_config.type}")
980
-
981
- return ImportedFunction(
982
- obj,
983
- fun,
984
- active_stub,
985
- is_async,
986
- is_generator,
987
- data_format,
988
- input_concurrency,
989
- function_def.is_auto_snapshot,
990
- function,
991
- )
363
+ # If we don't have an active app, create one on the fly
364
+ # The app object is used to carry the app layout etc
365
+ return _App()
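For context, the fallback above amounts to a name-keyed registry lookup: reuse the single matching app, warn on ambiguity, otherwise create a fresh one. A simplified, hypothetical version of the same logic (not modal's actual classes):

    import logging
    from typing import Optional

    logger = logging.getLogger(__name__)

    class App:
        _all_apps: dict[Optional[str], list["App"]] = {}

        def __init__(self, name: Optional[str] = None):
            self.name = name
            App._all_apps.setdefault(name, []).append(self)

    def get_active_app_fallback(app_name: Optional[str]) -> App:
        matching_apps = App._all_apps.get(app_name, [])
        if len(matching_apps) == 1:
            return matching_apps[0]  # unambiguous: reuse the single instantiated app
        if len(matching_apps) > 1:
            logger.warning("More than one app named %r; name your apps uniquely.", app_name)
        return App()  # otherwise create one on the fly to carry the app layout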
992
366
 
993
367
 
994
368
  def call_lifecycle_functions(
995
369
  event_loop: UserCodeEventLoop,
996
- function_io_manager, #: FunctionIOManager, TODO: this type is generated at runtime
997
- funcs: Iterable[Callable],
370
+ container_io_manager, #: ContainerIOManager, TODO: this type is generated at runtime
371
+ funcs: Sequence[Callable[..., Any]],
998
372
  ) -> None:
999
373
  """Call function(s), can be sync or async, but any return values are ignored."""
1000
- with function_io_manager.handle_user_exception():
374
+ with container_io_manager.handle_user_exception():
1001
375
  for func in funcs:
1002
376
  # We are deprecating parameterized exit methods but want to gracefully handle old code.
1003
377
  # We can remove this once the deprecation in the actual @exit decorator is enforced.
1004
- args = (None, None, None) if method_has_params(func) else ()
1005
- res = func(
1006
- *args
1007
- ) # in case func is non-async, it's executed here and sigint will by default interrupt it using a KeyboardInterrupt exception
378
+ args = (None, None, None) if callable_has_non_self_params(func) else ()
379
+ # in case func is non-async, it's executed here and sigint will by default
380
+ # interrupt it using a KeyboardInterrupt exception
381
+ res = func(*args)
1008
382
  if inspect.iscoroutine(res):
1009
383
  # if however func is async, we have to jump through some hoops
1010
384
  event_loop.run(res)
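The pattern here — call the hook, and only fall back to the event loop if it returned a coroutine — can be shown in isolation. A hedged sketch with assumed helper names, using a plain asyncio loop in place of the container's UserCodeEventLoop:

    import asyncio
    import inspect
    from typing import Any, Callable, Sequence

    def callable_has_non_self_params(func: Callable[..., Any]) -> bool:
        return any(name != "self" for name in inspect.signature(func).parameters)

    def call_lifecycle_functions(loop: asyncio.AbstractEventLoop, funcs: Sequence[Callable[..., Any]]) -> None:
        for func in funcs:
            # legacy @exit hooks took (exc_type, exc_value, traceback); pass Nones to keep them working
            args = (None, None, None) if callable_has_non_self_params(func) else ()
            res = func(*args)
            if inspect.iscoroutine(res):
                loop.run_until_complete(res)  # async hooks are driven to completion on the given loop

    loop = asyncio.new_event_loop()
    async def async_enter(): print("async enter")
    call_lifecycle_functions(loop, [lambda: print("sync enter"), async_enter])
    loop.close()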
1011
385
 
1012
386
 
387
+ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function, _client: "modal.client._Client"):
388
+ if function_def.class_parameter_info.format in (
389
+ api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_UNSPECIFIED,
390
+ api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PICKLE,
391
+ ):
392
+ # legacy serialization format - pickle of `(args, kwargs)` w/ support for modal object arguments
393
+ param_args, param_kwargs = deserialize(serialized_params, _client)
394
+ elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
395
+ param_args = ()
396
+ param_kwargs = deserialize_proto_params(serialized_params, list(function_def.class_parameter_info.schema))
397
+ else:
398
+ raise ExecutionError(
399
+ f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
400
+ )
401
+
402
+ return param_args, param_kwargs
403
+
404
+
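The format dispatch above reduces to a small sketch; the enum values and the proto decoder below are stand-ins for illustration, not modal's actual API:

    import pickle
    from enum import IntEnum

    class ParamFormat(IntEnum):
        UNSPECIFIED = 0
        PICKLE = 1
        PROTO = 2

    def decode_proto_kwargs(payload: bytes) -> dict:
        raise NotImplementedError("placeholder for a schema-aware protobuf decoder")

    def deserialize_params(serialized_params: bytes, fmt: ParamFormat):
        if fmt in (ParamFormat.UNSPECIFIED, ParamFormat.PICKLE):
            # legacy format: a pickled (args, kwargs) tuple
            param_args, param_kwargs = pickle.loads(serialized_params)
        elif fmt == ParamFormat.PROTO:
            # schema-driven format: no positional args, only named parameters decoded against a schema
            param_args, param_kwargs = (), decode_proto_kwargs(serialized_params)
        else:
            raise RuntimeError(f"Unknown class parameter serialization format: {fmt}")
        return param_args, param_kwargs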
1013
405
  def main(container_args: api_pb2.ContainerArguments, client: Client):
1014
- # This is a bit weird but we need both the blocking and async versions of FunctionIOManager.
406
+ # This is a bit weird but we need both the blocking and async versions of ContainerIOManager.
1015
407
  # At some point, we should fix that by having built-in support for running "user code"
1016
- function_io_manager = FunctionIOManager(container_args, client)
408
+ container_io_manager = ContainerIOManager(container_args, client)
409
+ active_app: _App
410
+ service: Service
411
+ function_def = container_args.function_def
412
+ is_auto_snapshot: bool = function_def.is_auto_snapshot
413
+ # The worker sets this flag to "1" for snapshot and restore tasks. Otherwise, this flag is unset,
414
+ # in which case snapshots should be disabled.
415
+ is_snapshotting_function = (
416
+ function_def.is_checkpointing_function and os.environ.get("MODAL_ENABLE_SNAP_RESTORE", "0") == "1"
417
+ )
418
+
419
+ _client: _Client = synchronizer._translate_in(client) # TODO(erikbern): ugly
1017
420
 
1018
- # Define a global app (need to do this before imports).
1019
- container_app: ContainerApp = function_io_manager.initialize_app()
421
+ # Call ContainerHello - currently a noop but might be used later for things
422
+ container_io_manager.hello()
1020
423
 
1021
- with function_io_manager.heartbeats(), UserCodeEventLoop() as event_loop:
424
+ with container_io_manager.heartbeats(is_snapshotting_function), UserCodeEventLoop() as event_loop:
1022
425
  # If this is a serialized function, fetch the definition from the server
1023
- if container_args.function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
1024
- ser_cls, ser_fun = function_io_manager.get_serialized_function()
426
+ if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
427
+ ser_cls, ser_fun = container_io_manager.get_serialized_function()
1025
428
  else:
1026
429
  ser_cls, ser_fun = None, None
1027
430
 
1028
431
  # Initialize the function, importing user code.
1029
- with function_io_manager.handle_user_exception():
1030
- imp_fun = import_function(
1031
- container_args.function_def,
1032
- ser_cls,
1033
- ser_fun,
1034
- container_args.serialized_params,
1035
- function_io_manager,
1036
- client,
1037
- )
432
+ with container_io_manager.handle_user_exception():
433
+ if container_args.serialized_params:
434
+ param_args, param_kwargs = deserialize_params(container_args.serialized_params, function_def, _client)
435
+ else:
436
+ param_args = ()
437
+ param_kwargs = {}
438
+
439
+ if function_def.is_class:
440
+ service = import_class_service(
441
+ function_def,
442
+ ser_cls,
443
+ param_args,
444
+ param_kwargs,
445
+ )
446
+ else:
447
+ service = import_single_function_service(
448
+ function_def,
449
+ ser_cls,
450
+ ser_fun,
451
+ param_args,
452
+ param_kwargs,
453
+ )
454
+
455
+ # If the cls/function decorator was applied in local scope, but the app is global, we can look it up
456
+ if service.app is not None:
457
+ active_app = service.app
458
+ else:
459
+ # if the app can't be inferred by the imported function, use name-based fallback
460
+ active_app = get_active_app_fallback(function_def)
461
+
462
+ if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
463
+ # Concurrency and batching doesn't apply for `modal shell`.
464
+ batch_max_size = 0
465
+ batch_wait_ms = 0
466
+ else:
467
+ batch_max_size = function_def.batch_max_size or 0
468
+ batch_wait_ms = function_def.batch_linger_ms or 0
1038
469
 
1039
- # Initialize objects on the stub.
1040
- if imp_fun.stub is not None:
1041
- container_app.associate_stub_container(imp_fun.stub)
470
+ # Get ids and metadata for objects (primarily functions and classes) on the app
471
+ container_app: RunningApp = container_io_manager.get_app_objects(container_args.app_layout)
472
+
473
+ # Initialize objects on the app.
474
+ # This is basically only functions and classes - anything else is deprecated and will be unsupported soon
475
+ app: App = synchronizer._translate_out(active_app)
476
+ app._init_container(client, container_app)
1042
477
 
1043
478
  # Hydrate all function dependencies.
1044
479
  # TODO(erikbern): we can remove this once we
1045
480
  # 1. Enable lazy hydration for all objects
1046
481
  # 2. Fully deprecate .new() objects
1047
- if imp_fun.function:
1048
- dep_object_ids: List[str] = [dep.object_id for dep in container_args.function_def.object_dependencies]
1049
- container_app.hydrate_function_deps(imp_fun.function, dep_object_ids)
482
+ if service.code_deps is not None: # this is not set for serialized or non-global scope functions
483
+ dep_object_ids: list[str] = [dep.object_id for dep in function_def.object_dependencies]
484
+ if len(service.code_deps) != len(dep_object_ids):
485
+ raise ExecutionError(
486
+ f"Function has {len(service.code_deps)} dependencies"
487
+ f" but container got {len(dep_object_ids)} object ids.\n"
488
+ f"Code deps: {service.code_deps}\n"
489
+ f"Object ids: {dep_object_ids}"
490
+ )
491
+ for object_id, obj in zip(dep_object_ids, service.code_deps):
492
+ metadata: Message = container_app.object_handle_metadata[object_id]
493
+ obj._hydrate(object_id, _client, metadata)
494
+
495
+ # Initialize clustered functions.
496
+ if function_def._experimental_group_size > 0:
497
+ initialize_clustered_function(
498
+ client,
499
+ container_args.task_id,
500
+ function_def._experimental_group_size,
501
+ )
1050
502
 
1051
- # Identify all "enter" methods that need to run before we checkpoint.
1052
- if imp_fun.obj is not None and not imp_fun.is_auto_snapshot:
1053
- pre_checkpoint_methods = _find_callables_for_obj(imp_fun.obj, _PartialFunctionFlags.ENTER_PRE_CHECKPOINT)
1054
- call_lifecycle_functions(event_loop, function_io_manager, pre_checkpoint_methods.values())
503
+ # Identify all "enter" methods that need to run before we snapshot.
504
+ if service.user_cls_instance is not None and not is_auto_snapshot:
505
+ pre_snapshot_methods = _find_callables_for_obj(
506
+ service.user_cls_instance, _PartialFunctionFlags.ENTER_PRE_SNAPSHOT
507
+ )
508
+ call_lifecycle_functions(event_loop, container_io_manager, list(pre_snapshot_methods.values()))
1055
509
 
1056
510
  # If this container is being used to create a checkpoint, checkpoint the container after
1057
- # global imports and innitialization. Checkpointed containers run from this point onwards.
1058
- if container_args.function_def.is_checkpointing_function:
1059
- function_io_manager.checkpoint()
511
+ # global imports and initialization. Checkpointed containers run from this point onwards.
512
+ if is_snapshotting_function:
513
+ container_io_manager.memory_snapshot()
1060
514
 
1061
515
  # Install hooks for interactive functions.
1062
- if container_args.function_def.pty_info.pty_type != api_pb2.PTYInfo.PTY_TYPE_UNSPECIFIED:
516
+ def breakpoint_wrapper():
517
+ # note: it would be nice to not have breakpoint_wrapper() included in the backtrace
518
+ container_io_manager.interact(from_breakpoint=True)
519
+ import pdb
1063
520
 
1064
- def breakpoint_wrapper():
1065
- # note: it would be nice to not have breakpoint_wrapper() included in the backtrace
1066
- interact()
1067
- import pdb
521
+ frame = inspect.currentframe().f_back
1068
522
 
1069
- pdb.set_trace()
523
+ pdb.Pdb().set_trace(frame)
1070
524
 
1071
- sys.breakpointhook = breakpoint_wrapper
525
+ sys.breakpointhook = breakpoint_wrapper
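The hook installed here replaces Python's default so that `breakpoint()` in user code can first switch the container into interactive mode and then start pdb at the caller's frame. A standalone sketch of just the hook mechanics (the modal-specific interact step is omitted):

    import inspect
    import pdb
    import sys

    def breakpoint_wrapper():
        # any setup needed before debugging (e.g. attaching an interactive terminal) would go here
        frame = inspect.currentframe().f_back  # the frame that called breakpoint()
        pdb.Pdb().set_trace(frame)

    sys.breakpointhook = breakpoint_wrapper

    def user_code():
        x = 41
        breakpoint()  # drops into pdb on this line, with `x` in scope
        return x + 1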
1072
526
 
1073
- # Identify the "enter" methods to run after resuming from a checkpoint.
1074
- if imp_fun.obj is not None and not imp_fun.is_auto_snapshot:
1075
- post_checkpoint_methods = _find_callables_for_obj(imp_fun.obj, _PartialFunctionFlags.ENTER_POST_CHECKPOINT)
1076
- call_lifecycle_functions(event_loop, function_io_manager, post_checkpoint_methods.values())
527
+ # Identify the "enter" methods to run after resuming from a snapshot.
528
+ if service.user_cls_instance is not None and not is_auto_snapshot:
529
+ post_snapshot_methods = _find_callables_for_obj(
530
+ service.user_cls_instance, _PartialFunctionFlags.ENTER_POST_SNAPSHOT
531
+ )
532
+ call_lifecycle_functions(event_loop, container_io_manager, list(post_snapshot_methods.values()))
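Pulling the snapshot-related steps together: pre-snapshot hooks run once before the memory snapshot (taken only when the worker enables it), and post-snapshot hooks run after the snapshot point, i.e. on every restore. An illustrative ordering with stand-in names, grounded on the MODAL_ENABLE_SNAP_RESTORE flag used above:

    import os

    def take_memory_snapshot() -> None:
        print("checkpointing container...")  # stand-in for the runtime's snapshot call

    def start_service(pre_snapshot_hooks, post_snapshot_hooks, is_checkpointing_function: bool) -> None:
        # the worker opts a task into snapshot/restore by setting this flag to "1"
        snapshots_enabled = is_checkpointing_function and os.environ.get("MODAL_ENABLE_SNAP_RESTORE", "0") == "1"
        for hook in pre_snapshot_hooks:
            hook()  # e.g. load model weights so they are captured in the snapshot
        if snapshots_enabled:
            take_memory_snapshot()  # restored containers resume execution from this point
        for hook in post_snapshot_hooks:
            hook()  # e.g. re-establish connections that cannot survive a snapshot

    start_service([lambda: print("pre-snapshot enter")], [lambda: print("post-snapshot enter")], False)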
1077
533
 
534
+ with container_io_manager.handle_user_exception():
535
+ finalized_functions = service.get_finalized_functions(function_def, container_io_manager)
1078
536
  # Execute the function.
537
+ lifespan_background_tasks = []
1079
538
  try:
1080
- if imp_fun.is_async:
1081
- event_loop.run(call_function_async(function_io_manager, imp_fun))
1082
- else:
1083
- # Set up a signal handler for `SIGUSR1`, which gets translated to an InputCancellation
1084
- # during function execution. This is sent to cancel inputs from the user.
1085
- def _cancel_input_signal_handler(signum, stackframe):
1086
- raise InputCancellation("Input was cancelled by user")
1087
-
1088
- signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
1089
-
1090
- call_function_sync(function_io_manager, imp_fun)
539
+ for finalized_function in finalized_functions.values():
540
+ if finalized_function.lifespan_manager:
541
+ lifespan_background_tasks.append(
542
+ event_loop.create_task(finalized_function.lifespan_manager.background_task())
543
+ )
544
+ with container_io_manager.handle_user_exception():
545
+ event_loop.run(finalized_function.lifespan_manager.lifespan_startup())
546
+ call_function(
547
+ event_loop,
548
+ container_io_manager,
549
+ finalized_functions,
550
+ batch_max_size,
551
+ batch_wait_ms,
552
+ )
1091
553
  finally:
1092
554
  # Run exit handlers. From this point onward, ignore all SIGINT signals that come from
1093
555
  # graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that
@@ -1096,15 +558,27 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
1096
558
  usr1_handler = signal.signal(signal.SIGUSR1, signal.SIG_IGN)
1097
559
 
1098
560
  try:
1099
- # Identify "exit" methods and run them.
1100
- if imp_fun.obj is not None and not imp_fun.is_auto_snapshot:
1101
- exit_methods = _find_callables_for_obj(imp_fun.obj, _PartialFunctionFlags.EXIT)
1102
- call_lifecycle_functions(event_loop, function_io_manager, exit_methods.values())
561
+ try:
562
+ # run lifespan shutdown for asgi apps
563
+ for finalized_function in finalized_functions.values():
564
+ if finalized_function.lifespan_manager:
565
+ with container_io_manager.handle_user_exception():
566
+ event_loop.run(finalized_function.lifespan_manager.lifespan_shutdown())
567
+ finally:
568
+ # no need to keep the lifespan asgi call around - we send it no more messages
569
+ for lifespan_background_task in lifespan_background_tasks:
570
+ lifespan_background_task.cancel() # prevent dangling tasks
571
+
572
+ # Identify "exit" methods and run them.
573
+ # want to make sure this is called even if the lifespan manager fails
574
+ if service.user_cls_instance is not None and not is_auto_snapshot:
575
+ exit_methods = _find_callables_for_obj(service.user_cls_instance, _PartialFunctionFlags.EXIT)
576
+ call_lifecycle_functions(event_loop, container_io_manager, list(exit_methods.values()))
1103
577
 
1104
578
  # Finally, commit on exit to catch uncommitted volume changes and surface background
1105
579
  # commit errors.
1106
- function_io_manager.volume_commit(
1107
- [v.volume_id for v in container_args.function_def.volume_mounts if v.allow_background_commits]
580
+ container_io_manager.volume_commit(
581
+ [v.volume_id for v in function_def.volume_mounts if v.allow_background_commits]
1108
582
  )
1109
583
  finally:
1110
584
  # Restore the original signal handler, needed for container_test hygiene since the
@@ -1117,7 +591,15 @@ if __name__ == "__main__":
1117
591
  logger.debug("Container: starting")
1118
592
 
1119
593
  container_args = api_pb2.ContainerArguments()
1120
- container_args.ParseFromString(base64.b64decode(sys.argv[1]))
594
+
595
+ container_arguments_path: Optional[str] = os.environ.get("MODAL_CONTAINER_ARGUMENTS_PATH")
596
+ if container_arguments_path is None:
597
+ # TODO(erikbern): this fallback is for old workers and we can remove it very soon (days)
598
+ import base64
599
+
600
+ container_args.ParseFromString(base64.b64decode(sys.argv[1]))
601
+ else:
602
+ container_args.ParseFromString(open(container_arguments_path, "rb").read())
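The two delivery paths for the serialized container arguments — a file path in MODAL_CONTAINER_ARGUMENTS_PATH on newer workers, base64 on the command line for older ones — reduce to a small helper. A sketch of just the byte-loading part (helper name assumed):

    import base64
    import os
    import sys

    def read_container_arguments_bytes() -> bytes:
        container_arguments_path = os.environ.get("MODAL_CONTAINER_ARGUMENTS_PATH")
        if container_arguments_path is not None:
            with open(container_arguments_path, "rb") as f:
                return f.read()  # newer workers write the serialized proto to a file
        # legacy fallback: the proto is passed base64-encoded as the first CLI argument
        return base64.b64decode(sys.argv[1])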
1121
603
 
1122
604
  # Note that we're creating the client in a synchronous context, but it will be running in a separate thread.
1123
605
  # This is good because if the function is long running then the client can still send heartbeats
@@ -1137,7 +619,7 @@ if __name__ == "__main__":
1137
619
  # from shutting down. The sleep(0) here is needed for finished ThreadPoolExecutor resources to
1138
620
  # shut down without triggering this warning (e.g., `@wsgi_app()`).
1139
621
  time.sleep(0)
1140
- lingering_threads: List[threading.Thread] = []
622
+ lingering_threads: list[threading.Thread] = []
1141
623
  for thread in threading.enumerate():
1142
624
  current_thread = threading.get_ident()
1143
625
  if thread.ident is not None and thread.ident != current_thread and not thread.daemon and thread.is_alive():
@@ -1145,7 +627,8 @@ if __name__ == "__main__":
1145
627
  if lingering_threads:
1146
628
  thread_names = ", ".join(t.name for t in lingering_threads)
1147
629
  logger.warning(
1148
- f"Detected {len(lingering_threads)} background thread(s) [{thread_names}] still running after container exit. This will prevent runner shutdown for up to 30 seconds."
630
+ f"Detected {len(lingering_threads)} background thread(s) [{thread_names}] still running "
631
+ "after container exit. This will prevent runner shutdown for up to 30 seconds."
1149
632
  )
1150
633
 
1151
634
  logger.debug("Container: done")
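The lingering-thread warning above boils down to enumerating live, non-daemon threads other than the current one. A minimal standalone version of that check:

    import threading

    def find_lingering_threads() -> list[threading.Thread]:
        current_thread = threading.get_ident()
        return [
            t for t in threading.enumerate()
            if t.ident is not None and t.ident != current_thread and not t.daemon and t.is_alive()
        ]

    print(find_lingering_threads())  # [] in a plain interpreter with no extra threads running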