modal 1.0.3.dev10__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic; further details were linked from the original diff page.

Files changed (160)
  1. modal/__init__.py +0 -2
  2. modal/__main__.py +3 -4
  3. modal/_billing.py +80 -0
  4. modal/_clustered_functions.py +7 -3
  5. modal/_clustered_functions.pyi +15 -3
  6. modal/_container_entrypoint.py +51 -69
  7. modal/_functions.py +508 -240
  8. modal/_grpc_client.py +171 -0
  9. modal/_load_context.py +105 -0
  10. modal/_object.py +81 -21
  11. modal/_output.py +58 -45
  12. modal/_partial_function.py +48 -73
  13. modal/_pty.py +7 -3
  14. modal/_resolver.py +26 -46
  15. modal/_runtime/asgi.py +4 -3
  16. modal/_runtime/container_io_manager.py +358 -220
  17. modal/_runtime/container_io_manager.pyi +296 -101
  18. modal/_runtime/execution_context.py +18 -2
  19. modal/_runtime/execution_context.pyi +64 -7
  20. modal/_runtime/gpu_memory_snapshot.py +262 -57
  21. modal/_runtime/user_code_imports.py +28 -58
  22. modal/_serialization.py +90 -6
  23. modal/_traceback.py +42 -1
  24. modal/_tunnel.pyi +380 -12
  25. modal/_utils/async_utils.py +84 -29
  26. modal/_utils/auth_token_manager.py +111 -0
  27. modal/_utils/blob_utils.py +181 -58
  28. modal/_utils/deprecation.py +19 -0
  29. modal/_utils/function_utils.py +91 -47
  30. modal/_utils/grpc_utils.py +89 -66
  31. modal/_utils/mount_utils.py +26 -1
  32. modal/_utils/name_utils.py +17 -3
  33. modal/_utils/task_command_router_client.py +536 -0
  34. modal/_utils/time_utils.py +34 -6
  35. modal/app.py +256 -88
  36. modal/app.pyi +909 -92
  37. modal/billing.py +5 -0
  38. modal/builder/2025.06.txt +18 -0
  39. modal/builder/PREVIEW.txt +18 -0
  40. modal/builder/base-images.json +58 -0
  41. modal/cli/_download.py +19 -3
  42. modal/cli/_traceback.py +3 -2
  43. modal/cli/app.py +4 -4
  44. modal/cli/cluster.py +15 -7
  45. modal/cli/config.py +5 -3
  46. modal/cli/container.py +7 -6
  47. modal/cli/dict.py +22 -16
  48. modal/cli/entry_point.py +12 -5
  49. modal/cli/environment.py +5 -4
  50. modal/cli/import_refs.py +3 -3
  51. modal/cli/launch.py +102 -5
  52. modal/cli/network_file_system.py +11 -12
  53. modal/cli/profile.py +3 -2
  54. modal/cli/programs/launch_instance_ssh.py +94 -0
  55. modal/cli/programs/run_jupyter.py +1 -1
  56. modal/cli/programs/run_marimo.py +95 -0
  57. modal/cli/programs/vscode.py +1 -1
  58. modal/cli/queues.py +57 -26
  59. modal/cli/run.py +91 -23
  60. modal/cli/secret.py +48 -22
  61. modal/cli/token.py +7 -8
  62. modal/cli/utils.py +4 -7
  63. modal/cli/volume.py +31 -25
  64. modal/client.py +15 -85
  65. modal/client.pyi +183 -62
  66. modal/cloud_bucket_mount.py +5 -3
  67. modal/cloud_bucket_mount.pyi +197 -5
  68. modal/cls.py +200 -126
  69. modal/cls.pyi +446 -68
  70. modal/config.py +29 -11
  71. modal/container_process.py +319 -19
  72. modal/container_process.pyi +190 -20
  73. modal/dict.py +290 -71
  74. modal/dict.pyi +835 -83
  75. modal/environments.py +15 -27
  76. modal/environments.pyi +46 -24
  77. modal/exception.py +14 -2
  78. modal/experimental/__init__.py +194 -40
  79. modal/experimental/flash.py +618 -0
  80. modal/experimental/flash.pyi +380 -0
  81. modal/experimental/ipython.py +11 -7
  82. modal/file_io.py +29 -36
  83. modal/file_io.pyi +251 -53
  84. modal/file_pattern_matcher.py +56 -16
  85. modal/functions.pyi +673 -92
  86. modal/gpu.py +1 -1
  87. modal/image.py +528 -176
  88. modal/image.pyi +1572 -145
  89. modal/io_streams.py +458 -128
  90. modal/io_streams.pyi +433 -52
  91. modal/mount.py +216 -151
  92. modal/mount.pyi +225 -78
  93. modal/network_file_system.py +45 -62
  94. modal/network_file_system.pyi +277 -56
  95. modal/object.pyi +93 -17
  96. modal/parallel_map.py +942 -129
  97. modal/parallel_map.pyi +294 -15
  98. modal/partial_function.py +0 -2
  99. modal/partial_function.pyi +234 -19
  100. modal/proxy.py +17 -8
  101. modal/proxy.pyi +36 -3
  102. modal/queue.py +270 -65
  103. modal/queue.pyi +817 -57
  104. modal/runner.py +115 -101
  105. modal/runner.pyi +205 -49
  106. modal/sandbox.py +512 -136
  107. modal/sandbox.pyi +845 -111
  108. modal/schedule.py +1 -1
  109. modal/secret.py +300 -70
  110. modal/secret.pyi +589 -34
  111. modal/serving.py +7 -11
  112. modal/serving.pyi +7 -8
  113. modal/snapshot.py +11 -8
  114. modal/snapshot.pyi +25 -4
  115. modal/token_flow.py +4 -4
  116. modal/token_flow.pyi +28 -8
  117. modal/volume.py +416 -158
  118. modal/volume.pyi +1117 -121
  119. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +10 -9
  120. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  121. modal_docs/mdmd/mdmd.py +17 -4
  122. modal_proto/api.proto +534 -79
  123. modal_proto/api_grpc.py +337 -1
  124. modal_proto/api_pb2.py +1522 -968
  125. modal_proto/api_pb2.pyi +1619 -134
  126. modal_proto/api_pb2_grpc.py +699 -4
  127. modal_proto/api_pb2_grpc.pyi +226 -14
  128. modal_proto/modal_api_grpc.py +175 -154
  129. modal_proto/sandbox_router.proto +145 -0
  130. modal_proto/sandbox_router_grpc.py +105 -0
  131. modal_proto/sandbox_router_pb2.py +149 -0
  132. modal_proto/sandbox_router_pb2.pyi +333 -0
  133. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  134. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  135. modal_proto/task_command_router.proto +144 -0
  136. modal_proto/task_command_router_grpc.py +105 -0
  137. modal_proto/task_command_router_pb2.py +149 -0
  138. modal_proto/task_command_router_pb2.pyi +333 -0
  139. modal_proto/task_command_router_pb2_grpc.py +203 -0
  140. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  141. modal_version/__init__.py +1 -1
  142. modal/requirements/PREVIEW.txt +0 -16
  143. modal/requirements/base-images.json +0 -26
  144. modal-1.0.3.dev10.dist-info/RECORD +0 -179
  145. modal_proto/modal_options_grpc.py +0 -3
  146. modal_proto/options.proto +0 -19
  147. modal_proto/options_grpc.py +0 -3
  148. modal_proto/options_pb2.py +0 -35
  149. modal_proto/options_pb2.pyi +0 -20
  150. modal_proto/options_pb2_grpc.py +0 -4
  151. modal_proto/options_pb2_grpc.pyi +0 -7
  152. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  153. /modal/{requirements → builder}/2023.12.txt +0 -0
  154. /modal/{requirements → builder}/2024.04.txt +0 -0
  155. /modal/{requirements → builder}/2024.10.txt +0 -0
  156. /modal/{requirements → builder}/README.md +0 -0
  157. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  158. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  159. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  160. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
