modal 1.1.5.dev66__py3-none-any.whl → 1.3.1.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic. See the package registry's advisory listing for more details.

Files changed (143)
  1. modal/__init__.py +4 -4
  2. modal/__main__.py +4 -29
  3. modal/_billing.py +84 -0
  4. modal/_clustered_functions.py +1 -3
  5. modal/_container_entrypoint.py +33 -208
  6. modal/_functions.py +171 -138
  7. modal/_grpc_client.py +191 -0
  8. modal/_ipython.py +16 -6
  9. modal/_load_context.py +106 -0
  10. modal/_object.py +72 -21
  11. modal/_output.py +12 -14
  12. modal/_partial_function.py +31 -4
  13. modal/_resolver.py +44 -57
  14. modal/_runtime/container_io_manager.py +30 -28
  15. modal/_runtime/container_io_manager.pyi +42 -44
  16. modal/_runtime/gpu_memory_snapshot.py +9 -7
  17. modal/_runtime/user_code_event_loop.py +80 -0
  18. modal/_runtime/user_code_imports.py +236 -10
  19. modal/_serialization.py +2 -1
  20. modal/_traceback.py +4 -13
  21. modal/_tunnel.py +16 -11
  22. modal/_tunnel.pyi +25 -3
  23. modal/_utils/async_utils.py +337 -10
  24. modal/_utils/auth_token_manager.py +1 -4
  25. modal/_utils/blob_utils.py +29 -22
  26. modal/_utils/function_utils.py +20 -21
  27. modal/_utils/grpc_testing.py +6 -3
  28. modal/_utils/grpc_utils.py +223 -64
  29. modal/_utils/mount_utils.py +26 -1
  30. modal/_utils/name_utils.py +2 -3
  31. modal/_utils/package_utils.py +0 -1
  32. modal/_utils/rand_pb_testing.py +8 -1
  33. modal/_utils/task_command_router_client.py +524 -0
  34. modal/_vendor/cloudpickle.py +144 -48
  35. modal/app.py +285 -105
  36. modal/app.pyi +216 -53
  37. modal/billing.py +5 -0
  38. modal/builder/2025.06.txt +6 -3
  39. modal/builder/PREVIEW.txt +2 -1
  40. modal/builder/base-images.json +4 -2
  41. modal/cli/_download.py +19 -3
  42. modal/cli/cluster.py +4 -2
  43. modal/cli/config.py +3 -1
  44. modal/cli/container.py +5 -4
  45. modal/cli/dict.py +5 -2
  46. modal/cli/entry_point.py +26 -2
  47. modal/cli/environment.py +2 -16
  48. modal/cli/launch.py +1 -76
  49. modal/cli/network_file_system.py +5 -20
  50. modal/cli/programs/run_jupyter.py +1 -1
  51. modal/cli/programs/vscode.py +1 -1
  52. modal/cli/queues.py +5 -4
  53. modal/cli/run.py +24 -204
  54. modal/cli/secret.py +1 -2
  55. modal/cli/shell.py +375 -0
  56. modal/cli/utils.py +1 -13
  57. modal/cli/volume.py +11 -17
  58. modal/client.py +16 -125
  59. modal/client.pyi +94 -144
  60. modal/cloud_bucket_mount.py +3 -1
  61. modal/cloud_bucket_mount.pyi +4 -0
  62. modal/cls.py +101 -64
  63. modal/cls.pyi +9 -8
  64. modal/config.py +21 -1
  65. modal/container_process.py +288 -12
  66. modal/container_process.pyi +99 -38
  67. modal/dict.py +72 -33
  68. modal/dict.pyi +88 -57
  69. modal/environments.py +16 -8
  70. modal/environments.pyi +6 -2
  71. modal/exception.py +154 -16
  72. modal/experimental/__init__.py +24 -53
  73. modal/experimental/flash.py +161 -74
  74. modal/experimental/flash.pyi +97 -49
  75. modal/file_io.py +50 -92
  76. modal/file_io.pyi +117 -89
  77. modal/functions.pyi +70 -87
  78. modal/image.py +82 -47
  79. modal/image.pyi +51 -30
  80. modal/io_streams.py +500 -149
  81. modal/io_streams.pyi +279 -189
  82. modal/mount.py +60 -46
  83. modal/mount.pyi +41 -17
  84. modal/network_file_system.py +19 -11
  85. modal/network_file_system.pyi +72 -39
  86. modal/object.pyi +114 -22
  87. modal/parallel_map.py +42 -44
  88. modal/parallel_map.pyi +9 -17
  89. modal/partial_function.pyi +4 -2
  90. modal/proxy.py +14 -6
  91. modal/proxy.pyi +10 -2
  92. modal/queue.py +45 -38
  93. modal/queue.pyi +88 -52
  94. modal/runner.py +96 -96
  95. modal/runner.pyi +44 -27
  96. modal/sandbox.py +225 -107
  97. modal/sandbox.pyi +226 -60
  98. modal/secret.py +58 -56
  99. modal/secret.pyi +28 -13
  100. modal/serving.py +7 -11
  101. modal/serving.pyi +7 -8
  102. modal/snapshot.py +29 -15
  103. modal/snapshot.pyi +18 -10
  104. modal/token_flow.py +1 -1
  105. modal/token_flow.pyi +4 -6
  106. modal/volume.py +102 -55
  107. modal/volume.pyi +125 -66
  108. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/METADATA +10 -9
  109. modal-1.3.1.dev8.dist-info/RECORD +189 -0
  110. modal_proto/api.proto +141 -70
  111. modal_proto/api_grpc.py +42 -26
  112. modal_proto/api_pb2.py +1123 -1103
  113. modal_proto/api_pb2.pyi +331 -83
  114. modal_proto/api_pb2_grpc.py +80 -48
  115. modal_proto/api_pb2_grpc.pyi +26 -18
  116. modal_proto/modal_api_grpc.py +175 -174
  117. modal_proto/task_command_router.proto +164 -0
  118. modal_proto/task_command_router_grpc.py +138 -0
  119. modal_proto/task_command_router_pb2.py +180 -0
  120. modal_proto/{sandbox_router_pb2.pyi → task_command_router_pb2.pyi} +148 -57
  121. modal_proto/task_command_router_pb2_grpc.py +272 -0
  122. modal_proto/task_command_router_pb2_grpc.pyi +100 -0
  123. modal_version/__init__.py +1 -1
  124. modal_version/__main__.py +1 -1
  125. modal/cli/programs/launch_instance_ssh.py +0 -94
  126. modal/cli/programs/run_marimo.py +0 -95
  127. modal-1.1.5.dev66.dist-info/RECORD +0 -191
  128. modal_proto/modal_options_grpc.py +0 -3
  129. modal_proto/options.proto +0 -19
  130. modal_proto/options_grpc.py +0 -3
  131. modal_proto/options_pb2.py +0 -35
  132. modal_proto/options_pb2.pyi +0 -20
  133. modal_proto/options_pb2_grpc.py +0 -4
  134. modal_proto/options_pb2_grpc.pyi +0 -7
  135. modal_proto/sandbox_router.proto +0 -125
  136. modal_proto/sandbox_router_grpc.py +0 -89
  137. modal_proto/sandbox_router_pb2.py +0 -128
  138. modal_proto/sandbox_router_pb2_grpc.py +0 -169
  139. modal_proto/sandbox_router_pb2_grpc.pyi +0 -63
  140. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/WHEEL +0 -0
  141. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/entry_points.txt +0 -0
  142. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/licenses/LICENSE +0 -0
  143. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,16 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import asyncio
