modal 1.1.5.dev83__py3-none-any.whl → 1.3.1.dev8__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic. See the original diff page for more details.

Files changed (139)
  1. modal/__init__.py +4 -4
  2. modal/__main__.py +4 -29
  3. modal/_billing.py +84 -0
  4. modal/_clustered_functions.py +1 -3
  5. modal/_container_entrypoint.py +33 -208
  6. modal/_functions.py +146 -121
  7. modal/_grpc_client.py +191 -0
  8. modal/_ipython.py +16 -6
  9. modal/_load_context.py +106 -0
  10. modal/_object.py +72 -21
  11. modal/_output.py +12 -14
  12. modal/_partial_function.py +31 -4
  13. modal/_resolver.py +44 -57
  14. modal/_runtime/container_io_manager.py +26 -28
  15. modal/_runtime/container_io_manager.pyi +42 -44
  16. modal/_runtime/gpu_memory_snapshot.py +9 -7
  17. modal/_runtime/user_code_event_loop.py +80 -0
  18. modal/_runtime/user_code_imports.py +236 -10
  19. modal/_serialization.py +2 -1
  20. modal/_traceback.py +4 -13
  21. modal/_tunnel.py +16 -11
  22. modal/_tunnel.pyi +25 -3
  23. modal/_utils/async_utils.py +337 -10
  24. modal/_utils/auth_token_manager.py +1 -4
  25. modal/_utils/blob_utils.py +29 -22
  26. modal/_utils/function_utils.py +20 -21
  27. modal/_utils/grpc_testing.py +6 -3
  28. modal/_utils/grpc_utils.py +223 -64
  29. modal/_utils/mount_utils.py +26 -1
  30. modal/_utils/package_utils.py +0 -1
  31. modal/_utils/rand_pb_testing.py +8 -1
  32. modal/_utils/task_command_router_client.py +524 -0
  33. modal/_vendor/cloudpickle.py +144 -48
  34. modal/app.py +215 -96
  35. modal/app.pyi +78 -37
  36. modal/billing.py +5 -0
  37. modal/builder/2025.06.txt +6 -3
  38. modal/builder/PREVIEW.txt +2 -1
  39. modal/builder/base-images.json +4 -2
  40. modal/cli/_download.py +19 -3
  41. modal/cli/cluster.py +4 -2
  42. modal/cli/config.py +3 -1
  43. modal/cli/container.py +5 -4
  44. modal/cli/dict.py +5 -2
  45. modal/cli/entry_point.py +26 -2
  46. modal/cli/environment.py +2 -16
  47. modal/cli/launch.py +1 -76
  48. modal/cli/network_file_system.py +5 -20
  49. modal/cli/queues.py +5 -4
  50. modal/cli/run.py +24 -204
  51. modal/cli/secret.py +1 -2
  52. modal/cli/shell.py +375 -0
  53. modal/cli/utils.py +1 -13
  54. modal/cli/volume.py +11 -17
  55. modal/client.py +16 -125
  56. modal/client.pyi +94 -144
  57. modal/cloud_bucket_mount.py +3 -1
  58. modal/cloud_bucket_mount.pyi +4 -0
  59. modal/cls.py +101 -64
  60. modal/cls.pyi +9 -8
  61. modal/config.py +21 -1
  62. modal/container_process.py +288 -12
  63. modal/container_process.pyi +99 -38
  64. modal/dict.py +72 -33
  65. modal/dict.pyi +88 -57
  66. modal/environments.py +16 -8
  67. modal/environments.pyi +6 -2
  68. modal/exception.py +154 -16
  69. modal/experimental/__init__.py +23 -5
  70. modal/experimental/flash.py +161 -74
  71. modal/experimental/flash.pyi +97 -49
  72. modal/file_io.py +50 -92
  73. modal/file_io.pyi +117 -89
  74. modal/functions.pyi +70 -87
  75. modal/image.py +73 -47
  76. modal/image.pyi +33 -30
  77. modal/io_streams.py +500 -149
  78. modal/io_streams.pyi +279 -189
  79. modal/mount.py +60 -45
  80. modal/mount.pyi +41 -17
  81. modal/network_file_system.py +19 -11
  82. modal/network_file_system.pyi +72 -39
  83. modal/object.pyi +114 -22
  84. modal/parallel_map.py +42 -44
  85. modal/parallel_map.pyi +9 -17
  86. modal/partial_function.pyi +4 -2
  87. modal/proxy.py +14 -6
  88. modal/proxy.pyi +10 -2
  89. modal/queue.py +45 -38
  90. modal/queue.pyi +88 -52
  91. modal/runner.py +96 -96
  92. modal/runner.pyi +44 -27
  93. modal/sandbox.py +225 -108
  94. modal/sandbox.pyi +226 -63
  95. modal/secret.py +58 -56
  96. modal/secret.pyi +28 -13
  97. modal/serving.py +7 -11
  98. modal/serving.pyi +7 -8
  99. modal/snapshot.py +29 -15
  100. modal/snapshot.pyi +18 -10
  101. modal/token_flow.py +1 -1
  102. modal/token_flow.pyi +4 -6
  103. modal/volume.py +102 -55
  104. modal/volume.pyi +125 -66
  105. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/METADATA +10 -9
  106. modal-1.3.1.dev8.dist-info/RECORD +189 -0
  107. modal_proto/api.proto +86 -30
  108. modal_proto/api_grpc.py +10 -25
  109. modal_proto/api_pb2.py +1080 -1047
  110. modal_proto/api_pb2.pyi +253 -79
  111. modal_proto/api_pb2_grpc.py +14 -48
  112. modal_proto/api_pb2_grpc.pyi +6 -18
  113. modal_proto/modal_api_grpc.py +175 -176
  114. modal_proto/{sandbox_router.proto → task_command_router.proto} +62 -45
  115. modal_proto/task_command_router_grpc.py +138 -0
  116. modal_proto/task_command_router_pb2.py +180 -0
  117. modal_proto/{sandbox_router_pb2.pyi → task_command_router_pb2.pyi} +110 -63
  118. modal_proto/task_command_router_pb2_grpc.py +272 -0
  119. modal_proto/task_command_router_pb2_grpc.pyi +100 -0
  120. modal_version/__init__.py +1 -1
  121. modal_version/__main__.py +1 -1
  122. modal/cli/programs/launch_instance_ssh.py +0 -94
  123. modal/cli/programs/run_marimo.py +0 -95
  124. modal-1.1.5.dev83.dist-info/RECORD +0 -191
  125. modal_proto/modal_options_grpc.py +0 -3
  126. modal_proto/options.proto +0 -19
  127. modal_proto/options_grpc.py +0 -3
  128. modal_proto/options_pb2.py +0 -35
  129. modal_proto/options_pb2.pyi +0 -20
  130. modal_proto/options_pb2_grpc.py +0 -4
  131. modal_proto/options_pb2_grpc.pyi +0 -7
  132. modal_proto/sandbox_router_grpc.py +0 -105
  133. modal_proto/sandbox_router_pb2.py +0 -148
  134. modal_proto/sandbox_router_pb2_grpc.py +0 -203
  135. modal_proto/sandbox_router_pb2_grpc.pyi +0 -75
  136. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/WHEEL +0 -0
  137. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/entry_points.txt +0 -0
  138. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/licenses/LICENSE +0 -0
  139. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/top_level.txt +0 -0