@@ -3,22 +3,79 @@ import contextvars
3
3
  import typing
4
4
  import typing_extensions
5
5
 
6
- def is_local() -> bool: ...
7
- async def _interact() -> None: ...
6
+ def is_local() -> bool:
7
+ """Returns if we are currently on the machine launching/deploying a Modal app
8
+
9
+ Returns `True` when executed locally on the user's machine.
10
+ Returns `False` when executed from a Modal container in the cloud.
11
+ """
12
+ ...
13
+
14
+ async def _interact() -> None:
15
+ """Enable interactivity with user input inside a Modal container.
16
+
17
+ See the [interactivity guide](https://modal.com/docs/guide/developing-debugging#interactivity)
18
+ for more information on how to use this function.
19
+ """
20
+ ...
8
21
 
9
22
  class __interact_spec(typing_extensions.Protocol):
10
- def __call__(self, /) -> None: ...
11
- async def aio(self, /) -> None: ...
23
+ def __call__(self, /) -> None:
24
+ """Enable interactivity with user input inside a Modal container.
25
+
26
+ See the [interactivity guide](https://modal.com/docs/guide/developing-debugging#interactivity)
27
+ for more information on how to use this function.
28
+ """
29
+ ...
30
+
31
+ async def aio(self, /) -> None:
32
+ """Enable interactivity with user input inside a Modal container.
33
+
34
+ See the [interactivity guide](https://modal.com/docs/guide/developing-debugging#interactivity)
35
+ for more information on how to use this function.
36
+ """
37
+ ...
12
38
 
