modal 1.1.1.dev41__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (68)
  1. modal/__main__.py +1 -2
  2. modal/_container_entrypoint.py +18 -7
  3. modal/_functions.py +135 -13
  4. modal/_object.py +13 -2
  5. modal/_partial_function.py +8 -8
  6. modal/_runtime/asgi.py +3 -2
  7. modal/_runtime/container_io_manager.py +20 -14
  8. modal/_runtime/container_io_manager.pyi +38 -13
  9. modal/_runtime/execution_context.py +18 -2
  10. modal/_runtime/execution_context.pyi +4 -1
  11. modal/_runtime/gpu_memory_snapshot.py +158 -54
  12. modal/_utils/blob_utils.py +83 -24
  13. modal/_utils/function_utils.py +4 -3
  14. modal/_utils/time_utils.py +28 -4
  15. modal/app.py +8 -4
  16. modal/app.pyi +8 -8
  17. modal/cli/dict.py +14 -11
  18. modal/cli/entry_point.py +9 -3
  19. modal/cli/launch.py +102 -4
  20. modal/cli/profile.py +1 -0
  21. modal/cli/programs/launch_instance_ssh.py +94 -0
  22. modal/cli/programs/run_marimo.py +95 -0
  23. modal/cli/queues.py +49 -19
  24. modal/cli/secret.py +45 -18
  25. modal/cli/volume.py +14 -16
  26. modal/client.pyi +2 -10
  27. modal/cls.py +12 -2
  28. modal/cls.pyi +9 -1
  29. modal/config.py +7 -7
  30. modal/dict.py +206 -12
  31. modal/dict.pyi +358 -4
  32. modal/experimental/__init__.py +130 -0
  33. modal/file_io.py +1 -1
  34. modal/file_io.pyi +2 -2
  35. modal/file_pattern_matcher.py +25 -16
  36. modal/functions.pyi +111 -11
  37. modal/image.py +9 -3
  38. modal/image.pyi +7 -7
  39. modal/mount.py +20 -13
  40. modal/mount.pyi +16 -3
  41. modal/network_file_system.py +8 -2
  42. modal/object.pyi +3 -0
  43. modal/parallel_map.py +346 -101
  44. modal/parallel_map.pyi +108 -0
  45. modal/proxy.py +2 -1
  46. modal/queue.py +199 -9
  47. modal/queue.pyi +357 -3
  48. modal/sandbox.py +6 -5
  49. modal/sandbox.pyi +17 -14
  50. modal/secret.py +196 -3
  51. modal/secret.pyi +372 -0
  52. modal/volume.py +239 -23
  53. modal/volume.pyi +405 -10
  54. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/METADATA +2 -2
  55. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/RECORD +68 -66
  56. modal_docs/mdmd/mdmd.py +11 -1
  57. modal_proto/api.proto +37 -10
  58. modal_proto/api_grpc.py +32 -0
  59. modal_proto/api_pb2.py +627 -597
  60. modal_proto/api_pb2.pyi +107 -19
  61. modal_proto/api_pb2_grpc.py +67 -2
  62. modal_proto/api_pb2_grpc.pyi +24 -8
  63. modal_proto/modal_api_grpc.py +2 -0
  64. modal_version/__init__.py +1 -1
  65. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/WHEEL +0 -0
  66. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/entry_points.txt +0 -0
  67. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/licenses/LICENSE +0 -0
  68. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/top_level.txt +0 -0
modal/_runtime/container_io_manager.pyi

@@ -27,6 +27,7 @@ class IOContext:
     input_ids: list[str]
     retry_counts: list[int]
     function_call_ids: list[str]
+    attempt_tokens: list[str]
     function_inputs: list[modal_proto.api_pb2.FunctionInput]
     finalized_function: modal._runtime.user_code_imports.FinalizedFunction
     _cancel_issued: bool
@@ -37,6 +38,7 @@ class IOContext:
         input_ids: list[str],
         retry_counts: list[int],
         function_call_ids: list[str],
+        attempt_tokens: list[str],
         finalized_function: modal._runtime.user_code_imports.FinalizedFunction,
         function_inputs: list[modal_proto.api_pb2.FunctionInput],
         is_batched: bool,
@@ -50,7 +52,7 @@ class IOContext:
         cls,
         client: modal.client._Client,
         finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
-        inputs: list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]],
+        inputs: list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]],
         is_batched: bool,
     ) -> IOContext: ...
     def set_cancel_callback(self, cb: collections.abc.Callable[[], None]): ...
@@ -133,12 +135,19 @@ class _ContainerIOManager:
     async def _dynamic_concurrency_loop(self): ...
     def serialize_data_format(self, obj: typing.Any, data_format: int) -> bytes: ...
     async def format_blob_data(self, data: bytes) -> dict[str, typing.Any]: ...
-    def get_data_in(self, function_call_id: str) -> collections.abc.AsyncIterator[typing.Any]:
+    def get_data_in(
+        self, function_call_id: str, attempt_token: typing.Optional[str]
+    ) -> collections.abc.AsyncIterator[typing.Any]:
         """Read from the `data_in` stream of a function call."""
         ...

     async def put_data_out(
-        self, function_call_id: str, start_index: int, data_format: int, serialized_messages: list[typing.Any]
+        self,
+        function_call_id: str,
+        attempt_token: str,
+        start_index: int,
+        data_format: int,
+        serialized_messages: list[typing.Any],
     ) -> None:
         """Put data onto the `data_out` stream of a function call.

@@ -149,7 +158,7 @@ class _ContainerIOManager:
         ...

     def generator_output_sender(
-        self, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+        self, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
     ) -> typing.AsyncContextManager[None]:
         """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
         ...
@@ -166,7 +175,7 @@ class _ContainerIOManager:
     def get_max_inputs_to_fetch(self): ...
     def _generate_inputs(
         self, batch_max_size: int, batch_wait_ms: int
-    ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+    ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
     def run_inputs_outputs(
         self,
         finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
@@ -332,11 +341,15 @@ class ContainerIOManager:
     format_blob_data: __format_blob_data_spec[typing_extensions.Self]

     class __get_data_in_spec(typing_extensions.Protocol[SUPERSELF]):
-        def __call__(self, /, function_call_id: str) -> typing.Iterator[typing.Any]:
+        def __call__(
+            self, /, function_call_id: str, attempt_token: typing.Optional[str]
+        ) -> typing.Iterator[typing.Any]:
             """Read from the `data_in` stream of a function call."""
             ...

-        def aio(self, /, function_call_id: str) -> collections.abc.AsyncIterator[typing.Any]:
+        def aio(
+            self, /, function_call_id: str, attempt_token: typing.Optional[str]
+        ) -> collections.abc.AsyncIterator[typing.Any]:
             """Read from the `data_in` stream of a function call."""
             ...

@@ -344,7 +357,13 @@ class ContainerIOManager:

     class __put_data_out_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
-            self, /, function_call_id: str, start_index: int, data_format: int, serialized_messages: list[typing.Any]
+            self,
+            /,
+            function_call_id: str,
+            attempt_token: str,
+            start_index: int,
+            data_format: int,
+            serialized_messages: list[typing.Any],
         ) -> None:
             """Put data onto the `data_out` stream of a function call.

@@ -355,7 +374,13 @@ class ContainerIOManager:
             ...

         async def aio(
-            self, /, function_call_id: str, start_index: int, data_format: int, serialized_messages: list[typing.Any]
+            self,
+            /,
+            function_call_id: str,
+            attempt_token: str,
+            start_index: int,
+            data_format: int,
+            serialized_messages: list[typing.Any],
         ) -> None:
             """Put data onto the `data_out` stream of a function call.

@@ -369,13 +394,13 @@ class ContainerIOManager:

     class __generator_output_sender_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
-            self, /, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+            self, /, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
         ) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]:
             """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
             ...

         def aio(
-            self, /, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+            self, /, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
         ) -> typing.AsyncContextManager[None]:
             """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
             ...
@@ -410,10 +435,10 @@ class ContainerIOManager:
     class ___generate_inputs_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
             self, /, batch_max_size: int, batch_wait_ms: int
-        ) -> typing.Iterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+        ) -> typing.Iterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
         def aio(
             self, /, batch_max_size: int, batch_wait_ms: int
-        ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+        ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...

     _generate_inputs: ___generate_inputs_spec[typing_extensions.Self]

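Taken together, these stub changes widen each queued input from a 4-tuple to a 5-tuple and thread an attempt token through the data-plane helpers (`get_data_in`, `put_data_out`, `generator_output_sender`). A minimal sketch of the new shapes, using placeholder names rather than Modal's actual implementation:

from typing import Any, Optional

# Hypothetical alias: each input now carries an attempt token alongside its IDs.
InputItem = tuple[str, int, str, str, Any]  # (input_id, retry_count, function_call_id, attempt_token, function_input)

def split_item(item: InputItem) -> tuple[str, Optional[str]]:
    """Pull out the identifiers a data-plane call would need (illustration only)."""
    input_id, retry_count, function_call_id, attempt_token, function_input = item
    # Per the updated stubs, attempt_token is passed next to function_call_id
    # when reading data_in or writing data_out for that function call.
    return function_call_id, attempt_token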
modal/_runtime/execution_context.py

@@ -72,22 +72,38 @@ def current_function_call_id() -> Optional[str]:
     return None


-def _set_current_context_ids(input_ids: list[str], function_call_ids: list[str]) -> Callable[[], None]:
-    assert len(input_ids) == len(function_call_ids) and len(input_ids) > 0
+def current_attempt_token() -> Optional[str]:
+    # This ContextVar isn't useful to expose to users.
+    try:
+        return _current_attempt_token.get()
+    except LookupError:
+        return None
+
+
+def _set_current_context_ids(
+    input_ids: list[str], function_call_ids: list[str], attempt_tokens: list[str]
+) -> Callable[[], None]:
+    assert len(input_ids) == len(function_call_ids) == len(attempt_tokens) and input_ids
+
     input_id = input_ids[0]
     function_call_id = function_call_ids[0]
+    attempt_token = attempt_tokens[0]
+
     input_token = _current_input_id.set(input_id)
     function_call_token = _current_function_call_id.set(function_call_id)
+    attempt_token_token = _current_attempt_token.set(attempt_token)

     def _reset_current_context_ids():
         _current_input_id.reset(input_token)
         _current_function_call_id.reset(function_call_token)
+        _current_attempt_token.reset(attempt_token_token)

     return _reset_current_context_ids


 _current_input_id: ContextVar = ContextVar("_current_input_id")
 _current_function_call_id: ContextVar = ContextVar("_current_function_call_id")
+_current_attempt_token: ContextVar = ContextVar("_current_attempt_token")

 _is_currently_importing = False  # we set this to True while a container is importing user code

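The new `_current_attempt_token` follows the same ContextVar discipline as the existing input and function-call IDs: `set()` returns a token, and the returned callback `reset()`s exactly that assignment. A standalone sketch of the pattern, with illustrative names rather than Modal's module:

from contextvars import ContextVar
from typing import Callable, Optional

_attempt_token: ContextVar[str] = ContextVar("_attempt_token")

def set_attempt_token(value: str) -> Callable[[], None]:
    token = _attempt_token.set(value)           # remember the reset token
    return lambda: _attempt_token.reset(token)  # undoes exactly this set()

def current_attempt_token() -> Optional[str]:
    try:
        return _attempt_token.get()
    except LookupError:  # nothing set in this context yet
        return None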
modal/_runtime/execution_context.pyi

@@ -68,11 +68,14 @@ def current_function_call_id() -> typing.Optional[str]:
     """
     ...

+def current_attempt_token() -> typing.Optional[str]: ...
 def _set_current_context_ids(
-    input_ids: list[str], function_call_ids: list[str]
+    input_ids: list[str], function_call_ids: list[str], attempt_tokens: list[str]
 ) -> collections.abc.Callable[[], None]: ...
 def _import_context(): ...

 _current_input_id: contextvars.ContextVar

 _current_function_call_id: contextvars.ContextVar
+
+_current_attempt_token: contextvars.ContextVar
modal/_runtime/gpu_memory_snapshot.py

@@ -1,17 +1,18 @@
 # Copyright Modal Labs 2022
 #
 # This module provides a simple interface for creating GPU memory snapshots,
-# provising a convenient interface to `cuda-checkpoint` [1]. This is intended
+# providing a convenient interface to `cuda-checkpoint` [1]. This is intended
 # to be used in conjunction with memory snapshots.
 #
 # [1] https://github.com/NVIDIA/cuda-checkpoint

 import subprocess
 import time
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
+from typing import List, Optional

 from modal.config import config, logger

@@ -19,7 +20,9 @@ CUDA_CHECKPOINT_PATH: str = config.get("cuda_checkpoint_path")


 class CudaCheckpointState(Enum):
-    """State representation from the CUDA API: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc96cdda177a2b8c296144567cbea4f23"""
+    """State representation from the CUDA API [1].
+
+    [1] https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html"""

     RUNNING = "running"
     LOCKED = "locked"
@@ -28,6 +31,8 @@ class CudaCheckpointState(Enum):


 class CudaCheckpointException(Exception):
+    """Exception raised for CUDA checkpoint operations."""
+
     pass


@@ -39,16 +44,31 @@ class CudaCheckpointProcess:
     pid: int
     state: CudaCheckpointState

-    def toggle(self, target_state: CudaCheckpointState, timeout_secs: float = 5 * 60.0):
+    def toggle(self, target_state: CudaCheckpointState, timeout_secs: float = 5 * 60.0) -> None:
         """Toggle CUDA checkpoint state for current process, moving GPU memory to the
-        CPU and back depending on the current process state when called."""
+        CPU and back depending on the current process state when called.
+        """
         logger.debug(f"PID: {self.pid} Toggling CUDA checkpoint state to {target_state.value}")

         start_time = time.monotonic()
+        retry_count = 0
+        max_retries = 3

         while self._should_continue_toggle(target_state, start_time, timeout_secs):
-            self._execute_toggle_command()
-            time.sleep(0.1)
+            try:
+                self._execute_toggle_command()
+                # Use exponential backoff for retries
+                sleep_time = min(0.1 * (2**retry_count), 1.0)
+                time.sleep(sleep_time)
+                retry_count = 0
+            except CudaCheckpointException as e:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    raise CudaCheckpointException(
+                        f"PID: {self.pid} Failed to toggle state after {max_retries} retries: {e}"
+                    )
+                logger.debug(f"PID: {self.pid} Retry {retry_count}/{max_retries} after error: {e}")
+                time.sleep(0.5 * retry_count)

         logger.debug(f"PID: {self.pid} Target state {target_state.value} reached")

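Instead of a flat 0.1 s sleep between cuda-checkpoint invocations, `toggle()` now sleeps with capped exponential backoff and raises after three consecutive failures. A simplified standalone sketch of that retry shape (the operation is a placeholder and the bookkeeping is condensed relative to the loop above):

import time
from typing import Callable

def call_with_backoff(operation: Callable[[], None], max_retries: int = 3) -> None:
    """Retry operation, sleeping min(0.1 * 2**n, 1.0) seconds between attempts."""
    failures = 0
    while True:
        try:
            operation()
            return
        except Exception:
            failures += 1
            if failures >= max_retries:
                raise  # give up after max_retries consecutive failures
            time.sleep(min(0.1 * (2 ** failures), 1.0))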
@@ -73,19 +93,25 @@ class CudaCheckpointProcess:

         return True

-    def _execute_toggle_command(self):
+    def _execute_toggle_command(self) -> None:
         """Execute the cuda-checkpoint toggle command."""
         try:
-            subprocess.run(
+            _ = subprocess.run(
                 [CUDA_CHECKPOINT_PATH, "--toggle", "--pid", str(self.pid)],
                 check=True,
                 capture_output=True,
                 text=True,
+                timeout=30,
             )
             logger.debug(f"PID: {self.pid} Successfully toggled CUDA checkpoint state")
         except subprocess.CalledProcessError as e:
-            logger.debug(f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}")
-            raise CudaCheckpointException(e.stderr)
+            error_msg = f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+        except subprocess.TimeoutExpired:
+            error_msg = f"PID: {self.pid} Toggle command timed out"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)

     def refresh_state(self) -> None:
         """Refreshes the current CUDA checkpoint state for this process."""
@@ -95,15 +121,20 @@ class CudaCheckpointProcess:
                 check=True,
                 capture_output=True,
                 text=True,
-                timeout=5,
+                timeout=10,
             )

             state_str = result.stdout.strip().lower()
             self.state = CudaCheckpointState(state_str)

         except subprocess.CalledProcessError as e:
-            logger.debug(f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}")
-            raise CudaCheckpointException(e.stderr)
+            error_msg = f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+        except subprocess.TimeoutExpired:
+            error_msg = f"PID: {self.pid} Get state command timed out"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)


 class CudaCheckpointSession:
@@ -111,12 +142,17 @@ class CudaCheckpointSession:

     def __init__(self):
         self.cuda_processes = self._get_cuda_pids()
-        logger.debug(f"PIDs with CUDA sessions: {[c.pid for c in self.cuda_processes]}")
+        if self.cuda_processes:
+            logger.debug(
+                f"Found {len(self.cuda_processes)} PID(s) with CUDA sessions: {[c.pid for c in self.cuda_processes]}"
+            )
+        else:
+            logger.debug("No CUDA sessions found.")

-    def _get_cuda_pids(self) -> list[CudaCheckpointProcess]:
+    def _get_cuda_pids(self) -> List[CudaCheckpointProcess]:
         """Iterates over all PIDs and identifies the ones that have running
         CUDA sessions."""
-        cuda_pids: list[CudaCheckpointProcess] = []
+        cuda_pids: List[CudaCheckpointProcess] = []

         # Get all active process IDs from /proc directory
         proc_dir = Path("/proc")
@@ -125,75 +161,143 @@ class CudaCheckpointSession:
                 "OS does not have /proc path rendering it incompatible with GPU memory snapshots."
             )

-        for entry in proc_dir.iterdir():
-            if not entry.name.isdigit():
-                continue
-
-            pid = int(entry.name)
-            try:
-                # Call cuda-checkpoint to check if this PID has a CUDA session
-                result = subprocess.run(
-                    [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
-                    capture_output=True,
-                    text=True,
-                    timeout=10,
-                )
-
-                # If the command succeeds (return code 0), this PID has a CUDA session
-                if result.returncode == 0:
-                    state_str = result.stdout.strip().lower()
-                    state = CudaCheckpointState(state_str)
-
-                    cuda_checkpoint_process = CudaCheckpointProcess(pid=pid, state=state)
-                    cuda_pids.append(cuda_checkpoint_process)
+        # Get all numeric directories (PIDs) from /proc
+        pid_dirs = [entry for entry in proc_dir.iterdir() if entry.name.isdigit()]

-            # Command failed, which is expected for PIDs without CUDA sessions
-            except subprocess.CalledProcessError:
-                continue
+        # Use ThreadPoolExecutor to check PIDs in parallel for better performance
+        with ThreadPoolExecutor(max_workers=min(50, len(pid_dirs))) as executor:
+            future_to_pid = {
+                executor.submit(self._check_cuda_session, int(entry.name)): int(entry.name) for entry in pid_dirs
+            }

-            # Raise other exceptions
-            except subprocess.TimeoutExpired:
-                raise CudaCheckpointException(f"Failed to get CUDA state for PID {pid}")
-            except Exception as e:
-                raise CudaCheckpointException(e)
+            for future in as_completed(future_to_pid):
+                pid = future_to_pid[future]
+                try:
+                    cuda_process = future.result()
+                    if cuda_process:
+                        cuda_pids.append(cuda_process)
+                except Exception as e:
+                    logger.debug(f"Error checking PID {pid}: {e}")

         # Sort PIDs for ordered checkpointing
         cuda_pids.sort(key=lambda x: x.pid)
         return cuda_pids

+    def _check_cuda_session(self, pid: int) -> Optional[CudaCheckpointProcess]:
+        """Check if a specific PID has a CUDA session."""
+        try:
+            result = subprocess.run(
+                [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
+                capture_output=True,
+                text=True,
+                timeout=5,
+            )
+
+            # If the command succeeds (return code 0), this PID has a CUDA session
+            if result.returncode == 0:
+                state_str = result.stdout.strip().lower()
+                state = CudaCheckpointState(state_str)
+                return CudaCheckpointProcess(pid=pid, state=state)
+
+        except subprocess.CalledProcessError:
+            # Command failed, which is expected for PIDs without CUDA sessions
+            pass
+        except subprocess.TimeoutExpired:
+            logger.debug(f"Timeout checking CUDA state for PID {pid}")
+        except Exception as e:
+            logger.debug(f"Error checking PID {pid}: {e}")
+
+        return None
+
     def checkpoint(self) -> None:
+        """Checkpoint all CUDA processes, moving GPU memory to CPU."""
+        if not self.cuda_processes:
+            logger.debug("No CUDA processes to checkpoint.")
+            return
+
         # Validate all states first
         for proc in self.cuda_processes:
+            proc.refresh_state()  # Refresh state before validation
             if proc.state != CudaCheckpointState.RUNNING:
-                raise CudaCheckpointException(f"CUDA session not in {CudaCheckpointState.RUNNING} state.")
+                raise CudaCheckpointException(
+                    f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.RUNNING.value} state. "
+                    f"Current state: {proc.state.value}"
+                )

         # Moving state from GPU to CPU can take several seconds per CUDA session.
         # Make a parallel call per CUDA session.
         start = time.perf_counter()

-        def checkpoint_impl(proc: CudaCheckpointProcess):
+        def checkpoint_impl(proc: CudaCheckpointProcess) -> None:
             proc.toggle(CudaCheckpointState.CHECKPOINTED)

         with ThreadPoolExecutor() as executor:
-            list(executor.map(checkpoint_impl, self.cuda_processes))
+            futures = [executor.submit(checkpoint_impl, proc) for proc in self.cuda_processes]
+
+            # Wait for all futures and collect any exceptions
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    exceptions.append(e)
+
+            if exceptions:
+                raise CudaCheckpointException(
+                    f"Failed to checkpoint {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                )

         elapsed = time.perf_counter() - start
-        logger.debug(f"Checkpointing CUDA sessions took => {elapsed:.3f}s")
+        logger.debug(f"Checkpointing {len(self.cuda_processes)} CUDA sessions took => {elapsed:.3f}s")

     def restore(self) -> None:
+        """Restore all CUDA processes, moving memory back from CPU to GPU."""
+        if not self.cuda_processes:
+            logger.debug("No CUDA sessions to restore.")
+            return
+
         # Validate all states first
         for proc in self.cuda_processes:
+            proc.refresh_state()  # Refresh state before validation
             if proc.state != CudaCheckpointState.CHECKPOINTED:
-                raise CudaCheckpointException(f"CUDA session not in {CudaCheckpointState.CHECKPOINTED} state.")
+                raise CudaCheckpointException(
+                    f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.CHECKPOINTED.value} state. "
+                    f"Current state: {proc.state.value}"
+                )

         # See checkpoint() for rationale about parallelism.
         start = time.perf_counter()

-        def restore_process(proc: CudaCheckpointProcess):
+        def restore_process(proc: CudaCheckpointProcess) -> None:
             proc.toggle(CudaCheckpointState.RUNNING)

         with ThreadPoolExecutor() as executor:
-            list(executor.map(restore_process, self.cuda_processes))
+            futures = [executor.submit(restore_process, proc) for proc in self.cuda_processes]
+
+            # Wait for all futures and collect any exceptions
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    exceptions.append(e)
+
+            if exceptions:
+                raise CudaCheckpointException(
+                    f"Failed to restore {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                )

         elapsed = time.perf_counter() - start
-        logger.debug(f"Restoring CUDA sessions took => {elapsed:.3f}s")
+        logger.debug(f"Restoring {len(self.cuda_processes)} CUDA session(s) took => {elapsed:.3f}s")
+
+    def get_process_count(self) -> int:
+        """Get the number of CUDA processes managed by this session."""
+        return len(self.cuda_processes)
+
+    def get_process_states(self) -> List[tuple[int, CudaCheckpointState]]:
+        """Get current states of all managed processes."""
+        states = []
+        for proc in self.cuda_processes:
+            proc.refresh_state()
+            states.append((proc.pid, proc.state))
+        return states
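Both `checkpoint()` and `restore()` now submit one task per CUDA process and gather failures via `as_completed()` instead of letting `executor.map()` stop at the first exception, so every process is attempted and all errors are reported together. A standalone sketch of that collect-then-raise pattern (the worker and items are placeholders, not Modal code):

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")

def run_all_or_report(worker: Callable[[T], None], items: Iterable[T]) -> None:
    """Run worker over every item in parallel; raise one summary error at the end."""
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(worker, item) for item in items]
        errors = []
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as exc:
                errors.append(exc)
    if errors:
        raise RuntimeError(f"{len(errors)} task(s) failed: " + "; ".join(str(e) for e in errors))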