@@ -54,7 +54,12 @@ class _NetworkFileSystem(modal._object._Object):
54
54
  """
55
55
  @staticmethod
56
56
  def from_name(
57
- name: str, *, namespace=None, environment_name: typing.Optional[str] = None, create_if_missing: bool = False
57
+ name: str,
58
+ *,
59
+ namespace=None,
60
+ environment_name: typing.Optional[str] = None,
61
+ create_if_missing: bool = False,
62
+ client: typing.Optional[modal.client._Client] = None,
58
63
  ) -> _NetworkFileSystem:
59
64
  """Reference a NetworkFileSystem by its name, creating if necessary.
60
65
 
@@ -163,8 +168,6 @@ class _NetworkFileSystem(modal._object._Object):
163
168
  """Remove a file in a network file system."""
164
169
  ...
165
170
 
166
- SUPERSELF = typing.TypeVar("SUPERSELF", covariant=True)
167
-
168
171
  class NetworkFileSystem(modal.object.Object):
169
172
  """A shared, writable file system accessible by one or more Modal functions.
170
173
 
@@ -210,7 +213,12 @@ class NetworkFileSystem(modal.object.Object):
210
213
 
211
214
  @staticmethod
212
215
  def from_name(
213
- name: str, *, namespace=None, environment_name: typing.Optional[str] = None, create_if_missing: bool = False
216
+ name: str,
217
+ *,
218
+ namespace=None,
219
+ environment_name: typing.Optional[str] = None,
220
+ create_if_missing: bool = False,
221
+ client: typing.Optional[modal.client.Client] = None,
214
222
  ) -> NetworkFileSystem:
215
223
  """Reference a NetworkFileSystem by its name, creating if necessary.
216
224
 