13
39
  interact: __interact_spec
14
40
 
15
- def current_input_id() -> typing.Optional[str]: ...
16
- def current_function_call_id() -> typing.Optional[str]: ...
41
+ def current_input_id() -> typing.Optional[str]:
42
+ """Returns the input ID for the current input.
43
+
44
+ Can only be called from Modal function (i.e. in a container context).
45
+
46
+ ```python
47
+ from modal import current_input_id
48
+
49
+ @app.function()
50
+ def process_stuff():
51
+ print(f"Starting to process {current_input_id()}")
52
+ ```
53
+ """
54
+ ...
55
+
56
+ def current_function_call_id() -> typing.Optional[str]:
57
+ """Returns the function call ID for the current input.
58
+
59
+ Can only be called from Modal function (i.e. in a container context).
60
+
61
+ ```python
62
+ from modal import current_function_call_id
63
+
64
+ @app.function()
65
+ def process_stuff():
66
+ print(f"Starting to process input from {current_function_call_id()}")
67
+ ```
68
+ """
69
+ ...
70
+
71
+ def current_attempt_token() -> typing.Optional[str]: ...
17
72
  def _set_current_context_ids(
18
- input_ids: list[str], function_call_ids: list[str]
73
+ input_ids: list[str], function_call_ids: list[str], attempt_tokens: list[str]
19
74
  ) -> collections.abc.Callable[[], None]: ...
20
75
  def _import_context(): ...
21
76
 
22
77
  _current_input_id: contextvars.ContextVar
23
78
 
24
79
  _current_function_call_id: contextvars.ContextVar
80
+
81
+ _current_attempt_token: contextvars.ContextVar
@@ -1,23 +1,34 @@
1
1
  # Copyright Modal Labs 2022
2
2
  #
3
3
  # This module provides a simple interface for creating GPU memory snapshots,
4
- # provising a convenient interface to `cuda-checkpoint` [1]. This is intended
4
+ # providing a convenient interface to `cuda-checkpoint` [1]. This is intended
5
5
  # to be used in conjunction with memory snapshots.
6
6
  #
7
7
  # [1] https://github.com/NVIDIA/cuda-checkpoint
8
8
 
9
- import os
10
9
  import subprocess
11
10
  import time
11
+ from concurrent.futures import ThreadPoolExecutor, as_completed
12
+ from dataclasses import dataclass
12
13
  from enum import Enum
14
+ from pathlib import Path
15
+ from typing import List, Optional
13
16
 
14
17
  from modal.config import config, logger
15
18
 
16
19
  CUDA_CHECKPOINT_PATH: str = config.get("cuda_checkpoint_path")
17
20
 
21
+ # Maximum total duration for an entire toggle operation.
22
+ CUDA_CHECKPOINT_TOGGLE_TIMEOUT: float = 5 * 60.0
23
+
24
+ # Maximum total duration for each individual `cuda-checkpoint` invocation.
25
+ CUDA_CHECKPOINT_TIMEOUT: float = 90
26
+
18
27
 
19
28
  class CudaCheckpointState(Enum):
20
- """State representation from the CUDA API: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc96cdda177a2b8c296144567cbea4f23"""
29
+ """State representation from the CUDA API [1].
30
+
31
+ [1] https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html"""
21
32
 
22
33
  RUNNING = "running"
23
34
  LOCKED = "locked"
@@ -26,76 +37,270 @@ class CudaCheckpointState(Enum):
26
37
 
27
38
 
28
39
  class CudaCheckpointException(Exception):
40
+ """Exception raised for CUDA checkpoint operations."""
41
+
29
42
  pass
