modal 1.0.3.dev10__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- modal/__init__.py +0 -2
- modal/__main__.py +3 -4
- modal/_billing.py +80 -0
- modal/_clustered_functions.py +7 -3
- modal/_clustered_functions.pyi +15 -3
- modal/_container_entrypoint.py +51 -69
- modal/_functions.py +508 -240
- modal/_grpc_client.py +171 -0
- modal/_load_context.py +105 -0
- modal/_object.py +81 -21
- modal/_output.py +58 -45
- modal/_partial_function.py +48 -73
- modal/_pty.py +7 -3
- modal/_resolver.py +26 -46
- modal/_runtime/asgi.py +4 -3
- modal/_runtime/container_io_manager.py +358 -220
- modal/_runtime/container_io_manager.pyi +296 -101
- modal/_runtime/execution_context.py +18 -2
- modal/_runtime/execution_context.pyi +64 -7
- modal/_runtime/gpu_memory_snapshot.py +262 -57
- modal/_runtime/user_code_imports.py +28 -58
- modal/_serialization.py +90 -6
- modal/_traceback.py +42 -1
- modal/_tunnel.pyi +380 -12
- modal/_utils/async_utils.py +84 -29
- modal/_utils/auth_token_manager.py +111 -0
- modal/_utils/blob_utils.py +181 -58
- modal/_utils/deprecation.py +19 -0
- modal/_utils/function_utils.py +91 -47
- modal/_utils/grpc_utils.py +89 -66
- modal/_utils/mount_utils.py +26 -1
- modal/_utils/name_utils.py +17 -3
- modal/_utils/task_command_router_client.py +536 -0
- modal/_utils/time_utils.py +34 -6
- modal/app.py +256 -88
- modal/app.pyi +909 -92
- modal/billing.py +5 -0
- modal/builder/2025.06.txt +18 -0
- modal/builder/PREVIEW.txt +18 -0
- modal/builder/base-images.json +58 -0
- modal/cli/_download.py +19 -3
- modal/cli/_traceback.py +3 -2
- modal/cli/app.py +4 -4
- modal/cli/cluster.py +15 -7
- modal/cli/config.py +5 -3
- modal/cli/container.py +7 -6
- modal/cli/dict.py +22 -16
- modal/cli/entry_point.py +12 -5
- modal/cli/environment.py +5 -4
- modal/cli/import_refs.py +3 -3
- modal/cli/launch.py +102 -5
- modal/cli/network_file_system.py +11 -12
- modal/cli/profile.py +3 -2
- modal/cli/programs/launch_instance_ssh.py +94 -0
- modal/cli/programs/run_jupyter.py +1 -1
- modal/cli/programs/run_marimo.py +95 -0
- modal/cli/programs/vscode.py +1 -1
- modal/cli/queues.py +57 -26
- modal/cli/run.py +91 -23
- modal/cli/secret.py +48 -22
- modal/cli/token.py +7 -8
- modal/cli/utils.py +4 -7
- modal/cli/volume.py +31 -25
- modal/client.py +15 -85
- modal/client.pyi +183 -62
- modal/cloud_bucket_mount.py +5 -3
- modal/cloud_bucket_mount.pyi +197 -5
- modal/cls.py +200 -126
- modal/cls.pyi +446 -68
- modal/config.py +29 -11
- modal/container_process.py +319 -19
- modal/container_process.pyi +190 -20
- modal/dict.py +290 -71
- modal/dict.pyi +835 -83
- modal/environments.py +15 -27
- modal/environments.pyi +46 -24
- modal/exception.py +14 -2
- modal/experimental/__init__.py +194 -40
- modal/experimental/flash.py +618 -0
- modal/experimental/flash.pyi +380 -0
- modal/experimental/ipython.py +11 -7
- modal/file_io.py +29 -36
- modal/file_io.pyi +251 -53
- modal/file_pattern_matcher.py +56 -16
- modal/functions.pyi +673 -92
- modal/gpu.py +1 -1
- modal/image.py +528 -176
- modal/image.pyi +1572 -145
- modal/io_streams.py +458 -128
- modal/io_streams.pyi +433 -52
- modal/mount.py +216 -151
- modal/mount.pyi +225 -78
- modal/network_file_system.py +45 -62
- modal/network_file_system.pyi +277 -56
- modal/object.pyi +93 -17
- modal/parallel_map.py +942 -129
- modal/parallel_map.pyi +294 -15
- modal/partial_function.py +0 -2
- modal/partial_function.pyi +234 -19
- modal/proxy.py +17 -8
- modal/proxy.pyi +36 -3
- modal/queue.py +270 -65
- modal/queue.pyi +817 -57
- modal/runner.py +115 -101
- modal/runner.pyi +205 -49
- modal/sandbox.py +512 -136
- modal/sandbox.pyi +845 -111
- modal/schedule.py +1 -1
- modal/secret.py +300 -70
- modal/secret.pyi +589 -34
- modal/serving.py +7 -11
- modal/serving.pyi +7 -8
- modal/snapshot.py +11 -8
- modal/snapshot.pyi +25 -4
- modal/token_flow.py +4 -4
- modal/token_flow.pyi +28 -8
- modal/volume.py +416 -158
- modal/volume.pyi +1117 -121
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +10 -9
- modal-1.2.3.dev7.dist-info/RECORD +195 -0
- modal_docs/mdmd/mdmd.py +17 -4
- modal_proto/api.proto +534 -79
- modal_proto/api_grpc.py +337 -1
- modal_proto/api_pb2.py +1522 -968
- modal_proto/api_pb2.pyi +1619 -134
- modal_proto/api_pb2_grpc.py +699 -4
- modal_proto/api_pb2_grpc.pyi +226 -14
- modal_proto/modal_api_grpc.py +175 -154
- modal_proto/sandbox_router.proto +145 -0
- modal_proto/sandbox_router_grpc.py +105 -0
- modal_proto/sandbox_router_pb2.py +149 -0
- modal_proto/sandbox_router_pb2.pyi +333 -0
- modal_proto/sandbox_router_pb2_grpc.py +203 -0
- modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
- modal_proto/task_command_router.proto +144 -0
- modal_proto/task_command_router_grpc.py +105 -0
- modal_proto/task_command_router_pb2.py +149 -0
- modal_proto/task_command_router_pb2.pyi +333 -0
- modal_proto/task_command_router_pb2_grpc.py +203 -0
- modal_proto/task_command_router_pb2_grpc.pyi +75 -0
- modal_version/__init__.py +1 -1
- modal/requirements/PREVIEW.txt +0 -16
- modal/requirements/base-images.json +0 -26
- modal-1.0.3.dev10.dist-info/RECORD +0 -179
- modal_proto/modal_options_grpc.py +0 -3
- modal_proto/options.proto +0 -19
- modal_proto/options_grpc.py +0 -3
- modal_proto/options_pb2.py +0 -35
- modal_proto/options_pb2.pyi +0 -20
- modal_proto/options_pb2_grpc.py +0 -4
- modal_proto/options_pb2_grpc.pyi +0 -7
- /modal/{requirements → builder}/2023.12.312.txt +0 -0
- /modal/{requirements → builder}/2023.12.txt +0 -0
- /modal/{requirements → builder}/2024.04.txt +0 -0
- /modal/{requirements → builder}/2024.10.txt +0 -0
- /modal/{requirements → builder}/README.md +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
modal/_runtime/execution_context.pyi

@@ -3,22 +3,79 @@ import contextvars
 import typing
 import typing_extensions

-def is_local() -> bool: ...
-async def _interact() -> None: ...
+def is_local() -> bool:
+    """Returns if we are currently on the machine launching/deploying a Modal app
+
+    Returns `True` when executed locally on the user's machine.
+    Returns `False` when executed from a Modal container in the cloud.
+    """
+    ...
+
+async def _interact() -> None:
+    """Enable interactivity with user input inside a Modal container.
+
+    See the [interactivity guide](https://modal.com/docs/guide/developing-debugging#interactivity)
+    for more information on how to use this function.
+    """
+    ...

 class __interact_spec(typing_extensions.Protocol):
-    def __call__(self, /) -> None: ...
-    async def aio(self, /) -> None: ...
+    def __call__(self, /) -> None:
+        """Enable interactivity with user input inside a Modal container.
+
+        See the [interactivity guide](https://modal.com/docs/guide/developing-debugging#interactivity)
+        for more information on how to use this function.
+        """
+        ...
+
+    async def aio(self, /) -> None:
+        """Enable interactivity with user input inside a Modal container.
+
+        See the [interactivity guide](https://modal.com/docs/guide/developing-debugging#interactivity)
+        for more information on how to use this function.
+        """
+        ...

 interact: __interact_spec

-def current_input_id() -> typing.Optional[str]: ...
-def current_function_call_id() -> typing.Optional[str]: ...
+def current_input_id() -> typing.Optional[str]:
+    """Returns the input ID for the current input.
+
+    Can only be called from Modal function (i.e. in a container context).
+
+    ```python
+    from modal import current_input_id
+
+    @app.function()
+    def process_stuff():
+        print(f"Starting to process {current_input_id()}")
+    ```
+    """
+    ...
+
+def current_function_call_id() -> typing.Optional[str]:
+    """Returns the function call ID for the current input.
+
+    Can only be called from Modal function (i.e. in a container context).
+
+    ```python
+    from modal import current_function_call_id
+
+    @app.function()
+    def process_stuff():
+        print(f"Starting to process input from {current_function_call_id()}")
+    ```
+    """
+    ...
+
+def current_attempt_token() -> typing.Optional[str]: ...
 def _set_current_context_ids(
-    input_ids: list[str], function_call_ids: list[str]
+    input_ids: list[str], function_call_ids: list[str], attempt_tokens: list[str]
 ) -> collections.abc.Callable[[], None]: ...
 def _import_context(): ...

 _current_input_id: contextvars.ContextVar

 _current_function_call_id: contextvars.ContextVar
+
+_current_attempt_token: contextvars.ContextVar
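The new stub adds `current_attempt_token()` and threads `attempt_tokens` through `_set_current_context_ids`, alongside the existing `current_input_id()` / `current_function_call_id()` helpers. A minimal usage sketch, under two assumptions not stated in the diff: the `App` setup is illustrative, and `current_attempt_token` is imported from `modal._runtime.execution_context` (the module this stub describes), since the diff does not show it being re-exported from `modal`:

```python
import modal
from modal import current_function_call_id, current_input_id

# Assumed import path: the stub above lives in modal/_runtime/execution_context.pyi.
from modal._runtime.execution_context import current_attempt_token

app = modal.App("context-ids-demo")


@app.function()
def process_stuff():
    # Each helper returns Optional[str]; outside a running container they may be None.
    print(f"input_id={current_input_id()}")
    print(f"function_call_id={current_function_call_id()}")
    print(f"attempt_token={current_attempt_token()}")
```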
modal/_runtime/gpu_memory_snapshot.py

@@ -1,23 +1,34 @@
 # Copyright Modal Labs 2022
 #
 # This module provides a simple interface for creating GPU memory snapshots,
-#
+# providing a convenient interface to `cuda-checkpoint` [1]. This is intended
 # to be used in conjunction with memory snapshots.
 #
 # [1] https://github.com/NVIDIA/cuda-checkpoint

-import os
 import subprocess
 import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
 from enum import Enum
+from pathlib import Path
+from typing import List, Optional

 from modal.config import config, logger

 CUDA_CHECKPOINT_PATH: str = config.get("cuda_checkpoint_path")

+# Maximum total duration for an entire toggle operation.
+CUDA_CHECKPOINT_TOGGLE_TIMEOUT: float = 5 * 60.0
+
+# Maximum total duration for each individual `cuda-checkpoint` invocation.
+CUDA_CHECKPOINT_TIMEOUT: float = 90
+

 class CudaCheckpointState(Enum):
-    """State representation from the CUDA API
+    """State representation from the CUDA API [1].
+
+    [1] https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html"""

     RUNNING = "running"
     LOCKED = "locked"

@@ -26,76 +37,270 @@ class CudaCheckpointState(Enum):


 class CudaCheckpointException(Exception):
+    """Exception raised for CUDA checkpoint operations."""
+
     pass


-
-
-
-
-    logger.debug(f"Toggling CUDA checkpoint state for PID {pid}")
+@dataclass
+class CudaCheckpointProcess:
+    """Contains a reference to a PID with active CUDA session. This also provides
+    methods for checkpointing and restoring GPU memory."""

-
-
-            [
-                CUDA_CHECKPOINT_PATH,
-                "--toggle",
-                "--pid",
-                str(pid),
-            ],
-            check=True,
-            capture_output=True,
-            text=True,
-        )
-        logger.debug("Successfully toggled CUDA checkpoint state")
+    pid: int
+    state: CudaCheckpointState

-
-
-
+    def toggle(self, target_state: CudaCheckpointState, skip_first_refresh: bool = False) -> None:
+        """Toggle CUDA checkpoint state for current process, moving GPU memory to the
+        CPU and back depending on the current process state when called.
+        """
+        logger.debug(f"PID: {self.pid} Toggling CUDA checkpoint state to {target_state.value}")

+        start_time = time.monotonic()
+        retry_count = 0
+        max_retries = 3

-
-
-
+        attempts = 0
+        while self._should_continue_toggle(
+            target_state, start_time, refresh=not (skip_first_refresh and attempts == 0)
+        ):
+            attempts += 1
+            try:
+                self._execute_toggle_command()
+                # Use exponential backoff for retries
+                sleep_time = min(0.1 * (2**retry_count), 1.0)
+                time.sleep(sleep_time)
+                retry_count = 0
+            except CudaCheckpointException as e:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    raise CudaCheckpointException(
+                        f"PID: {self.pid} Failed to toggle state after {max_retries} retries: {e}"
+                    )
+                logger.debug(f"PID: {self.pid} Retry {retry_count}/{max_retries} after error: {e}")
+                time.sleep(0.5 * retry_count)

-
-        result = subprocess.run(
-            [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)], check=True, capture_output=True, text=True
-        )
+        logger.debug(f"PID: {self.pid} Target state {target_state.value} reached")

-
-
-
+    def _should_continue_toggle(
+        self, target_state: CudaCheckpointState, start_time: float, refresh: bool = True
+    ) -> bool:
+        """Check if toggle operation should continue based on current state and timeout."""
+        if refresh:
+            self.refresh_state()

-
-
-        raise CudaCheckpointException(e.stderr)
+        if self.state == target_state:
+            return False

+        if self.state == CudaCheckpointState.FAILED:
+            raise CudaCheckpointException(f"PID: {self.pid} CUDA process state is {self.state}")

-
-
-
-
+        elapsed = time.monotonic() - start_time
+        if elapsed >= CUDA_CHECKPOINT_TOGGLE_TIMEOUT:
+            raise CudaCheckpointException(
+                f"PID: {self.pid} Timeout after {elapsed:.2f}s waiting for state {target_state.value}. "
+                f"Current state: {self.state}"
+            )

-
-        current_state = get_state()
+        return True

-
-
-
+    def _execute_toggle_command(self) -> None:
+        """Execute the cuda-checkpoint toggle command."""
+        try:
+            _ = subprocess.run(
+                [CUDA_CHECKPOINT_PATH, "--toggle", "--pid", str(self.pid)],
+                check=True,
+                capture_output=True,
+                text=True,
+                timeout=CUDA_CHECKPOINT_TIMEOUT,
+            )
+            logger.debug(f"PID: {self.pid} Successfully toggled CUDA checkpoint state")
+        except subprocess.CalledProcessError as e:
+            error_msg = f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+        except subprocess.TimeoutExpired:
+            error_msg = f"PID: {self.pid} Toggle command timed out"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)

-
-
+    def refresh_state(self) -> None:
+        """Refreshes the current CUDA checkpoint state for this process."""
+        try:
+            result = subprocess.run(
+                [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(self.pid)],
+                check=True,
+                capture_output=True,
+                text=True,
+                timeout=CUDA_CHECKPOINT_TIMEOUT,
+            )

-
-
-
+            state_str = result.stdout.strip().lower()
+            self.state = CudaCheckpointState(state_str)
+
+        except subprocess.CalledProcessError as e:
+            error_msg = f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+        except subprocess.TimeoutExpired:
+            error_msg = f"PID: {self.pid} Get state command timed out"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+
+
+class CudaCheckpointSession:
+    """Manages the checkpointing state of processes with active CUDA sessions."""
+
+    def __init__(self):
+        self.cuda_processes = self._get_cuda_pids()
+        if self.cuda_processes:
+            logger.debug(
+                f"Found {len(self.cuda_processes)} PID(s) with CUDA sessions: {[c.pid for c in self.cuda_processes]}"
+            )
+        else:
+            logger.debug("No CUDA sessions found.")
+
+    def _get_cuda_pids(self) -> List[CudaCheckpointProcess]:
+        """Iterates over all PIDs and identifies the ones that have running
+        CUDA sessions."""
+        cuda_pids: List[CudaCheckpointProcess] = []
+
+        # Get all active process IDs from /proc directory
+        proc_dir = Path("/proc")
+        if not proc_dir.exists():
+            raise CudaCheckpointException(
+                "OS does not have /proc path rendering it incompatible with GPU memory snapshots."
+            )
+
+        # Get all numeric directories (PIDs) from /proc
+        pid_dirs = [entry for entry in proc_dir.iterdir() if entry.name.isdigit()]
+
+        # Use ThreadPoolExecutor to check PIDs in parallel for better performance
+        with ThreadPoolExecutor(max_workers=min(50, len(pid_dirs))) as executor:
+            future_to_pid = {
+                executor.submit(self._check_cuda_session, int(entry.name)): int(entry.name) for entry in pid_dirs
+            }
+
+            for future in as_completed(future_to_pid):
+                pid = future_to_pid[future]
+                try:
+                    cuda_process = future.result()
+                    if cuda_process:
+                        cuda_pids.append(cuda_process)
+                except Exception as e:
+                    logger.debug(f"Error checking PID {pid}: {e}")
+
+        # Sort PIDs for ordered checkpointing
+        cuda_pids.sort(key=lambda x: x.pid)
+        return cuda_pids
+
+    def _check_cuda_session(self, pid: int) -> Optional[CudaCheckpointProcess]:
+        """Check if a specific PID has a CUDA session."""
+        try:
+            result = subprocess.run(
+                [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
+                capture_output=True,
+                text=True,
+                # This should be quick since no checkpoint has taken place yet
+                timeout=5,
+            )
+
+            # If the command succeeds (return code 0), this PID has a CUDA session
+            if result.returncode == 0:
+                state_str = result.stdout.strip().lower()
+                state = CudaCheckpointState(state_str)
+                return CudaCheckpointProcess(pid=pid, state=state)
+
+        except subprocess.CalledProcessError:
+            # Command failed, which is expected for PIDs without CUDA sessions
+            pass
+        except subprocess.TimeoutExpired:
+            logger.debug(f"Timeout checking CUDA state for PID {pid}")
+        except Exception as e:
+            logger.debug(f"Error checking PID {pid}: {e}")
+
+        return None
+
+    def checkpoint(self) -> None:
+        """Checkpoint all CUDA processes, moving GPU memory to CPU."""
+        if not self.cuda_processes:
+            logger.debug("No CUDA processes to checkpoint.")
+            return
+
+        # Validate all states first
+        for proc in self.cuda_processes:
+            proc.refresh_state()  # Refresh state before validation
+            if proc.state != CudaCheckpointState.RUNNING:
+                raise CudaCheckpointException(
+                    f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.RUNNING.value} state. "
+                    f"Current state: {proc.state.value}"
+                )
+
+        # Moving state from GPU to CPU can take several seconds per CUDA session.
+        # Make a parallel call per CUDA session.
+        start = time.perf_counter()
+
+        def checkpoint_impl(proc: CudaCheckpointProcess) -> None:
+            proc.toggle(CudaCheckpointState.CHECKPOINTED)
+
+        with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(checkpoint_impl, proc) for proc in self.cuda_processes]
+
+            # Wait for all futures and collect any exceptions
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    exceptions.append(e)
+
+            if exceptions:
+                raise CudaCheckpointException(
+                    f"Failed to checkpoint {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                )
+
+        elapsed = time.perf_counter() - start
+        logger.debug(f"Checkpointing {len(self.cuda_processes)} CUDA sessions took => {elapsed:.3f}s")
+
+    def restore(self) -> None:
+        """Restore all CUDA processes, moving memory back from CPU to GPU."""
+        if not self.cuda_processes:
+            logger.debug("No CUDA sessions to restore.")
+            return
+
+        # See checkpoint() for rationale about parallelism.
+        start = time.perf_counter()
+
+        def restore_process(proc: CudaCheckpointProcess) -> None:
+            proc.toggle(CudaCheckpointState.RUNNING, skip_first_refresh=True)
+
+        with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(restore_process, proc) for proc in self.cuda_processes]
+
+            # Wait for all futures and collect any exceptions
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    exceptions.append(e)
+
+            if exceptions:
+                raise CudaCheckpointException(
+                    f"Failed to restore {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                )

-    time.
+        elapsed = time.perf_counter() - start
+        logger.debug(f"Restoring {len(self.cuda_processes)} CUDA session(s) took => {elapsed:.3f}s")

+    def get_process_count(self) -> int:
+        """Get the number of CUDA processes managed by this session."""
+        return len(self.cuda_processes)

-    def
-
-
-
-
+    def get_process_states(self) -> List[tuple[int, CudaCheckpointState]]:
+        """Get current states of all managed processes."""
+        states = []
+        for proc in self.cuda_processes:
+            proc.refresh_state()
+            states.append((proc.pid, proc.state))
+        return states
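The old module-level toggle/get-state helpers are replaced by `CudaCheckpointProcess` and `CudaCheckpointSession`. A minimal sketch of how the session API could be driven, based only on the methods shown in this hunk; the snapshot step is a placeholder, and the real call sites inside Modal's runtime are not part of this diff:

```python
from modal._runtime.gpu_memory_snapshot import CudaCheckpointSession


def take_host_memory_snapshot() -> None:
    """Placeholder for whatever captures the CPU-side memory snapshot (not a Modal API)."""


session = CudaCheckpointSession()  # scans /proc for PIDs with active CUDA sessions
print(f"Managing {session.get_process_count()} CUDA process(es)")

# Move GPU memory to host memory so the container's memory can be snapshotted...
session.checkpoint()
print(session.get_process_states())  # each PID should now report the checkpointed state

take_host_memory_snapshot()

# ...then move memory back onto the GPU and resume the CUDA sessions.
session.restore()
```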
modal/_runtime/user_code_imports.py

@@ -29,7 +29,7 @@ class FinalizedFunction:
     callable: Callable[..., Any]
     is_async: bool
     is_generator: bool
-
+    supported_output_formats: Sequence["api_pb2.DataFormat.ValueType"]
     lifespan_manager: Optional["LifespanManager"] = None


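`FinalizedFunction` now records which wire formats a function's outputs may be serialized to; the remaining hunks of `modal/_runtime/user_code_imports.py` below populate it from the protobuf function definition, falling back to pickle/CBOR (or ASGI for web endpoints) when nothing is specified. A hedged sketch of that selection pattern; the helper name and the `is_web_endpoint` flag are illustrative, only the field and constant names come from the diff:

```python
from modal_proto import api_pb2


def pick_output_formats(fun_def: api_pb2.Function, is_web_endpoint: bool) -> list:
    """Illustrative only: mirrors the fallback pattern used in the hunks below."""
    if is_web_endpoint:
        # Web endpoints stream ASGI messages back to the server.
        return list(fun_def.supported_output_formats) or [api_pb2.DATA_FORMAT_ASGI]
    # Regular functions fall back to pickle and CBOR when the definition is silent.
    return list(fun_def.supported_output_formats) or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
```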
@@ -93,9 +93,9 @@ def construct_webhook_callable(

 @dataclass
 class ImportedFunction(Service):
-    user_cls_instance: Any
     app: modal.app._App
     service_deps: Optional[Sequence["modal._object._Object"]]
+    user_cls_instance = None

     _user_defined_callable: Callable[..., Any]

@@ -108,6 +108,7 @@ class ImportedFunction(Service):
         is_generator = fun_def.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR

         webhook_config = fun_def.webhook_config
+
         if not webhook_config.type:
             # for non-webhooks, the runnable is straight forward:
             return {
@@ -115,7 +116,10 @@ class ImportedFunction(Service):
                     callable=self._user_defined_callable,
                     is_async=is_async,
                     is_generator=is_generator,
-
+                    supported_output_formats=fun_def.supported_output_formats
+                    # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]` is only
+                    # needed for tests
+                    or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
                 )
             }

@@ -129,7 +133,8 @@ class ImportedFunction(Service):
                     lifespan_manager=lifespan_manager,
                     is_async=True,
                     is_generator=True,
-
+                    # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_ASGI]` is only needed for tests
+                    supported_output_formats=fun_def.supported_output_formats or [api_pb2.DATA_FORMAT_ASGI],
                 )
             }

@@ -154,6 +159,7 @@ class ImportedClass(Service):
             # Use the function definition for whether this is a generator (overriden by webhooks)
             is_generator = _partial.params.is_generator
             webhook_config = _partial.params.webhook_config
+            method_def = fun_def.method_definitions[method_name]

             bound_func = user_func.__get__(self.user_cls_instance)

@@ -163,7 +169,10 @@ class ImportedClass(Service):
                     callable=bound_func,
                     is_async=is_async,
                     is_generator=bool(is_generator),
-
+                    # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]` is only
+                    # needed for tests
+                    supported_output_formats=method_def.supported_output_formats
+                    or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
                 )
             else:
                 web_callable, lifespan_manager = construct_webhook_callable(
@@ -174,7 +183,8 @@ class ImportedClass(Service):
                     lifespan_manager=lifespan_manager,
                     is_async=True,
                     is_generator=True,
-
+                    # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_ASGI]` is only needed for tests
+                    supported_output_formats=method_def.supported_output_formats or [api_pb2.DATA_FORMAT_ASGI],
                 )
             finalized_functions[method_name] = finalized_function
         return finalized_functions
@@ -199,7 +209,6 @@ def get_user_class_instance(_cls: modal.cls._Cls, args: tuple[Any, ...], kwargs:

 def import_single_function_service(
     function_def: api_pb2.Function,
-    ser_cls: Optional[type],  # used only for @build functions
     ser_fun: Optional[Callable[..., Any]],
 ) -> Service:
     """Imports a function dynamically, and locates the app.
@@ -228,12 +237,9 @@ def import_single_function_service(
     service_deps: Optional[Sequence["modal._object._Object"]] = None
     active_app: modal.app._App

-    user_cls_or_cls: typing.Union[None, type, modal.cls.Cls]
-    user_cls_instance = None
-
     if ser_fun is not None:
         # This is a serialized function we already fetched from the server
-
+        user_defined_callable = ser_fun
         active_app = get_active_app_fallback(function_def)
     else:
         # Load the module dynamically
@@ -244,58 +250,22 @@ def import_single_function_service(
             raise LocalFunctionError("Attempted to load a function defined in a function scope")

         parts = qual_name.split(".")
-        if len(parts) == 1:
-            # This is a function
-            user_cls_or_cls = None
-            f = getattr(module, qual_name)
-            if isinstance(f, Function):
-                _function: modal._functions._Function[Any, Any, Any] = synchronizer._translate_in(f)  # type: ignore
-                service_deps = _function.deps(only_explicit_mounts=True)
-                user_defined_callable = _function.get_raw_f()
-                assert _function._app  # app should always be set on a decorated function
-                active_app = _function._app
-            else:
-                user_defined_callable = f
-                active_app = get_active_app_fallback(function_def)
-
-        elif len(parts) == 2:
-            # This path should only be triggered by @build class builder methods and can be removed
-            # once @build is deprecated.
-            assert not function_def.use_method_name  # new "placeholder methods" should not be invoked directly!
-            assert function_def.is_builder_function
-            cls_name, fun_name = parts
-            user_cls_or_cls = getattr(module, cls_name)
-            if isinstance(user_cls_or_cls, modal.cls.Cls):
-                # The cls decorator is in global scope
-                _cls = typing.cast(modal.cls._Cls, synchronizer._translate_in(user_cls_or_cls))
-                user_defined_callable = _cls._callables[fun_name]
-                # Intentionally not including these, since @build functions don't actually
-                # forward the information from their parent class.
-                # service_deps = _cls._get_class_service_function().deps(only_explicit_mounts=True)
-                assert _cls._app
-                active_app = _cls._app
-            else:
-                # This is non-decorated class
-                user_defined_callable = getattr(user_cls_or_cls, fun_name)  # unbound method
-                active_app = get_active_app_fallback(function_def)
-        else:
+        if len(parts) != 1:
             raise InvalidError(f"Invalid function qualname {qual_name}")

-
-
-
-
-
-
-
+        f = getattr(module, qual_name)
+        if isinstance(f, Function):
+            _function: modal._functions._Function[Any, Any, Any] = synchronizer._translate_in(f)  # type: ignore
+            service_deps = _function.deps(only_explicit_mounts=True)
+            user_defined_callable = _function.get_raw_f()
+            assert _function._app  # app should always be set on a decorated function
+            active_app = _function._app
         else:
-            #
-
-
-            user_defined_callable = user_defined_callable.__get__(user_cls_instance)
+            # function isn't decorated in global scope
+            user_defined_callable = f
+            active_app = get_active_app_fallback(function_def)

     return ImportedFunction(
-        user_cls_instance,
         active_app,
         service_deps,
         user_defined_callable,