modal 0.62.115__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +407 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +197 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +946 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/METADATA +5 -5
  128. modal-0.72.11.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
@@ -1,63 +1,103 @@
1
1
  # Copyright Modal Labs 2022
2
+ # ruff: noqa: E402
3
+ import os
4
+
5
+ from modal._runtime.user_code_imports import Service, import_class_service, import_single_function_service
6
+
7
+ telemetry_socket = os.environ.get("MODAL_TELEMETRY_SOCKET")
8
+ if telemetry_socket:
9
+ from ._runtime.telemetry import instrument_imports
10
+
11
+ instrument_imports(telemetry_socket)
12
+
2
13
  import asyncio
3
- import base64
4
- import importlib
14
+ import concurrent.futures
5
15
  import inspect
16
+ import queue
6
17
  import signal
7
18
  import sys
8
19
  import threading
9
20
  import time
10
- from dataclasses import dataclass
11
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Type
21
+ from collections.abc import Sequence
22
+ from typing import TYPE_CHECKING, Any, Callable, Optional
12
23
 
13
24
  from google.protobuf.message import Message
14
- from synchronicity import Interface
15
25
 
26
+ from modal._clustered_functions import initialize_clustered_function
27
+ from modal._proxy_tunnel import proxy_tunnel
28
+ from modal._serialization import deserialize, deserialize_proto_params
29
+ from modal._utils.async_utils import TaskContext, synchronizer
30
+ from modal._utils.function_utils import (
31
+ callable_has_non_self_params,
32
+ )
33
+ from modal.app import App, _App
34
+ from modal.client import Client, _Client
35
+ from modal.config import logger
36
+ from modal.exception import ExecutionError, InputCancellation, InvalidError
37
+ from modal.partial_function import (
38
+ _find_callables_for_obj,
39
+ _PartialFunctionFlags,
40
+ )
41
+ from modal.running_app import RunningApp
16
42
  from modal_proto import api_pb2
17
43
 