30
43
 
31
44
 
32
- def toggle():
33
- """Toggle CUDA checkpoint state for current process, moving GPU memory to the
34
- CPU and back depending on the current process state when called."""
35
- pid = get_own_pid()
36
- logger.debug(f"Toggling CUDA checkpoint state for PID {pid}")
45
+ @dataclass
46
+ class CudaCheckpointProcess:
47
+ """Contains a reference to a PID with active CUDA session. This also provides
48
+ methods for checkpointing and restoring GPU memory."""
37
49
 
38
- try:
39
- subprocess.run(
40
- [
41
- CUDA_CHECKPOINT_PATH,
42
- "--toggle",
43
- "--pid",
44
- str(pid),
45
- ],
46
- check=True,
47
- capture_output=True,
48
- text=True,
49
- )
50
- logger.debug("Successfully toggled CUDA checkpoint state")
50
+ pid: int
51
+ state: CudaCheckpointState
51
52
 
52
- except subprocess.CalledProcessError as e:
53
- logger.debug(f"Failed to toggle CUDA checkpoint state: {e.stderr}")
54
- raise CudaCheckpointException(e.stderr)
53
+ def toggle(self, target_state: CudaCheckpointState, skip_first_refresh: bool = False) -> None:
54
+ """Toggle CUDA checkpoint state for current process, moving GPU memory to the
55
+ CPU and back depending on the current process state when called.
56
+ """
57
+ logger.debug(f"PID: {self.pid} Toggling CUDA checkpoint state to {target_state.value}")
55
58
 
59
+ start_time = time.monotonic()
60
+ retry_count = 0
61
+ max_retries = 3
56
62
 
57
- def get_state() -> CudaCheckpointState:
58
- """Get current CUDA checkpoint state for this process."""
59
- pid = get_own_pid()
63
+ attempts = 0
64
+ while self._should_continue_toggle(
65
+ target_state, start_time, refresh=not (skip_first_refresh and attempts == 0)
66
+ ):
67
+ attempts += 1
68
+ try:
69
+ self._execute_toggle_command()
70
+ # Use exponential backoff for retries
71
+ sleep_time = min(0.1 * (2**retry_count), 1.0)
72
+ time.sleep(sleep_time)
73
+ retry_count = 0
74
+ except CudaCheckpointException as e:
75
+ retry_count += 1
76
+ if retry_count >= max_retries:
77
+ raise CudaCheckpointException(
78
+ f"PID: {self.pid} Failed to toggle state after {max_retries} retries: {e}"
79
+ )
80
+ logger.debug(f"PID: {self.pid} Retry {retry_count}/{max_retries} after error: {e}")
81
+ time.sleep(0.5 * retry_count)
60
82
 
61
- try:
62
- result = subprocess.run(
63
- [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)], check=True, capture_output=True, text=True
64
- )
83
+ logger.debug(f"PID: {self.pid} Target state {target_state.value} reached")
65
84
 
66
- # Parse output to get state
67
- state_str = result.stdout.strip().lower()
68
- return CudaCheckpointState(state_str)
85
+ def _should_continue_toggle(
86
+ self, target_state: CudaCheckpointState, start_time: float, refresh: bool = True
87
+ ) -> bool:
88
+ """Check if toggle operation should continue based on current state and timeout."""
89
+ if refresh:
90
+ self.refresh_state()
69
91
 
70
- except subprocess.CalledProcessError as e:
71
- logger.debug(f"Failed to get CUDA checkpoint state: {e.stderr}")
72
- raise CudaCheckpointException(e.stderr)
92
+ if self.state == target_state:
93
+ return False
73
94
 
95
+ if self.state == CudaCheckpointState.FAILED:
96
+ raise CudaCheckpointException(f"PID: {self.pid} CUDA process state is {self.state}")
74
97
 
75
- def wait_for_state(target_state: CudaCheckpointState, timeout_secs: float = 5.0):
76
- """Wait for CUDA checkpoint to reach a specific state."""
77
- logger.debug(f"Waiting for CUDA checkpoint state {target_state.value}")
78
- start_time = time.monotonic()
98
+ elapsed = time.monotonic() - start_time
99
+ if elapsed >= CUDA_CHECKPOINT_TOGGLE_TIMEOUT:
100
+ raise CudaCheckpointException(
101
+ f"PID: {self.pid} Timeout after {elapsed:.2f}s waiting for state {target_state.value}. "
102
+ f"Current state: {self.state}"
103
+ )
79
104
 