@@ -228,27 +236,52 @@ class NetworkFileSystem(modal.object.Object):
228
236
  """
229
237
  ...
230
238
 
231
- @classmethod
232
- def ephemeral(
233
- cls: type[NetworkFileSystem],
234
- client: typing.Optional[modal.client.Client] = None,
235
- environment_name: typing.Optional[str] = None,
236
- _heartbeat_sleep: float = 300,
237
- ) -> synchronicity.combined_types.AsyncAndBlockingContextManager[NetworkFileSystem]:
238
- """Creates a new ephemeral network filesystem within a context manager:
239
+ class __ephemeral_spec(typing_extensions.Protocol):
240
+ def __call__(
241
+ self,
242
+ /,
243
+ client: typing.Optional[modal.client.Client] = None,
244
+ environment_name: typing.Optional[str] = None,
245
+ _heartbeat_sleep: float = 300,
246
+ ) -> synchronicity.combined_types.AsyncAndBlockingContextManager[NetworkFileSystem]:
247
+ """Creates a new ephemeral network filesystem within a context manager:
248
+
249
+ Usage:
250
+ ```python
251
+ with modal.NetworkFileSystem.ephemeral() as nfs:
252
+ assert nfs.listdir("/") == []
253
+ ```
254
+
255
+ ```python notest
256
+ async with modal.NetworkFileSystem.ephemeral() as nfs:
257
+ assert await nfs.listdir("/") == []
258
+ ```
259
+ """
260
+ ...
239
261
 
240
- Usage:
241
- ```python
242
- with modal.NetworkFileSystem.ephemeral() as nfs:
243
- assert nfs.listdir("/") == []
244
- ```
262
+ def aio(
263
+ self,
264
+ /,
265
+ client: typing.Optional[modal.client.Client] = None,
266
+ environment_name: typing.Optional[str] = None,
267
+ _heartbeat_sleep: float = 300,
268
+ ) -> typing.AsyncContextManager[NetworkFileSystem]:
269
+ """Creates a new ephemeral network filesystem within a context manager:
270
+
271
+ Usage:
272
+ ```python
273
+ with modal.NetworkFileSystem.ephemeral() as nfs:
274
+ assert nfs.listdir("/") == []
275
+ ```
276
+
277
+ ```python notest
278
+ async with modal.NetworkFileSystem.ephemeral() as nfs:
279
+ assert await nfs.listdir("/") == []
280
+ ```
281
+ """
282
+ ...
245
283
 
246
- ```python notest
247
- async with modal.NetworkFileSystem.ephemeral() as nfs:
248
- assert await nfs.listdir("/") == []
249
- ```
250
- """
251
- ...
284
+ ephemeral: typing.ClassVar[__ephemeral_spec]
252
285
 
253
286
  class __create_deployed_spec(typing_extensions.Protocol):
254
287
  def __call__(
@@ -273,7 +306,7 @@ class NetworkFileSystem(modal.object.Object):
273
306
  """mdmd:hidden"""
274
307
  ...
275
308
 
276
- create_deployed: __create_deployed_spec
309
+ create_deployed: typing.ClassVar[__create_deployed_spec]
277
310
 
278
311
  class __delete_spec(typing_extensions.Protocol):
279
312
  def __call__(
@@ -291,9 +324,9 @@ class NetworkFileSystem(modal.object.Object):
291
324
  environment_name: typing.Optional[str] = None,
292
325
  ): ...
293
326
 
294
- delete: __delete_spec
327
+ delete: typing.ClassVar[__delete_spec]
295
328
 
296
- class __write_file_spec(typing_extensions.Protocol[SUPERSELF]):
329
+ class __write_file_spec(typing_extensions.Protocol):
297
330
  def __call__(
298
331
  self,
299
332
  /,
@@ -326,9 +359,9 @@ class NetworkFileSystem(modal.object.Object):
326
359
  """
327
360
  ...
328
361
 
329
- write_file: __write_file_spec[typing_extensions.Self]
362
+ write_file: __write_file_spec
330
363
 
331
- class __read_file_spec(typing_extensions.Protocol[SUPERSELF]):
364
+ class __read_file_spec(typing_extensions.Protocol):
332
365
  def __call__(self, /, path: str) -> typing.Iterator[bytes]:
333
366
  """Read a file from the network file system"""
334
367
  ...
@@ -337,9 +370,9 @@ class NetworkFileSystem(modal.object.Object):
337
370
  """Read a file from the network file system"""
338
371
  ...
339
372
 
340
- read_file: __read_file_spec[typing_extensions.Self]
373
+ read_file: __read_file_spec
341
374
 