18
- from ._asgi import (
19
- asgi_app_wrapper,
20
- get_ip_address,
21
- wait_for_web_server,
22
- web_server_proxy,
23
- webhook_asgi_app,
24
- wsgi_app_wrapper,
44
+ from ._runtime.container_io_manager import (
45
+ ContainerIOManager,
46
+ IOContext,
47
+ UserException,
48
+ _ContainerIOManager,
25
49
  )
26
- from ._container_io_manager import ContainerIOManager, UserException, _ContainerIOManager
27
- from ._proxy_tunnel import proxy_tunnel
28
- from ._serialization import deserialize
29
- from ._utils.async_utils import TaskContext, synchronizer
30
- from ._utils.function_utils import (
31
- LocalFunctionError,
32
- is_async as get_is_async,
33
- is_global_function,
34
- method_has_params,
35
- )
36
- from .app import App, _App
37
- from .client import Client, _Client
38
- from .cls import Cls
39
- from .config import logger
40
- from .exception import ExecutionError, InputCancellation, InvalidError
41
- from .execution_context import _set_current_context_ids, interact
42
- from .functions import Function, _Function
43
- from .partial_function import _find_callables_for_obj, _PartialFunctionFlags
44
- from .running_app import RunningApp
50
+ from ._runtime.execution_context import _set_current_context_ids
45
51
 
46
52
  if TYPE_CHECKING:
47
- from types import ModuleType
53
+ import modal._runtime.container_io_manager
54
+ import modal.object
48
55
 
49
56
 
50
- @dataclass
51
- class ImportedFunction:
52
- obj: Any
53
- fun: Callable
54
- app: Optional[_App]
55
- is_async: bool
56
- is_generator: bool
57
- data_format: int # api_pb2.DataFormat
58
- input_concurrency: int
59
- is_auto_snapshot: bool
60
- function: _Function
57
+ class DaemonizedThreadPool:
58
+ # Used instead of ThreadPoolExecutor, since the latter won't allow
59
+ # the interpreter to shut down before the currently running tasks
60
+ # have finished
61
+ def __init__(self, max_threads: int):
62
+ self.max_threads = max_threads
63
+
64
+ def __enter__(self):
65
+ self.spawned_workers = 0
66
+ self.inputs: queue.Queue[Any] = queue.Queue()
67
+ self.finished = threading.Event()
68
+ return self
69
+
70
+ def __exit__(self, exc_type, exc_value, traceback):
71
+ self.finished.set()
72
+
73
+ if exc_type is None:
74
+ self.inputs.join()
75
+ else:
76
+ # special case - allows us to exit the
77
+ if self.inputs.unfinished_tasks:
78
+ logger.info(
79
+ f"Exiting DaemonizedThreadPool with {self.inputs.unfinished_tasks} active "
80
+ f"inputs due to exception: {repr(exc_type)}"
81
+ )
82
+
83
+ def submit(self, func, *args):
84
+ def worker_thread():
85
+ while not self.finished.is_set():
86
+ try:
87
+ _func, _args = self.inputs.get(timeout=1)
88
+ except queue.Empty:
89
+ continue
90
+ try:
91
+ _func(*_args)
92
+ except BaseException:
93
+ logger.exception(f"Exception raised by {_func} in DaemonizedThreadPool worker!")
94
+ self.inputs.task_done()
95
+
96
+ if self.spawned_workers < self.max_threads:
97
+ threading.Thread(target=worker_thread, daemon=True).start()
98
+ self.spawned_workers += 1
99
+
100
+ self.inputs.put((func, args))
61
101
 
62
102
 
63
103
  class UserCodeEventLoop:
@@ -76,14 +116,25 @@ class UserCodeEventLoop:
76
116
 
77
117
  def __enter__(self):
78
118
  self.loop = asyncio.new_event_loop()
119
+ self.tasks = set()
79
120
  return self
80
121
 
81
122
  def __exit__(self, exc_type, exc_value, traceback):
82
123
  self.loop.run_until_complete(self.loop.shutdown_asyncgens())
83
124
  if sys.version_info[:2] >= (3, 9):
84
125
  self.loop.run_until_complete(self.loop.shutdown_default_executor()) # Introduced in Python 3.9
126
+
127
+ for task in self.tasks:
128
+ task.cancel()
129
+
85
130
  self.loop.close()
86
131
 
132
+ def create_task(self, coro):
133
+ task = self.loop.create_task(coro)
134
+ self.tasks.add(task)
135
+ task.add_done_callback(self.tasks.discard)
136
+ return task
137
+
87
138
  def run(self, coro):
88
139
  task = asyncio.ensure_future(coro, loop=self.loop)
89
140
  self._sigints = 0
@@ -99,7 +150,9 @@ class UserCodeEventLoop:
99
150
  # first sigint is graceful
100
151
  task.cancel()
101
152
  return
102
- raise KeyboardInterrupt() # this should normally not happen, but the second sigint would "hard kill" the event loop!
153
+
154
+ # this should normally not happen, but the second sigint would "hard kill" the event loop!
155
+ raise KeyboardInterrupt()
103
156
 
104
157
  ignore_sigint = signal.getsignal(signal.SIGINT) == signal.SIG_IGN
105
158
  if not ignore_sigint:
@@ -122,111 +175,21 @@ class UserCodeEventLoop:
122
175
  self.loop.remove_signal_handler(signal.SIGINT)
123
176
 
124
177
 
125
- def call_function_sync(
126
- container_io_manager, #: ContainerIOManager, TODO: this type is generated at runtime
127
- imp_fun: ImportedFunction,
178
+ def call_function(
179
+ user_code_event_loop: UserCodeEventLoop,
180
+ container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager",
181
+ finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
182
+ batch_max_size: int,
183
+ batch_wait_ms: int,
128
184
  ):
129
- def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
185
+ async def run_input_async(io_context: IOContext) -> None:
130
186
  started_at = time.time()
131
- reset_context = _set_current_context_ids(input_id, function_call_id)
132
- with container_io_manager.handle_input_exception(input_id, started_at):
133
- logger.debug(f"Starting input {input_id} (sync)")
134
- res = imp_fun.fun(*args, **kwargs)
135
- logger.debug(f"Finished input {input_id} (sync)")
136
-
187
+ input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids
188
+ reset_context = _set_current_context_ids(input_ids, function_call_ids)
189
+ async with container_io_manager.handle_input_exception.aio(io_context, started_at):
190
+ res = io_context.call_finalized_function()
137
191
  # TODO(erikbern): any exception below shouldn't be considered a user exception
138
- if imp_fun.is_generator:
139
- if not inspect.isgenerator(res):
140
- raise InvalidError(f"Generator function returned value of type {type(res)}")
141
-
142
- # Send up to this many outputs at a time.
143
- generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
144
- generator_output_task = container_io_manager.generator_output_task(
145
- function_call_id,
146
- imp_fun.data_format,
147
- generator_queue,
148
- _future=True, # Synchronicity magic to return a future.
149
- )
150
-
151
- item_count = 0
152
- for value in res:
153
- container_io_manager._queue_put(generator_queue, value)
154
- item_count += 1
155
-
156
- container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
157
- generator_output_task.result() # Wait to finish sending generator outputs.
158
- message = api_pb2.GeneratorDone(items_total=item_count)
159
- container_io_manager.push_output(input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
160
- else:
161
- if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
162
- raise InvalidError(
163
- f"Sync (non-generator) function return value of type {type(res)}."
164
- " You might need to use @app.function(..., is_generator=True)."
165
- )
166
- container_io_manager.push_output(input_id, started_at, res, imp_fun.data_format)
167
- reset_context()
168
-
169
- if imp_fun.input_concurrency > 1:
170
- # We can't use `concurrent.futures.ThreadPoolExecutor` here because in Python 3.11+, this
171
- # class has no workaround that allows us to exit the Python interpreter process without
172
- # waiting for the worker threads to finish. We need this behavior on SIGINT.
173
-
174
- import queue
175
- import threading
176
-
177
- spawned_workers = 0
178
- inputs: queue.Queue[Any] = queue.Queue()
179
- finished = threading.Event()
180
-
181
- def worker_thread():
182
- while not finished.is_set():
183
- try:
184
- args = inputs.get(timeout=1)
185
- except queue.Empty:
186
- continue
187
- try:
188
- run_input(*args)
189
- except BaseException:
190
- # This should basically never happen, since only KeyboardInterrupt is the only error that can
191
- # bubble out of from handle_input_exception and those wouldn't be raised outside the main thread
192
- pass
193
- inputs.task_done()
194
-
195
- for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs(
196
- imp_fun.input_concurrency
197
- ):
198
- if spawned_workers < imp_fun.input_concurrency:
199
- threading.Thread(target=worker_thread, daemon=True).start()
200
- spawned_workers += 1
201
- inputs.put((input_id, function_call_id, args, kwargs))
202
-
203
- finished.set()
204
- inputs.join()
205
-
206
- else:
207
- for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs(
208
- imp_fun.input_concurrency
209
- ):
210
- try:
211
- run_input(input_id, function_call_id, args, kwargs)
212
- except:
213
- raise
214
-
215
-
216
- async def call_function_async(
217
- container_io_manager, #: ContainerIOManager, TODO: this type is generated at runtime
218
- imp_fun: ImportedFunction,
219
- ):
220
- async def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
221
- started_at = time.time()
222
- reset_context = _set_current_context_ids(input_id, function_call_id)
223
- async with container_io_manager.handle_input_exception.aio(input_id, started_at):
224
- logger.debug(f"Starting input {input_id} (async)")
225
- res = imp_fun.fun(*args, **kwargs)
226
- logger.debug(f"Finished input {input_id} (async)")
227
-
228
- # TODO(erikbern): any exception below shouldn't be considered a user exception
229
- if imp_fun.is_generator:
192
+ if io_context.finalized_function.is_generator:
230
193
  if not inspect.isasyncgen(res):
231
194
  raise InvalidError(f"Async generator function returned value of type {type(res)}")
232
195
 
@@ -234,8 +197,8 @@ async def call_function_async(
234
197
  generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024)
235
198
  generator_output_task = asyncio.create_task(
236
199
  container_io_manager.generator_output_task.aio(
237
- function_call_id,
238
- imp_fun.data_format,
200
+ function_call_ids[0],
201
+ io_context.finalized_function.data_format,
239
202
  generator_queue,
240
203
  )
241
204
  )
@@ -248,8 +211,11 @@ async def call_function_async(
248
211
  await container_io_manager._queue_put.aio(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
249
212
  await generator_output_task # Wait to finish sending generator outputs.
250
213
  message = api_pb2.GeneratorDone(items_total=item_count)
251
- await container_io_manager.push_output.aio(
252
- input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE
214
+ await container_io_manager.push_outputs.aio(
215
+ io_context,
216
+ started_at,
217
+ message,
218
+ api_pb2.DATA_FORMAT_GENERATOR_DONE,
253
219
  )
254
220
  else:
255
221
  if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
@@ -258,310 +224,332 @@ async def call_function_async(
258
224
  " You might need to use @app.function(..., is_generator=True)."
259
225
  )
260
226
  value = await res
261
- await container_io_manager.push_output.aio(input_id, started_at, value, imp_fun.data_format)
227
+ await container_io_manager.push_outputs.aio(
228
+ io_context,
229
+ started_at,
230
+ value,
231
+ io_context.finalized_function.data_format,
232
+ )
262
233
  reset_context()
263
234
 
264
- if imp_fun.input_concurrency > 1:
265
- # all run_input coroutines will have completed by the time we leave the execution context
266
- # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
267
- # for them to resolve gracefully:
268
- async with TaskContext(0.01) as task_context:
269
- async for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs.aio(
270
- imp_fun.input_concurrency
271
- ):
272
- # Note that run_inputs_outputs will not return until the concurrency semaphore has
273
- # released all its slots so that they can be acquired by the run_inputs_outputs finalizer
274
- # This prevents leaving the task_context before outputs have been created
275
- # TODO: refactor to make this a bit more easy to follow?
276
- task_context.create_task(run_input(input_id, function_call_id, args, kwargs))
277
- else:
278
- async for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs.aio(
279
- imp_fun.input_concurrency
280
- ):
281
- await run_input(input_id, function_call_id, args, kwargs)
282
-
283
-
284
- def import_function(
285
- function_def: api_pb2.Function,
286
- ser_cls,
287
- ser_fun,
288
- ser_params: Optional[bytes],
289
- container_io_manager,
290
- client: Client,
291
- ) -> ImportedFunction:
292
- """Imports a function dynamically, and locates the app.
293
-
294
- This is somewhat complex because we're dealing with 3 quite different type of functions:
295
- 1. Functions defined in global scope and decorated in global scope (Function objects)
296
- 2. Functions defined in global scope but decorated elsewhere (these will be raw callables)
297
- 3. Serialized functions
298
-
299
- In addition, we also need to handle
300
- * Normal functions
301
- * Methods on classes (in which case we need to instantiate the object)
302
-
303
- This helper also handles web endpoints, ASGI/WSGI servers, and HTTP servers.
304
-
305
- In order to locate the app, we try two things:
306
- * If the function is a Function, we can get the app directly from it
307
- * Otherwise, use the app name and look it up from a global list of apps: this
308
- typically only happens in case 2 above, or in sometimes for case 3
309
-
310
- Note that `import_function` is *not* synchronized, becase we need it to run on the main
311
- thread. This is so that any user code running in global scope (which executes as a part of
312
- the import) runs on the right thread.
313
- """
314
- module: Optional[ModuleType] = None
315
- cls: Optional[Type] = None
316
- fun: Callable
317
- function: Optional[_Function] = None
318
- active_app: Optional[_App] = None
319
- pty_info: api_pb2.PTYInfo = function_def.pty_info
320
-
321
- if ser_fun is not None:
322
- # This is a serialized function we already fetched from the server
323
- cls, fun = ser_cls, ser_fun
324
- else:
325
- # Load the module dynamically
326
- module = importlib.import_module(function_def.module_name)
327
- qual_name: str = function_def.function_name
328
-
329
- if not is_global_function(qual_name):
330
- raise LocalFunctionError("Attempted to load a function defined in a function scope")
331
-
332
- parts = qual_name.split(".")
333
- if len(parts) == 1:
334
- # This is a function
335
- cls = None
336
- f = getattr(module, qual_name)
337
- if isinstance(f, Function):
338
- function = synchronizer._translate_in(f)
339
- fun = function.get_raw_f()
340
- active_app = function._app
341
- else:
342
- fun = f
343
- elif len(parts) == 2:
344
- # This is a method on a class
345
- cls_name, fun_name = parts
346
- cls = getattr(module, cls_name)
347
- if isinstance(cls, Cls):
348
- # The cls decorator is in global scope
349
- _cls = synchronizer._translate_in(cls)
350
- fun = _cls._callables[fun_name]
351
- function = _cls._functions.get(fun_name)
352
- active_app = _cls._app
353
- else:
354
- # This is a raw class
355
- fun = getattr(cls, fun_name)
356
- else:
357
- raise InvalidError(f"Invalid function qualname {qual_name}")
358
-
359
- # If the cls/function decorator was applied in local scope, but the app is global, we can look it up
360
- if active_app is None:
361
- # This branch is reached in the special case that the imported function is 1) not serialized, and 2) isn't a FunctionHandle - i.e, not decorated at definition time
362
- # Look at all instantiated apps - if there is only one with the indicated name, use that one
363
- app_name: Optional[str] = function_def.app_name or None # coalesce protobuf field to None
364
- matching_apps = _App._all_apps.get(app_name, [])
365
- if len(matching_apps) > 1:
366
- if app_name is not None:
367
- warning_sub_message = f"app with the same name ('{app_name}')"
368
- else:
369
- warning_sub_message = "unnamed app"
370
- logger.warning(
371
- f"You have more than one {warning_sub_message}. It's recommended to name all your Apps uniquely when using multiple apps"
372
- )
373
- elif len(matching_apps) == 1:
374
- (active_app,) = matching_apps
375
- # there could also technically be zero found apps, but that should probably never be an issue since that would mean user won't use is_inside or other function handles anyway
235
+ def run_input_sync(io_context: IOContext) -> None:
236
+ started_at = time.time()
237
+ input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids
238
+ reset_context = _set_current_context_ids(input_ids, function_call_ids)
239
+ with container_io_manager.handle_input_exception(io_context, started_at):
240
+ res = io_context.call_finalized_function()
376
241
 
377
- # Check this property before we turn it into a method (overriden by webhooks)
378
- is_async = get_is_async(fun)
242
+ # TODO(erikbern): any exception below shouldn't be considered a user exception
243
+ if io_context.finalized_function.is_generator:
244
+ if not inspect.isgenerator(res):
245
+ raise InvalidError(f"Generator function returned value of type {type(res)}")
379
246
 
380
- # Use the function definition for whether this is a generator (overriden by webhooks)
381
- is_generator = function_def.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
247
+ # Send up to this many outputs at a time.
248
+ generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
249
+ generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore
250
+ function_call_ids[0],
251
+ io_context.finalized_function.data_format,
252
+ generator_queue,
253
+ _future=True, # type: ignore # Synchronicity magic to return a future.
254
+ )
382
255
 
383
- # What data format is used for function inputs and outputs
384
- data_format = api_pb2.DATA_FORMAT_PICKLE
256
+ item_count = 0
257
+ for value in res:
258
+ container_io_manager._queue_put(generator_queue, value)
259
+ item_count += 1
385
260
 
386
- # Container can fetch multiple inputs simultaneously
387
- if pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
388
- # Concurrency doesn't apply for `modal shell`.
389
- input_concurrency = 1
390
- else:
391
- input_concurrency = function_def.allow_concurrent_inputs or 1
261
+ container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
262
+ generator_output_task.result() # Wait to finish sending generator outputs.
263
+ message = api_pb2.GeneratorDone(items_total=item_count)
264
+ container_io_manager.push_outputs(io_context, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
265
+ else:
266
+ if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
267
+ raise InvalidError(
268
+ f"Sync (non-generator) function return value of type {type(res)}."
269
+ " You might need to use @app.function(..., is_generator=True)."
270
+ )
271
+ container_io_manager.push_outputs(
272
+ io_context, started_at, res, io_context.finalized_function.data_format
273
+ )
274
+ reset_context()
392
275
 
393
- # Instantiate the class if it's defined
394
- if cls:
395
- if ser_params:
396
- _client: _Client = synchronizer._translate_in(client)
397
- args, kwargs = deserialize(ser_params, _client)
398
- else:
399
- args, kwargs = (), {}
400
- obj = cls(*args, **kwargs)
401
- if isinstance(cls, Cls):
402
- obj = obj.get_obj()
403
- # Bind the function to the instance (using the descriptor protocol!)
404
- fun = fun.__get__(obj)
276
+ if container_io_manager.target_concurrency > 1:
277
+ with DaemonizedThreadPool(max_threads=container_io_manager.max_concurrency) as thread_pool:
278
+
279
+ def make_async_cancel_callback(task):
280
+ def f():
281
+ user_code_event_loop.loop.call_soon_threadsafe(task.cancel)
282
+
283
+ return f
284
+
285
+ did_sigint = False
286
+
287
+ def cancel_callback_sync():
288
+ nonlocal did_sigint
289
+ # We only want one sigint even if multiple inputs are cancelled
290
+ # A second sigint would forcibly shut down the event loop and spew
291
+ # out a bunch of tracebacks, which we only want to happen in case
292
+ # the worker kills this process after a failed self-termination
293
+ if not did_sigint:
294
+ did_sigint = True
295
+ logger.warning(
296
+ "User cancelling input of non-async functions with allow_concurrent_inputs > 1.\n"
297
+ "This shuts down the container, causing concurrently running inputs to be "
298
+ "rescheduled in other containers."
299
+ )
300
+ os.kill(os.getpid(), signal.SIGINT)
301
+
302
+ async def run_concurrent_inputs():
303
+ # all run_input coroutines will have completed by the time we leave the execution context
304
+ # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
305
+ # for them to resolve gracefully:
306
+ async with TaskContext(0.01) as task_context:
307
+ async for io_context in container_io_manager.run_inputs_outputs.aio(
308
+ finalized_functions, batch_max_size, batch_wait_ms
309
+ ):
310
+ # Note that run_inputs_outputs will not return until all the input slots are released
311
+ # so that they can be acquired by the run_inputs_outputs finalizer
312
+ # This prevents leaving the task_context before outputs have been created
313
+ # TODO: refactor to make this a bit easier to follow?
314
+ if io_context.finalized_function.is_async:
315
+ input_task = task_context.create_task(run_input_async(io_context))
316
+ io_context.set_cancel_callback(make_async_cancel_callback(input_task))
317
+ else:
318
+ # run sync input in thread
319
+ thread_pool.submit(run_input_sync, io_context)
320
+ io_context.set_cancel_callback(cancel_callback_sync)
321
+
322
+ user_code_event_loop.run(run_concurrent_inputs())
405
323
  else:
406
- obj = None
407
-
408
- if function_def.webhook_config.type:
409
- is_async = True
410
- is_generator = True
411
- data_format = api_pb2.DATA_FORMAT_ASGI
412
-
413
- if function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP:
414
- # Function returns an asgi_app, which we can use as a callable.
415
- fun = asgi_app_wrapper(fun(), container_io_manager)
416
-
417
- elif function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP:
418
- # Function returns an wsgi_app, which we can use as a callable.
419
- fun = wsgi_app_wrapper(fun(), container_io_manager)
420
-
421
- elif function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_FUNCTION:
422
- # Function is a webhook without an ASGI app. Create one for it.
423
- fun = asgi_app_wrapper(
424
- webhook_asgi_app(fun, function_def.webhook_config.method),
425
- container_io_manager,
426
- )
427
-
428
- elif function_def.webhook_config.type == api_pb2.WEBHOOK_TYPE_WEB_SERVER:
429
- # Function spawns an HTTP web server listening at a port.
430
- fun()
431
-
432
- # We intentionally try to connect to the external interface instead of the loopback
433
- # interface here so users are forced to expose the server. This allows us to potentially
434
- # change the implementation to use an external bridge in the future.
435
- host = get_ip_address(b"eth0")
436
- port = function_def.webhook_config.web_server_port
437
- startup_timeout = function_def.webhook_config.web_server_startup_timeout
438
- wait_for_web_server(host, port, timeout=startup_timeout)
439
- fun = asgi_app_wrapper(web_server_proxy(host, port), container_io_manager)
324
+ for io_context in container_io_manager.run_inputs_outputs(finalized_functions, batch_max_size, batch_wait_ms):
325
+ if io_context.finalized_function.is_async:
326
+ user_code_event_loop.run(run_input_async(io_context))
327
+ else:
328
+ # Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation
329
+ # during function execution. This is sent to cancel inputs from the user
330
+ def _cancel_input_signal_handler(signum, stackframe):
331
+ raise InputCancellation("Input was cancelled by user")
440
332
 
333
+ usr1_handler = signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
334
+ # run this sync code in the main thread, blocking the "userland" event loop
335
+ # this lets us cancel it using a signal handler that raises an exception
336
+ try:
337
+ run_input_sync(io_context)
338
+ finally:
339
+ signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler
340
+
341
+
342
+ def get_active_app_fallback(function_def: api_pb2.Function) -> _App:
343
+ # This branch is reached in the special case that the imported function/class is:
344
+ # 1) not serialized, and
345
+ # 2) isn't a FunctionHandle - i.e., not decorated at definition time
346
+ # Look at all instantiated apps - if there is only one with the indicated name, use that one
347
+ app_name: Optional[str] = function_def.app_name or None # coalesce protobuf field to None
348
+ matching_apps = _App._all_apps.get(app_name, [])
349
+ if len(matching_apps) == 1:
350
+ active_app: _App = matching_apps[0]
351
+ return active_app
352
+
353
+ if len(matching_apps) > 1:
354
+ if app_name is not None:
355
+ warning_sub_message = f"app with the same name ('{app_name}')"
441
356
  else:
442
- raise InvalidError(f"Unrecognized web endpoint type {function_def.webhook_config.type}")
443
-
444
- return ImportedFunction(
445
- obj,
446
- fun,
447
- active_app,
448
- is_async,
449
- is_generator,
450
- data_format,
451
- input_concurrency,
452
- function_def.is_auto_snapshot,
453
- function,
454
- )
357
+ warning_sub_message = "unnamed app"
358
+ logger.warning(
359
+ f"You have more than one {warning_sub_message}. "
360
+ "It's recommended to name all your Apps uniquely when using multiple apps"
361
+ )
362
+
363
+ # If we don't have an active app, create one on the fly
364
+ # The app object is used to carry the app layout etc
365
+ return _App()
455
366
 
456
367
 
457
368
  def call_lifecycle_functions(
458
369
  event_loop: UserCodeEventLoop,
459
370
  container_io_manager, #: ContainerIOManager, TODO: this type is generated at runtime
460
- funcs: Sequence[Callable],
371
+ funcs: Sequence[Callable[..., Any]],
461
372
  ) -> None:
462
373
  """Call function(s), can be sync or async, but any return values are ignored."""
463
374
  with container_io_manager.handle_user_exception():
464
375
  for func in funcs:
465
376
  # We are deprecating parameterized exit methods but want to gracefully handle old code.
466
377
  # We can remove this once the deprecation in the actual @exit decorator is enforced.
467
- args = (None, None, None) if method_has_params(func) else ()
468
- res = func(
469
- *args
470
- ) # in case func is non-async, it's executed here and sigint will by default interrupt it using a KeyboardInterrupt exception
378
+ args = (None, None, None) if callable_has_non_self_params(func) else ()
379
+ # in case func is non-async, it's executed here and sigint will by default
380
+ # interrupt it using a KeyboardInterrupt exception
381
+ res = func(*args)
471
382
  if inspect.iscoroutine(res):
472
383
  # if however func is async, we have to jump through some hoops
473
384
  event_loop.run(res)
474
385
 
475
386
 
387
+ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function, _client: "modal.client._Client"):
388
+ if function_def.class_parameter_info.format in (
389
+ api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_UNSPECIFIED,
390
+ api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PICKLE,
391
+ ):
392
+ # legacy serialization format - pickle of `(args, kwargs)` w/ support for modal object arguments
393
+ param_args, param_kwargs = deserialize(serialized_params, _client)
394
+ elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
395
+ param_args = ()
396
+ param_kwargs = deserialize_proto_params(serialized_params, list(function_def.class_parameter_info.schema))
397
+ else:
398
+ raise ExecutionError(
399
+ f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
400
+ )
401
+
402
+ return param_args, param_kwargs
403
+
404
+
476
405
  def main(container_args: api_pb2.ContainerArguments, client: Client):
477
406
  # This is a bit weird but we need both the blocking and async versions of ContainerIOManager.
478
407
  # At some point, we should fix that by having built-in support for running "user code"
479
408
  container_io_manager = ContainerIOManager(container_args, client)
409
+ active_app: _App
410
+ service: Service
411
+ function_def = container_args.function_def
412
+ is_auto_snapshot: bool = function_def.is_auto_snapshot
413
+ # The worker sets this flag to "1" for snapshot and restore tasks. Otherwise, this flag is unset,
414
+ # in which case snapshots should be disabled.
415
+ is_snapshotting_function = (
416
+ function_def.is_checkpointing_function and os.environ.get("MODAL_ENABLE_SNAP_RESTORE", "0") == "1"
417
+ )
418
+
419
+ _client: _Client = synchronizer._translate_in(client) # TODO(erikbern): ugly
480
420
 
481
- with container_io_manager.heartbeats(), UserCodeEventLoop() as event_loop:
421
+ # Call ContainerHello - currently a noop but might be used later for things
422
+ container_io_manager.hello()
423
+
424
+ with container_io_manager.heartbeats(is_snapshotting_function), UserCodeEventLoop() as event_loop:
482
425
  # If this is a serialized function, fetch the definition from the server
483
- if container_args.function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
426
+ if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
484
427
  ser_cls, ser_fun = container_io_manager.get_serialized_function()
485
428
  else:
486
429
  ser_cls, ser_fun = None, None
487
430
 
488
431
  # Initialize the function, importing user code.
489
432
  with container_io_manager.handle_user_exception():
490
- imp_fun = import_function(
491
- container_args.function_def,
492
- ser_cls,
493
- ser_fun,
494
- container_args.serialized_params,
495
- container_io_manager,
496
- client,
497
- )
433
+ if container_args.serialized_params:
434
+ param_args, param_kwargs = deserialize_params(container_args.serialized_params, function_def, _client)
435
+ else:
436
+ param_args = ()
437
+ param_kwargs = {}
438
+
439
+ if function_def.is_class:
440
+ service = import_class_service(
441
+ function_def,
442
+ ser_cls,
443
+ param_args,
444
+ param_kwargs,
445
+ )
446
+ else:
447
+ service = import_single_function_service(
448
+ function_def,
449
+ ser_cls,
450
+ ser_fun,
451
+ param_args,
452
+ param_kwargs,
453
+ )
454
+
455
+ # If the cls/function decorator was applied in local scope, but the app is global, we can look it up
456
+ if service.app is not None:
457
+ active_app = service.app
458
+ else:
459
+ # if the app can't be inferred by the imported function, use name-based fallback
460
+ active_app = get_active_app_fallback(function_def)
461
+
462
+ if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
463
+ # Concurrency and batching doesn't apply for `modal shell`.
464
+ batch_max_size = 0
465
+ batch_wait_ms = 0
466
+ else:
467
+ batch_max_size = function_def.batch_max_size or 0
468
+ batch_wait_ms = function_def.batch_linger_ms or 0
498
469
 
499
470
  # Get ids and metadata for objects (primarily functions and classes) on the app
500
- container_app: RunningApp = container_io_manager.get_app_objects()
471
+ container_app: RunningApp = container_io_manager.get_app_objects(container_args.app_layout)
501
472
 
502
473
  # Initialize objects on the app.
503
474
  # This is basically only functions and classes - anything else is deprecated and will be unsupported soon
504
- if imp_fun.app is not None:
505
- app: App = synchronizer._translate_out(imp_fun.app, Interface.BLOCKING)
506
- app._init_container(client, container_app)
475
+ app: App = synchronizer._translate_out(active_app)
476
+ app._init_container(client, container_app)
507
477
 
508
478
  # Hydrate all function dependencies.
509
479
  # TODO(erikbern): we can remove this once we
510
480
  # 1. Enable lazy hydration for all objects
511
481
  # 2. Fully deprecate .new() objects
512
- if imp_fun.function:
513
- _client: _Client = synchronizer._translate_in(client) # TODO(erikbern): ugly
514
- dep_object_ids: List[str] = [dep.object_id for dep in container_args.function_def.object_dependencies]
515
- function_deps = imp_fun.function.deps(only_explicit_mounts=True)
516
- if len(function_deps) != len(dep_object_ids):
482
+ if service.code_deps is not None: # this is not set for serialized or non-global scope functions
483
+ dep_object_ids: list[str] = [dep.object_id for dep in function_def.object_dependencies]
484
+ if len(service.code_deps) != len(dep_object_ids):
517
485
  raise ExecutionError(
518
- f"Function has {len(function_deps)} dependencies"
519
- f" but container got {len(dep_object_ids)} object ids."
486
+ f"Function has {len(service.code_deps)} dependencies"
487
+ f" but container got {len(dep_object_ids)} object ids.\n"
488
+ f"Code deps: {service.code_deps}\n"
489
+ f"Object ids: {dep_object_ids}"
520
490
  )
521
- for object_id, obj in zip(dep_object_ids, function_deps):
491
+ for object_id, obj in zip(dep_object_ids, service.code_deps):
522
492
  metadata: Message = container_app.object_handle_metadata[object_id]
523
493
  obj._hydrate(object_id, _client, metadata)
524
494
 
525
- # Identify all "enter" methods that need to run before we checkpoint.
526
- if imp_fun.obj is not None and not imp_fun.is_auto_snapshot:
527
- pre_checkpoint_methods = _find_callables_for_obj(imp_fun.obj, _PartialFunctionFlags.ENTER_PRE_CHECKPOINT)
528
- call_lifecycle_functions(event_loop, container_io_manager, list(pre_checkpoint_methods.values()))
495
+ # Initialize clustered functions.
496
+ if function_def._experimental_group_size > 0:
497
+ initialize_clustered_function(
498
+ client,
499
+ container_args.task_id,
500
+ function_def._experimental_group_size,
501
+ )
502
+
503
+ # Identify all "enter" methods that need to run before we snapshot.
504
+ if service.user_cls_instance is not None and not is_auto_snapshot:
505
+ pre_snapshot_methods = _find_callables_for_obj(
506
+ service.user_cls_instance, _PartialFunctionFlags.ENTER_PRE_SNAPSHOT
507
+ )
508
+ call_lifecycle_functions(event_loop, container_io_manager, list(pre_snapshot_methods.values()))
529
509
 
530
510
  # If this container is being used to create a checkpoint, checkpoint the container after
531
- # global imports and innitialization. Checkpointed containers run from this point onwards.
532
- if container_args.function_def.is_checkpointing_function:
533
- container_io_manager.checkpoint()
511
+ # global imports and initialization. Checkpointed containers run from this point onwards.
512
+ if is_snapshotting_function:
513
+ container_io_manager.memory_snapshot()
534
514
 
535
515
  # Install hooks for interactive functions.
536
- if container_args.function_def.pty_info.pty_type != api_pb2.PTYInfo.PTY_TYPE_UNSPECIFIED:
516
+ def breakpoint_wrapper():
517
+ # note: it would be nice to not have breakpoint_wrapper() included in the backtrace
518
+ container_io_manager.interact(from_breakpoint=True)
519
+ import pdb
537
520
 
538
- def breakpoint_wrapper():
539
- # note: it would be nice to not have breakpoint_wrapper() included in the backtrace
540
- interact()
541
- import pdb
521
+ frame = inspect.currentframe().f_back
542
522
 
543
- pdb.set_trace()
523
+ pdb.Pdb().set_trace(frame)
544
524
 
545
- sys.breakpointhook = breakpoint_wrapper
525
+ sys.breakpointhook = breakpoint_wrapper
546
526
 
547
- # Identify the "enter" methods to run after resuming from a checkpoint.
548
- if imp_fun.obj is not None and not imp_fun.is_auto_snapshot:
549
- post_checkpoint_methods = _find_callables_for_obj(imp_fun.obj, _PartialFunctionFlags.ENTER_POST_CHECKPOINT)
550
- call_lifecycle_functions(event_loop, container_io_manager, list(post_checkpoint_methods.values()))
527
+ # Identify the "enter" methods to run after resuming from a snapshot.
528
+ if service.user_cls_instance is not None and not is_auto_snapshot:
529
+ post_snapshot_methods = _find_callables_for_obj(
530
+ service.user_cls_instance, _PartialFunctionFlags.ENTER_POST_SNAPSHOT
531
+ )
532
+ call_lifecycle_functions(event_loop, container_io_manager, list(post_snapshot_methods.values()))
551
533
 
534
+ with container_io_manager.handle_user_exception():
535
+ finalized_functions = service.get_finalized_functions(function_def, container_io_manager)
552
536
  # Execute the function.
537
+ lifespan_background_tasks = []
553
538
  try:
554
- if imp_fun.is_async:
555
- event_loop.run(call_function_async(container_io_manager, imp_fun))
556
- else:
557
- # Set up a signal handler for `SIGUSR1`, which gets translated to an InputCancellation
558
- # during function execution. This is sent to cancel inputs from the user.
559
- def _cancel_input_signal_handler(signum, stackframe):
560
- raise InputCancellation("Input was cancelled by user")
561
-
562
- signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
563
-
564
- call_function_sync(container_io_manager, imp_fun)
539
+ for finalized_function in finalized_functions.values():
540
+ if finalized_function.lifespan_manager:
541
+ lifespan_background_tasks.append(
542
+ event_loop.create_task(finalized_function.lifespan_manager.background_task())
543
+ )
544
+ with container_io_manager.handle_user_exception():
545
+ event_loop.run(finalized_function.lifespan_manager.lifespan_startup())
546
+ call_function(
547
+ event_loop,
548
+ container_io_manager,
549
+ finalized_functions,
550
+ batch_max_size,
551
+ batch_wait_ms,
552
+ )
565
553
  finally:
566
554
  # Run exit handlers. From this point onward, ignore all SIGINT signals that come from
567
555
  # graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that
@@ -570,15 +558,27 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
570
558
  usr1_handler = signal.signal(signal.SIGUSR1, signal.SIG_IGN)
571
559
 
572
560
  try:
573
- # Identify "exit" methods and run them.
574
- if imp_fun.obj is not None and not imp_fun.is_auto_snapshot:
575
- exit_methods = _find_callables_for_obj(imp_fun.obj, _PartialFunctionFlags.EXIT)
576
- call_lifecycle_functions(event_loop, container_io_manager, list(exit_methods.values()))
561
+ try:
562
+ # run lifespan shutdown for asgi apps
563
+ for finalized_function in finalized_functions.values():
564
+ if finalized_function.lifespan_manager:
565
+ with container_io_manager.handle_user_exception():
566
+ event_loop.run(finalized_function.lifespan_manager.lifespan_shutdown())
567
+ finally:
568
+ # no need to keep the lifespan asgi call around - we send it no more messages
569
+ for lifespan_background_task in lifespan_background_tasks:
570
+ lifespan_background_task.cancel() # prevent dangling tasks
571
+
572
+ # Identify "exit" methods and run them.
573
+ # want to make sure this is called even if the lifespan manager fails
574
+ if service.user_cls_instance is not None and not is_auto_snapshot:
575
+ exit_methods = _find_callables_for_obj(service.user_cls_instance, _PartialFunctionFlags.EXIT)
576
+ call_lifecycle_functions(event_loop, container_io_manager, list(exit_methods.values()))
577
577
 
578
578
  # Finally, commit on exit to catch uncommitted volume changes and surface background
579
579
  # commit errors.
580
580
  container_io_manager.volume_commit(
581
- [v.volume_id for v in container_args.function_def.volume_mounts if v.allow_background_commits]
581
+ [v.volume_id for v in function_def.volume_mounts if v.allow_background_commits]
582
582
  )
583
583
  finally:
584
584
  # Restore the original signal handler, needed for container_test hygiene since the
@@ -591,7 +591,15 @@ if __name__ == "__main__":
591
591
  logger.debug("Container: starting")
592
592
 
593
593
  container_args = api_pb2.ContainerArguments()
594
- container_args.ParseFromString(base64.b64decode(sys.argv[1]))
594
+
595
+ container_arguments_path: Optional[str] = os.environ.get("MODAL_CONTAINER_ARGUMENTS_PATH")
596
+ if container_arguments_path is None:
597
+ # TODO(erikbern): this fallback is for old workers and we can remove it very soon (days)
598
+ import base64
599
+
600
+ container_args.ParseFromString(base64.b64decode(sys.argv[1]))
601
+ else:
602
+ container_args.ParseFromString(open(container_arguments_path, "rb").read())
595
603
 
596
604
  # Note that we're creating the client in a synchronous context, but it will be running in a separate thread.
597
605
  # This is good because if the function is long running then we the client can still send heartbeats
@@ -611,7 +619,7 @@ if __name__ == "__main__":
611
619
  # from shutting down. The sleep(0) here is needed for finished ThreadPoolExecutor resources to
612
620
  # shut down without triggering this warning (e.g., `@wsgi_app()`).
613
621
  time.sleep(0)
614
- lingering_threads: List[threading.Thread] = []
622
+ lingering_threads: list[threading.Thread] = []
615
623
  for thread in threading.enumerate():
616
624
  current_thread = threading.get_ident()
617
625
  if thread.ident is not None and thread.ident != current_thread and not thread.daemon and thread.is_alive():
@@ -619,7 +627,8 @@ if __name__ == "__main__":
619
627
  if lingering_threads:
620
628
  thread_names = ", ".join(t.name for t in lingering_threads)
621
629
  logger.warning(
622
- f"Detected {len(lingering_threads)} background thread(s) [{thread_names}] still running after container exit. This will prevent runner shutdown for up to 30 seconds."
630
+ f"Detected {len(lingering_threads)} background thread(s) [{thread_names}] still running "
631
+ "after container exit. This will prevent runner shutdown for up to 30 seconds."
623
632
  )
624
633
 
625
634
  logger.debug("Container: done")