80
- while True:
81
- current_state = get_state()
105
+ return True
82
106
 
83
- if current_state == target_state:
84
- logger.debug(f"Target state {target_state.value} reached")
85
- break
107
+ def _execute_toggle_command(self) -> None:
108
+ """Execute the cuda-checkpoint toggle command."""
109
+ try:
110
+ _ = subprocess.run(
111
+ [CUDA_CHECKPOINT_PATH, "--toggle", "--pid", str(self.pid)],
112
+ check=True,
113
+ capture_output=True,
114
+ text=True,
115
+ timeout=CUDA_CHECKPOINT_TIMEOUT,
116
+ )
117
+ logger.debug(f"PID: {self.pid} Successfully toggled CUDA checkpoint state")
118
+ except subprocess.CalledProcessError as e:
119
+ error_msg = f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}"
120
+ logger.debug(error_msg)
121
+ raise CudaCheckpointException(error_msg)
122
+ except subprocess.TimeoutExpired:
123
+ error_msg = f"PID: {self.pid} Toggle command timed out"
124
+ logger.debug(error_msg)
125
+ raise CudaCheckpointException(error_msg)
86
126
 
87
- if current_state == CudaCheckpointState.FAILED:
88
- raise CudaCheckpointException(f"CUDA process state is {current_state}")
127
+ def refresh_state(self) -> None:
128
+ """Refreshes the current CUDA checkpoint state for this process."""
129
+ try:
130
+ result = subprocess.run(
131
+ [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(self.pid)],
132
+ check=True,
133
+ capture_output=True,
134
+ text=True,
135
+ timeout=CUDA_CHECKPOINT_TIMEOUT,
136
+ )
89
137
 