3
3
  import concurrent.futures
4
+ import contextlib
4
5
  import functools
5
6
  import inspect
6
7
  import itertools
8
+ import os
7
9
  import sys
8
10
  import time
11
+ import types
9
12
  import typing
13
+ import warnings
10
14
  from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Iterable, Iterator
11
15
  from contextlib import asynccontextmanager
12
16
  from dataclasses import dataclass
@@ -22,10 +26,14 @@ from typing import (
22
26
 
23
27
  import synchronicity
24
28
  from synchronicity.async_utils import Runner
29
+ from synchronicity.combined_types import MethodWithAio
25
30
  from synchronicity.exceptions import NestedEventLoops
26
31
  from typing_extensions import ParamSpec, assert_type
27
32
 
28
- from ..exception import InvalidError
33
+ from modal._ipython import is_interactive_ipython
34
+ from modal._utils.deprecation import deprecation_warning
35
+
36
+ from ..exception import AsyncUsageWarning, InvalidError
29
37
  from .logger import logger
30
38
 
31
39
  T = TypeVar("T")
@@ -36,7 +44,285 @@ if sys.platform == "win32":
36
44
  # quick workaround for deadlocks on shutdown - need to investigate further
37
45
  asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
38
46
 
39
- synchronizer = synchronicity.Synchronizer()
47
+
48
+ def rewrite_sync_to_async(code_line: str, original_func: Callable) -> tuple[bool, str]:
49
+ """
50
+ Rewrite a blocking call to use async/await syntax.
51
+
52
+ Handles four patterns:
53
+ 1. __aiter__: for x in obj -> async for x in obj
54
+ 2. __aenter__: with obj as x -> async with obj as x
55
+ 3. Async generators in for loops: for x in obj.method(...) -> async for x in obj.method(...)
56
+ 4. Regular methods: obj.method() -> await obj.method.aio()
57
+
58
+ Args:
59
+ code_line: The line of code containing the blocking call
60
+ original_func: The original function object being called
61
+
62
+ Returns:
63
+ A tuple of (success, rewritten_code):
64
+ - success: True if the pattern was found and rewritten, False if falling back to generic
65
+ - rewritten_code: The rewritten code or a generic suggestion
66
+ """
67
+ import re
68
+
69
+ func_name = original_func.__name__ # type: ignore
70
+
71
+ # Check if this is an async generator function
72
+ is_async_gen = inspect.isasyncgenfunction(original_func)
73
+
74
+ # Handle __aiter__ pattern: for x in obj -> async for x in obj
75
+ if func_name == "__aiter__" and code_line.startswith("for "):
76
+ suggestion = code_line.replace("for ", "async for ", 1)
77
+ return (True, suggestion)
78
+
79
+ # Handle __aenter__ pattern: with obj as x -> async with obj as x
80
+ if func_name == "__aenter__" and code_line.startswith("with "):
81
+ suggestion = code_line.replace("with ", "async with ", 1)
82
+ return (True, suggestion)
83
+
84
+ # Handle __setitem__ pattern: dct['key'] = value -> suggest alternative
85
+ if func_name == "__setitem__":
86
+ # Try to extract the object and key from the bracket syntax
87
+ setitem_match = re.match(r"(\w+)\[([^\]]+)\]\s*=\s*(.+)", code_line.strip())
88
+ if setitem_match:
89
+ obj, key, value = setitem_match.groups()
90
+ suggestion = (
91
+ f"You can't use `{obj}[{key}] = {value}` syntax asynchronously - "
92
+ f"there may be an alternative api, e.g. {obj}.put.aio({key}, {value})"
93
+ )
94
+ return (False, suggestion)
95
+ return (False, f"await ...{func_name}.aio(...)")
96
+
97
+ # Handle __getitem__ pattern: dct['key'] -> suggest alternative
98
+ if func_name == "__getitem__":
99
+ # Try to extract the object and key from the bracket syntax
100
+ getitem_match = re.match(r"(\w+)\[([^\]]+)\]$", code_line.strip())
101
+ if getitem_match:
102
+ obj, key = getitem_match.groups()
103
+ suggestion = (
104
+ f"You can't use `{obj}[{key}]` syntax asynchronously - "
105
+ f"there may be an alternative api, e.g. {obj}.get.aio({key})"
106
+ )
107
+ return (False, suggestion)
108
+ return (False, f"await ...{func_name}.aio(...)")
109
+
110
+ # Handle async generator methods in for loops: for x in obj.method(...) -> async for x in obj.method(...)
111
+ if is_async_gen and code_line.strip().startswith("for "):
112
+ # Pattern: for <var> in <expr>.<method>(<args>):
113
+ for_pattern = rf"(for\s+\w+\s+in\s+.*\.){re.escape(func_name)}(\s*\()"
114
+ for_match = re.search(for_pattern, code_line)
115
+
116
+ if for_match:
117
+ # Just replace "for" with "async for" - no .aio() needed for async generators
118
+ suggestion = code_line.replace("for ", "async for ", 1)
119
+ return (True, suggestion)
120
+
121
+ # Handle regular method calls and property access
122
+ # First check if it's a property access (no parentheses after the name)
123
+ property_pattern = rf"\.{re.escape(func_name)}(?!\s*\()"
124
+ property_match = re.search(property_pattern, code_line)
125
+
126
+ if property_match:
127
+ # This is a property access, rewrite to use await without .aio()
128
+ # Find the start of the expression (skip statement keywords and assignments)
129
+ statement_start = 0
130
+ prefix_match = re.match(r"^(\s*(?:\w+\s*=|return|yield|raise)\s+)", code_line)
131
+ if prefix_match:
132
+ statement_start = len(prefix_match.group(1))
133
+
134
+ before_expr = code_line[:statement_start]
135
+ after_prefix = code_line[statement_start:]
136
+
137
+ # Just add await before the expression for properties
138
+ suggestion = before_expr + "await " + after_prefix.lstrip()
139
+ return (True, suggestion)
140
+
141
+ # Try to find a method call (with parentheses)
142
+ method_pattern = rf"\.{re.escape(func_name)}\s*\("
143
+ method_match = re.search(method_pattern, code_line)
144
+
145
+ if not method_match:
146
+ # Can't find the function call or property
147
+ return (False, f"await ...{func_name}.aio(...)")
148
+
149
+ # Safety check: don't attempt rewrite for complex expressions
150
+ unsafe_keywords = ["if", "elif", "while", "and", "or", "not", "in", "is", "for"]
151
+
152
+ # Check if line contains control flow keywords (might be too complex)
153
+ for keyword in unsafe_keywords:
154
+ if re.search(rf"\b{keyword}\b", code_line):
155
+ # Fall back to generic suggestion for complex expressions
156
+ return (False, f"await ...{func_name}.aio(...)")
157
+
158
+ # Find the start of the object expression that leads to the method call
159
+ # We need to find where the object/chain starts, e.g., in "2 * foo.bar.method()" we want "foo"
160
+ # Work backwards from the method match to find the start of the identifier chain
161
+ method_start = method_match.start()
162
+
163
+ # Find the start of the identifier chain (the object being called)
164
+ # Walk backwards to find identifiers and dots that form the chain
165
+ expr_start = method_start
166
+ i = method_start - 1
167
+ while i >= 0:
168
+ c = code_line[i]
169
+ if c.isalnum() or c == "_" or c == ".":
170
+ expr_start = i
171
+ i -= 1
172
+ elif c.isspace():
173
+ # Skip whitespace within the chain (though unusual)
174
+ i -= 1
175
+ else:
176
+ # Found a non-identifier character, stop
177
+ break
178
+
179
+ # Now expr_start points to the start of the object chain (e.g., "foo" in "foo.method()")
180
+ # But we need to check if the identifier we found is actually a keyword like return/yield/raise
181
+ # In that case, skip over it and find the actual object
182
+ before_obj = code_line[:expr_start]
183
+ obj_and_rest = code_line[expr_start:]
184
+
185
+ # Check if what we found starts with a statement keyword
186
+ keyword_match = re.match(r"^(return|yield|raise)\s+", obj_and_rest)
187
+ if keyword_match:
188
+ # The "object" we found is actually a keyword, adjust to skip it
189
+ keyword_len = len(keyword_match.group(0))
190
+ before_obj = code_line[: expr_start + keyword_len]
191
+ obj_and_rest = code_line[expr_start + keyword_len :]
192
+
193
+ # Add .aio() after the method name and await before the object
194
+ rewritten_expr = re.sub(rf"(\.{re.escape(func_name)})\s*\(", r"\1.aio(", obj_and_rest, count=1)
195
+ suggestion = before_obj + "await " + rewritten_expr
196
+
197
+ return (True, suggestion)
198
+
199
+
200
+ @dataclass
201
+ class _CallFrame:
202
+ """Simple dataclass to hold call frame information."""
203
+
204
+ filename: str
205
+ lineno: int
206
+ line: Optional[str]
207
+
208
+
209
+ def _extract_user_call_frame():
210
+ """
211
+ Extract the call frame from user code by filtering out frames from synchronicity and asyncio.
212
+
213
+ Returns a _CallFrame with the filename, line number, and source line, or None if not found.
214
+ """
215
+ import linecache
216
+ import os
217
+
218
+ # Get the current call stack
219
+ stack = inspect.stack()
220
+
221
+ # Get the absolute path of this module to filter it out
222
+ this_file = os.path.abspath(__file__)
223
+
224
+ # Filter out frames from synchronicity, asyncio, and this module
225
+ for frame_info in stack:
226
+ filename = frame_info.filename
227
+ # Skip frames from synchronicity, asyncio packages, and this module
228
+ # Use path separators to ensure we're matching packages, not just filenames containing these words
229
+ if (
230
+ os.path.sep + "synchronicity" + os.path.sep in filename
231
+ or os.path.sep + "asyncio" + os.path.sep in filename
232
+ or os.path.abspath(filename) == this_file
233
+ ):
234
+ continue
235
+
236
+ # Found a user frame
237
+ line = linecache.getline(filename, frame_info.lineno)
238
+ return _CallFrame(filename=filename, lineno=frame_info.lineno, line=line if line else None)
239
+
240
+ # Fallback if we can't find a suitable frame
241
+ return None
242
+
243
+
244
+ def _blocking_in_async_warning(original_func: types.FunctionType):
245
+ if is_interactive_ipython():
246
+ # in notebooks or interactive sessions where sync usage is expected
247
+ # even if it's actually running in an event loop
248
+ return
249
+
250
+ import warnings
251
+
252
+ # Skip warnings for __aexit__ and __anext__ - the __aenter__ and __aiter__ warnings are sufficient
253
+ if original_func:
254
+ func_name = getattr(original_func, "__name__", str(original_func))
255
+ if func_name in ("__aexit__", "__anext__"):
256
+ # These dunders would typically already have caused a warning on the __aenter__ or __aiter__ respectively
257
+ return
258
+
259
+ # Extract the call frame from the stack
260
+ call_frame = _extract_user_call_frame()
261
+
262
+ # Build detailed warning message with location and function first
263
+ message_parts = [
264
+ "A blocking Modal interface is being used in an async context.",
265
+ "\n\nThis may cause performance issues or bugs.",
266
+ " Consider rewriting to use Modal's async interfaces:",
267
+ "\nhttps://modal.com/docs/guide/async",
268
+ ]
269
+
270
+ # Generate intelligent suggestion based on the context
271
+ suggestion = None
272
+ code_line = None
273
+
274
+ if original_func and call_frame and call_frame.line:
275
+ code_line = call_frame.line.strip()
276
+ # Use the unified rewrite function for all patterns
277
+ _, suggestion = rewrite_sync_to_async(code_line, original_func)
278
+
279
+ # Add suggestion in "change X to Y" format
280
+ if suggestion and code_line:
281
+ # this is a bit ugly, but the warnings formatter will show the offending source line
282
+ # on the last line regardless what we do, so we add this to not make it look out of place
283
+ message_parts.append(f"\n\nSuggested rewrite:\n {suggestion}\n\nOriginal line:")
284
+
285
+ # Use warn_explicit to provide precise location information from the call frame
286
+ if call_frame:
287
+ # Extract module name from filename, or use a default
288
+ module_name = os.path.splitext(os.path.basename(call_frame.filename))[0]
289
+
290
+ warnings.warn_explicit(
291
+ "".join(message_parts),
292
+ AsyncUsageWarning,
293
+ filename=call_frame.filename,
294
+ lineno=call_frame.lineno,
295
+ module=module_name,
296
+ )
297
+ else:
298
+ # Fallback to regular warn if no frame information available
299
+ warnings.warn("".join(message_parts), AsyncUsageWarning)
300
+
301
+
302
+ def _safe_blocking_in_async_warning(original_func: types.FunctionType):
303
+ """
304
+ Safety wrapper around _blocking_in_async_warning to ensure it never raises exceptions.
305
+
306
+ This is non-critical functionality (just a warning), so we don't want it to break user code.
307
+ However, if the warning has been configured to be treated as an error (via filterwarnings),
308
+ we should let that propagate.
309
+ """
310
+ from ..config import config
311
+
312
+ if not config.get("async_warnings"):
313
+ return
314
+ try:
315
+ _blocking_in_async_warning(original_func)
316
+ except AsyncUsageWarning:
317
+ # Re-raise the warning if it's been configured as an error
318
+ raise
319
+ except Exception:
320
+ # Silently ignore any other errors in the warning system
321
+ # We don't want the warning mechanism itself to cause problems
322
+ pass
323
+
324
+
325
+ synchronizer = synchronicity.Synchronizer(blocking_in_async_callback=_safe_blocking_in_async_warning)
40
326
 
41
327
 
42
328
  def synchronize_api(obj, target_module=None):
@@ -51,6 +337,10 @@ def synchronize_api(obj, target_module=None):
51
337
  return synchronizer.create_blocking(obj, blocking_name, target_module=target_module)
52
338
 
53
339
 
340
+ # Used for testing to configure the `n_attempts` that `retry` will use.
341
+ RETRY_N_ATTEMPTS_OVERRIDE: Optional[int] = None
342
+
343
+
54
344
  def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout=90):
