modal 1.1.1.dev41__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__main__.py +1 -2
- modal/_container_entrypoint.py +18 -7
- modal/_functions.py +135 -13
- modal/_object.py +13 -2
- modal/_partial_function.py +8 -8
- modal/_runtime/asgi.py +3 -2
- modal/_runtime/container_io_manager.py +20 -14
- modal/_runtime/container_io_manager.pyi +38 -13
- modal/_runtime/execution_context.py +18 -2
- modal/_runtime/execution_context.pyi +4 -1
- modal/_runtime/gpu_memory_snapshot.py +158 -54
- modal/_utils/blob_utils.py +83 -24
- modal/_utils/function_utils.py +4 -3
- modal/_utils/time_utils.py +28 -4
- modal/app.py +8 -4
- modal/app.pyi +8 -8
- modal/cli/dict.py +14 -11
- modal/cli/entry_point.py +9 -3
- modal/cli/launch.py +102 -4
- modal/cli/profile.py +1 -0
- modal/cli/programs/launch_instance_ssh.py +94 -0
- modal/cli/programs/run_marimo.py +95 -0
- modal/cli/queues.py +49 -19
- modal/cli/secret.py +45 -18
- modal/cli/volume.py +14 -16
- modal/client.pyi +2 -10
- modal/cls.py +12 -2
- modal/cls.pyi +9 -1
- modal/config.py +7 -7
- modal/dict.py +206 -12
- modal/dict.pyi +358 -4
- modal/experimental/__init__.py +130 -0
- modal/file_io.py +1 -1
- modal/file_io.pyi +2 -2
- modal/file_pattern_matcher.py +25 -16
- modal/functions.pyi +111 -11
- modal/image.py +9 -3
- modal/image.pyi +7 -7
- modal/mount.py +20 -13
- modal/mount.pyi +16 -3
- modal/network_file_system.py +8 -2
- modal/object.pyi +3 -0
- modal/parallel_map.py +346 -101
- modal/parallel_map.pyi +108 -0
- modal/proxy.py +2 -1
- modal/queue.py +199 -9
- modal/queue.pyi +357 -3
- modal/sandbox.py +6 -5
- modal/sandbox.pyi +17 -14
- modal/secret.py +196 -3
- modal/secret.pyi +372 -0
- modal/volume.py +239 -23
- modal/volume.pyi +405 -10
- {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/METADATA +2 -2
- {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/RECORD +68 -66
- modal_docs/mdmd/mdmd.py +11 -1
- modal_proto/api.proto +37 -10
- modal_proto/api_grpc.py +32 -0
- modal_proto/api_pb2.py +627 -597
- modal_proto/api_pb2.pyi +107 -19
- modal_proto/api_pb2_grpc.py +67 -2
- modal_proto/api_pb2_grpc.pyi +24 -8
- modal_proto/modal_api_grpc.py +2 -0
- modal_version/__init__.py +1 -1
- {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/WHEEL +0 -0
- {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/entry_points.txt +0 -0
- {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/licenses/LICENSE +0 -0
- {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/top_level.txt +0 -0
modal/_runtime/container_io_manager.pyi
@@ -27,6 +27,7 @@ class IOContext:
     input_ids: list[str]
     retry_counts: list[int]
     function_call_ids: list[str]
+    attempt_tokens: list[str]
     function_inputs: list[modal_proto.api_pb2.FunctionInput]
     finalized_function: modal._runtime.user_code_imports.FinalizedFunction
     _cancel_issued: bool
@@ -37,6 +38,7 @@ class IOContext:
         input_ids: list[str],
         retry_counts: list[int],
         function_call_ids: list[str],
+        attempt_tokens: list[str],
         finalized_function: modal._runtime.user_code_imports.FinalizedFunction,
         function_inputs: list[modal_proto.api_pb2.FunctionInput],
         is_batched: bool,
@@ -50,7 +52,7 @@ class IOContext:
         cls,
         client: modal.client._Client,
         finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
-        inputs: list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]],
+        inputs: list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]],
         is_batched: bool,
     ) -> IOContext: ...
     def set_cancel_callback(self, cb: collections.abc.Callable[[], None]): ...
@@ -133,12 +135,19 @@ class _ContainerIOManager:
     async def _dynamic_concurrency_loop(self): ...
     def serialize_data_format(self, obj: typing.Any, data_format: int) -> bytes: ...
     async def format_blob_data(self, data: bytes) -> dict[str, typing.Any]: ...
-    def get_data_in(self, function_call_id: str) -> collections.abc.AsyncIterator[typing.Any]:
+    def get_data_in(
+        self, function_call_id: str, attempt_token: typing.Optional[str]
+    ) -> collections.abc.AsyncIterator[typing.Any]:
         """Read from the `data_in` stream of a function call."""
         ...

     async def put_data_out(
-        self, function_call_id: str, start_index: int, data_format: int, serialized_messages: list[typing.Any]
+        self,
+        function_call_id: str,
+        attempt_token: str,
+        start_index: int,
+        data_format: int,
+        serialized_messages: list[typing.Any],
     ) -> None:
         """Put data onto the `data_out` stream of a function call.

@@ -149,7 +158,7 @@ class _ContainerIOManager:
         ...

     def generator_output_sender(
-        self, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+        self, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
     ) -> typing.AsyncContextManager[None]:
         """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
         ...
@@ -166,7 +175,7 @@ class _ContainerIOManager:
     def get_max_inputs_to_fetch(self): ...
     def _generate_inputs(
         self, batch_max_size: int, batch_wait_ms: int
-    ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+    ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
     def run_inputs_outputs(
         self,
         finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
@@ -332,11 +341,15 @@ class ContainerIOManager:
     format_blob_data: __format_blob_data_spec[typing_extensions.Self]

     class __get_data_in_spec(typing_extensions.Protocol[SUPERSELF]):
-        def __call__(self, /, function_call_id: str) -> typing.Iterator[typing.Any]:
+        def __call__(
+            self, /, function_call_id: str, attempt_token: typing.Optional[str]
+        ) -> typing.Iterator[typing.Any]:
             """Read from the `data_in` stream of a function call."""
             ...

-        def aio(self, /, function_call_id: str) -> collections.abc.AsyncIterator[typing.Any]:
+        def aio(
+            self, /, function_call_id: str, attempt_token: typing.Optional[str]
+        ) -> collections.abc.AsyncIterator[typing.Any]:
             """Read from the `data_in` stream of a function call."""
             ...

@@ -344,7 +357,13 @@ class ContainerIOManager:

     class __put_data_out_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
-            self, /, function_call_id: str, start_index: int, data_format: int, serialized_messages: list[typing.Any]
+            self,
+            /,
+            function_call_id: str,
+            attempt_token: str,
+            start_index: int,
+            data_format: int,
+            serialized_messages: list[typing.Any],
         ) -> None:
             """Put data onto the `data_out` stream of a function call.

@@ -355,7 +374,13 @@ class ContainerIOManager:
             ...

         async def aio(
-            self, /, function_call_id: str, start_index: int, data_format: int, serialized_messages: list[typing.Any]
+            self,
+            /,
+            function_call_id: str,
+            attempt_token: str,
+            start_index: int,
+            data_format: int,
+            serialized_messages: list[typing.Any],
         ) -> None:
             """Put data onto the `data_out` stream of a function call.

@@ -369,13 +394,13 @@ class ContainerIOManager:

     class __generator_output_sender_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
-            self, /, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+            self, /, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
         ) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]:
             """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
             ...

         def aio(
-            self, /, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+            self, /, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
         ) -> typing.AsyncContextManager[None]:
             """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
             ...
@@ -410,10 +435,10 @@ class ContainerIOManager:
     class ___generate_inputs_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
             self, /, batch_max_size: int, batch_wait_ms: int
-        ) -> typing.Iterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+        ) -> typing.Iterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
         def aio(
             self, /, batch_max_size: int, batch_wait_ms: int
-        ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+        ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...

     _generate_inputs: ___generate_inputs_spec[typing_extensions.Self]

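The practical effect of the IOContext and _ContainerIOManager stub changes above is that each queued input now carries an attempt token alongside its input ID, retry count, function call ID, and serialized payload. A minimal standalone sketch of the widened tuple shape follows; all values are hypothetical, this is not Modal's runtime code, and the element order is inferred from the IOContext field order:

# Hypothetical 5-tuple: (input_id, retry_count, function_call_id, attempt_token, function_input)
inputs = [
    ("in-abc123", 0, "fc-def456", "at-ghi789", b"<serialized FunctionInput>"),
]
for input_id, retry_count, function_call_id, attempt_token, function_input in inputs:
    print(f"{input_id} (attempt token {attempt_token}) belongs to {function_call_id}")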
modal/_runtime/execution_context.py
@@ -72,22 +72,38 @@ def current_function_call_id() -> Optional[str]:
     return None


-def _set_current_context_ids(input_ids: list[str], function_call_ids: list[str]) -> Callable[[], None]:
-    assert len(input_ids) == len(function_call_ids) and input_ids
+def current_attempt_token() -> Optional[str]:
+    # This ContextVar isn't useful to expose to users.
+    try:
+        return _current_attempt_token.get()
+    except LookupError:
+        return None
+
+
+def _set_current_context_ids(
+    input_ids: list[str], function_call_ids: list[str], attempt_tokens: list[str]
+) -> Callable[[], None]:
+    assert len(input_ids) == len(function_call_ids) == len(attempt_tokens) and input_ids
+
     input_id = input_ids[0]
     function_call_id = function_call_ids[0]
+    attempt_token = attempt_tokens[0]
+
     input_token = _current_input_id.set(input_id)
     function_call_token = _current_function_call_id.set(function_call_id)
+    attempt_token_token = _current_attempt_token.set(attempt_token)

     def _reset_current_context_ids():
         _current_input_id.reset(input_token)
         _current_function_call_id.reset(function_call_token)
+        _current_attempt_token.reset(attempt_token_token)

     return _reset_current_context_ids


 _current_input_id: ContextVar = ContextVar("_current_input_id")
 _current_function_call_id: ContextVar = ContextVar("_current_function_call_id")
+_current_attempt_token: ContextVar = ContextVar("_current_attempt_token")

 _is_currently_importing = False  # we set this to True while a container is importing user code

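The new current_attempt_token() accessor follows the same ContextVar set/reset pattern as the existing current_input_id() and current_function_call_id() helpers. A standalone sketch of that pattern (names and values below are local to the sketch, not imported from Modal):

from contextvars import ContextVar
from typing import Optional

_attempt_token: ContextVar = ContextVar("_attempt_token")

def current_attempt_token() -> Optional[str]:
    try:
        return _attempt_token.get()
    except LookupError:
        return None  # nothing has been set in this execution context

reset_token = _attempt_token.set("at-example")  # hypothetical token value
assert current_attempt_token() == "at-example"
_attempt_token.reset(reset_token)  # restore the previous (unset) state
assert current_attempt_token() is None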
modal/_runtime/execution_context.pyi
@@ -68,11 +68,14 @@ def current_function_call_id() -> typing.Optional[str]:
     """
     ...

+def current_attempt_token() -> typing.Optional[str]: ...
 def _set_current_context_ids(
-    input_ids: list[str], function_call_ids: list[str]
+    input_ids: list[str], function_call_ids: list[str], attempt_tokens: list[str]
 ) -> collections.abc.Callable[[], None]: ...
 def _import_context(): ...

 _current_input_id: contextvars.ContextVar

 _current_function_call_id: contextvars.ContextVar
+
+_current_attempt_token: contextvars.ContextVar
modal/_runtime/gpu_memory_snapshot.py
@@ -1,17 +1,18 @@
 # Copyright Modal Labs 2022
 #
 # This module provides a simple interface for creating GPU memory snapshots,
-#
+# providing a convenient interface to `cuda-checkpoint` [1]. This is intended
 # to be used in conjunction with memory snapshots.
 #
 # [1] https://github.com/NVIDIA/cuda-checkpoint

 import subprocess
 import time
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
+from typing import List, Optional

 from modal.config import config, logger

@@ -19,7 +20,9 @@ CUDA_CHECKPOINT_PATH: str = config.get("cuda_checkpoint_path")


 class CudaCheckpointState(Enum):
-    """State representation from the CUDA API
+    """State representation from the CUDA API [1].
+
+    [1] https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html"""

     RUNNING = "running"
     LOCKED = "locked"
@@ -28,6 +31,8 @@ class CudaCheckpointState(Enum):


 class CudaCheckpointException(Exception):
+    """Exception raised for CUDA checkpoint operations."""
+
     pass


@@ -39,16 +44,31 @@ class CudaCheckpointProcess:
     pid: int
     state: CudaCheckpointState

-    def toggle(self, target_state: CudaCheckpointState, timeout_secs: float = 5 * 60.0):
+    def toggle(self, target_state: CudaCheckpointState, timeout_secs: float = 5 * 60.0) -> None:
         """Toggle CUDA checkpoint state for current process, moving GPU memory to the
-        CPU and back depending on the current process state when called."""
+        CPU and back depending on the current process state when called.
+        """
         logger.debug(f"PID: {self.pid} Toggling CUDA checkpoint state to {target_state.value}")

         start_time = time.monotonic()
+        retry_count = 0
+        max_retries = 3

         while self._should_continue_toggle(target_state, start_time, timeout_secs):
-
-
+            try:
+                self._execute_toggle_command()
+                # Use exponential backoff for retries
+                sleep_time = min(0.1 * (2**retry_count), 1.0)
+                time.sleep(sleep_time)
+                retry_count = 0
+            except CudaCheckpointException as e:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    raise CudaCheckpointException(
+                        f"PID: {self.pid} Failed to toggle state after {max_retries} retries: {e}"
+                    )
+                logger.debug(f"PID: {self.pid} Retry {retry_count}/{max_retries} after error: {e}")
+                time.sleep(0.5 * retry_count)

         logger.debug(f"PID: {self.pid} Target state {target_state.value} reached")

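The rewritten toggle() loop retries the cuda-checkpoint call instead of failing on the first error: after a successful toggle attempt it sleeps min(0.1 * 2**retry_count, 1.0) seconds, and after a failure it sleeps 0.5 * retry_count seconds, giving up once max_retries (3) is reached. A quick standalone illustration of those two schedules (plain arithmetic, no Modal imports):

# Sleep after a successful toggle command, capped at 1.0 s
for retry_count in range(5):
    print("success path:", retry_count, min(0.1 * (2**retry_count), 1.0))  # 0.1, 0.2, 0.4, 0.8, 1.0

# Sleep before retrying after a failed toggle command (the third failure raises instead)
for retry_count in range(1, 3):
    print("failure path:", retry_count, 0.5 * retry_count)  # 0.5, 1.0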
@@ -73,19 +93,25 @@ class CudaCheckpointProcess:

         return True

-    def _execute_toggle_command(self):
+    def _execute_toggle_command(self) -> None:
         """Execute the cuda-checkpoint toggle command."""
         try:
-            subprocess.run(
+            _ = subprocess.run(
                 [CUDA_CHECKPOINT_PATH, "--toggle", "--pid", str(self.pid)],
                 check=True,
                 capture_output=True,
                 text=True,
+                timeout=30,
             )
             logger.debug(f"PID: {self.pid} Successfully toggled CUDA checkpoint state")
         except subprocess.CalledProcessError as e:
-
-
+            error_msg = f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+        except subprocess.TimeoutExpired:
+            error_msg = f"PID: {self.pid} Toggle command timed out"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)

     def refresh_state(self) -> None:
         """Refreshes the current CUDA checkpoint state for this process."""
@@ -95,15 +121,20 @@ class CudaCheckpointProcess:
                 check=True,
                 capture_output=True,
                 text=True,
-                timeout=
+                timeout=10,
             )

             state_str = result.stdout.strip().lower()
             self.state = CudaCheckpointState(state_str)

         except subprocess.CalledProcessError as e:
-
-
+            error_msg = f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+        except subprocess.TimeoutExpired:
+            error_msg = f"PID: {self.pid} Get state command timed out"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)


 class CudaCheckpointSession:
@@ -111,12 +142,17 @@ class CudaCheckpointSession:

     def __init__(self):
         self.cuda_processes = self._get_cuda_pids()
-
+        if self.cuda_processes:
+            logger.debug(
+                f"Found {len(self.cuda_processes)} PID(s) with CUDA sessions: {[c.pid for c in self.cuda_processes]}"
+            )
+        else:
+            logger.debug("No CUDA sessions found.")

-    def _get_cuda_pids(self) -> list[CudaCheckpointProcess]:
+    def _get_cuda_pids(self) -> List[CudaCheckpointProcess]:
         """Iterates over all PIDs and identifies the ones that have running
         CUDA sessions."""
-        cuda_pids: list[CudaCheckpointProcess] = []
+        cuda_pids: List[CudaCheckpointProcess] = []

         # Get all active process IDs from /proc directory
         proc_dir = Path("/proc")
@@ -125,75 +161,143 @@ class CudaCheckpointSession:
                 "OS does not have /proc path rendering it incompatible with GPU memory snapshots."
             )

-
-
-                continue
-
-            pid = int(entry.name)
-            try:
-                # Call cuda-checkpoint to check if this PID has a CUDA session
-                result = subprocess.run(
-                    [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
-                    capture_output=True,
-                    text=True,
-                    timeout=10,
-                )
-
-                # If the command succeeds (return code 0), this PID has a CUDA session
-                if result.returncode == 0:
-                    state_str = result.stdout.strip().lower()
-                    state = CudaCheckpointState(state_str)
-
-                    cuda_checkpoint_process = CudaCheckpointProcess(pid=pid, state=state)
-                    cuda_pids.append(cuda_checkpoint_process)
+        # Get all numeric directories (PIDs) from /proc
+        pid_dirs = [entry for entry in proc_dir.iterdir() if entry.name.isdigit()]

-
-
-
+        # Use ThreadPoolExecutor to check PIDs in parallel for better performance
+        with ThreadPoolExecutor(max_workers=min(50, len(pid_dirs))) as executor:
+            future_to_pid = {
+                executor.submit(self._check_cuda_session, int(entry.name)): int(entry.name) for entry in pid_dirs
+            }

-
-
-
-
-
+            for future in as_completed(future_to_pid):
+                pid = future_to_pid[future]
+                try:
+                    cuda_process = future.result()
+                    if cuda_process:
+                        cuda_pids.append(cuda_process)
+                except Exception as e:
+                    logger.debug(f"Error checking PID {pid}: {e}")

         # Sort PIDs for ordered checkpointing
         cuda_pids.sort(key=lambda x: x.pid)
         return cuda_pids

+    def _check_cuda_session(self, pid: int) -> Optional[CudaCheckpointProcess]:
+        """Check if a specific PID has a CUDA session."""
+        try:
+            result = subprocess.run(
+                [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
+                capture_output=True,
+                text=True,
+                timeout=5,
+            )
+
+            # If the command succeeds (return code 0), this PID has a CUDA session
+            if result.returncode == 0:
+                state_str = result.stdout.strip().lower()
+                state = CudaCheckpointState(state_str)
+                return CudaCheckpointProcess(pid=pid, state=state)
+
+        except subprocess.CalledProcessError:
+            # Command failed, which is expected for PIDs without CUDA sessions
+            pass
+        except subprocess.TimeoutExpired:
+            logger.debug(f"Timeout checking CUDA state for PID {pid}")
+        except Exception as e:
+            logger.debug(f"Error checking PID {pid}: {e}")
+
+        return None
+
     def checkpoint(self) -> None:
+        """Checkpoint all CUDA processes, moving GPU memory to CPU."""
+        if not self.cuda_processes:
+            logger.debug("No CUDA processes to checkpoint.")
+            return
+
         # Validate all states first
         for proc in self.cuda_processes:
+            proc.refresh_state()  # Refresh state before validation
             if proc.state != CudaCheckpointState.RUNNING:
-                raise CudaCheckpointException(
+                raise CudaCheckpointException(
+                    f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.RUNNING.value} state. "
+                    f"Current state: {proc.state.value}"
+                )

         # Moving state from GPU to CPU can take several seconds per CUDA session.
         # Make a parallel call per CUDA session.
         start = time.perf_counter()

-        def checkpoint_impl(proc: CudaCheckpointProcess):
+        def checkpoint_impl(proc: CudaCheckpointProcess) -> None:
            proc.toggle(CudaCheckpointState.CHECKPOINTED)

         with ThreadPoolExecutor() as executor:
-
+            futures = [executor.submit(checkpoint_impl, proc) for proc in self.cuda_processes]
+
+            # Wait for all futures and collect any exceptions
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    exceptions.append(e)
+
+            if exceptions:
+                raise CudaCheckpointException(
+                    f"Failed to checkpoint {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                )

         elapsed = time.perf_counter() - start
-        logger.debug(f"Checkpointing CUDA sessions took => {elapsed:.3f}s")
+        logger.debug(f"Checkpointing {len(self.cuda_processes)} CUDA sessions took => {elapsed:.3f}s")

     def restore(self) -> None:
+        """Restore all CUDA processes, moving memory back from CPU to GPU."""
+        if not self.cuda_processes:
+            logger.debug("No CUDA sessions to restore.")
+            return
+
         # Validate all states first
         for proc in self.cuda_processes:
+            proc.refresh_state()  # Refresh state before validation
             if proc.state != CudaCheckpointState.CHECKPOINTED:
-                raise CudaCheckpointException(
+                raise CudaCheckpointException(
+                    f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.CHECKPOINTED.value} state. "
+                    f"Current state: {proc.state.value}"
+                )

         # See checkpoint() for rationale about parallelism.
         start = time.perf_counter()

-        def restore_process(proc: CudaCheckpointProcess):
+        def restore_process(proc: CudaCheckpointProcess) -> None:
             proc.toggle(CudaCheckpointState.RUNNING)

         with ThreadPoolExecutor() as executor:
-
+            futures = [executor.submit(restore_process, proc) for proc in self.cuda_processes]
+
+            # Wait for all futures and collect any exceptions
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    exceptions.append(e)
+
+            if exceptions:
+                raise CudaCheckpointException(
+                    f"Failed to restore {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                )

         elapsed = time.perf_counter() - start
-        logger.debug(f"Restoring CUDA
+        logger.debug(f"Restoring {len(self.cuda_processes)} CUDA session(s) took => {elapsed:.3f}s")
+
+    def get_process_count(self) -> int:
+        """Get the number of CUDA processes managed by this session."""
+        return len(self.cuda_processes)
+
+    def get_process_states(self) -> List[tuple[int, CudaCheckpointState]]:
+        """Get current states of all managed processes."""
+        states = []
+        for proc in self.cuda_processes:
+            proc.refresh_state()
+            states.append((proc.pid, proc.state))
+        return states