90
- elapsed = time.monotonic() - start_time
91
- if elapsed >= timeout_secs:
92
- raise CudaCheckpointException(f"Timeout after {elapsed:.2f}s waiting for state {target_state.value}")
138
+ state_str = result.stdout.strip().lower()
139
+ self.state = CudaCheckpointState(state_str)
140
+
141
+ except subprocess.CalledProcessError as e:
142
+ error_msg = f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}"
143
+ logger.debug(error_msg)
144
+ raise CudaCheckpointException(error_msg)
145
+ except subprocess.TimeoutExpired:
146
+ error_msg = f"PID: {self.pid} Get state command timed out"
147
+ logger.debug(error_msg)
148
+ raise CudaCheckpointException(error_msg)
149
+
150
+
151
+ class CudaCheckpointSession:
152
+ """Manages the checkpointing state of processes with active CUDA sessions."""
153
+
154
+ def __init__(self):
155
+ self.cuda_processes = self._get_cuda_pids()
156
+ if self.cuda_processes:
157
+ logger.debug(
158
+ f"Found {len(self.cuda_processes)} PID(s) with CUDA sessions: {[c.pid for c in self.cuda_processes]}"
159
+ )
160
+ else:
161
+ logger.debug("No CUDA sessions found.")
162
+
163
+ def _get_cuda_pids(self) -> List[CudaCheckpointProcess]:
164
+ """Iterates over all PIDs and identifies the ones that have running
165
+ CUDA sessions."""
166
+ cuda_pids: List[CudaCheckpointProcess] = []
167
+
168
+ # Get all active process IDs from /proc directory
169
+ proc_dir = Path("/proc")
170
+ if not proc_dir.exists():
171
+ raise CudaCheckpointException(
172
+ "OS does not have /proc path rendering it incompatible with GPU memory snapshots."
173
+ )
174
+
175
+ # Get all numeric directories (PIDs) from /proc
176
+ pid_dirs = [entry for entry in proc_dir.iterdir() if entry.name.isdigit()]
177
+
178
+ # Use ThreadPoolExecutor to check PIDs in parallel for better performance
179
+ with ThreadPoolExecutor(max_workers=min(50, len(pid_dirs))) as executor:
180
+ future_to_pid = {
181
+ executor.submit(self._check_cuda_session, int(entry.name)): int(entry.name) for entry in pid_dirs
182
+ }
183
+
184
+ for future in as_completed(future_to_pid):
185
+ pid = future_to_pid[future]
186
+ try:
187
+ cuda_process = future.result()
188
+ if cuda_process:
189
+ cuda_pids.append(cuda_process)
190
+ except Exception as e:
191
+ logger.debug(f"Error checking PID {pid}: {e}")
192
+
193
+ # Sort PIDs for ordered checkpointing
194
+ cuda_pids.sort(key=lambda x: x.pid)
195
+ return cuda_pids
196
+
197
+ def _check_cuda_session(self, pid: int) -> Optional[CudaCheckpointProcess]:
198
+ """Check if a specific PID has a CUDA session."""
199
+ try:
200
+ result = subprocess.run(
201
+ [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
202
+ capture_output=True,
203
+ text=True,
204
+ # This should be quick since no checkpoint has taken place yet
205
+ timeout=5,
206
+ )
207
+
208
+ # If the command succeeds (return code 0), this PID has a CUDA session
209
+ if result.returncode == 0:
210
+ state_str = result.stdout.strip().lower()
211
+ state = CudaCheckpointState(state_str)
212
+ return CudaCheckpointProcess(pid=pid, state=state)
213
+
214
+ except subprocess.CalledProcessError:
215
+ # Command failed, which is expected for PIDs without CUDA sessions
216
+ pass
217
+ except subprocess.TimeoutExpired:
218
+ logger.debug(f"Timeout checking CUDA state for PID {pid}")
219
+ except Exception as e:
220
+ logger.debug(f"Error checking PID {pid}: {e}")
221
+
222
+ return None
223
+
224
+ def checkpoint(self) -> None:
225
+ """Checkpoint all CUDA processes, moving GPU memory to CPU."""
226
+ if not self.cuda_processes:
227
+ logger.debug("No CUDA processes to checkpoint.")
228
+ return
229
+
230
+ # Validate all states first
231
+ for proc in self.cuda_processes:
232
+ proc.refresh_state() # Refresh state before validation
233
+ if proc.state != CudaCheckpointState.RUNNING:
234
+ raise CudaCheckpointException(
235
+ f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.RUNNING.value} state. "
236
+ f"Current state: {proc.state.value}"
237
+ )
238
+
239
+ # Moving state from GPU to CPU can take several seconds per CUDA session.
240
+ # Make a parallel call per CUDA session.
241
+ start = time.perf_counter()
242
+
243
+ def checkpoint_impl(proc: CudaCheckpointProcess) -> None:
244
+ proc.toggle(CudaCheckpointState.CHECKPOINTED)
245
+
246
+ with ThreadPoolExecutor() as executor:
247
+ futures = [executor.submit(checkpoint_impl, proc) for proc in self.cuda_processes]
248
+
249
+ # Wait for all futures and collect any exceptions
250
+ exceptions = []
251
+ for future in as_completed(futures):
252
+ try:
253
+ future.result()
254
+ except Exception as e:
255
+ exceptions.append(e)
256
+
257
+ if exceptions:
258
+ raise CudaCheckpointException(
259
+ f"Failed to checkpoint {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
260
+ )
261
+
262
+ elapsed = time.perf_counter() - start
263
+ logger.debug(f"Checkpointing {len(self.cuda_processes)} CUDA sessions took => {elapsed:.3f}s")
264
+
265
+ def restore(self) -> None:
266
+ """Restore all CUDA processes, moving memory back from CPU to GPU."""
267
+ if not self.cuda_processes:
268
+ logger.debug("No CUDA sessions to restore.")
269
+ return
270
+
271
+ # See checkpoint() for rationale about parallelism.
272
+ start = time.perf_counter()
273
+
274
+ def restore_process(proc: CudaCheckpointProcess) -> None:
275
+ proc.toggle(CudaCheckpointState.RUNNING, skip_first_refresh=True)
276
+
277
+ with ThreadPoolExecutor() as executor:
278
+ futures = [executor.submit(restore_process, proc) for proc in self.cuda_processes]
279
+
280
+ # Wait for all futures and collect any exceptions
281
+ exceptions = []
282
+ for future in as_completed(futures):
283
+ try:
284
+ future.result()
285
+ except Exception as e:
286
+ exceptions.append(e)
287
+
288
+ if exceptions:
289
+ raise CudaCheckpointException(
290
+ f"Failed to restore {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
291
+ )
93
292
 
94
- time.sleep(0.1)
293
+ elapsed = time.perf_counter() - start
294
+ logger.debug(f"Restoring {len(self.cuda_processes)} CUDA session(s) took => {elapsed:.3f}s")
95
295
 
