modal 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (75)
  1. modal/__main__.py +2 -2
  2. modal/_clustered_functions.py +3 -0
  3. modal/_clustered_functions.pyi +3 -2
  4. modal/_functions.py +78 -26
  5. modal/_object.py +9 -1
  6. modal/_output.py +14 -25
  7. modal/_runtime/gpu_memory_snapshot.py +158 -54
  8. modal/_utils/async_utils.py +6 -4
  9. modal/_utils/auth_token_manager.py +1 -1
  10. modal/_utils/blob_utils.py +16 -21
  11. modal/_utils/function_utils.py +16 -4
  12. modal/_utils/time_utils.py +8 -4
  13. modal/app.py +0 -4
  14. modal/app.pyi +0 -4
  15. modal/cli/_traceback.py +3 -2
  16. modal/cli/app.py +4 -4
  17. modal/cli/cluster.py +4 -4
  18. modal/cli/config.py +2 -2
  19. modal/cli/container.py +2 -2
  20. modal/cli/dict.py +4 -4
  21. modal/cli/entry_point.py +2 -2
  22. modal/cli/import_refs.py +3 -3
  23. modal/cli/network_file_system.py +8 -9
  24. modal/cli/profile.py +2 -2
  25. modal/cli/queues.py +5 -5
  26. modal/cli/secret.py +5 -5
  27. modal/cli/utils.py +3 -4
  28. modal/cli/volume.py +8 -9
  29. modal/client.py +8 -1
  30. modal/client.pyi +9 -2
  31. modal/container_process.py +2 -2
  32. modal/dict.py +47 -3
  33. modal/dict.pyi +55 -0
  34. modal/exception.py +4 -0
  35. modal/experimental/__init__.py +1 -1
  36. modal/experimental/flash.py +18 -2
  37. modal/experimental/flash.pyi +19 -0
  38. modal/functions.pyi +0 -1
  39. modal/image.py +26 -10
  40. modal/image.pyi +12 -4
  41. modal/mount.py +1 -1
  42. modal/object.pyi +4 -0
  43. modal/parallel_map.py +432 -4
  44. modal/parallel_map.pyi +28 -0
  45. modal/queue.py +46 -3
  46. modal/queue.pyi +53 -0
  47. modal/sandbox.py +105 -25
  48. modal/sandbox.pyi +108 -18
  49. modal/secret.py +48 -5
  50. modal/secret.pyi +55 -0
  51. modal/token_flow.py +3 -3
  52. modal/volume.py +49 -18
  53. modal/volume.pyi +50 -8
  54. {modal-1.1.0.dist-info → modal-1.1.1.dist-info}/METADATA +2 -2
  55. {modal-1.1.0.dist-info → modal-1.1.1.dist-info}/RECORD +75 -75
  56. modal_proto/api.proto +140 -14
  57. modal_proto/api_grpc.py +80 -0
  58. modal_proto/api_pb2.py +927 -756
  59. modal_proto/api_pb2.pyi +488 -34
  60. modal_proto/api_pb2_grpc.py +166 -0
  61. modal_proto/api_pb2_grpc.pyi +52 -0
  62. modal_proto/modal_api_grpc.py +5 -0
  63. modal_version/__init__.py +1 -1
  64. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  65. /modal/{requirements → builder}/2023.12.txt +0 -0
  66. /modal/{requirements → builder}/2024.04.txt +0 -0
  67. /modal/{requirements → builder}/2024.10.txt +0 -0
  68. /modal/{requirements → builder}/2025.06.txt +0 -0
  69. /modal/{requirements → builder}/PREVIEW.txt +0 -0
  70. /modal/{requirements → builder}/README.md +0 -0
  71. /modal/{requirements → builder}/base-images.json +0 -0
  72. {modal-1.1.0.dist-info → modal-1.1.1.dist-info}/WHEEL +0 -0
  73. {modal-1.1.0.dist-info → modal-1.1.1.dist-info}/entry_points.txt +0 -0
  74. {modal-1.1.0.dist-info → modal-1.1.1.dist-info}/licenses/LICENSE +0 -0
  75. {modal-1.1.0.dist-info → modal-1.1.1.dist-info}/top_level.txt +0 -0
modal/_runtime/gpu_memory_snapshot.py CHANGED
@@ -1,17 +1,18 @@
  # Copyright Modal Labs 2022
  #
  # This module provides a simple interface for creating GPU memory snapshots,
- # provising a convenient interface to `cuda-checkpoint` [1]. This is intended
+ # providing a convenient interface to `cuda-checkpoint` [1]. This is intended
  # to be used in conjunction with memory snapshots.
  #
  # [1] https://github.com/NVIDIA/cuda-checkpoint

  import subprocess
  import time
- from concurrent.futures import ThreadPoolExecutor
+ from concurrent.futures import ThreadPoolExecutor, as_completed
  from dataclasses import dataclass
  from enum import Enum
  from pathlib import Path
+ from typing import List, Optional

  from modal.config import config, logger

@@ -19,7 +20,9 @@ CUDA_CHECKPOINT_PATH: str = config.get("cuda_checkpoint_path")


  class CudaCheckpointState(Enum):
-     """State representation from the CUDA API: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc96cdda177a2b8c296144567cbea4f23"""
+     """State representation from the CUDA API [1].
+
+     [1] https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html"""

      RUNNING = "running"
      LOCKED = "locked"
@@ -28,6 +31,8 @@ class CudaCheckpointState(Enum):


  class CudaCheckpointException(Exception):
+     """Exception raised for CUDA checkpoint operations."""
+
      pass


@@ -39,16 +44,31 @@ class CudaCheckpointProcess:
      pid: int
      state: CudaCheckpointState

-     def toggle(self, target_state: CudaCheckpointState, timeout_secs: float = 5 * 60.0):
+     def toggle(self, target_state: CudaCheckpointState, timeout_secs: float = 5 * 60.0) -> None:
          """Toggle CUDA checkpoint state for current process, moving GPU memory to the
-         CPU and back depending on the current process state when called."""
+         CPU and back depending on the current process state when called.
+         """
          logger.debug(f"PID: {self.pid} Toggling CUDA checkpoint state to {target_state.value}")

          start_time = time.monotonic()
+         retry_count = 0
+         max_retries = 3

          while self._should_continue_toggle(target_state, start_time, timeout_secs):
-             self._execute_toggle_command()
-             time.sleep(0.1)
+             try:
+                 self._execute_toggle_command()
+                 # Use exponential backoff for retries
+                 sleep_time = min(0.1 * (2**retry_count), 1.0)
+                 time.sleep(sleep_time)
+                 retry_count = 0
+             except CudaCheckpointException as e:
+                 retry_count += 1
+                 if retry_count >= max_retries:
+                     raise CudaCheckpointException(
+                         f"PID: {self.pid} Failed to toggle state after {max_retries} retries: {e}"
+                     )
+                 logger.debug(f"PID: {self.pid} Retry {retry_count}/{max_retries} after error: {e}")
+                 time.sleep(0.5 * retry_count)

          logger.debug(f"PID: {self.pid} Target state {target_state.value} reached")

@@ -73,19 +93,25 @@ class CudaCheckpointProcess:

          return True

-     def _execute_toggle_command(self):
+     def _execute_toggle_command(self) -> None:
          """Execute the cuda-checkpoint toggle command."""
          try:
-             subprocess.run(
+             _ = subprocess.run(
                  [CUDA_CHECKPOINT_PATH, "--toggle", "--pid", str(self.pid)],
                  check=True,
                  capture_output=True,
                  text=True,
+                 timeout=30,
              )
              logger.debug(f"PID: {self.pid} Successfully toggled CUDA checkpoint state")
          except subprocess.CalledProcessError as e:
-             logger.debug(f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}")
-             raise CudaCheckpointException(e.stderr)
+             error_msg = f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}"
+             logger.debug(error_msg)
+             raise CudaCheckpointException(error_msg)
+         except subprocess.TimeoutExpired:
+             error_msg = f"PID: {self.pid} Toggle command timed out"
+             logger.debug(error_msg)
+             raise CudaCheckpointException(error_msg)

      def refresh_state(self) -> None:
          """Refreshes the current CUDA checkpoint state for this process."""
@@ -95,15 +121,20 @@ class CudaCheckpointProcess:
                  check=True,
                  capture_output=True,
                  text=True,
-                 timeout=5,
+                 timeout=10,
              )

              state_str = result.stdout.strip().lower()
              self.state = CudaCheckpointState(state_str)

          except subprocess.CalledProcessError as e:
-             logger.debug(f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}")
-             raise CudaCheckpointException(e.stderr)
+             error_msg = f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}"
+             logger.debug(error_msg)
+             raise CudaCheckpointException(error_msg)
+         except subprocess.TimeoutExpired:
+             error_msg = f"PID: {self.pid} Get state command timed out"
+             logger.debug(error_msg)
+             raise CudaCheckpointException(error_msg)


  class CudaCheckpointSession:
@@ -111,12 +142,17 @@ class CudaCheckpointSession:

      def __init__(self):
          self.cuda_processes = self._get_cuda_pids()
-         logger.debug(f"PIDs with CUDA sessions: {[c.pid for c in self.cuda_processes]}")
+         if self.cuda_processes:
+             logger.debug(
+                 f"Found {len(self.cuda_processes)} PID(s) with CUDA sessions: {[c.pid for c in self.cuda_processes]}"
+             )
+         else:
+             logger.debug("No CUDA sessions found.")

-     def _get_cuda_pids(self) -> list[CudaCheckpointProcess]:
+     def _get_cuda_pids(self) -> List[CudaCheckpointProcess]:
          """Iterates over all PIDs and identifies the ones that have running
          CUDA sessions."""
-         cuda_pids: list[CudaCheckpointProcess] = []
+         cuda_pids: List[CudaCheckpointProcess] = []

          # Get all active process IDs from /proc directory
          proc_dir = Path("/proc")
@@ -125,75 +161,143 @@ class CudaCheckpointSession:
                  "OS does not have /proc path rendering it incompatible with GPU memory snapshots."
              )

-         for entry in proc_dir.iterdir():
-             if not entry.name.isdigit():
-                 continue
-
-             pid = int(entry.name)
-             try:
-                 # Call cuda-checkpoint to check if this PID has a CUDA session
-                 result = subprocess.run(
-                     [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
-                     capture_output=True,
-                     text=True,
-                     timeout=10,
-                 )
-
-                 # If the command succeeds (return code 0), this PID has a CUDA session
-                 if result.returncode == 0:
-                     state_str = result.stdout.strip().lower()
-                     state = CudaCheckpointState(state_str)
-
-                     cuda_checkpoint_process = CudaCheckpointProcess(pid=pid, state=state)
-                     cuda_pids.append(cuda_checkpoint_process)
+         # Get all numeric directories (PIDs) from /proc
+         pid_dirs = [entry for entry in proc_dir.iterdir() if entry.name.isdigit()]

-             # Command failed, which is expected for PIDs without CUDA sessions
-             except subprocess.CalledProcessError:
-                 continue
+         # Use ThreadPoolExecutor to check PIDs in parallel for better performance
+         with ThreadPoolExecutor(max_workers=min(50, len(pid_dirs))) as executor:
+             future_to_pid = {
+                 executor.submit(self._check_cuda_session, int(entry.name)): int(entry.name) for entry in pid_dirs
+             }

-             # Raise other exceptions
-             except subprocess.TimeoutExpired:
-                 raise CudaCheckpointException(f"Failed to get CUDA state for PID {pid}")
-             except Exception as e:
-                 raise CudaCheckpointException(e)
+             for future in as_completed(future_to_pid):
+                 pid = future_to_pid[future]
+                 try:
+                     cuda_process = future.result()
+                     if cuda_process:
+                         cuda_pids.append(cuda_process)
+                 except Exception as e:
+                     logger.debug(f"Error checking PID {pid}: {e}")

          # Sort PIDs for ordered checkpointing
          cuda_pids.sort(key=lambda x: x.pid)
          return cuda_pids

+     def _check_cuda_session(self, pid: int) -> Optional[CudaCheckpointProcess]:
+         """Check if a specific PID has a CUDA session."""
+         try:
+             result = subprocess.run(
+                 [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
+                 capture_output=True,
+                 text=True,
+                 timeout=5,
+             )
+
+             # If the command succeeds (return code 0), this PID has a CUDA session
+             if result.returncode == 0:
+                 state_str = result.stdout.strip().lower()
+                 state = CudaCheckpointState(state_str)
+                 return CudaCheckpointProcess(pid=pid, state=state)
+
+         except subprocess.CalledProcessError:
+             # Command failed, which is expected for PIDs without CUDA sessions
+             pass
+         except subprocess.TimeoutExpired:
+             logger.debug(f"Timeout checking CUDA state for PID {pid}")
+         except Exception as e:
+             logger.debug(f"Error checking PID {pid}: {e}")
+
+         return None
+
      def checkpoint(self) -> None:
+         """Checkpoint all CUDA processes, moving GPU memory to CPU."""
+         if not self.cuda_processes:
+             logger.debug("No CUDA processes to checkpoint.")
+             return
+
          # Validate all states first
          for proc in self.cuda_processes:
+             proc.refresh_state()  # Refresh state before validation
              if proc.state != CudaCheckpointState.RUNNING:
-                 raise CudaCheckpointException(f"CUDA session not in {CudaCheckpointState.RUNNING} state.")
+                 raise CudaCheckpointException(
+                     f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.RUNNING.value} state. "
+                     f"Current state: {proc.state.value}"
+                 )

          # Moving state from GPU to CPU can take several seconds per CUDA session.
          # Make a parallel call per CUDA session.
          start = time.perf_counter()

-         def checkpoint_impl(proc: CudaCheckpointProcess):
+         def checkpoint_impl(proc: CudaCheckpointProcess) -> None:
              proc.toggle(CudaCheckpointState.CHECKPOINTED)

          with ThreadPoolExecutor() as executor:
-             list(executor.map(checkpoint_impl, self.cuda_processes))
+             futures = [executor.submit(checkpoint_impl, proc) for proc in self.cuda_processes]
+
+             # Wait for all futures and collect any exceptions
+             exceptions = []
+             for future in as_completed(futures):
+                 try:
+                     future.result()
+                 except Exception as e:
+                     exceptions.append(e)
+
+             if exceptions:
+                 raise CudaCheckpointException(
+                     f"Failed to checkpoint {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                 )

          elapsed = time.perf_counter() - start
-         logger.debug(f"Checkpointing CUDA sessions took => {elapsed:.3f}s")
+         logger.debug(f"Checkpointing {len(self.cuda_processes)} CUDA sessions took => {elapsed:.3f}s")

      def restore(self) -> None:
+         """Restore all CUDA processes, moving memory back from CPU to GPU."""
+         if not self.cuda_processes:
+             logger.debug("No CUDA sessions to restore.")
+             return
+
          # Validate all states first
          for proc in self.cuda_processes:
+             proc.refresh_state()  # Refresh state before validation
              if proc.state != CudaCheckpointState.CHECKPOINTED:
-                 raise CudaCheckpointException(f"CUDA session not in {CudaCheckpointState.CHECKPOINTED} state.")
+                 raise CudaCheckpointException(
+                     f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.CHECKPOINTED.value} state. "
+                     f"Current state: {proc.state.value}"
+                 )

          # See checkpoint() for rationale about parallelism.
          start = time.perf_counter()

-         def restore_process(proc: CudaCheckpointProcess):
+         def restore_process(proc: CudaCheckpointProcess) -> None:
              proc.toggle(CudaCheckpointState.RUNNING)

          with ThreadPoolExecutor() as executor:
-             list(executor.map(restore_process, self.cuda_processes))
+             futures = [executor.submit(restore_process, proc) for proc in self.cuda_processes]
+
+             # Wait for all futures and collect any exceptions
+             exceptions = []
+             for future in as_completed(futures):
+                 try:
+                     future.result()
+                 except Exception as e:
+                     exceptions.append(e)
+
+             if exceptions:
+                 raise CudaCheckpointException(
+                     f"Failed to restore {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                 )

          elapsed = time.perf_counter() - start
-         logger.debug(f"Restoring CUDA sessions took => {elapsed:.3f}s")
+         logger.debug(f"Restoring {len(self.cuda_processes)} CUDA session(s) took => {elapsed:.3f}s")
+
+     def get_process_count(self) -> int:
+         """Get the number of CUDA processes managed by this session."""
+         return len(self.cuda_processes)
+
+     def get_process_states(self) -> List[tuple[int, CudaCheckpointState]]:
+         """Get current states of all managed processes."""
+         states = []
+         for proc in self.cuda_processes:
+             proc.refresh_state()
+             states.append((proc.pid, proc.state))
+         return states
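For orientation, here is a minimal sketch of how the reworked CudaCheckpointSession is meant to be driven: toggle every CUDA process to CHECKPOINTED before a memory snapshot and back to RUNNING afterwards. This class is internal runtime machinery normally invoked by Modal's snapshot flow, so the direct calls below are illustrative only, not a supported public API.

# Illustrative only: CudaCheckpointSession is internal to modal's runtime.
from modal._runtime.gpu_memory_snapshot import CudaCheckpointException, CudaCheckpointSession

session = CudaCheckpointSession()  # scans /proc in parallel for PIDs with CUDA sessions
print(f"Managing {session.get_process_count()} CUDA process(es)")

try:
    session.checkpoint()  # toggle each process to CHECKPOINTED (GPU memory -> CPU)
    # ... the runtime would take the memory snapshot here ...
    session.restore()     # toggle each process back to RUNNING (CPU -> GPU)
except CudaCheckpointException as exc:
    # Individual toggles are retried with exponential backoff; persistent
    # failures across processes are aggregated into a single exception.
    print(f"GPU snapshot failed: {exc}")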
modal/_utils/async_utils.py CHANGED
@@ -279,7 +279,9 @@ class TimestampPriorityQueue(Generic[T]):

      def __init__(self, maxsize: int = 0):
          self.condition = asyncio.Condition()
-         self._queue: asyncio.PriorityQueue[tuple[float, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
+         self._queue: asyncio.PriorityQueue[tuple[float, int, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
+         # Used to tiebreak items with the same timestamp that are not comparable. (eg. protos)
+         self._counter = itertools.count()

      async def close(self):
          await self.put(self._MAX_PRIORITY, None)
@@ -288,7 +290,7 @@ class TimestampPriorityQueue(Generic[T]):
          """
          Add an item to the queue to be processed at a specific timestamp.
          """
-         await self._queue.put((timestamp, item))
+         await self._queue.put((timestamp, next(self._counter), item))
          async with self.condition:
              self.condition.notify_all()  # notify any waiting coroutines

@@ -301,7 +303,7 @@ class TimestampPriorityQueue(Generic[T]):
                  while self.empty():
                      await self.condition.wait()
                  # peek at the next item
-                 timestamp, item = await self._queue.get()
+                 timestamp, counter, item = await self._queue.get()
                  now = time.time()
                  if timestamp < now:
                      return item
@@ -309,7 +311,7 @@ class TimestampPriorityQueue(Generic[T]):
                      return None
                  # not ready yet, calculate sleep time
                  sleep_time = timestamp - now
-                 self._queue.put_nowait((timestamp, item))  # put it back
+                 self._queue.put_nowait((timestamp, counter, item))  # put it back
                  # wait until either the timeout or a new item is added
                  try:
                      await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
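The new `_counter` field works around a property of `asyncio.PriorityQueue`: entries are compared as tuples, so two entries with the same timestamp fall through to comparing the payloads, and payloads such as protobuf messages define no ordering and raise TypeError. A monotonically increasing integer in the middle of the tuple guarantees the comparison never reaches the payload. A minimal sketch of the pattern (a stand-in class, not Modal's actual implementation):

import asyncio
import itertools


class TimestampedQueue:
    def __init__(self) -> None:
        self._queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self._counter = itertools.count()  # strictly increasing, so ties never compare the payload

    async def put(self, timestamp: float, item) -> None:
        # Without the counter, two entries with equal timestamps would make the
        # underlying heap compare the items themselves, which fails for objects
        # (e.g. protobuf messages) that do not implement ordering.
        await self._queue.put((timestamp, next(self._counter), item))

    async def get(self):
        timestamp, _counter, item = await self._queue.get()
        return timestamp, item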
modal/_utils/auth_token_manager.py CHANGED
@@ -27,7 +27,7 @@ class _AuthTokenManager:
          self._expiry = 0.0
          self._lock: typing.Union[asyncio.Lock, None] = None

-     async def get_token(self):
+     async def get_token(self) -> str:
          """
          When called, the AuthTokenManager can be in one of three states:
          1. Has a valid cached token. It is returned to the caller.
modal/_utils/blob_utils.py CHANGED
@@ -188,16 +188,10 @@ def get_content_length(data: BinaryIO) -> int:
      return content_length - pos


- async def _measure_endpoint_latency(item: str) -> int:
-     latency_ms = 0
-     t0 = time.monotonic_ns()
-     async with ClientSessionRegistry.get_session().head(item) as _:
-         latency_ms = (time.monotonic_ns() - t0) // 1_000_000
-     return latency_ms
-
-
- async def _blob_upload_with_fallback(items, blob_ids: list[str], callback) -> tuple[str, bool, int]:
-     r2_latency_ms = 0
+ async def _blob_upload_with_fallback(
+     items, blob_ids: list[str], callback, content_length: int
+ ) -> tuple[str, bool, int]:
+     r2_throughput_bytes_s = 0
      r2_failed = False
      for idx, (item, blob_id) in enumerate(zip(items, blob_ids)):
          # We want to default to R2 95% of the time and S3 5% of the time.
@@ -206,14 +200,13 @@ async def _blob_upload_with_fallback(items, blob_ids: list[str], callback) -> tu
              continue
          try:
              if blob_id.endswith(":r2"):
-                 # measure the time it takes to contact the bucket endpoint
-                 r2_latency_ms, _ = await asyncio.gather(
-                     _measure_endpoint_latency(item),
-                     callback(item),
-                 )
+                 t0 = time.monotonic_ns()
+                 await callback(item)
+                 dt_ns = time.monotonic_ns() - t0
+                 r2_throughput_bytes_s = (content_length * 1_000_000_000) // max(dt_ns, 1)
              else:
                  await callback(item)
-             return blob_id, r2_failed, r2_latency_ms
+             return blob_id, r2_failed, r2_throughput_bytes_s
          except Exception as _:
              if blob_id.endswith(":r2"):
                  r2_failed = True
@@ -251,10 +244,11 @@ async def _blob_upload(
              progress_report_cb=progress_report_cb,
          )

-         blob_id, r2_failed, r2_latency_ms = await _blob_upload_with_fallback(
+         blob_id, r2_failed, r2_throughput_bytes_s = await _blob_upload_with_fallback(
              resp.multiparts.items,
              resp.blob_ids,
              upload_multipart_upload,
+             content_length=content_length,
          )
      else:
          from .bytes_io_segment_payload import BytesIOSegmentPayload
@@ -271,16 +265,17 @@ async def _blob_upload(
              content_md5_b64=upload_hashes.md5_base64,
          )

-         blob_id, r2_failed, r2_latency_ms = await _blob_upload_with_fallback(
+         blob_id, r2_failed, r2_throughput_bytes_s = await _blob_upload_with_fallback(
              resp.upload_urls.items,
              resp.blob_ids,
              upload_to_s3_url,
+             content_length=content_length,
          )

      if progress_report_cb:
          progress_report_cb(complete=True)

-     return blob_id, r2_failed, r2_latency_ms
+     return blob_id, r2_failed, r2_throughput_bytes_s


  async def blob_upload_with_r2_failure_info(payload: bytes, stub: ModalClientModal) -> tuple[str, bool, int]:
@@ -291,13 +286,13 @@ async def blob_upload_with_r2_failure_info(payload: bytes, stub: ModalClientModa
          logger.warning("Blob uploading string, not bytes - auto-encoding as utf8")
          payload = payload.encode("utf8")
      upload_hashes = get_upload_hashes(payload)
-     blob_id, r2_failed, r2_latency_ms = await _blob_upload(upload_hashes, payload, stub)
+     blob_id, r2_failed, r2_throughput_bytes_s = await _blob_upload(upload_hashes, payload, stub)
      dur_s = max(time.time() - t0, 0.001)  # avoid division by zero
      throughput_mib_s = (size_mib) / dur_s
      logger.debug(
          f"Uploaded large blob of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s, total {dur_s:.2f}s). {blob_id}"
      )
-     return blob_id, r2_failed, r2_latency_ms
+     return blob_id, r2_failed, r2_throughput_bytes_s


  async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
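The blob-upload change drops the separate HEAD-request latency probe and instead reports R2 throughput derived from the upload itself: bytes uploaded divided by the measured wall-clock duration, guarded against a zero-nanosecond interval. A hedged sketch of that arithmetic (the real callback is async and the reported field comes from the proto; the helper name below is hypothetical):

import time


def throughput_bytes_per_s(content_length: int, upload) -> int:
    # Time the upload callback and convert to integer bytes/second.
    t0 = time.monotonic_ns()
    upload()  # in the real code this is `await callback(item)`
    dt_ns = time.monotonic_ns() - t0
    return (content_length * 1_000_000_000) // max(dt_ns, 1)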
modal/_utils/function_utils.py CHANGED
@@ -385,9 +385,16 @@ def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool:


  async def _stream_function_call_data(
-     client, stub, function_call_id: str, variant: Literal["data_in", "data_out"]
+     client,
+     stub,
+     function_call_id: Optional[str],
+     variant: Literal["data_in", "data_out"],
+     attempt_token: Optional[str] = None,
  ) -> AsyncGenerator[Any, None]:
      """Read from the `data_in` or `data_out` stream of a function call."""
+     if function_call_id is None and attempt_token is None:
+         raise ValueError("function_call_id or attempt_token is required for data_out stream")
+
      if stub is None:
          stub = client.stub

@@ -405,7 +412,11 @@ async def _stream_function_call_data(
          raise ValueError(f"Invalid variant {variant}")

      while True:
-         req = api_pb2.FunctionCallGetDataRequest(function_call_id=function_call_id, last_index=last_index)
+         req = api_pb2.FunctionCallGetDataRequest(
+             function_call_id=function_call_id,
+             last_index=last_index,
+             attempt_token=attempt_token,
+         )
          try:
              async for chunk in stub_fn.unary_stream(req):
                  if chunk.index <= last_index:
@@ -531,6 +542,7 @@ def should_upload(
      )


+ # This must be called against the client stub, not the input-plane stub.
  async def _create_input(
      args,
      kwargs,
@@ -552,7 +564,7 @@ async def _create_input(
      args_serialized = serialize((args, kwargs))

      if should_upload(len(args_serialized), max_object_size_bytes, function_call_invocation_type):
-         args_blob_id, r2_failed, r2_latency_ms = await blob_upload_with_r2_failure_info(args_serialized, stub)
+         args_blob_id, r2_failed, r2_throughput_bytes_s = await blob_upload_with_r2_failure_info(args_serialized, stub)
          return api_pb2.FunctionPutInputsItem(
              input=api_pb2.FunctionInput(
                  args_blob_id=args_blob_id,
@@ -561,7 +573,7 @@ async def _create_input(
              ),
              idx=idx,
              r2_failed=r2_failed,
-             r2_latency_ms=r2_latency_ms,
+             r2_throughput_bytes_s=r2_throughput_bytes_s,
          )
      else:
          return api_pb2.FunctionPutInputsItem(
modal/_utils/time_utils.py CHANGED
@@ -3,13 +3,17 @@ from datetime import datetime
  from typing import Optional


- def timestamp_to_local(ts: float, isotz: bool = True) -> Optional[str]:
+ def timestamp_to_localized_dt(ts: float) -> datetime:
+     locale_tz = datetime.now().astimezone().tzinfo
+     return datetime.fromtimestamp(ts, tz=locale_tz)
+
+
+ def timestamp_to_localized_str(ts: float, isotz: bool = True) -> Optional[str]:
      if ts > 0:
-         locale_tz = datetime.now().astimezone().tzinfo
-         dt = datetime.fromtimestamp(ts, tz=locale_tz)
+         dt = timestamp_to_localized_dt(ts)
          if isotz:
              return dt.isoformat(sep=" ", timespec="seconds")
          else:
-             return f"{datetime.strftime(dt, '%Y-%m-%d %H:%M')} {locale_tz.tzname(dt)}"
+             return f"{dt:%Y-%m-%d %H:%M %Z}"
      else:
          return None
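For reference, a small usage sketch of the renamed helpers: `timestamp_to_localized_dt` returns a timezone-aware datetime in the local timezone, and `timestamp_to_localized_str` formats it either as an ISO string or as `%Y-%m-%d %H:%M %Z`. Exact output depends on the machine's local timezone, so the values shown are illustrative (for a UTC host):

from modal._utils.time_utils import timestamp_to_localized_dt, timestamp_to_localized_str

ts = 1_700_000_000.0
print(timestamp_to_localized_dt(ts))                 # timezone-aware datetime in the local tz
print(timestamp_to_localized_str(ts))                # e.g. "2023-11-14 22:13:20+00:00"
print(timestamp_to_localized_str(ts, isotz=False))   # e.g. "2023-11-14 22:13 UTC"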
modal/app.py CHANGED
@@ -665,7 +665,6 @@ class _App:
          ] = None,  # Experimental controls over fine-grained scheduling (alpha).
          _experimental_proxy_ip: Optional[str] = None,  # IP address of proxy
          _experimental_custom_scaling_factor: Optional[float] = None,  # Custom scaling factor
-         _experimental_enable_gpu_snapshot: bool = False,  # Experimentally enable GPU memory snapshots.
          # Parameters below here are deprecated. Please update your code as suggested
          keep_warm: Optional[int] = None,  # Replaced with `min_containers`
          concurrency_limit: Optional[int] = None,  # Replaced with `max_containers`
@@ -830,7 +829,6 @@ class _App:
              include_source=include_source if include_source is not None else self._include_source_default,
              experimental_options={k: str(v) for k, v in (experimental_options or {}).items()},
              _experimental_proxy_ip=_experimental_proxy_ip,
-             _experimental_enable_gpu_snapshot=_experimental_enable_gpu_snapshot,
          )

          self._add_function(function, webhook_config is not None)
@@ -889,7 +887,6 @@ class _App:
          ] = None,  # Experimental controls over fine-grained scheduling (alpha).
          _experimental_proxy_ip: Optional[str] = None,  # IP address of proxy
          _experimental_custom_scaling_factor: Optional[float] = None,  # Custom scaling factor
-         _experimental_enable_gpu_snapshot: bool = False,  # Experimentally enable GPU memory snapshots.
          # Parameters below here are deprecated. Please update your code as suggested
          keep_warm: Optional[int] = None,  # Replaced with `min_containers`
          concurrency_limit: Optional[int] = None,  # Replaced with `max_containers`
@@ -1014,7 +1011,6 @@ class _App:
              experimental_options={k: str(v) for k, v in (experimental_options or {}).items()},
              _experimental_proxy_ip=_experimental_proxy_ip,
              _experimental_custom_scaling_factor=_experimental_custom_scaling_factor,
-             _experimental_enable_gpu_snapshot=_experimental_enable_gpu_snapshot,
          )

          self._add_function(cls_func, is_web_endpoint=False)
modal/app.pyi CHANGED
@@ -425,7 +425,6 @@ class _App:
          _experimental_scheduler_placement: typing.Optional[modal.scheduler_placement.SchedulerPlacement] = None,
          _experimental_proxy_ip: typing.Optional[str] = None,
          _experimental_custom_scaling_factor: typing.Optional[float] = None,
-         _experimental_enable_gpu_snapshot: bool = False,
          keep_warm: typing.Optional[int] = None,
          concurrency_limit: typing.Optional[int] = None,
          container_idle_timeout: typing.Optional[int] = None,
@@ -477,7 +476,6 @@ class _App:
          _experimental_scheduler_placement: typing.Optional[modal.scheduler_placement.SchedulerPlacement] = None,
          _experimental_proxy_ip: typing.Optional[str] = None,
          _experimental_custom_scaling_factor: typing.Optional[float] = None,
-         _experimental_enable_gpu_snapshot: bool = False,
          keep_warm: typing.Optional[int] = None,
          concurrency_limit: typing.Optional[int] = None,
          container_idle_timeout: typing.Optional[int] = None,
@@ -1030,7 +1028,6 @@ class App:
          _experimental_scheduler_placement: typing.Optional[modal.scheduler_placement.SchedulerPlacement] = None,
          _experimental_proxy_ip: typing.Optional[str] = None,
          _experimental_custom_scaling_factor: typing.Optional[float] = None,
-         _experimental_enable_gpu_snapshot: bool = False,
          keep_warm: typing.Optional[int] = None,
          concurrency_limit: typing.Optional[int] = None,
          container_idle_timeout: typing.Optional[int] = None,
@@ -1082,7 +1079,6 @@ class App:
          _experimental_scheduler_placement: typing.Optional[modal.scheduler_placement.SchedulerPlacement] = None,
          _experimental_proxy_ip: typing.Optional[str] = None,
          _experimental_custom_scaling_factor: typing.Optional[float] = None,
-         _experimental_enable_gpu_snapshot: bool = False,
          keep_warm: typing.Optional[int] = None,
          concurrency_limit: typing.Optional[int] = None,
          container_idle_timeout: typing.Optional[int] = None,
modal/cli/_traceback.py CHANGED
@@ -6,12 +6,13 @@ import re
  import warnings
  from typing import Optional

- from rich.console import Console, RenderResult, group
+ from rich.console import RenderResult, group
  from rich.panel import Panel
  from rich.syntax import Syntax
  from rich.text import Text
  from rich.traceback import PathHighlighter, Stack, Traceback, install

+ from .._output import make_console
  from ..exception import DeprecationError, PendingDeprecationError, ServerWarning


@@ -193,7 +194,7 @@ def highlight_modal_warnings() -> None:
                  title=title,
                  title_align="left",
              )
-             Console().print(panel)
+             make_console().print(panel)
          else:
              base_showwarning(warning, category, filename, lineno, file=None, line=None)

modal/cli/app.py CHANGED
@@ -15,7 +15,7 @@ from modal.client import _Client
  from modal.environments import ensure_env
  from modal_proto import api_pb2

- from .._utils.time_utils import timestamp_to_local
+ from .._utils.time_utils import timestamp_to_localized_str
  from .utils import ENV_OPTION, display_table, get_app_id_from_name, stream_app_logs

  APP_IDENTIFIER = Argument("", help="App name or ID")
@@ -71,8 +71,8 @@ async def list_(env: Optional[str] = ENV_OPTION, json: bool = False):
                  app_stats.description,
                  state,
                  str(app_stats.n_running_tasks),
-                 timestamp_to_local(app_stats.created_at, json),
-                 timestamp_to_local(app_stats.stopped_at, json),
+                 timestamp_to_localized_str(app_stats.created_at, json),
+                 timestamp_to_localized_str(app_stats.stopped_at, json),
              ]
          )

@@ -217,7 +217,7 @@ async def history(

          row = [
              Text(f"v{app_stats.version}", style=style),
-             Text(timestamp_to_local(app_stats.deployed_at, json), style=style),
+             Text(timestamp_to_localized_str(app_stats.deployed_at, json), style=style),
              Text(app_stats.client_version, style=style),
              Text(app_stats.deployed_by, style=style),
          ]