342
- class __iterdir_spec(typing_extensions.Protocol[SUPERSELF]):
375
+ class __iterdir_spec(typing_extensions.Protocol):
343
376
  def __call__(self, /, path: str) -> typing.Iterator[modal.volume.FileEntry]:
344
377
  """Iterate over all files in a directory in the network file system.
345
378
 
@@ -360,9 +393,9 @@ class NetworkFileSystem(modal.object.Object):
360
393
  """
361
394
  ...
362
395
 
363
- iterdir: __iterdir_spec[typing_extensions.Self]
396
+ iterdir: __iterdir_spec
364
397
 
365
- class __add_local_file_spec(typing_extensions.Protocol[SUPERSELF]):
398
+ class __add_local_file_spec(typing_extensions.Protocol):
366
399
  def __call__(
367
400
  self,
368
401
  /,
@@ -378,9 +411,9 @@ class NetworkFileSystem(modal.object.Object):
378
411
  progress_cb: typing.Optional[collections.abc.Callable[..., typing.Any]] = None,
379
412
  ): ...
380
413
 
381
- add_local_file: __add_local_file_spec[typing_extensions.Self]
414
+ add_local_file: __add_local_file_spec
382
415
 
383
- class __add_local_dir_spec(typing_extensions.Protocol[SUPERSELF]):
416
+ class __add_local_dir_spec(typing_extensions.Protocol):
384
417
  def __call__(
385
418
  self,
386
419
  /,
@@ -396,9 +429,9 @@ class NetworkFileSystem(modal.object.Object):
396
429
  progress_cb: typing.Optional[collections.abc.Callable[..., typing.Any]] = None,
397
430
  ): ...
398
431
 
399
- add_local_dir: __add_local_dir_spec[typing_extensions.Self]
432
+ add_local_dir: __add_local_dir_spec
400
433
 
401
- class __listdir_spec(typing_extensions.Protocol[SUPERSELF]):
434
+ class __listdir_spec(typing_extensions.Protocol):
402
435
  def __call__(self, /, path: str) -> list[modal.volume.FileEntry]:
403
436
  """List all files in a directory in the network file system.
404
437
 
@@ -419,9 +452,9 @@ class NetworkFileSystem(modal.object.Object):
419
452
  """
420
453
  ...
421
454
 
422
- listdir: __listdir_spec[typing_extensions.Self]
455
+ listdir: __listdir_spec
423
456
 
424
- class __remove_file_spec(typing_extensions.Protocol[SUPERSELF]):
457
+ class __remove_file_spec(typing_extensions.Protocol):
425
458
  def __call__(self, /, path: str, recursive=False):
426
459
  """Remove a file in a network file system."""
427
460
  ...
@@ -430,4 +463,4 @@ class NetworkFileSystem(modal.object.Object):
430
463
  """Remove a file in a network file system."""
431
464
  ...
432
465
 
433
- remove_file: __remove_file_spec[typing_extensions.Self]
466
+ remove_file: __remove_file_spec
modal/object.pyi CHANGED
@@ -1,5 +1,6 @@
1
1
  import collections.abc
2
2
  import google.protobuf.message
3
+ import modal._load_context
3
4
  import modal._resolver
4
5
  import modal.client
5
6
  import typing
@@ -12,12 +13,14 @@ class Object:
12
13
  _prefix_to_type: typing.ClassVar[dict[str, type]]
13
14
  _load: typing.Optional[
14
15
  collections.abc.Callable[
15
- [typing_extensions.Self, modal._resolver.Resolver, typing.Optional[str]], collections.abc.Awaitable[None]
16
+ [typing_extensions.Self, modal._resolver.Resolver, modal._load_context.LoadContext, typing.Optional[str]],
17
+ collections.abc.Awaitable[None],
16
18
  ]
17
19
  ]
18
20
  _preload: typing.Optional[
19
21
  collections.abc.Callable[
20
- [typing_extensions.Self, modal._resolver.Resolver, typing.Optional[str]], collections.abc.Awaitable[None]
22
+ [typing_extensions.Self, modal._resolver.Resolver, modal._load_context.LoadContext, typing.Optional[str]],
23
+ collections.abc.Awaitable[None],
21
24
  ]
22
25
  ]
23
26
  _rep: str
@@ -27,6 +30,7 @@ class Object:
27
30
  _deduplication_key: typing.Optional[
28
31
  collections.abc.Callable[[], collections.abc.Awaitable[collections.abc.Hashable]]
29
32
  ]
33
+ _load_context_overrides: modal._load_context.LoadContext
30
34
  _object_id: typing.Optional[str]