296
+ def get_process_count(self) -> int:
297
+ """Get the number of CUDA processes managed by this session."""
298
+ return len(self.cuda_processes)
96
299
 
97
- def get_own_pid():
98
- """Returns the Process ID (PID) of the current Python process
99
- using only the standard library.
100
- """
101
- return os.getpid()
300
+ def get_process_states(self) -> List[tuple[int, CudaCheckpointState]]:
301
+ """Get current states of all managed processes."""
302
+ states = []
303
+ for proc in self.cuda_processes:
304
+ proc.refresh_state()
305
+ states.append((proc.pid, proc.state))
306
+ return states
@@ -29,7 +29,7 @@ class FinalizedFunction:
29
29
  callable: Callable[..., Any]
30
30
  is_async: bool
31
31
  is_generator: bool
32
- data_format: int # api_pb2.DataFormat
32
+ supported_output_formats: Sequence["api_pb2.DataFormat.ValueType"]
33
33
  lifespan_manager: Optional["LifespanManager"] = None
34
34
 
35
35
 
@@ -93,9 +93,9 @@ def construct_webhook_callable(
93
93
 
94
94
  @dataclass
95
95
  class ImportedFunction(Service):
96
- user_cls_instance: Any
97
96
  app: modal.app._App
98
97
  service_deps: Optional[Sequence["modal._object._Object"]]
98
+ user_cls_instance = None
99
99
 
100
100
  _user_defined_callable: Callable[..., Any]
101
101
 
@@ -108,6 +108,7 @@ class ImportedFunction(Service):
108
108
  is_generator = fun_def.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
109
109
 
110
110
  webhook_config = fun_def.webhook_config
111
+
111
112
  if not webhook_config.type:
112
113
  # for non-webhooks, the runnable is straight forward:
113
114
  return {
@@ -115,7 +116,10 @@ class ImportedFunction(Service):
115
116
  callable=self._user_defined_callable,
116
117
  is_async=is_async,
117
118
  is_generator=is_generator,
118
- data_format=api_pb2.DATA_FORMAT_PICKLE,
119
+ supported_output_formats=fun_def.supported_output_formats
120
+ # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]` is only
121
+ # needed for tests
122
+ or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
119
123
  )
120
124
  }
121
125
 
@@ -129,7 +133,8 @@ class ImportedFunction(Service):
129
133
  lifespan_manager=lifespan_manager,
130
134
  is_async=True,
131
135
  is_generator=True,
132
- data_format=api_pb2.DATA_FORMAT_ASGI,
136
+ # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_ASGI]` is only needed for tests
137
+ supported_output_formats=fun_def.supported_output_formats or [api_pb2.DATA_FORMAT_ASGI],
133
138
  )
134
139
  }
135
140
 
@@ -154,6 +159,7 @@ class ImportedClass(Service):
154
159
  # Use the function definition for whether this is a generator (overriden by webhooks)
155
160
  is_generator = _partial.params.is_generator
156
161
  webhook_config = _partial.params.webhook_config
162
+ method_def = fun_def.method_definitions[method_name]
157
163
 
158
164
  bound_func = user_func.__get__(self.user_cls_instance)
159
165
 
@@ -163,7 +169,10 @@ class ImportedClass(Service):
163
169
  callable=bound_func,
164
170
  is_async=is_async,
165
171
  is_generator=bool(is_generator),
166
- data_format=api_pb2.DATA_FORMAT_PICKLE,
172
+ # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]` is only
173
+ # needed for tests
174
+ supported_output_formats=method_def.supported_output_formats
175
+ or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
167
176
  )
168
177
  else:
169
178
  web_callable, lifespan_manager = construct_webhook_callable(
@@ -174,7 +183,8 @@ class ImportedClass(Service):
174
183
  lifespan_manager=lifespan_manager,
175
184
  is_async=True,
176
185
  is_generator=True,
177
- data_format=api_pb2.DATA_FORMAT_ASGI,
186
+ # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_ASGI]` is only needed for tests
187
+ supported_output_formats=method_def.supported_output_formats or [api_pb2.DATA_FORMAT_ASGI],
178
188
  )
179
189
  finalized_functions[method_name] = finalized_function
180
190
  return finalized_functions