55
345
  """Decorator that calls an async function multiple times, with a given timeout.
56
346
 
@@ -75,8 +365,13 @@ def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout
75
365
  def decorator(fn):
76
366
  @functools.wraps(fn)
77
367
  async def f_wrapped(*args, **kwargs):
368
+ if RETRY_N_ATTEMPTS_OVERRIDE is not None:
369
+ local_n_attempts = RETRY_N_ATTEMPTS_OVERRIDE
370
+ else:
371
+ local_n_attempts = n_attempts
372
+
78
373
  delay = base_delay
79
- for i in range(n_attempts):
374
+ for i in range(local_n_attempts):
80
375
  t0 = time.time()
81
376
  try:
82
377
  return await asyncio.wait_for(fn(*args, **kwargs), timeout=timeout)
@@ -84,12 +379,12 @@ def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout
84
379
  logger.debug(f"Function {fn} was cancelled")
85
380
  raise
86
381
  except Exception as e:
87
- if i >= n_attempts - 1:
382
+ if i >= local_n_attempts - 1:
88
383
  raise
89
384
  logger.debug(
90
385
  f"Failed invoking function {fn}: {e}"
91
386
  f" (took {time.time() - t0}s, sleeping {delay}s"
92
- f" and trying {n_attempts - i - 1} more times)"
387
+ f" and trying {local_n_attempts - i - 1} more times)"
93
388
  )
94
389
  await asyncio.sleep(delay)
95
390
  delay *= delay_factor
@@ -125,7 +420,8 @@ class TaskContext:
125
420
  _loops: set[asyncio.Task]
126
421
 
127
422
  def __init__(self, grace: Optional[float] = None):
128
- self._grace = grace
423
+ self._grace = grace # grace is the time we want for tasks to finish before cancelling them
424
+ self._cancellation_grace: float = 1.0 # extra graceperiod for the cancellation itself to "bubble up"
129
425
  self._loops = set()
130
426
 
131
427
  async def start(self):
@@ -157,22 +453,29 @@ class TaskContext:
157
453
  # still needs to be handled
158
454
  # (https://stackoverflow.com/a/63356323/2475114)
159
455
  if gather_future:
160
- try:
456
+ with contextlib.suppress(asyncio.CancelledError):
161
457
  await gather_future
162
- except asyncio.CancelledError:
163
- pass
164
458
 
459
+ cancelled_tasks: list[asyncio.Task] = []
165
460
  for task in self._tasks:
166
461
  if task.done() and not task.cancelled():
167
462
  # Raise any exceptions if they happened.
168
463
  # Only tasks without a done_callback will still be present in self._tasks
169
464
  task.result()
170
465
 
171
- if task.done() or task in self._loops: # Note: Legacy code, we can probably cancel loops.
466
+ if task.done():
172
467
  continue
173
468
 
174
469
  # Cancel any remaining unfinished tasks.
175
470
  task.cancel()
471
+ cancelled_tasks.append(task)
472
+
473
+ cancellation_gather = asyncio.gather(*cancelled_tasks, return_exceptions=True)
474
+ try:
475
+ await asyncio.wait_for(cancellation_gather, timeout=self._cancellation_grace)
476
+ except asyncio.TimeoutError:
477
+ warnings.warn(f"Internal warning: Tasks did not cancel in a timely manner: {cancelled_tasks}")
478
+
176
479
  await asyncio.sleep(0) # wake up coroutines waiting for cancellations
177
480
 
178
481
  async def __aexit__(self, exc_type, value, tb):
@@ -370,6 +673,7 @@ class _WarnIfGeneratorIsNotConsumed:
370
673
  self.function_name = function_name
371
674
  self.iterated = False
372
675
  self.warned = False
676
+ self.__wrapped__ = gen
373
677
 
374
678
  def __aiter__(self):
375
679
  self.iterated = True
@@ -878,3 +1182,26 @@ async def async_chain(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T,
878
1182
  logger.exception(f"Error closing async generator: {e}")
879
1183
  if first_exception is not None:
880
1184
  raise first_exception
1185
+
1186
+
1187
+ def deprecate_aio_usage(deprecation_date: tuple[int, int, int], readable_sync_call: str):
1188
+ # Note: Currently only works on methods, not top level functions
1189
+ def deco(sync_implementation):
1190
+ if isinstance(sync_implementation, classmethod):
1191
+ sync_implementation = sync_implementation.__func__
1192
+ is_classmethod = True
1193
+ else:
1194
+ is_classmethod = False
1195
+
1196
+ async def _async_proxy(*args, **kwargs):
1197
+ deprecation_warning(
1198
+ deprecation_date,
1199
+ f"""The async constructor {readable_sync_call}.aio(...) will be deprecated in a future version of Modal.
1200
+ Please use {readable_sync_call}(...) instead (it doesn't perform any IO, and is safe in async contexts)
1201
+ """,
1202
+ )
1203
+ return sync_implementation(*args, **kwargs)
1204
+
1205
+ return MethodWithAio(sync_implementation, _async_proxy, synchronizer, is_classmethod=is_classmethod)
1206
+
1207
+ return deco
@@ -9,7 +9,6 @@ from typing import Any
9
9
  from modal.exception import ExecutionError
10
10
  from modal_proto import api_pb2, modal_api_grpc
11
11
 
12
- from .grpc_utils import retry_transient_errors
13
12
  from .logger import logger
14
13
 
15
14
 
@@ -66,9 +65,7 @@ class _AuthTokenManager:
66
65
  # new token. Once we have a new token, the other coroutines will unblock and return from here.
67
66
  if self._token and not self._needs_refresh():
68
67
  return
69
- resp: api_pb2.AuthTokenGetResponse = await retry_transient_errors(
70
- self._stub.AuthTokenGet, api_pb2.AuthTokenGetRequest()
71
- )
68
+ resp: api_pb2.AuthTokenGetResponse = await self._stub.AuthTokenGet(api_pb2.AuthTokenGetRequest())
72
69
  if not resp.token:
73
70
  # Not expected
74
71
  raise ExecutionError(
@@ -4,7 +4,6 @@ import dataclasses
4
4
  import hashlib
5
5
  import os
6
6
  import platform
7
- import random
8
7
  import time
9
8
  from collections.abc import AsyncIterator
10
9
  from contextlib import AbstractContextManager, contextmanager
@@ -27,7 +26,6 @@ from modal_proto.modal_api_grpc import ModalClientModal
27
26
 
28
27
  from ..exception import ExecutionError
29
28
  from .async_utils import TaskContext, retry
30
- from .grpc_utils import retry_transient_errors
31
29
  from .hash_utils import UploadHashes, get_upload_hashes
32
30
  from .http_utils import ClientSessionRegistry
33
31
  from .logger import logger
@@ -59,10 +57,8 @@ MULTIPART_UPLOAD_THRESHOLD = 1024**3
59
57
  # For block based storage like volumefs2: the size of a block
60
58
  BLOCK_SIZE: int = 8 * 1024 * 1024
61
59
 
62
- HEALTHY_R2_UPLOAD_PERCENTAGE = 0.95
63
60
 
64
-
65
- @retry(n_attempts=5, base_delay=0.5, timeout=None)
61
+ @retry(n_attempts=3, base_delay=0.3, timeout=None)
66
62
  async def _upload_to_s3_url(
67
63
  upload_url,
68
64
  payload: "BytesIOSegmentPayload",
@@ -153,12 +149,13 @@ async def perform_multipart_upload(
153
149
  part_etags = await TaskContext.gather(*upload_coros)
154
150
 
155
151
  # The body of the complete_multipart_upload command needs some data in xml format:
156
- completion_body = "<CompleteMultipartUpload>\n"
152
+ completion_parts = ["<CompleteMultipartUpload>"]
157
153
  for part_number, etag in enumerate(part_etags, 1):
158
- completion_body += f"""<Part>\n<PartNumber>{part_number}</PartNumber>\n<ETag>"{etag}"</ETag>\n</Part>\n"""
159
- completion_body += "</CompleteMultipartUpload>"
154
+ completion_parts.append(f"""<Part>\n<PartNumber>{part_number}</PartNumber>\n<ETag>"{etag}"</ETag>\n</Part>""")
155
+ completion_parts.append("</CompleteMultipartUpload>")
156
+ completion_body = "\n".join(completion_parts)
160
157
 
161
- # etag of combined object should be md5 hex of concatendated md5 *bytes* from parts + `-{num_parts}`
158
+ # etag of combined object should be md5 hex of concatenated md5 *bytes* from parts + `-{num_parts}`
162
159
  bin_hash_parts = [bytes.fromhex(etag) for etag in part_etags]
163
160
 
164
161
  expected_multipart_etag = hashlib.md5(b"".join(bin_hash_parts)).hexdigest() + f"-{len(part_etags)}"
@@ -191,13 +188,10 @@ def get_content_length(data: BinaryIO) -> int:
191
188
  async def _blob_upload_with_fallback(
192
189
  items, blob_ids: list[str], callback, content_length: int
193
190
  ) -> tuple[str, bool, int]:
191
+ """Try uploading to each provider in order, with fallback on failure."""
194
192
  r2_throughput_bytes_s = 0
195
193
  r2_failed = False
196
194
  for idx, (item, blob_id) in enumerate(zip(items, blob_ids)):
197
- # We want to default to R2 95% of the time and S3 5% of the time.
198
- # To ensure the failure path is continuously exercised.
199
- if idx == 0 and len(items) > 1 and random.random() > HEALTHY_R2_UPLOAD_PERCENTAGE:
200
- continue
201
195
  try:
202
196
  if blob_id.endswith(":r2"):
203
197
  t0 = time.monotonic_ns()
@@ -207,7 +201,7 @@ async def _blob_upload_with_fallback(
207
201
  else:
208
202
  await callback(item)
209
203
  return blob_id, r2_failed, r2_throughput_bytes_s
210
- except Exception as _:
204
+ except Exception:
211
205
  if blob_id.endswith(":r2"):
212
206
  r2_failed = True
213
207
  # Ignore all errors except the last one, since we're out of fallback options.
@@ -229,7 +223,7 @@ async def _blob_upload(
229
223
  content_sha256_base64=upload_hashes.sha256_base64,
230
224
  content_length=content_length,
231
225
  )
232
- resp = await retry_transient_errors(stub.BlobCreate, req)
226
+ resp = await stub.BlobCreate(req)
233
227
 
234
228
  if resp.WhichOneof("upload_types_oneof") == "multiparts":
235
229
 
@@ -335,7 +329,7 @@ async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes:
335
329
  logger.debug(f"Downloading large blob {blob_id}")
336
330
  t0 = time.time()
337
331
  req = api_pb2.BlobGetRequest(blob_id=blob_id)
338
- resp = await retry_transient_errors(stub.BlobGet, req)
332
+ resp = await stub.BlobGet(req)
339
333
  data = await _download_from_url(resp.download_url)
340
334
  size_mib = len(data) / 1024 / 1024
341
335
  dur_s = max(time.time() - t0, 0.001) # avoid division by zero
@@ -348,7 +342,7 @@ async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes:
348
342
 
349
343
  async def blob_iter(blob_id: str, stub: ModalClientModal) -> AsyncIterator[bytes]:
350
344
  req = api_pb2.BlobGetRequest(blob_id=blob_id)
351
- resp = await retry_transient_errors(stub.BlobGet, req)
345
+ resp = await stub.BlobGet(req)
352
346
  download_url = resp.download_url
353
347
  async with ClientSessionRegistry.get_session().get(download_url) as s3_resp:
354
348
  # S3 signal to slow down request rate.
@@ -372,11 +366,17 @@ class FileUploadSpec:
372
366
  mount_filename: str
373
367
 
374
368
  use_blob: bool
375
- content: Optional[bytes] # typically None if using blob, required otherwise
376
369
  sha256_hex: str
377
370
  md5_hex: str
378
371
  mode: int # file permission bits (last 12 bits of st_mode)
379
372
  size: int
373
+ content: Optional[bytes] = None # Set for very small files to avoid double-read
374
+
375
+ def read_content(self) -> bytes:
376
+ """Read content from source."""
377
+ with self.source() as fp:
378
+ fp.seek(0)
379
+ return fp.read()
380
380
 
381
381
 
382
382
  def _get_file_upload_spec(
@@ -385,6 +385,7 @@ def _get_file_upload_spec(
385
385
  mount_filename: PurePosixPath,
386
386
  mode: int,
387
387
  ) -> FileUploadSpec:
388
+ content = None
388
389
  with source() as fp:
389
390
  # Current position is ignored - we always upload from position 0
390
391
  fp.seek(0, os.SEEK_END)
@@ -395,12 +396,18 @@ def _get_file_upload_spec(
395
396
  # TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs
396
397
  md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None
397
398
  use_blob = True
398
- content = None
399
399
  hashes = get_upload_hashes(fp, md5_hex=md5_hex)
400
400
  else:
401
401
  use_blob = False
402
- content = fp.read()
403
- hashes = get_upload_hashes(content)
402
+ # For very small files (< 256 KiB), read content once and cache it
403
+ # This avoids double-read penalty while limiting memory usage
404
+ if size < 256 * 1024: # 256 KiB threshold
405
+ fp.seek(0)
406
+ content = fp.read()
407
+ hashes = get_upload_hashes(content)
408
+ else:
409
+ # For medium files (256 KiB - 4 MiB), compute hashes without caching content
410
+ hashes = get_upload_hashes(fp)
404
411
 
405
412
  return FileUploadSpec(
406
413
  source=source,
@@ -408,11 +415,11 @@ def _get_file_upload_spec(
408
415
  source_is_path=isinstance(source_description, Path),
409
416
  mount_filename=mount_filename.as_posix(),
410
417
  use_blob=use_blob,
411
- content=content,
412
418
  sha256_hex=hashes.sha256_hex(),
413
419
  md5_hex=hashes.md5_hex(),
414
420
  mode=mode & 0o7777,
415
421
  size=size,
422
+ content=content,
416
423
  )
417
424
 
418
425