31
35
  _client: typing.Optional[modal.client.Client]
32
36
  _is_hydrated: bool
@@ -40,22 +44,40 @@ class Object:
40
44
  @classmethod
41
45
  def __init_subclass__(cls, type_prefix: typing.Optional[str] = None): ...
42
46
 
43
- class ___init_spec(typing_extensions.Protocol[SUPERSELF]):
47
+ class ___init_spec(typing_extensions.Protocol):
44
48
  def __call__(
45
49
  self,
46
50
  /,
47
51
  rep: str,
48
52
  load: typing.Optional[
49
- collections.abc.Callable[[SUPERSELF, modal._resolver.Resolver, typing.Optional[str]], None]
53
+ collections.abc.Callable[
54
+ [
55
+ typing_extensions.Self,
56
+ modal._resolver.Resolver,
57
+ modal._load_context.LoadContext,
58
+ typing.Optional[str],
59
+ ],
60
+ None,
61
+ ]
50
62
  ] = None,
51
63
  is_another_app: bool = False,
52
64
  preload: typing.Optional[
53
- collections.abc.Callable[[SUPERSELF, modal._resolver.Resolver, typing.Optional[str]], None]
65
+ collections.abc.Callable[
66
+ [
67
+ typing_extensions.Self,
68
+ modal._resolver.Resolver,
69
+ modal._load_context.LoadContext,
70
+ typing.Optional[str],
71
+ ],
72
+ None,
73
+ ]
54
74
  ] = None,
55
75
  hydrate_lazily: bool = False,
56
76
  deps: typing.Optional[collections.abc.Callable[..., collections.abc.Sequence[Object]]] = None,
57
77
  deduplication_key: typing.Optional[collections.abc.Callable[[], collections.abc.Hashable]] = None,
58
78
  name: typing.Optional[str] = None,
79
+ *,
80
+ load_context_overrides: typing.Optional[modal._load_context.LoadContext] = None,
59
81
  ): ...
60
82
  def aio(
61
83
  self,
@@ -63,13 +85,25 @@ class Object:
63
85
  rep: str,
64
86
  load: typing.Optional[
65
87
  collections.abc.Callable[
66
- [SUPERSELF, modal._resolver.Resolver, typing.Optional[str]], collections.abc.Awaitable[None]
88
+ [
89
+ typing_extensions.Self,
90
+ modal._resolver.Resolver,
91
+ modal._load_context.LoadContext,
92
+ typing.Optional[str],
93
+ ],
94
+ collections.abc.Awaitable[None],
67
95
  ]
68
96
  ] = None,
69
97
  is_another_app: bool = False,
70
98
  preload: typing.Optional[
71
99
  collections.abc.Callable[
72
- [SUPERSELF, modal._resolver.Resolver, typing.Optional[str]], collections.abc.Awaitable[None]
100
+ [
101
+ typing_extensions.Self,
102
+ modal._resolver.Resolver,
103
+ modal._load_context.LoadContext,
104
+ typing.Optional[str],
105
+ ],
106
+ collections.abc.Awaitable[None],
73
107
  ]
74
108
  ] = None,
75
109
  hydrate_lazily: bool = False,
@@ -78,9 +112,11 @@ class Object:
78
112
  collections.abc.Callable[[], collections.abc.Awaitable[collections.abc.Hashable]]
79
113
  ] = None,
80
114
  name: typing.Optional[str] = None,
115
+ *,
116
+ load_context_overrides: typing.Optional[modal._load_context.LoadContext] = None,
81
117
  ): ...
82
118
 
83
- _init: ___init_spec[typing_extensions.Self]
119
+ _init: ___init_spec
84
120
 
85
121
  def _unhydrate(self): ...
86
122
  def _initialize_from_empty(self): ...