@@ -199,7 +209,6 @@ def get_user_class_instance(_cls: modal.cls._Cls, args: tuple[Any, ...], kwargs:
199
209
 
200
210
  def import_single_function_service(
201
211
  function_def: api_pb2.Function,
202
- ser_cls: Optional[type], # used only for @build functions
203
212
  ser_fun: Optional[Callable[..., Any]],
204
213
  ) -> Service:
205
214
  """Imports a function dynamically, and locates the app.
@@ -228,12 +237,9 @@ def import_single_function_service(
228
237
  service_deps: Optional[Sequence["modal._object._Object"]] = None
229
238
  active_app: modal.app._App
230
239
 
231
- user_cls_or_cls: typing.Union[None, type, modal.cls.Cls]
232
- user_cls_instance = None
233
-
234
240
  if ser_fun is not None:
235
241
  # This is a serialized function we already fetched from the server
236
- user_cls_or_cls, user_defined_callable = ser_cls, ser_fun
242
+ user_defined_callable = ser_fun
237
243
  active_app = get_active_app_fallback(function_def)
238
244
  else:
239
245
  # Load the module dynamically
@@ -244,58 +250,22 @@ def import_single_function_service(
244
250
  raise LocalFunctionError("Attempted to load a function defined in a function scope")
245
251
 
246
252
  parts = qual_name.split(".")
247
- if len(parts) == 1:
248
- # This is a function
249
- user_cls_or_cls = None
250
- f = getattr(module, qual_name)
251
- if isinstance(f, Function):
252
- _function: modal._functions._Function[Any, Any, Any] = synchronizer._translate_in(f) # type: ignore
253
- service_deps = _function.deps(only_explicit_mounts=True)
254
- user_defined_callable = _function.get_raw_f()
255
- assert _function._app # app should always be set on a decorated function
256
- active_app = _function._app
257
- else:
258
- user_defined_callable = f
259
- active_app = get_active_app_fallback(function_def)
260
-
261
- elif len(parts) == 2:
262
- # This path should only be triggered by @build class builder methods and can be removed
263
- # once @build is deprecated.
264
- assert not function_def.use_method_name # new "placeholder methods" should not be invoked directly!
265
- assert function_def.is_builder_function
266
- cls_name, fun_name = parts
267
- user_cls_or_cls = getattr(module, cls_name)
268
- if isinstance(user_cls_or_cls, modal.cls.Cls):
269
- # The cls decorator is in global scope
270
- _cls = typing.cast(modal.cls._Cls, synchronizer._translate_in(user_cls_or_cls))
271
- user_defined_callable = _cls._callables[fun_name]
272
- # Intentionally not including these, since @build functions don't actually
273
- # forward the information from their parent class.
274
- # service_deps = _cls._get_class_service_function().deps(only_explicit_mounts=True)
275
- assert _cls._app
276
- active_app = _cls._app
277
- else:
278
- # This is non-decorated class
279
- user_defined_callable = getattr(user_cls_or_cls, fun_name) # unbound method
280
- active_app = get_active_app_fallback(function_def)
281
- else:
253
+ if len(parts) != 1:
282
254
  raise InvalidError(f"Invalid function qualname {qual_name}")
283
255
 
284
- # Instantiate the class if it's defined
285
- if user_cls_or_cls:
286
- if isinstance(user_cls_or_cls, modal.cls.Cls):
287
- # This code is only used for @build methods on classes
288
- _cls = typing.cast(modal.cls._Cls, user_cls_or_cls)
289
- user_cls_instance = get_user_class_instance(_cls, (), {})
290
- # Bind the unbound method to the instance as self (using the descriptor protocol!)
256
+ f = getattr(module, qual_name)
257
+ if isinstance(f, Function):
258
+ _function: modal._functions._Function[Any, Any, Any] = synchronizer._translate_in(f) # type: ignore
259
+ service_deps = _function.deps(only_explicit_mounts=True)
260
+ user_defined_callable = _function.get_raw_f()
261
+ assert _function._app # app should always be set on a decorated function
262
+ active_app = _function._app
291
263
  else:
292
- # serialized=True or "undecorated"
293
- user_cls_instance = user_cls_or_cls()
294
-
295
- user_defined_callable = user_defined_callable.__get__(user_cls_instance)
264
+ # function isn't decorated in global scope
265
+ user_defined_callable = f
266
+ active_app = get_active_app_fallback(function_def)
296
267
 
297
268
  return ImportedFunction(
298
- user_cls_instance,
299
269
  active_app,
300
270
  service_deps,
301
271
  user_defined_callable,