modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (220)
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +402 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1025 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +196 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +943 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
  128. modal-0.72.13.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
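Several of these entries are reorganizations rather than new code: the container runtime moved into a private `modal/_runtime/` package (for example `modal/_asgi.py` → `modal/_runtime/asgi.py` and `modal/execution_context.py` → `modal/_runtime/execution_context.py`), and `modal/_container_io_manager.py` was replaced by `modal/_runtime/container_io_manager.py`. A minimal sketch of why callers should prefer the stable top-level re-exports over such private paths; it assumes `current_input_id` is re-exported on the `modal` package itself, which holds for current releases:

```python
import modal

# The deep module path changed in this release
# (modal/execution_context.py -> modal/_runtime/execution_context.py), so use
# the top-level re-export rather than importing from the private module.
input_id = modal.current_input_id()  # expected to be None outside a running Modal container
print(input_id)
```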
modal/_container_io_manager.py (deleted)
@@ -1,646 +0,0 @@
-# Copyright Modal Labs 2024
-import asyncio
-import json
-import math
-import os
-import signal
-import time
-import traceback
-from pathlib import Path
-from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, List, Optional, Set, Tuple
-
-from google.protobuf.empty_pb2 import Empty
-from google.protobuf.message import Message
-from grpclib import Status
-from synchronicity.async_wrap import asynccontextmanager
-
-from modal_proto import api_pb2
-
-from ._serialization import deserialize, deserialize_data_format, serialize, serialize_data_format
-from ._traceback import extract_traceback
-from ._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer
-from ._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
-from ._utils.function_utils import _stream_function_call_data
-from ._utils.grpc_utils import get_proto_oneof, retry_transient_errors
-from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
-from .config import config, logger
-from .exception import InputCancellation, InvalidError
-from .running_app import RunningApp
-
-MAX_OUTPUT_BATCH_SIZE: int = 49
-
-RTT_S: float = 0.5  # conservative estimate of RTT in seconds.
-
-
-class UserException(Exception):
-    """Used to shut down the task gracefully."""
-
-
-class Sentinel:
-    """Used to get type-stubs to work with this object."""
-
-
-class _ContainerIOManager:
-    """Synchronizes all RPC calls and network operations for a running container.
-
-    TODO: maybe we shouldn't synchronize the whole class.
-    Then we could potentially move a bunch of the global functions onto it.
-    """
-
-    cancelled_input_ids: Set[str]
-    task_id: str
-    function_id: str
-    app_id: str
-    function_def: api_pb2.Function
-    checkpoint_id: Optional[str]
-
-    calls_completed: int
-    total_user_time: float
-    current_input_id: Optional[str]
-    current_input_started_at: Optional[float]
-
-    _input_concurrency: Optional[int]
-    _semaphore: Optional[asyncio.Semaphore]
-    _environment_name: str
-    _waiting_for_checkpoint: bool
-    _heartbeat_loop: Optional[asyncio.Task]
-
-    _is_interactivity_enabled: bool
-    _fetching_inputs: bool
-
-    _client: _Client
-
-    _GENERATOR_STOP_SENTINEL: ClassVar[Sentinel] = Sentinel()
-    _singleton: ClassVar[Optional["_ContainerIOManager"]] = None
-
-    def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):
-        self.cancelled_input_ids = set()
-        self.task_id = container_args.task_id
-        self.function_id = container_args.function_id
-        self.app_id = container_args.app_id
-        self.function_def = container_args.function_def
-        self.checkpoint_id = container_args.checkpoint_id or None
-
-        self.calls_completed = 0
-        self.total_user_time = 0.0
-        self.current_input_id = None
-        self.current_input_started_at = None
-
-        self._input_concurrency = None
-
-        self._semaphore = None
-        self._environment_name = container_args.environment_name
-        self._waiting_for_checkpoint = False
-        self._heartbeat_loop = None
-
-        self._is_interactivity_enabled = False
-        self._fetching_inputs = True
-
-        self._client = client
-        assert isinstance(self._client, _Client)
-
-    def __new__(cls, container_args: api_pb2.ContainerArguments, client: _Client) -> "_ContainerIOManager":
-        cls._singleton = super().__new__(cls)
-        cls._singleton._init(container_args, client)
-        return cls._singleton
-
-    @classmethod
-    def _reset_singleton(cls):
-        """Only used for tests."""
-        cls._singleton = None
-
-    async def _run_heartbeat_loop(self):
-        while 1:
-            t0 = time.monotonic()
-            try:
-                if await self._heartbeat_handle_cancellations():
-                    # got a cancellation event, fine to start another heartbeat immediately
-                    # since the cancellation queue should be empty on the worker server
-                    # however, we wait at least 1s to prevent short-circuiting the heartbeat loop
-                    # in case there is ever a bug. This means it will take at least 1s between
-                    # two subsequent cancellations on the same task at the moment
-                    await asyncio.sleep(1.0)
-                    continue
-            except Exception as exc:
-                # don't stop heartbeat loop if there are transient exceptions!
-                time_elapsed = time.monotonic() - t0
-                error = exc
-                logger.warning(f"Heartbeat attempt failed ({time_elapsed=}, {error=})")
-
-            heartbeat_duration = time.monotonic() - t0
-            time_until_next_hearbeat = max(0.0, HEARTBEAT_INTERVAL - heartbeat_duration)
-            await asyncio.sleep(time_until_next_hearbeat)
-
-    async def _heartbeat_handle_cancellations(self) -> bool:
-        # Return True if a cancellation event was received, in that case we shouldn't wait too long for another heartbeat
-
-        # Don't send heartbeats for tasks waiting to be checkpointed.
-        # Calling gRPC methods open new connections which block the
-        # checkpointing process.
-        if self._waiting_for_checkpoint:
-            return False
-
-        request = api_pb2.ContainerHeartbeatRequest(supports_graceful_input_cancellation=True)
-        if self.current_input_id is not None:
-            request.current_input_id = self.current_input_id
-        if self.current_input_started_at is not None:
-            request.current_input_started_at = self.current_input_started_at
-
-        # TODO(erikbern): capture exceptions?
-        response = await retry_transient_errors(
-            self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
-        )
-
-        if response.HasField("cancel_input_event"):
-            # Pause processing of the current input by signaling self a SIGUSR1.
-            input_ids_to_cancel = response.cancel_input_event.input_ids
-            if input_ids_to_cancel:
-                if self._input_concurrency > 1:
-                    logger.info(
-                        "Shutting down task to stop some subset of inputs (concurrent functions don't support fine-grained cancellation)"
-                    )
-                    # This is equivalent to a task cancellation or preemption from worker code,
-                    # except we do not send a SIGKILL to forcefully exit after 30 seconds.
-                    #
-                    # SIGINT always interrupts the main thread, but not any auxiliary threads. On a
-                    # sync function without concurrent inputs, this raises a KeyboardInterrupt. When
-                    # there are concurrent inputs, we cannot interrupt the thread pool, but the
-                    # interpreter stops waiting for daemon threads and exits. On async functions,
-                    # this signal lands outside the event loop, stopping `run_until_complete()`.
-                    os.kill(os.getpid(), signal.SIGINT)
-
-                elif self.current_input_id in input_ids_to_cancel:
-                    # This goes to a registered signal handler for sync Modal functions, or to the
-                    # `SignalHandlingEventLoop` for async functions.
-                    #
-                    # We only send this signal on functions that do not have concurrent inputs enabled.
-                    # This allows us to do fine-grained input cancellation. On sync functions, the
-                    # SIGUSR1 signal should interrupt the main thread where user code is running,
-                    # raising an InputCancellation() exception. On async functions, the signal should
-                    # reach a handler in SignalHandlingEventLoop, which cancels the task.
-                    os.kill(os.getpid(), signal.SIGUSR1)
-            return True
-        return False
-
-    @asynccontextmanager
-    async def heartbeats(self) -> AsyncGenerator[None, None]:
-        async with TaskContext() as tc:
-            self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
-            t.set_name("heartbeat loop")
-            try:
-                yield
-            finally:
-                t.cancel()
-
-    def stop_heartbeat(self):
-        if self._heartbeat_loop:
-            self._heartbeat_loop.cancel()
-
-    async def get_app_objects(self) -> RunningApp:
-        req = api_pb2.AppGetObjectsRequest(app_id=self.app_id, include_unindexed=True)
-        resp = await retry_transient_errors(self._client.stub.AppGetObjects, req)
-        logger.debug(f"AppGetObjects received {len(resp.items)} objects for app {self.app_id}")
-
-        tag_to_object_id = {}
-        object_handle_metadata = {}
-        for item in resp.items:
-            handle_metadata: Optional[Message] = get_proto_oneof(item.object, "handle_metadata_oneof")
-            object_handle_metadata[item.object.object_id] = handle_metadata
-            if item.tag:
-                tag_to_object_id[item.tag] = item.object.object_id
-
-        return RunningApp(
-            self.app_id,
-            environment_name=self._environment_name,
-            tag_to_object_id=tag_to_object_id,
-            object_handle_metadata=object_handle_metadata,
-        )
-
-    async def get_serialized_function(self) -> Tuple[Optional[Any], Callable]:
-        # Fetch the serialized function definition
-        request = api_pb2.FunctionGetSerializedRequest(function_id=self.function_id)
-        response = await self._client.stub.FunctionGetSerialized(request)
-        fun = self.deserialize(response.function_serialized)
-
-        if response.class_serialized:
-            cls = self.deserialize(response.class_serialized)
-        else:
-            cls = None
-
-        return cls, fun
-
-    def serialize(self, obj: Any) -> bytes:
-        return serialize(obj)
-
-    def deserialize(self, data: bytes) -> Any:
-        return deserialize(data, self._client)
-
-    @synchronizer.no_io_translation
-    def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
-        return serialize_data_format(obj, data_format)
-
-    def deserialize_data_format(self, data: bytes, data_format: int) -> Any:
-        return deserialize_data_format(data, data_format, self._client)
-
-    async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
-        """Read from the `data_in` stream of a function call."""
-        async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
-            yield data
-
-    async def put_data_out(
-        self,
-        function_call_id: str,
-        start_index: int,
-        data_format: int,
-        messages_bytes: List[Any],
-    ) -> None:
-        """Put data onto the `data_out` stream of a function call.
-
-        This is used for generator outputs, which includes web endpoint responses. Note that this
-        was introduced as a performance optimization in client version 0.57, so older clients will
-        still use the previous Postgres-backed system based on `FunctionPutOutputs()`.
-        """
-        data_chunks: List[api_pb2.DataChunk] = []
-        for i, message_bytes in enumerate(messages_bytes):
-            chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i)  # type: ignore
-            if len(message_bytes) > MAX_OBJECT_SIZE_BYTES:
-                chunk.data_blob_id = await blob_upload(message_bytes, self._client.stub)
-            else:
-                chunk.data = message_bytes
-            data_chunks.append(chunk)
-
-        req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks)
-        await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req)
-
-    async def generator_output_task(self, function_call_id: str, data_format: int, message_rx: asyncio.Queue) -> None:
-        """Task that feeds generator outputs into a function call's `data_out` stream."""
-        index = 1
-        received_sentinel = False
-        while not received_sentinel:
-            message = await message_rx.get()
-            if message is self._GENERATOR_STOP_SENTINEL:
-                break
-            # ASGI 'http.response.start' and 'http.response.body' msgs are observed to be separated by 1ms.
-            # If we don't sleep here for 1ms we end up with an extra call to .put_data_out().
-            if index == 1:
-                await asyncio.sleep(0.001)
-            messages_bytes = [serialize_data_format(message, data_format)]
-            total_size = len(messages_bytes[0]) + 512
-            while total_size < 16 * 1024 * 1024:  # 16 MiB, maximum size in a single message
-                try:
-                    message = message_rx.get_nowait()
-                except asyncio.QueueEmpty:
-                    break
-                if message is self._GENERATOR_STOP_SENTINEL:
-                    received_sentinel = True
-                    break
-                else:
-                    messages_bytes.append(serialize_data_format(message, data_format))
-                    total_size += len(messages_bytes[-1]) + 512  # 512 bytes for estimated framing overhead
-            await self.put_data_out(function_call_id, index, data_format, messages_bytes)
-            index += len(messages_bytes)
-
-    async def _queue_create(self, size: int) -> asyncio.Queue:
-        """Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
-        return asyncio.Queue(size)
-
-    async def _queue_put(self, queue: asyncio.Queue, value: Any) -> None:
-        """Put a value onto a queue, using the synchronicity event loop."""
-        await queue.put(value)
-
-    async def populate_input_blobs(self, item: api_pb2.FunctionInput):
-        args = await blob_download(item.args_blob_id, self._client.stub)
-
-        # Mutating
-        item.ClearField("args_blob_id")
-        item.args = args
-        return item
-
-    def get_average_call_time(self) -> float:
-        if self.calls_completed == 0:
-            return 0
-
-        return self.total_user_time / self.calls_completed
-
-    def get_max_inputs_to_fetch(self):
-        if self.calls_completed == 0:
-            return 1
-
-        return math.ceil(RTT_S / max(self.get_average_call_time(), 1e-6))
-
-    @synchronizer.no_io_translation
-    async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.FunctionInput]]:
-        request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
-        eof_received = False
-        iteration = 0
-        while not eof_received and self._fetching_inputs:
-            request.average_call_time = self.get_average_call_time()
-            request.max_values = self.get_max_inputs_to_fetch()  # Deprecated; remove.
-            request.input_concurrency = self._input_concurrency
-
-            await self._semaphore.acquire()
-            yielded = False
-            try:
-                # If number of active inputs is at max queue size, this will block.
-                iteration += 1
-                response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors(
-                    self._client.stub.FunctionGetInputs, request
-                )
-
-                if response.rate_limit_sleep_duration:
-                    logger.info(
-                        "Task exceeded rate limit, sleeping for %.2fs before trying again."
-                        % response.rate_limit_sleep_duration
-                    )
-                    await asyncio.sleep(response.rate_limit_sleep_duration)
-                elif response.inputs:
-                    # for input cancellations and concurrency logic we currently assume
-                    # that there is no input buffering in the container
-                    assert len(response.inputs) == 1
-
-                    for item in response.inputs:
-                        if item.kill_switch:
-                            logger.debug(f"Task {self.task_id} input kill signal input.")
-                            eof_received = True
-                            break
-                        if item.input_id in self.cancelled_input_ids:
-                            continue
-
-                        # If we got a pointer to a blob, download it from S3.
-                        if item.input.WhichOneof("args_oneof") == "args_blob_id":
-                            input_pb = await self.populate_input_blobs(item.input)
-                        else:
-                            input_pb = item.input
-
-                        # If yielded, allow semaphore to be released via complete_call
-                        yield (item.input_id, item.function_call_id, input_pb)
-                        yielded = True
-
-                        # We only support max_inputs = 1 at the moment
-                        if item.input.final_input or self.function_def.max_inputs == 1:
-                            eof_received = True
-                            break
-            finally:
-                if not yielded:
-                    self._semaphore.release()
-
-    @synchronizer.no_io_translation
-    async def run_inputs_outputs(self, input_concurrency: int = 1) -> AsyncIterator[Tuple[str, str, Any, Any]]:
-        # Ensure we do not fetch new inputs when container is too busy.
-        # Before trying to fetch an input, acquire the semaphore:
-        # - if no input is fetched, release the semaphore.
-        # - or, when the output for the fetched input is sent, release the semaphore.
-        self._input_concurrency = input_concurrency
-        self._semaphore = asyncio.Semaphore(input_concurrency)
-
-        try:
-            async for input_id, function_call_id, input_pb in self._generate_inputs():
-                args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {})
-                self.current_input_id, self.current_input_started_at = (input_id, time.time())
-                yield input_id, function_call_id, args, kwargs
-                self.current_input_id, self.current_input_started_at = (None, None)
-        finally:
-            # collect all active input slots, meaning all inputs have wrapped up.
-            for _ in range(input_concurrency):
-                await self._semaphore.acquire()
-
-    async def _push_output(self, input_id, started_at: float, data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs):
-        # upload data to S3 if too big.
-        if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES:
-            data_blob_id = await blob_upload(kwargs["data"], self._client.stub)
-            # mutating kwargs.
-            del kwargs["data"]
-            kwargs["data_blob_id"] = data_blob_id
-
-        output = api_pb2.FunctionPutOutputsItem(
-            input_id=input_id,
-            input_started_at=started_at,
-            output_created_at=time.time(),
-            result=api_pb2.GenericResult(**kwargs),
-            data_format=data_format,
-        )
-
-        await retry_transient_errors(
-            self._client.stub.FunctionPutOutputs,
-            api_pb2.FunctionPutOutputsRequest(outputs=[output]),
-            additional_status_codes=[Status.RESOURCE_EXHAUSTED],
-            max_retries=None,  # Retry indefinitely, trying every 1s.
-        )
-
-    def serialize_exception(self, exc: BaseException) -> Optional[bytes]:
-        try:
-            return self.serialize(exc)
-        except Exception as serialization_exc:
-            logger.info(f"Failed to serialize exception {exc}: {serialization_exc}")
-            # We can't always serialize exceptions.
-            return None
-
-    def serialize_traceback(self, exc: BaseException) -> Tuple[Optional[bytes], Optional[bytes]]:
-        serialized_tb, tb_line_cache = None, None
-
-        try:
-            tb_dict, line_cache = extract_traceback(exc, self.task_id)
-            serialized_tb = self.serialize(tb_dict)
-            tb_line_cache = self.serialize(line_cache)
-        except Exception:
-            logger.info("Failed to serialize exception traceback.")
-
-        return serialized_tb, tb_line_cache
-
-    @asynccontextmanager
-    async def handle_user_exception(self) -> AsyncGenerator[None, None]:
-        """Sets the task as failed in a way where it's not retried.
-
-        Used for handling exceptions from container lifecycle methods at the moment, which should
-        trigger a task failure state.
-        """
-        try:
-            yield
-        except KeyboardInterrupt:
-            # Send no task result in case we get sigint:ed by the runner
-            # The status of the input should have been handled externally already in that case
-            raise
-        except BaseException as exc:
-            # Since this is on a different thread, sys.exc_info() can't find the exception in the stack.
-            traceback.print_exception(type(exc), exc, exc.__traceback__)
-
-            serialized_tb, tb_line_cache = self.serialize_traceback(exc)
-
-            result = api_pb2.GenericResult(
-                status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
-                data=self.serialize_exception(exc),
-                exception=repr(exc),
-                traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
-                serialized_tb=serialized_tb,
-                tb_line_cache=tb_line_cache,
-            )
-
-            req = api_pb2.TaskResultRequest(result=result)
-            await retry_transient_errors(self._client.stub.TaskResult, req)
-
-            # Shut down the task gracefully
-            raise UserException()
-
-    @asynccontextmanager
-    async def handle_input_exception(self, input_id, started_at: float) -> AsyncGenerator[None, None]:
-        """Handle an exception while processing a function input."""
-        try:
-            yield
-        except KeyboardInterrupt:
-            raise
-        except (InputCancellation, asyncio.CancelledError):
-            # just skip creating any output for this input and keep going with the next instead
-            # it should have been marked as cancelled already in the backend at this point so it
-            # won't be retried
-            logger.warning(f"The current input ({input_id=}) was cancelled by a user request")
-            await self.complete_call(started_at)
-            return
-        except BaseException as exc:
-            # print exception so it's logged
-            traceback.print_exc()
-            serialized_tb, tb_line_cache = self.serialize_traceback(exc)
-
-            # Note: we're not serializing the traceback since it contains
-            # local references that means we can't unpickle it. We *are*
-            # serializing the exception, which may have some issues (there
-            # was an earlier note about it that it might not be possible
-            # to unpickle it in some cases). Let's watch out for issues.
-            await self._push_output(
-                input_id,
-                started_at=started_at,
-                data_format=api_pb2.DATA_FORMAT_PICKLE,
-                status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
-                data=self.serialize_exception(exc),
-                exception=repr(exc),
-                traceback=traceback.format_exc(),
-                serialized_tb=serialized_tb,
-                tb_line_cache=tb_line_cache,
-            )
-            await self.complete_call(started_at)
-
-    async def complete_call(self, started_at):
-        self.total_user_time += time.time() - started_at
-        self.calls_completed += 1
-        self._semaphore.release()
-
-    @synchronizer.no_io_translation
-    async def push_output(self, input_id, started_at: float, data: Any, data_format: int) -> None:
-        await self._push_output(
-            input_id,
-            started_at=started_at,
-            data_format=data_format,
-            status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
-            data=self.serialize_data_format(data, data_format),
-        )
-        await self.complete_call(started_at)
-
-    async def restore(self) -> None:
-        # Busy-wait for restore. `/__modal/restore-state.json` is created
-        # by the worker process with updates to the container config.
-        restored_path = Path(config.get("restore_state_path"))
-        start = time.perf_counter()
-        while not restored_path.exists():
-            logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)")
-            await asyncio.sleep(0.01)
-            continue
-
-        logger.debug("Container: restored")
-
-        # Look for state file and create new client with updated credentials.
-        # State data is serialized with key-value pairs, example: {"task_id": "tk-000"}
-        with restored_path.open("r") as file:
-            restored_state = json.load(file)
-
-        # Local ContainerIOManager state.
-        for key in ["task_id", "function_id"]:
-            if value := restored_state.get(key):
-                logger.debug(f"Updating ContainerIOManager.{key} = {value}")
-                setattr(self, key, restored_state[key])
-
-        # Env vars and global state.
-        for key, value in restored_state.items():
-            # Empty string indicates that value does not need to be updated.
-            if value != "":
-                config.override_locally(key, value)
-
-        # Restore input to default state.
-        self.current_input_id = None
-        self.current_input_started_at = None
-
-        self._client = await _Client.from_env()
-        self._waiting_for_checkpoint = False
-
-    async def checkpoint(self) -> None:
-        """Message server indicating that function is ready to be checkpointed."""
-        if self.checkpoint_id:
-            logger.debug(f"Checkpoint ID: {self.checkpoint_id}")
-
-        await self._client.stub.ContainerCheckpoint(
-            api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
-        )
-
-        self._waiting_for_checkpoint = True
-        await self._client._close()
-
-        logger.debug("Checkpointing request sent. Connection closed.")
-        await self.restore()
-
-    async def volume_commit(self, volume_ids: List[str]) -> None:
-        """
-        Perform volume commit for given `volume_ids`.
-        Only used on container exit to persist uncommitted changes on behalf of user.
-        """
-        if not volume_ids:
-            return
-        await asyncify(os.sync)()
-        results = await asyncio.gather(
-            *[
-                retry_transient_errors(
-                    self._client.stub.VolumeCommit,
-                    api_pb2.VolumeCommitRequest(volume_id=v_id),
-                    max_retries=9,
-                    base_delay=0.25,
-                    max_delay=256,
-                    delay_factor=2,
-                )
-                for v_id in volume_ids
-            ],
-            return_exceptions=True,
-        )
-        for volume_id, res in zip(volume_ids, results):
-            if isinstance(res, Exception):
-                logger.error(f"modal.Volume background commit failed for {volume_id}. Exception: {res}")
-            else:
-                logger.debug(f"modal.Volume background commit success for {volume_id}.")
-
-    async def interact(self):
-        if self._is_interactivity_enabled:
-            # Currently, interactivity is enabled forever
-            return
-        self._is_interactivity_enabled = True
-
-        if not self.function_def.pty_info:
-            raise InvalidError(
-                "Interactivity is not enabled in this function. Use MODAL_INTERACTIVE_FUNCTIONS=1 to enable interactivity."
-            )
-
-        if self.function_def.concurrency_limit > 1:
-            print(
-                "Warning: Interactivity is not supported on functions with concurrency > 1. You may experience unexpected behavior."
-            )
-
-        # todo(nathan): add warning if concurrency limit > 1. but idk how to check this here
-        # todo(nathan): check if function interactivity is enabled
-        try:
-            await self._client.stub.FunctionStartPtyShell(Empty())
-        except Exception as e:
-            print("Error: Failed to start PTY shell.")
-            raise e
-
-    @classmethod
-    def stop_fetching_inputs(cls):
-        assert cls._singleton
-        cls._singleton._fetching_inputs = False
-
-
-ContainerIOManager = synchronize_api(_ContainerIOManager)
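
Two patterns in the deleted `_ContainerIOManager` are worth calling out. First, `_run_heartbeat_loop` compensates for the time each heartbeat attempt takes, so the cadence stays close to `HEARTBEAT_INTERVAL`, and it deliberately survives transient failures. A minimal standalone sketch of that pattern; the 15-second interval is an illustrative assumption, not Modal's actual setting (the real constant is imported from `modal.client` above):

```python
import asyncio
import time

HEARTBEAT_INTERVAL = 15.0  # illustrative value; not Modal's actual interval


async def heartbeat_loop(send_heartbeat) -> None:
    """Call `send_heartbeat` roughly every HEARTBEAT_INTERVAL seconds,
    subtracting the time each attempt itself took."""
    while True:
        t0 = time.monotonic()
        try:
            await send_heartbeat()
        except Exception as exc:
            # A transient failure must not stop the loop.
            print(f"heartbeat attempt failed: {exc!r}")
        elapsed = time.monotonic() - t0
        await asyncio.sleep(max(0.0, HEARTBEAT_INTERVAL - elapsed))
```

Second, `generator_output_task` greedily drains whatever is already queued into a single batch, stopping at the 16 MiB message cap, so that one `put_data_out` call can carry many generator outputs. A simplified sketch of just that batching step (sentinel handling and serialization omitted):

```python
import asyncio

MAX_BATCH_BYTES = 16 * 1024 * 1024  # 16 MiB cap per message, as in the code above
FRAMING_OVERHEAD = 512              # rough per-chunk framing estimate, as above


def drain_into_batch(queue: "asyncio.Queue[bytes]", first: bytes) -> "list[bytes]":
    """Greedily add already-queued messages to the batch until the cap is reached."""
    batch = [first]
    total = len(first) + FRAMING_OVERHEAD
    while total < MAX_BATCH_BYTES:
        try:
            msg = queue.get_nowait()  # non-blocking: only take what is already queued
        except asyncio.QueueEmpty:
            break
        batch.append(msg)
        total += len(msg) + FRAMING_OVERHEAD
    return batch
```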