@@ -98,20 +134,76 @@ class Object:
98
134
  """
99
135
  ...
100
136
 
101
- @classmethod
102
- def _from_loader(
103
- cls,
104
- load: collections.abc.Callable[[typing_extensions.Self, modal._resolver.Resolver, typing.Optional[str]], None],
105
- rep: str,
106
- is_another_app: bool = False,
107
- preload: typing.Optional[
108
- collections.abc.Callable[[typing_extensions.Self, modal._resolver.Resolver, typing.Optional[str]], None]
109
- ] = None,
110
- hydrate_lazily: bool = False,
111
- deps: typing.Optional[collections.abc.Callable[..., collections.abc.Sequence[Object]]] = None,
112
- deduplication_key: typing.Optional[collections.abc.Callable[[], collections.abc.Hashable]] = None,
113
- name: typing.Optional[str] = None,
114
- ): ...
137
+ class ___from_loader_spec(typing_extensions.Protocol):
138
+ def __call__(
139
+ self,
140
+ /,
141
+ load: collections.abc.Callable[
142
+ [
143
+ typing_extensions.Self,
144
+ modal._resolver.Resolver,
145
+ modal._load_context.LoadContext,
146
+ typing.Optional[str],
147
+ ],
148
+ None,
149
+ ],
150
+ rep: str,
151
+ is_another_app: bool = False,
152
+ preload: typing.Optional[
153
+ collections.abc.Callable[
154
+ [
155
+ typing_extensions.Self,
156
+ modal._resolver.Resolver,
157
+ modal._load_context.LoadContext,
158
+ typing.Optional[str],
159
+ ],
160
+ None,
161
+ ]
162
+ ] = None,
163
+ hydrate_lazily: bool = False,
164
+ deps: typing.Optional[collections.abc.Callable[..., collections.abc.Sequence[Object]]] = None,
165
+ deduplication_key: typing.Optional[collections.abc.Callable[[], collections.abc.Hashable]] = None,
166
+ name: typing.Optional[str] = None,
167
+ *,
168
+ load_context_overrides: modal._load_context.LoadContext,
169
+ ): ...
170
+ def aio(
171
+ self,
172
+ /,
173
+ load: collections.abc.Callable[
174
+ [
175
+ typing_extensions.Self,
176
+ modal._resolver.Resolver,
177
+ modal._load_context.LoadContext,
178
+ typing.Optional[str],
179
+ ],
180
+ collections.abc.Awaitable[None],
181
+ ],
182
+ rep: str,
183
+ is_another_app: bool = False,
184
+ preload: typing.Optional[
185
+ collections.abc.Callable[
186
+ [
187
+ typing_extensions.Self,
188
+ modal._resolver.Resolver,
189
+ modal._load_context.LoadContext,
190
+ typing.Optional[str],
191
+ ],
192
+ collections.abc.Awaitable[None],
193
+ ]
194
+ ] = None,
195
+ hydrate_lazily: bool = False,
196
+ deps: typing.Optional[collections.abc.Callable[..., collections.abc.Sequence[Object]]] = None,
197
+ deduplication_key: typing.Optional[
198
+ collections.abc.Callable[[], collections.abc.Awaitable[collections.abc.Hashable]]
199
+ ] = None,
200
+ name: typing.Optional[str] = None,
201
+ *,
202
+ load_context_overrides: modal._load_context.LoadContext,
203
+ ): ...
204
+
205
+ _from_loader: typing.ClassVar[___from_loader_spec]
206
+
115
207
  @staticmethod
116
208
  def _get_type_from_id(object_id: str) -> type[Object]: ...
117
209
  @classmethod
modal/parallel_map.py CHANGED
@@ -35,7 +35,7 @@ from modal._utils.function_utils import (
35
35
  _create_input,
36
36
  _process_result,
37
37
  )
38
- from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, RetryWarningMessage, retry_transient_errors
38
+ from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, Retry, RetryWarningMessage
39
39
  from modal._utils.jwt_utils import DecodedJwt
40
40
  from modal.config import logger
41
41
  from modal.retries import RetryManager
@@ -187,7 +187,7 @@ class InputPumper:
187
187
  f" push is {self.input_queue.qsize()}. "
188
188
  )
189
189
 
190
- resp = await self._send_inputs(self.client.stub.FunctionPutInputs, request)
190
+ resp = await self.client.stub.FunctionPutInputs(request, retry=self._function_inputs_retry)
191
191
  self.inputs_sent += len(items)
192
192
  # Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
193
193
  if self.map_items_manager is not None:
@@ -198,11 +198,8 @@ class InputPumper:
198
198
  )
199
199
  yield
200
200
 
201
- async def _send_inputs(
202
- self,
203
- fn: "modal.client.UnaryUnaryWrapper",
204
- request: typing.Union[api_pb2.FunctionPutInputsRequest, api_pb2.FunctionRetryInputsRequest],
205
- ) -> typing.Union[api_pb2.FunctionPutInputsResponse, api_pb2.FunctionRetryInputsResponse]:
201
+ @property
202
+ def _function_inputs_retry(self) -> Retry:
206
203
  # with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
207
204
  retry_warning_message = RetryWarningMessage(
208
205
  message=f"Warning: map progress for function {self.function._function_name} is limited."
@@ -210,13 +207,11 @@ class InputPumper:
210
207
  warning_interval=8,
211
208
  errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
212
209
  )
213
- return await retry_transient_errors(
214
- fn,
215
- request,
210
+ return Retry(
216
211
  max_retries=None,
217
212
  max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
218
213
  additional_status_codes=[Status.RESOURCE_EXHAUSTED],
219
- retry_warning_message=retry_warning_message,
214
+ warning_message=retry_warning_message,
220
215
  )
221
216
 
222
217
 
@@ -255,7 +250,7 @@ class SyncInputPumper(InputPumper):
255
250
  function_call_jwt=self.function_call_jwt,
256
251
  inputs=inputs,
257
252
  )
258
- resp = await self._send_inputs(self.client.stub.FunctionRetryInputs, request)
253
+ resp = await self.client.stub.FunctionRetryInputs(request, retry=self._function_inputs_retry)
259
254
  # Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
260
255
  # to the new value in the response.
261
256
  self.map_items_manager.handle_retry_response(resp.input_jwts)
@@ -289,7 +284,7 @@ class AsyncInputPumper(InputPumper):
289
284
  function_call_id=self.function_call_id,
290
285
  num_inputs=self.inputs_sent,
291
286
  )
292
- await retry_transient_errors(self.client.stub.FunctionFinishInputs, request, max_retries=None)
287
+ await self.client.stub.FunctionFinishInputs(request, retry=Retry(max_retries=None))
293
288
  yield
294
289
 
295
290
 
@@ -303,7 +298,7 @@ async def _spawn_map_invocation(
303
298
  function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
304
299
  function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC,
305
300
  )
306
- response: api_pb2.FunctionMapResponse = await retry_transient_errors(client.stub.FunctionMap, request)
301
+ response: api_pb2.FunctionMapResponse = await client.stub.FunctionMap(request)
307
302
  function_call_id = response.function_call_id
308
303
 
309
304
  have_all_inputs = False
@@ -382,7 +377,7 @@ async def _map_invocation(
382
377
  return_exceptions=return_exceptions,
383
378
  function_call_invocation_type=function_call_invocation_type,
384
379
  )
385
- response: api_pb2.FunctionMapResponse = await retry_transient_errors(client.stub.FunctionMap, request)
380
+ response: api_pb2.FunctionMapResponse = await client.stub.FunctionMap(request)
386
381
 
387
382
  function_call_id = response.function_call_id
388
383
  function_call_jwt = response.function_call_jwt
@@ -478,11 +473,12 @@ async def _map_invocation(
478
473
  input_jwts=input_jwts,
479
474
  )
480
475
  get_response_task = asyncio.create_task(
481
- retry_transient_errors(
482
- client.stub.FunctionGetOutputs,
476
+ client.stub.FunctionGetOutputs(
483
477
  request,
484
- max_retries=20,
485
- attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
478
+ retry=Retry(
479
+ max_retries=20,
480
+ attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
481
+ ),
486
482
  )
487
483
  )
488
484
  map_done_task = asyncio.create_task(map_done_event.wait())
@@ -541,7 +537,7 @@ async def _map_invocation(
541
537
  clear_on_success=True,
542
538
  requested_at=time.time(),
543
539
  )
544
- await retry_transient_errors(client.stub.FunctionGetOutputs, request)
540
+ await client.stub.FunctionGetOutputs(request)
545
541
  await retry_queue.close()
546
542
 
547
543
  async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
@@ -671,8 +667,8 @@ async def _map_invocation_inputplane(
671
667
  last_entry_id = ""
672
668
 
673
669
  # The input-plane server returns this after the first request.
674
- function_call_id = None
675
- function_call_id_received = asyncio.Event()
670
+ map_token = None
671
+ map_token_received = asyncio.Event()
676
672
 
677
673
  # Single priority queue that holds *both* fresh inputs (timestamp == now)
678
674
  # and future retries (timestamp > now).
@@ -751,7 +747,7 @@ async def _map_invocation_inputplane(
751
747
  yield
752
748
 
753
749
  async def pump_inputs():
754
- nonlocal function_call_id, max_inputs_outstanding
750
+ nonlocal map_token, max_inputs_outstanding
755
751
  async for batch in queue_batch_iterator(queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
756
752
  # Convert the queued items into the proto format expected by the RPC.
757
753
  request_items: list[api_pb2.MapStartOrContinueItem] = [
@@ -763,20 +759,21 @@ async def _map_invocation_inputplane(
763
759
  # Build request
764
760
  request = api_pb2.MapStartOrContinueRequest(
765
761
  function_id=function.object_id,
766
- function_call_id=function_call_id,
762
+ map_token=map_token,
767
763
  parent_input_id=current_input_id() or "",
768
764
  items=request_items,
769
765
  )
770
766
 
771
767
  metadata = await client.get_input_plane_metadata(function._input_plane_region)
772
768
 
773
- response: api_pb2.MapStartOrContinueResponse = await retry_transient_errors(
774
- input_plane_stub.MapStartOrContinue,
769
+ response: api_pb2.MapStartOrContinueResponse = await input_plane_stub.MapStartOrContinue(
775
770
  request,
771
+ retry=Retry(
772
+ additional_status_codes=[Status.RESOURCE_EXHAUSTED],
773
+ max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
774
+ max_retries=None,
775
+ ),
776
776
  metadata=metadata,
777
- additional_status_codes=[Status.RESOURCE_EXHAUSTED],
778
- max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
779
- max_retries=None,
780
777
  )
781
778
 
782
779
  # match response items to the corresponding request item index
@@ -789,9 +786,9 @@ async def _map_invocation_inputplane(
789
786
 
790
787
  # Set the function call id and actual retry policy with the data from the first response.
791
788
  # This conditional is skipped for subsequent iterations of this for-loop.
792
- if function_call_id is None:
793
- function_call_id = response.function_call_id
794
- function_call_id_received.set()
789
+ if map_token is None:
790
+ map_token = response.map_token
791
+ map_token_received.set()
795
792
  max_inputs_outstanding = response.max_inputs_outstanding or MAX_INPUTS_OUTSTANDING_DEFAULT
796
793
  map_items_manager.set_retry_policy(response.retry_policy)
797
794
  # Update the retry policy for the first batch of inputs.
@@ -804,8 +801,8 @@ async def _map_invocation_inputplane(
804
801
  nonlocal last_entry_id # shared with get_all_outputs
805
802
  try:
806
803
  while not map_done_event.is_set():
807
- if function_call_id is None:
808
- await function_call_id_received.wait()
804
+ if map_token is None:
805
+ await map_token_received.wait()
809
806
  continue
810
807
 
811
808
  sleep_task = asyncio.create_task(asyncio.sleep(1))
@@ -824,8 +821,8 @@ async def _map_invocation_inputplane(
824
821
  )
825
822
 
826
823
  metadata = await client.get_input_plane_metadata(function._input_plane_region)
827
- response: api_pb2.MapCheckInputsResponse = await retry_transient_errors(
828
- input_plane_stub.MapCheckInputs, request, metadata=metadata
824
+ response: api_pb2.MapCheckInputsResponse = await input_plane_stub.MapCheckInputs(
825
+ request, metadata=metadata
829
826
  )
830
827
  check_inputs_response = [
831
828
  (check_inputs[resp_idx][0], response.lost[resp_idx]) for resp_idx, _ in enumerate(response.lost)
@@ -847,23 +844,24 @@ async def _map_invocation_inputplane(
847
844
  last_entry_id
848
845
 
849
846
  while not map_done_event.is_set():
850
- if function_call_id is None:
851
- await function_call_id_received.wait()
847
+ if map_token is None:
848
+ await map_token_received.wait()
852
849
  continue
853
850
 
854
851
  request = api_pb2.MapAwaitRequest(
855
- function_call_id=function_call_id,
852
+ map_token=map_token,
856
853
  last_entry_id=last_entry_id,
857
854
  requested_at=time.time(),
858
855
  timeout=OUTPUTS_TIMEOUT,
859
856
  )
860
857
  metadata = await client.get_input_plane_metadata(function._input_plane_region)
861
858
  get_response_task = asyncio.create_task(
862
- retry_transient_errors(
863
- input_plane_stub.MapAwait,
859
+ input_plane_stub.MapAwait(
864
860
  request,
865
- max_retries=20,
866
- attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
861
+ retry=Retry(
862
+ max_retries=20,
863
+ attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
864
+ ),
867
865
  metadata=metadata,
868
866
  )
869
867
  )
@@ -963,7 +961,7 @@ async def _map_invocation_inputplane(
963
961
  f"Map stats:\nsuccessful_completions={successful_completions} failed_completions={failed_completions} "
964
962
  f"no_context_duplicates={no_context_duplicates} stale_retry_duplicates={stale_retry_duplicates} "
965
963
  f"already_complete_duplicates={already_complete_duplicates} retried_outputs={retried_outputs} "
966
- f"function_call_id={function_call_id} max_inputs_outstanding={max_inputs_outstanding} "
964
+ f"map_token={map_token} max_inputs_outstanding={max_inputs_outstanding} "
967
965
  f"map_items_manager_size={len(map_items_manager)} input_queue_size={input_queue_size}"
968
966
  )
969
967