modal 1.0.6.dev58__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of modal might be problematic.
- modal/__main__.py +3 -4
- modal/_billing.py +80 -0
- modal/_clustered_functions.py +7 -3
- modal/_clustered_functions.pyi +4 -2
- modal/_container_entrypoint.py +41 -49
- modal/_functions.py +424 -195
- modal/_grpc_client.py +171 -0
- modal/_load_context.py +105 -0
- modal/_object.py +68 -20
- modal/_output.py +58 -45
- modal/_partial_function.py +36 -11
- modal/_pty.py +7 -3
- modal/_resolver.py +21 -35
- modal/_runtime/asgi.py +4 -3
- modal/_runtime/container_io_manager.py +301 -186
- modal/_runtime/container_io_manager.pyi +70 -61
- modal/_runtime/execution_context.py +18 -2
- modal/_runtime/execution_context.pyi +4 -1
- modal/_runtime/gpu_memory_snapshot.py +170 -63
- modal/_runtime/user_code_imports.py +28 -58
- modal/_serialization.py +57 -1
- modal/_utils/async_utils.py +33 -12
- modal/_utils/auth_token_manager.py +2 -5
- modal/_utils/blob_utils.py +110 -53
- modal/_utils/function_utils.py +49 -42
- modal/_utils/grpc_utils.py +80 -50
- modal/_utils/mount_utils.py +26 -1
- modal/_utils/name_utils.py +17 -3
- modal/_utils/task_command_router_client.py +536 -0
- modal/_utils/time_utils.py +34 -6
- modal/app.py +219 -83
- modal/app.pyi +229 -56
- modal/billing.py +5 -0
- modal/{requirements → builder}/2025.06.txt +1 -0
- modal/{requirements → builder}/PREVIEW.txt +1 -0
- modal/cli/_download.py +19 -3
- modal/cli/_traceback.py +3 -2
- modal/cli/app.py +4 -4
- modal/cli/cluster.py +15 -7
- modal/cli/config.py +5 -3
- modal/cli/container.py +7 -6
- modal/cli/dict.py +22 -16
- modal/cli/entry_point.py +12 -5
- modal/cli/environment.py +5 -4
- modal/cli/import_refs.py +3 -3
- modal/cli/launch.py +102 -5
- modal/cli/network_file_system.py +9 -13
- modal/cli/profile.py +3 -2
- modal/cli/programs/launch_instance_ssh.py +94 -0
- modal/cli/programs/run_jupyter.py +1 -1
- modal/cli/programs/run_marimo.py +95 -0
- modal/cli/programs/vscode.py +1 -1
- modal/cli/queues.py +57 -26
- modal/cli/run.py +58 -16
- modal/cli/secret.py +48 -22
- modal/cli/utils.py +3 -4
- modal/cli/volume.py +28 -25
- modal/client.py +13 -116
- modal/client.pyi +9 -91
- modal/cloud_bucket_mount.py +5 -3
- modal/cloud_bucket_mount.pyi +5 -1
- modal/cls.py +130 -102
- modal/cls.pyi +45 -85
- modal/config.py +29 -10
- modal/container_process.py +291 -13
- modal/container_process.pyi +95 -32
- modal/dict.py +282 -63
- modal/dict.pyi +423 -73
- modal/environments.py +15 -27
- modal/environments.pyi +5 -15
- modal/exception.py +8 -0
- modal/experimental/__init__.py +143 -38
- modal/experimental/flash.py +247 -78
- modal/experimental/flash.pyi +137 -9
- modal/file_io.py +14 -28
- modal/file_io.pyi +2 -2
- modal/file_pattern_matcher.py +25 -16
- modal/functions.pyi +134 -61
- modal/image.py +255 -86
- modal/image.pyi +300 -62
- modal/io_streams.py +436 -126
- modal/io_streams.pyi +236 -171
- modal/mount.py +62 -157
- modal/mount.pyi +45 -172
- modal/network_file_system.py +30 -53
- modal/network_file_system.pyi +16 -76
- modal/object.pyi +42 -8
- modal/parallel_map.py +821 -113
- modal/parallel_map.pyi +134 -0
- modal/partial_function.pyi +4 -1
- modal/proxy.py +16 -7
- modal/proxy.pyi +10 -2
- modal/queue.py +263 -61
- modal/queue.pyi +409 -66
- modal/runner.py +112 -92
- modal/runner.pyi +45 -27
- modal/sandbox.py +451 -124
- modal/sandbox.pyi +513 -67
- modal/secret.py +291 -67
- modal/secret.pyi +425 -19
- modal/serving.py +7 -11
- modal/serving.pyi +7 -8
- modal/snapshot.py +11 -8
- modal/token_flow.py +4 -4
- modal/volume.py +344 -98
- modal/volume.pyi +464 -68
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +9 -8
- modal-1.2.3.dev7.dist-info/RECORD +195 -0
- modal_docs/mdmd/mdmd.py +11 -1
- modal_proto/api.proto +399 -67
- modal_proto/api_grpc.py +241 -1
- modal_proto/api_pb2.py +1395 -1000
- modal_proto/api_pb2.pyi +1239 -79
- modal_proto/api_pb2_grpc.py +499 -4
- modal_proto/api_pb2_grpc.pyi +162 -14
- modal_proto/modal_api_grpc.py +175 -160
- modal_proto/sandbox_router.proto +145 -0
- modal_proto/sandbox_router_grpc.py +105 -0
- modal_proto/sandbox_router_pb2.py +149 -0
- modal_proto/sandbox_router_pb2.pyi +333 -0
- modal_proto/sandbox_router_pb2_grpc.py +203 -0
- modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
- modal_proto/task_command_router.proto +144 -0
- modal_proto/task_command_router_grpc.py +105 -0
- modal_proto/task_command_router_pb2.py +149 -0
- modal_proto/task_command_router_pb2.pyi +333 -0
- modal_proto/task_command_router_pb2_grpc.py +203 -0
- modal_proto/task_command_router_pb2_grpc.pyi +75 -0
- modal_version/__init__.py +1 -1
- modal-1.0.6.dev58.dist-info/RECORD +0 -183
- modal_proto/modal_options_grpc.py +0 -3
- modal_proto/options.proto +0 -19
- modal_proto/options_grpc.py +0 -3
- modal_proto/options_pb2.py +0 -35
- modal_proto/options_pb2.pyi +0 -20
- modal_proto/options_pb2_grpc.py +0 -4
- modal_proto/options_pb2_grpc.pyi +0 -7
- /modal/{requirements → builder}/2023.12.312.txt +0 -0
- /modal/{requirements → builder}/2023.12.txt +0 -0
- /modal/{requirements → builder}/2024.04.txt +0 -0
- /modal/{requirements → builder}/2024.10.txt +0 -0
- /modal/{requirements → builder}/README.md +0 -0
- /modal/{requirements → builder}/base-images.json +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
modal/_runtime/container_io_manager.pyi

@@ -27,6 +27,7 @@ class IOContext:
     input_ids: list[str]
     retry_counts: list[int]
     function_call_ids: list[str]
+    attempt_tokens: list[str]
     function_inputs: list[modal_proto.api_pb2.FunctionInput]
     finalized_function: modal._runtime.user_code_imports.FinalizedFunction
     _cancel_issued: bool
@@ -37,6 +38,7 @@ class IOContext:
         input_ids: list[str],
         retry_counts: list[int],
         function_call_ids: list[str],
+        attempt_tokens: list[str],
         finalized_function: modal._runtime.user_code_imports.FinalizedFunction,
         function_inputs: list[modal_proto.api_pb2.FunctionInput],
         is_batched: bool,
@@ -50,14 +52,29 @@ class IOContext:
         cls,
         client: modal.client._Client,
         finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
-        inputs: list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]],
+        inputs: list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]],
         is_batched: bool,
     ) -> IOContext: ...
     def set_cancel_callback(self, cb: collections.abc.Callable[[], None]): ...
     def cancel(self): ...
     def _args_and_kwargs(self) -> tuple[tuple[typing.Any, ...], dict[str, list[typing.Any]]]: ...
-    def
-    def
+    def _generator_output_format(self) -> int: ...
+    def _prepare_batch_output(self, data: typing.Any) -> list[typing.Any]: ...
+    def call_function_sync(self) -> list[typing.Any]: ...
+    async def call_function_async(self) -> list[typing.Any]: ...
+    def call_generator_sync(self) -> typing.Generator[typing.Any, None, None]: ...
+    def call_generator_async(self) -> collections.abc.AsyncGenerator[typing.Any, None]: ...
+    async def output_items_cancellation(self, started_at: float): ...
+    def _determine_output_format(self, input_format: int) -> int: ...
+    async def output_items_exception(
+        self, started_at: float, task_id: str, exc: BaseException
+    ) -> list[modal_proto.api_pb2.FunctionPutOutputsItem]: ...
+    def output_items_generator_done(
+        self, started_at: float, items_total: int
+    ) -> list[modal_proto.api_pb2.FunctionPutOutputsItem]: ...
+    async def output_items(
+        self, started_at: float, data: list[typing.Any]
+    ) -> list[modal_proto.api_pb2.FunctionPutOutputsItem]: ...
 
 class InputSlots:
     """A semaphore that allows dynamically adjusting the concurrency."""
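The widened `inputs` tuple is the thread that ties these stub changes together: each fetched item now carries an attempt token between the function call ID and the `FunctionInput` payload. A minimal sketch of how a consumer might unpack the new 5-tuples (field roles are inferred from the surrounding names; this is illustrative, not Modal's internal code):

from typing import Any

# Field order taken from the stub's tuple type; the attempt token is assumed
# to correlate a specific execution attempt with its retries and data streams.
def unpack_items(items: list[tuple[str, int, str, str, Any]]) -> None:
    for input_id, retry_count, function_call_id, attempt_token, function_input in items:
        print(input_id, retry_count, function_call_id, attempt_token)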
@@ -131,14 +148,19 @@ class _ContainerIOManager:
     def stop_heartbeat(self): ...
     def dynamic_concurrency_manager(self) -> typing.AsyncContextManager[None]: ...
     async def _dynamic_concurrency_loop(self): ...
-    def
-
-
+    def get_data_in(
+        self, function_call_id: str, attempt_token: typing.Optional[str]
+    ) -> collections.abc.AsyncIterator[typing.Any]:
         """Read from the `data_in` stream of a function call."""
         ...
 
     async def put_data_out(
-        self,
+        self,
+        function_call_id: str,
+        attempt_token: str,
+        start_index: int,
+        data_format: int,
+        serialized_messages: list[typing.Any],
     ) -> None:
         """Put data onto the `data_out` stream of a function call.
 
@@ -149,7 +171,7 @@ class _ContainerIOManager:
         ...
 
     def generator_output_sender(
-        self, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+        self, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
     ) -> typing.AsyncContextManager[None]:
         """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
         ...
@@ -166,22 +188,17 @@ class _ContainerIOManager:
     def get_max_inputs_to_fetch(self): ...
     def _generate_inputs(
         self, batch_max_size: int, batch_wait_ms: int
-    ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+    ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
     def run_inputs_outputs(
         self,
         finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
         batch_max_size: int = 0,
         batch_wait_ms: int = 0,
     ) -> collections.abc.AsyncIterator[IOContext]: ...
-    async def
-
-
-
-        data_format: int,
-        results: list[modal_proto.api_pb2.GenericResult],
-    ) -> None: ...
-    def serialize_exception(self, exc: BaseException) -> bytes: ...
-    def serialize_traceback(self, exc: BaseException) -> tuple[typing.Optional[bytes], typing.Optional[bytes]]: ...
+    async def _send_outputs(self, started_at: float, outputs: list[modal_proto.api_pb2.FunctionPutOutputsItem]) -> None:
+        """Send pre-built output items with retry and chunking."""
+        ...
+
     def handle_user_exception(self) -> typing.AsyncContextManager[None]:
         """Sets the task as failed in a way where it's not retried.
 
@@ -195,9 +212,7 @@ class _ContainerIOManager:
         ...
 
     def exit_context(self, started_at, input_ids: list[str]): ...
-    async def push_outputs(
-        self, io_context: IOContext, started_at: float, data: typing.Any, data_format: int
-    ) -> None: ...
+    async def push_outputs(self, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...
     async def memory_restore(self) -> None: ...
     async def memory_snapshot(self) -> None:
         """Message server indicating that function is ready to be checkpointed."""
@@ -323,20 +338,16 @@ class ContainerIOManager:
 
     _dynamic_concurrency_loop: ___dynamic_concurrency_loop_spec[typing_extensions.Self]
 
-    def serialize_data_format(self, obj: typing.Any, data_format: int) -> bytes: ...
-
-    class __format_blob_data_spec(typing_extensions.Protocol[SUPERSELF]):
-        def __call__(self, /, data: bytes) -> dict[str, typing.Any]: ...
-        async def aio(self, /, data: bytes) -> dict[str, typing.Any]: ...
-
-    format_blob_data: __format_blob_data_spec[typing_extensions.Self]
-
     class __get_data_in_spec(typing_extensions.Protocol[SUPERSELF]):
-        def __call__(
+        def __call__(
+            self, /, function_call_id: str, attempt_token: typing.Optional[str]
+        ) -> typing.Iterator[typing.Any]:
             """Read from the `data_in` stream of a function call."""
             ...
 
-        def aio(
+        def aio(
+            self, /, function_call_id: str, attempt_token: typing.Optional[str]
+        ) -> collections.abc.AsyncIterator[typing.Any]:
             """Read from the `data_in` stream of a function call."""
             ...
 
@@ -344,7 +355,13 @@ class ContainerIOManager:
 
     class __put_data_out_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
-            self,
+            self,
+            /,
+            function_call_id: str,
+            attempt_token: str,
+            start_index: int,
+            data_format: int,
+            serialized_messages: list[typing.Any],
         ) -> None:
             """Put data onto the `data_out` stream of a function call.
 
@@ -355,7 +372,13 @@ class ContainerIOManager:
             ...
 
         async def aio(
-            self,
+            self,
+            /,
+            function_call_id: str,
+            attempt_token: str,
+            start_index: int,
+            data_format: int,
+            serialized_messages: list[typing.Any],
         ) -> None:
             """Put data onto the `data_out` stream of a function call.
 
@@ -369,13 +392,13 @@ class ContainerIOManager:
 
     class __generator_output_sender_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
-            self, /, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+            self, /, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
         ) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]:
             """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
             ...
 
         def aio(
-            self, /, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+            self, /, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
         ) -> typing.AsyncContextManager[None]:
             """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
             ...
@@ -410,10 +433,10 @@ class ContainerIOManager:
     class ___generate_inputs_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
             self, /, batch_max_size: int, batch_wait_ms: int
-        ) -> typing.Iterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+        ) -> typing.Iterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
         def aio(
             self, /, batch_max_size: int, batch_wait_ms: int
-        ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+        ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
 
     _generate_inputs: ___generate_inputs_spec[typing_extensions.Self]
 
@@ -435,28 +458,16 @@ class ContainerIOManager:
 
     run_inputs_outputs: __run_inputs_outputs_spec[typing_extensions.Self]
 
-    class
-        def __call__(
-
-
-            io_context: IOContext,
-            started_at: float,
-            data_format: int,
-            results: list[modal_proto.api_pb2.GenericResult],
-        ) -> None: ...
-        async def aio(
-            self,
-            /,
-            io_context: IOContext,
-            started_at: float,
-            data_format: int,
-            results: list[modal_proto.api_pb2.GenericResult],
-        ) -> None: ...
-
-    def serialize_traceback(self, exc: BaseException) -> tuple[typing.Optional[bytes], typing.Optional[bytes]]: ...
+    class ___send_outputs_spec(typing_extensions.Protocol[SUPERSELF]):
+        def __call__(self, /, started_at: float, outputs: list[modal_proto.api_pb2.FunctionPutOutputsItem]) -> None:
+            """Send pre-built output items with retry and chunking."""
+            ...
+
+        async def aio(self, /, started_at: float, outputs: list[modal_proto.api_pb2.FunctionPutOutputsItem]) -> None:
+            """Send pre-built output items with retry and chunking."""
+            ...
+
+    _send_outputs: ___send_outputs_spec[typing_extensions.Self]
 
     class __handle_user_exception_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(self, /) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]:
@@ -493,10 +504,8 @@ class ContainerIOManager:
     def exit_context(self, started_at, input_ids: list[str]): ...
 
     class __push_outputs_spec(typing_extensions.Protocol[SUPERSELF]):
-        def __call__(self, /, io_context: IOContext, started_at: float,
-        async def aio(
-            self, /, io_context: IOContext, started_at: float, data: typing.Any, data_format: int
-        ) -> None: ...
+        def __call__(self, /, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...
+        async def aio(self, /, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...
 
     push_outputs: __push_outputs_spec[typing_extensions.Self]
 
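The `__call__`/`aio` pairs in these `*_spec` protocols reflect the dual blocking/async interface Modal generates with synchronicity: the same underlying coroutine can be called synchronously or awaited via `.aio`. A simplified sketch of that pattern (a stand-in, not the actual synchronicity wrapper):

import asyncio
from typing import Any, Awaitable, Callable

class DualMethod:
    """Stand-in for a synchronicity-wrapped method: plain calls block on the
    coroutine, while .aio exposes it to async callers."""

    def __init__(self, coro_fn: Callable[..., Awaitable[Any]]) -> None:
        self._coro_fn = coro_fn

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return asyncio.run(self._coro_fn(*args, **kwargs))  # blocking form

    async def aio(self, *args: Any, **kwargs: Any) -> Any:
        return await self._coro_fn(*args, **kwargs)  # async form

async def _push(started_at: float, outputs: list) -> None:
    await asyncio.sleep(0)  # placeholder for the actual RPC

send_outputs = DualMethod(_push)
send_outputs(0.0, [])  # blocking, like manager._send_outputs(...)
# In async code: await send_outputs.aio(0.0, [])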
modal/_runtime/execution_context.py

@@ -72,22 +72,38 @@ def current_function_call_id() -> Optional[str]:
     return None
 
 
-def
-
+def current_attempt_token() -> Optional[str]:
+    # This ContextVar isn't useful to expose to users.
+    try:
+        return _current_attempt_token.get()
+    except LookupError:
+        return None
+
+
+def _set_current_context_ids(
+    input_ids: list[str], function_call_ids: list[str], attempt_tokens: list[str]
+) -> Callable[[], None]:
+    assert len(input_ids) == len(function_call_ids) == len(attempt_tokens) and input_ids
+
     input_id = input_ids[0]
     function_call_id = function_call_ids[0]
+    attempt_token = attempt_tokens[0]
+
     input_token = _current_input_id.set(input_id)
     function_call_token = _current_function_call_id.set(function_call_id)
+    attempt_token_token = _current_attempt_token.set(attempt_token)
 
     def _reset_current_context_ids():
         _current_input_id.reset(input_token)
         _current_function_call_id.reset(function_call_token)
+        _current_attempt_token.reset(attempt_token_token)
 
     return _reset_current_context_ids
 
 
 _current_input_id: ContextVar = ContextVar("_current_input_id")
 _current_function_call_id: ContextVar = ContextVar("_current_function_call_id")
+_current_attempt_token: ContextVar = ContextVar("_current_attempt_token")
 
 _is_currently_importing = False  # we set this to True while a container is importing user code
 
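The set/reset pairing above is the standard `contextvars` token protocol: `ContextVar.set` returns a `Token` recording the previous value, and `reset` restores it, so batched or nested inputs cannot leak their IDs into other contexts. A stdlib-only sketch of the same pattern:

from contextvars import ContextVar
from typing import Callable, Optional

_request_id: ContextVar[Optional[str]] = ContextVar("_request_id", default=None)

def set_request_id(value: str) -> Callable[[], None]:
    token = _request_id.set(value)  # Token remembers the prior value

    def reset() -> None:
        _request_id.reset(token)  # restore whatever was set before

    return reset

reset = set_request_id("in-123")
assert _request_id.get() == "in-123"
reset()
assert _request_id.get() is None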
modal/_runtime/execution_context.pyi

@@ -68,11 +68,14 @@ def current_function_call_id() -> typing.Optional[str]:
     """
     ...
 
+def current_attempt_token() -> typing.Optional[str]: ...
 def _set_current_context_ids(
-    input_ids: list[str], function_call_ids: list[str]
+    input_ids: list[str], function_call_ids: list[str], attempt_tokens: list[str]
 ) -> collections.abc.Callable[[], None]: ...
 def _import_context(): ...
 
 _current_input_id: contextvars.ContextVar
 
 _current_function_call_id: contextvars.ContextVar
+
+_current_attempt_token: contextvars.ContextVar
modal/_runtime/gpu_memory_snapshot.py

@@ -1,25 +1,34 @@
 # Copyright Modal Labs 2022
 #
 # This module provides a simple interface for creating GPU memory snapshots,
-#
+# providing a convenient interface to `cuda-checkpoint` [1]. This is intended
 # to be used in conjunction with memory snapshots.
 #
 # [1] https://github.com/NVIDIA/cuda-checkpoint
 
 import subprocess
 import time
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
+from typing import List, Optional
 
 from modal.config import config, logger
 
 CUDA_CHECKPOINT_PATH: str = config.get("cuda_checkpoint_path")
 
+# Maximum total duration for an entire toggle operation.
+CUDA_CHECKPOINT_TOGGLE_TIMEOUT: float = 5 * 60.0
+
+# Maximum total duration for each individual `cuda-checkpoint` invocation.
+CUDA_CHECKPOINT_TIMEOUT: float = 90
+
 
 class CudaCheckpointState(Enum):
-    """State representation from the CUDA API
+    """State representation from the CUDA API [1].
+
+    [1] https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html"""
 
     RUNNING = "running"
     LOCKED = "locked"
@@ -28,6 +37,8 @@ class CudaCheckpointState(Enum):
 
 
 class CudaCheckpointException(Exception):
+    """Exception raised for CUDA checkpoint operations."""
+
     pass
 
 
@@ -39,24 +50,44 @@ class CudaCheckpointProcess:
     pid: int
     state: CudaCheckpointState
 
-    def toggle(self, target_state: CudaCheckpointState,
+    def toggle(self, target_state: CudaCheckpointState, skip_first_refresh: bool = False) -> None:
         """Toggle CUDA checkpoint state for current process, moving GPU memory to the
-        CPU and back depending on the current process state when called.
+        CPU and back depending on the current process state when called.
+        """
         logger.debug(f"PID: {self.pid} Toggling CUDA checkpoint state to {target_state.value}")
 
         start_time = time.monotonic()
-
-
-
-
+        retry_count = 0
+        max_retries = 3
+
+        attempts = 0
+        while self._should_continue_toggle(
+            target_state, start_time, refresh=not (skip_first_refresh and attempts == 0)
+        ):
+            attempts += 1
+            try:
+                self._execute_toggle_command()
+                # Use exponential backoff for retries
+                sleep_time = min(0.1 * (2**retry_count), 1.0)
+                time.sleep(sleep_time)
+                retry_count = 0
+            except CudaCheckpointException as e:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    raise CudaCheckpointException(
+                        f"PID: {self.pid} Failed to toggle state after {max_retries} retries: {e}"
+                    )
+                logger.debug(f"PID: {self.pid} Retry {retry_count}/{max_retries} after error: {e}")
+                time.sleep(0.5 * retry_count)
 
         logger.debug(f"PID: {self.pid} Target state {target_state.value} reached")
 
     def _should_continue_toggle(
-        self, target_state: CudaCheckpointState, start_time: float,
+        self, target_state: CudaCheckpointState, start_time: float, refresh: bool = True
     ) -> bool:
         """Check if toggle operation should continue based on current state and timeout."""
-
+        if refresh:
+            self.refresh_state()
 
         if self.state == target_state:
             return False
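The retry loop in `toggle` sleeps `min(0.1 * 2**retry_count, 1.0)` seconds between polls and `0.5 * retry_count` after a failed attempt, so waits grow geometrically but stay bounded. A standalone sketch of that capped exponential schedule:

# Capped exponential backoff as used above: 0.1s, 0.2s, 0.4s, 0.8s, then 1.0s.
def backoff_schedule(attempts: int, base: float = 0.1, cap: float = 1.0) -> list[float]:
    return [min(base * (2**n), cap) for n in range(attempts)]

print(backoff_schedule(6))  # [0.1, 0.2, 0.4, 0.8, 1.0, 1.0]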
@@ -65,7 +96,7 @@ class CudaCheckpointProcess:
             raise CudaCheckpointException(f"PID: {self.pid} CUDA process state is {self.state}")
 
         elapsed = time.monotonic() - start_time
-        if elapsed >=
+        if elapsed >= CUDA_CHECKPOINT_TOGGLE_TIMEOUT:
             raise CudaCheckpointException(
                 f"PID: {self.pid} Timeout after {elapsed:.2f}s waiting for state {target_state.value}. "
                 f"Current state: {self.state}"
@@ -73,19 +104,25 @@ class CudaCheckpointProcess:
 
         return True
 
-    def _execute_toggle_command(self):
+    def _execute_toggle_command(self) -> None:
         """Execute the cuda-checkpoint toggle command."""
         try:
-            subprocess.run(
+            _ = subprocess.run(
                 [CUDA_CHECKPOINT_PATH, "--toggle", "--pid", str(self.pid)],
                 check=True,
                 capture_output=True,
                 text=True,
+                timeout=CUDA_CHECKPOINT_TIMEOUT,
             )
             logger.debug(f"PID: {self.pid} Successfully toggled CUDA checkpoint state")
         except subprocess.CalledProcessError as e:
-
-
+            error_msg = f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+        except subprocess.TimeoutExpired:
+            error_msg = f"PID: {self.pid} Toggle command timed out"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
 
     def refresh_state(self) -> None:
         """Refreshes the current CUDA checkpoint state for this process."""
@@ -95,15 +132,20 @@ class CudaCheckpointProcess:
                 check=True,
                 capture_output=True,
                 text=True,
-                timeout=
+                timeout=CUDA_CHECKPOINT_TIMEOUT,
             )
 
             state_str = result.stdout.strip().lower()
             self.state = CudaCheckpointState(state_str)
 
         except subprocess.CalledProcessError as e:
-
-
+            error_msg = f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+        except subprocess.TimeoutExpired:
+            error_msg = f"PID: {self.pid} Get state command timed out"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
 
 
 class CudaCheckpointSession:
@@ -111,12 +153,17 @@ class CudaCheckpointSession:
 
     def __init__(self):
         self.cuda_processes = self._get_cuda_pids()
-
+        if self.cuda_processes:
+            logger.debug(
+                f"Found {len(self.cuda_processes)} PID(s) with CUDA sessions: {[c.pid for c in self.cuda_processes]}"
+            )
+        else:
+            logger.debug("No CUDA sessions found.")
 
-    def _get_cuda_pids(self) ->
+    def _get_cuda_pids(self) -> List[CudaCheckpointProcess]:
         """Iterates over all PIDs and identifies the ones that have running
         CUDA sessions."""
-        cuda_pids:
+        cuda_pids: List[CudaCheckpointProcess] = []
 
         # Get all active process IDs from /proc directory
         proc_dir = Path("/proc")
@@ -125,75 +172,135 @@ class CudaCheckpointSession:
                 "OS does not have /proc path rendering it incompatible with GPU memory snapshots."
             )
 
-
-
-                    continue
+        # Get all numeric directories (PIDs) from /proc
+        pid_dirs = [entry for entry in proc_dir.iterdir() if entry.name.isdigit()]
 
-
-
-
-
-
-                    capture_output=True,
-                    text=True,
-                    timeout=10,
-                )
-
-                # If the command succeeds (return code 0), this PID has a CUDA session
-                if result.returncode == 0:
-                    state_str = result.stdout.strip().lower()
-                    state = CudaCheckpointState(state_str)
-
-                    cuda_checkpoint_process = CudaCheckpointProcess(pid=pid, state=state)
-                    cuda_pids.append(cuda_checkpoint_process)
-
-            # Command failed, which is expected for PIDs without CUDA sessions
-            except subprocess.CalledProcessError:
-                continue
+        # Use ThreadPoolExecutor to check PIDs in parallel for better performance
+        with ThreadPoolExecutor(max_workers=min(50, len(pid_dirs))) as executor:
+            future_to_pid = {
+                executor.submit(self._check_cuda_session, int(entry.name)): int(entry.name) for entry in pid_dirs
+            }
 
+            for future in as_completed(future_to_pid):
+                pid = future_to_pid[future]
+                try:
+                    cuda_process = future.result()
+                    if cuda_process:
+                        cuda_pids.append(cuda_process)
+                except Exception as e:
+                    logger.debug(f"Error checking PID {pid}: {e}")
 
         # Sort PIDs for ordered checkpointing
         cuda_pids.sort(key=lambda x: x.pid)
         return cuda_pids
 
+    def _check_cuda_session(self, pid: int) -> Optional[CudaCheckpointProcess]:
+        """Check if a specific PID has a CUDA session."""
+        try:
+            result = subprocess.run(
+                [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
+                capture_output=True,
+                text=True,
+                # This should be quick since no checkpoint has taken place yet
+                timeout=5,
+            )
+
+            # If the command succeeds (return code 0), this PID has a CUDA session
+            if result.returncode == 0:
+                state_str = result.stdout.strip().lower()
+                state = CudaCheckpointState(state_str)
+                return CudaCheckpointProcess(pid=pid, state=state)
+
+        except subprocess.CalledProcessError:
+            # Command failed, which is expected for PIDs without CUDA sessions
+            pass
+        except subprocess.TimeoutExpired:
+            logger.debug(f"Timeout checking CUDA state for PID {pid}")
+        except Exception as e:
+            logger.debug(f"Error checking PID {pid}: {e}")
+
+        return None
+
     def checkpoint(self) -> None:
+        """Checkpoint all CUDA processes, moving GPU memory to CPU."""
+        if not self.cuda_processes:
+            logger.debug("No CUDA processes to checkpoint.")
+            return
+
         # Validate all states first
         for proc in self.cuda_processes:
+            proc.refresh_state()  # Refresh state before validation
             if proc.state != CudaCheckpointState.RUNNING:
-                raise CudaCheckpointException(
+                raise CudaCheckpointException(
+                    f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.RUNNING.value} state. "
+                    f"Current state: {proc.state.value}"
+                )
 
         # Moving state from GPU to CPU can take several seconds per CUDA session.
         # Make a parallel call per CUDA session.
         start = time.perf_counter()
 
-        def checkpoint_impl(proc: CudaCheckpointProcess):
+        def checkpoint_impl(proc: CudaCheckpointProcess) -> None:
             proc.toggle(CudaCheckpointState.CHECKPOINTED)
 
         with ThreadPoolExecutor() as executor:
-
+            futures = [executor.submit(checkpoint_impl, proc) for proc in self.cuda_processes]
+
+            # Wait for all futures and collect any exceptions
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    exceptions.append(e)
+
+            if exceptions:
+                raise CudaCheckpointException(
+                    f"Failed to checkpoint {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                )
 
         elapsed = time.perf_counter() - start
-        logger.debug(f"Checkpointing CUDA sessions took => {elapsed:.3f}s")
+        logger.debug(f"Checkpointing {len(self.cuda_processes)} CUDA sessions took => {elapsed:.3f}s")
 
     def restore(self) -> None:
-
-
-
-
+        """Restore all CUDA processes, moving memory back from CPU to GPU."""
+        if not self.cuda_processes:
+            logger.debug("No CUDA sessions to restore.")
+            return
 
         # See checkpoint() for rationale about parallelism.
         start = time.perf_counter()
 
-        def restore_process(proc: CudaCheckpointProcess):
-            proc.toggle(CudaCheckpointState.RUNNING)
+        def restore_process(proc: CudaCheckpointProcess) -> None:
+            proc.toggle(CudaCheckpointState.RUNNING, skip_first_refresh=True)
 
         with ThreadPoolExecutor() as executor:
-
+            futures = [executor.submit(restore_process, proc) for proc in self.cuda_processes]
+
+            # Wait for all futures and collect any exceptions
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    exceptions.append(e)
+
+            if exceptions:
+                raise CudaCheckpointException(
+                    f"Failed to restore {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                )
 
         elapsed = time.perf_counter() - start
-        logger.debug(f"Restoring CUDA
+        logger.debug(f"Restoring {len(self.cuda_processes)} CUDA session(s) took => {elapsed:.3f}s")
+
+    def get_process_count(self) -> int:
+        """Get the number of CUDA processes managed by this session."""
+        return len(self.cuda_processes)
+
+    def get_process_states(self) -> List[tuple[int, CudaCheckpointState]]:
+        """Get current states of all managed processes."""
+        states = []
+        for proc in self.cuda_processes:
+            proc.refresh_state()
+            states.append((proc.pid, proc.state))
+        return states
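Taken together, the new `CudaCheckpointSession` surface is: enumerate CUDA PIDs in parallel at construction, `checkpoint()` to move GPU memory to the CPU before a memory snapshot, `restore()` to move it back, and `get_process_count()`/`get_process_states()` for inspection. A hedged usage sketch built only from the signatures above (this is an internal module, and the example assumes a host with `cuda-checkpoint` available and live CUDA processes):

from modal._runtime.gpu_memory_snapshot import (
    CudaCheckpointException,
    CudaCheckpointSession,
)

session = CudaCheckpointSession()  # scans /proc for CUDA PIDs in parallel
print(f"Managing {session.get_process_count()} CUDA process(es)")

try:
    session.checkpoint()  # GPU -> CPU; safe to snapshot container memory now
    # ... take the container memory snapshot here ...
    session.restore()  # CPU -> GPU; skips the first state refresh internally
except CudaCheckpointException as exc:
    print(f"Checkpoint/restore failed: {exc}")

for pid, state in session.get_process_states():  # refreshes each state
    print(pid, state.value)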