modal 1.1.1.dev41__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic; see the package registry's release page for more details.

Files changed (68):
  1. modal/__main__.py +1 -2
  2. modal/_container_entrypoint.py +18 -7
  3. modal/_functions.py +135 -13
  4. modal/_object.py +13 -2
  5. modal/_partial_function.py +8 -8
  6. modal/_runtime/asgi.py +3 -2
  7. modal/_runtime/container_io_manager.py +20 -14
  8. modal/_runtime/container_io_manager.pyi +38 -13
  9. modal/_runtime/execution_context.py +18 -2
  10. modal/_runtime/execution_context.pyi +4 -1
  11. modal/_runtime/gpu_memory_snapshot.py +158 -54
  12. modal/_utils/blob_utils.py +83 -24
  13. modal/_utils/function_utils.py +4 -3
  14. modal/_utils/time_utils.py +28 -4
  15. modal/app.py +8 -4
  16. modal/app.pyi +8 -8
  17. modal/cli/dict.py +14 -11
  18. modal/cli/entry_point.py +9 -3
  19. modal/cli/launch.py +102 -4
  20. modal/cli/profile.py +1 -0
  21. modal/cli/programs/launch_instance_ssh.py +94 -0
  22. modal/cli/programs/run_marimo.py +95 -0
  23. modal/cli/queues.py +49 -19
  24. modal/cli/secret.py +45 -18
  25. modal/cli/volume.py +14 -16
  26. modal/client.pyi +2 -10
  27. modal/cls.py +12 -2
  28. modal/cls.pyi +9 -1
  29. modal/config.py +7 -7
  30. modal/dict.py +206 -12
  31. modal/dict.pyi +358 -4
  32. modal/experimental/__init__.py +130 -0
  33. modal/file_io.py +1 -1
  34. modal/file_io.pyi +2 -2
  35. modal/file_pattern_matcher.py +25 -16
  36. modal/functions.pyi +111 -11
  37. modal/image.py +9 -3
  38. modal/image.pyi +7 -7
  39. modal/mount.py +20 -13
  40. modal/mount.pyi +16 -3
  41. modal/network_file_system.py +8 -2
  42. modal/object.pyi +3 -0
  43. modal/parallel_map.py +346 -101
  44. modal/parallel_map.pyi +108 -0
  45. modal/proxy.py +2 -1
  46. modal/queue.py +199 -9
  47. modal/queue.pyi +357 -3
  48. modal/sandbox.py +6 -5
  49. modal/sandbox.pyi +17 -14
  50. modal/secret.py +196 -3
  51. modal/secret.pyi +372 -0
  52. modal/volume.py +239 -23
  53. modal/volume.pyi +405 -10
  54. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/METADATA +2 -2
  55. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/RECORD +68 -66
  56. modal_docs/mdmd/mdmd.py +11 -1
  57. modal_proto/api.proto +37 -10
  58. modal_proto/api_grpc.py +32 -0
  59. modal_proto/api_pb2.py +627 -597
  60. modal_proto/api_pb2.pyi +107 -19
  61. modal_proto/api_pb2_grpc.py +67 -2
  62. modal_proto/api_pb2_grpc.pyi +24 -8
  63. modal_proto/modal_api_grpc.py +2 -0
  64. modal_version/__init__.py +1 -1
  65. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/WHEEL +0 -0
  66. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/entry_points.txt +0 -0
  67. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/licenses/LICENSE +0 -0
  68. {modal-1.1.1.dev41.dist-info → modal-1.1.2.dist-info}/top_level.txt +0 -0
@@ -136,7 +136,7 @@ class _NetworkFileSystem(_Object, type_prefix="sv"):
136
136
  cls: type["_NetworkFileSystem"],
137
137
  client: Optional[_Client] = None,
138
138
  environment_name: Optional[str] = None,
139
- _heartbeat_sleep: float = EPHEMERAL_OBJECT_HEARTBEAT_SLEEP,
139
+ _heartbeat_sleep: float = EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, # mdmd:line-hidden
140
140
  ) -> AsyncIterator["_NetworkFileSystem"]:
141
141
  """Creates a new ephemeral network filesystem within a context manager:
142
142
 
@@ -161,7 +161,13 @@ class _NetworkFileSystem(_Object, type_prefix="sv"):
161
161
  async with TaskContext() as tc:
162
162
  request = api_pb2.SharedVolumeHeartbeatRequest(shared_volume_id=response.shared_volume_id)
163
163
  tc.infinite_loop(lambda: client.stub.SharedVolumeHeartbeat(request), sleep=_heartbeat_sleep)
164
- yield cls._new_hydrated(response.shared_volume_id, client, None, is_another_app=True)
164
+ yield cls._new_hydrated(
165
+ response.shared_volume_id,
166
+ client,
167
+ None,
168
+ is_another_app=True,
169
+ rep="modal.NetworkFileSystem.ephemeral()",
170
+ )
165
171
 
166
172
  @staticmethod
167
173
  async def lookup(
modal/object.pyi CHANGED
@@ -117,12 +117,15 @@ class Object:
117
117
  @classmethod
118
118
  def _is_id_type(cls, object_id) -> bool: ...
119
119
  @classmethod
120
+ def _repr(cls, name: str, environment_name: typing.Optional[str] = None) -> str: ...
121
+ @classmethod
120
122
  def _new_hydrated(
121
123
  cls,
122
124
  object_id: str,
123
125
  client: modal.client.Client,
124
126
  handle_metadata: typing.Optional[google.protobuf.message.Message],
125
127
  is_another_app: bool = False,
128
+ rep: typing.Optional[str] = None,
126
129
  ) -> typing_extensions.Self: ...
127
130
  def _hydrate_from_other(self, other: typing_extensions.Self): ...
128
131
  def __repr__(self): ...
modal/parallel_map.py CHANGED
@@ -86,6 +86,274 @@ if typing.TYPE_CHECKING:
86
86
  import modal.functions
87
87
 
88
88
 
89
+ class InputPreprocessor:
90
+ """
91
+ Constructs FunctionPutInputsItem objects from the raw-input queue, and puts them in the processed-input queue.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ client: "modal.client._Client",
97
+ *,
98
+ raw_input_queue: _SynchronizedQueue,
99
+ processed_input_queue: asyncio.Queue,
100
+ function: "modal.functions._Function",
101
+ created_callback: Callable[[int], None],
102
+ done_callback: Callable[[], None],
103
+ ):
104
+ self.client = client
105
+ self.function = function
106
+ self.inputs_created = 0
107
+ self.raw_input_queue = raw_input_queue
108
+ self.processed_input_queue = processed_input_queue
109
+ self.created_callback = created_callback
110
+ self.done_callback = done_callback
111
+
112
+ async def input_iter(self):
113
+ while 1:
114
+ raw_input = await self.raw_input_queue.get()
115
+ if raw_input is None: # end of input sentinel
116
+ break
117
+ yield raw_input # args, kwargs
118
+
119
+ def create_input_factory(self):
120
+ async def create_input(argskwargs):
121
+ idx = self.inputs_created
122
+ self.inputs_created += 1
123
+ self.created_callback(self.inputs_created)
124
+ (args, kwargs) = argskwargs
125
+ return await _create_input(
126
+ args,
127
+ kwargs,
128
+ self.client.stub,
129
+ max_object_size_bytes=self.function._max_object_size_bytes,
130
+ idx=idx,
131
+ method_name=self.function._use_method_name,
132
+ )
133
+
134
+ return create_input
135
+
136
+ async def drain_input_generator(self):
137
+ # Parallelize uploading blobs
138
+ async with aclosing(
139
+ async_map_ordered(self.input_iter(), self.create_input_factory(), concurrency=BLOB_MAX_PARALLELISM)
140
+ ) as streamer:
141
+ async for item in streamer:
142
+ await self.processed_input_queue.put(item)
143
+
144
+ # close queue iterator
145
+ await self.processed_input_queue.put(None)
146
+ self.done_callback()
147
+ yield
148
+
149
+
150
+ class InputPumper:
151
+ """
152
+ Reads inputs from a queue of FunctionPutInputsItems, and sends them to the server.
153
+ """
154
+
155
+ def __init__(
156
+ self,
157
+ client: "modal.client._Client",
158
+ *,
159
+ input_queue: asyncio.Queue,
160
+ function: "modal.functions._Function",
161
+ function_call_id: str,
162
+ map_items_manager: Optional["_MapItemsManager"] = None,
163
+ ):
164
+ self.client = client
165
+ self.function = function
166
+ self.map_items_manager = map_items_manager
167
+ self.input_queue = input_queue
168
+ self.inputs_sent = 0
169
+ self.function_call_id = function_call_id
170
+
171
+ async def pump_inputs(self):
172
+ assert self.client.stub
173
+ async for items in queue_batch_iterator(self.input_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
174
+ # Add items to the manager. Their state will be SENDING.
175
+ if self.map_items_manager is not None:
176
+ await self.map_items_manager.add_items(items)
177
+ request = api_pb2.FunctionPutInputsRequest(
178
+ function_id=self.function.object_id,
179
+ inputs=items,
180
+ function_call_id=self.function_call_id,
181
+ )
182
+ logger.debug(
183
+ f"Pushing {len(items)} inputs to server. Num queued inputs awaiting"
184
+ f" push is {self.input_queue.qsize()}. "
185
+ )
186
+
187
+ resp = await self._send_inputs(self.client.stub.FunctionPutInputs, request)
188
+ self.inputs_sent += len(items)
189
+ # Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
190
+ if self.map_items_manager is not None:
191
+ self.map_items_manager.handle_put_inputs_response(resp.inputs)
192
+ logger.debug(
193
+ f"Successfully pushed {len(items)} inputs to server. "
194
+ f"Num queued inputs awaiting push is {self.input_queue.qsize()}."
195
+ )
196
+ yield
197
+
198
+ async def _send_inputs(
199
+ self,
200
+ fn: "modal.client.UnaryUnaryWrapper",
201
+ request: typing.Union[api_pb2.FunctionPutInputsRequest, api_pb2.FunctionRetryInputsRequest],
202
+ ) -> typing.Union[api_pb2.FunctionPutInputsResponse, api_pb2.FunctionRetryInputsResponse]:
203
+ # with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
204
+ retry_warning_message = RetryWarningMessage(
205
+ message=f"Warning: map progress for function {self.function._function_name} is limited."
206
+ " Common bottlenecks include slow iteration over results, or function backlogs.",
207
+ warning_interval=8,
208
+ errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
209
+ )
210
+ return await retry_transient_errors(
211
+ fn,
212
+ request,
213
+ max_retries=None,
214
+ max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
215
+ additional_status_codes=[Status.RESOURCE_EXHAUSTED],
216
+ retry_warning_message=retry_warning_message,
217
+ )
218
+
219
+
220
+ class SyncInputPumper(InputPumper):
221
+ def __init__(
222
+ self,
223
+ client: "modal.client._Client",
224
+ *,
225
+ input_queue: asyncio.Queue,
226
+ retry_queue: TimestampPriorityQueue,
227
+ function: "modal.functions._Function",
228
+ function_call_jwt: str,
229
+ function_call_id: str,
230
+ map_items_manager: "_MapItemsManager",
231
+ ):
232
+ super().__init__(
233
+ client,
234
+ input_queue=input_queue,
235
+ function=function,
236
+ function_call_id=function_call_id,
237
+ map_items_manager=map_items_manager,
238
+ )
239
+ self.retry_queue = retry_queue
240
+ self.inputs_retried = 0
241
+ self.function_call_jwt = function_call_jwt
242
+
243
+ async def retry_inputs(self):
244
+ async for retriable_idxs in queue_batch_iterator(self.retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
245
+ # For each index, use the context in the manager to create a FunctionRetryInputsItem.
246
+ # This will also update the context state to RETRYING.
247
+ inputs: list[api_pb2.FunctionRetryInputsItem] = await self.map_items_manager.prepare_items_for_retry(
248
+ retriable_idxs
249
+ )
250
+ request = api_pb2.FunctionRetryInputsRequest(
251
+ function_call_jwt=self.function_call_jwt,
252
+ inputs=inputs,
253
+ )
254
+ resp = await self._send_inputs(self.client.stub.FunctionRetryInputs, request)
255
+ # Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
256
+ # to the new value in the response.
257
+ self.map_items_manager.handle_retry_response(resp.input_jwts)
258
+ logger.debug(f"Successfully pushed retry for {len(inputs)} to server.")
259
+ self.inputs_retried += len(inputs)
260
+ yield
261
+
262
+
263
+ class AsyncInputPumper(InputPumper):
264
+ def __init__(
265
+ self,
266
+ client: "modal.client._Client",
267
+ *,
268
+ input_queue: asyncio.Queue,
269
+ function: "modal.functions._Function",
270
+ function_call_id: str,
271
+ ):
272
+ super().__init__(client, input_queue=input_queue, function=function, function_call_id=function_call_id)
273
+
274
+ async def pump_inputs(self):
275
+ async for _ in super().pump_inputs():
276
+ pass
277
+ request = api_pb2.FunctionFinishInputsRequest(
278
+ function_id=self.function.object_id,
279
+ function_call_id=self.function_call_id,
280
+ num_inputs=self.inputs_sent,
281
+ )
282
+ await retry_transient_errors(self.client.stub.FunctionFinishInputs, request, max_retries=None)
283
+ yield
284
+
285
+
286
+ async def _spawn_map_invocation(
287
+ function: "modal.functions._Function", raw_input_queue: _SynchronizedQueue, client: "modal.client._Client"
288
+ ) -> tuple[str, int]:
289
+ assert client.stub
290
+ request = api_pb2.FunctionMapRequest(
291
+ function_id=function.object_id,
292
+ parent_input_id=current_input_id() or "",
293
+ function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
294
+ function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC,
295
+ )
296
+ response: api_pb2.FunctionMapResponse = await retry_transient_errors(client.stub.FunctionMap, request)
297
+ function_call_id = response.function_call_id
298
+
299
+ have_all_inputs = False
300
+ inputs_created = 0
301
+
302
+ def set_inputs_created(set_inputs_created):
303
+ nonlocal inputs_created
304
+ assert set_inputs_created is None or set_inputs_created > inputs_created
305
+ inputs_created = set_inputs_created
306
+
307
+ def set_have_all_inputs():
308
+ nonlocal have_all_inputs
309
+ have_all_inputs = True
310
+
311
+ input_queue: asyncio.Queue[api_pb2.FunctionPutInputsItem | None] = asyncio.Queue()
312
+ input_preprocessor = InputPreprocessor(
313
+ client=client,
314
+ raw_input_queue=raw_input_queue,
315
+ processed_input_queue=input_queue,
316
+ function=function,
317
+ created_callback=set_inputs_created,
318
+ done_callback=set_have_all_inputs,
319
+ )
320
+
321
+ input_pumper = AsyncInputPumper(
322
+ client=client,
323
+ input_queue=input_queue,
324
+ function=function,
325
+ function_call_id=function_call_id,
326
+ )
327
+
328
+ def log_stats():
329
+ logger.debug(
330
+ f"have_all_inputs={have_all_inputs} inputs_created={inputs_created} inputs_sent={input_pumper.inputs_sent} "
331
+ )
332
+
333
+ async def log_task():
334
+ while True:
335
+ log_stats()
336
+ try:
337
+ await asyncio.sleep(10)
338
+ except asyncio.CancelledError:
339
+ # Log final stats before exiting
340
+ log_stats()
341
+ break
342
+
343
+ async def consume_generator(gen):
344
+ async for _ in gen:
345
+ pass
346
+
347
+ log_debug_stats_task = asyncio.create_task(log_task())
348
+ await asyncio.gather(
349
+ consume_generator(input_preprocessor.drain_input_generator()),
350
+ consume_generator(input_pumper.pump_inputs()),
351
+ )
352
+ log_debug_stats_task.cancel()
353
+ await log_debug_stats_task
354
+ return function_call_id, inputs_created
355
+
356
+
89
357
  async def _map_invocation(
90
358
  function: "modal.functions._Function",
91
359
  raw_input_queue: _SynchronizedQueue,
@@ -117,8 +385,6 @@ async def _map_invocation(
117
385
  have_all_inputs = False
118
386
  map_done_event = asyncio.Event()
119
387
  inputs_created = 0
120
- inputs_sent = 0
121
- inputs_retried = 0
122
388
  outputs_completed = 0
123
389
  outputs_received = 0
124
390
  retried_outputs = 0
@@ -135,25 +401,24 @@ async def _map_invocation(
135
401
  retry_policy, function_call_invocation_type, retry_queue, sync_client_retries_enabled, max_inputs_outstanding
136
402
  )
137
403
 
138
- async def create_input(argskwargs):
139
- idx = inputs_created
140
- update_state(set_inputs_created=inputs_created + 1)
141
- (args, kwargs) = argskwargs
142
- return await _create_input(
143
- args,
144
- kwargs,
145
- client.stub,
146
- max_object_size_bytes=function._max_object_size_bytes,
147
- idx=idx,
148
- method_name=function._use_method_name,
149
- )
404
+ input_preprocessor = InputPreprocessor(
405
+ client=client,
406
+ raw_input_queue=raw_input_queue,
407
+ processed_input_queue=input_queue,
408
+ function=function,
409
+ created_callback=lambda x: update_state(set_inputs_created=x),
410
+ done_callback=lambda: update_state(set_have_all_inputs=True),
411
+ )
150
412
 
151
- async def input_iter():
152
- while 1:
153
- raw_input = await raw_input_queue.get()
154
- if raw_input is None: # end of input sentinel
155
- break
156
- yield raw_input # args, kwargs
413
+ input_pumper = SyncInputPumper(
414
+ client=client,
415
+ input_queue=input_queue,
416
+ retry_queue=retry_queue,
417
+ function=function,
418
+ map_items_manager=map_items_manager,
419
+ function_call_jwt=function_call_jwt,
420
+ function_call_id=function_call_id,
421
+ )
157
422
 
158
423
  def update_state(set_have_all_inputs=None, set_inputs_created=None, set_outputs_completed=None):
159
424
  # This should be the only method that needs nonlocal of the following vars
@@ -175,84 +440,6 @@ async def _map_invocation(
175
440
  # map is done
176
441
  map_done_event.set()
177
442
 
178
- async def drain_input_generator():
179
- # Parallelize uploading blobs
180
- async with aclosing(
181
- async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
182
- ) as streamer:
183
- async for item in streamer:
184
- await input_queue.put(item)
185
-
186
- # close queue iterator
187
- await input_queue.put(None)
188
- update_state(set_have_all_inputs=True)
189
- yield
190
-
191
- async def pump_inputs():
192
- assert client.stub
193
- nonlocal inputs_sent
194
- async for items in queue_batch_iterator(input_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
195
- # Add items to the manager. Their state will be SENDING.
196
- await map_items_manager.add_items(items)
197
- request = api_pb2.FunctionPutInputsRequest(
198
- function_id=function.object_id,
199
- inputs=items,
200
- function_call_id=function_call_id,
201
- )
202
- logger.debug(
203
- f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
204
- )
205
-
206
- resp = await send_inputs(client.stub.FunctionPutInputs, request)
207
- inputs_sent += len(items)
208
- # Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
209
- map_items_manager.handle_put_inputs_response(resp.inputs)
210
- logger.debug(
211
- f"Successfully pushed {len(items)} inputs to server. "
212
- f"Num queued inputs awaiting push is {input_queue.qsize()}."
213
- )
214
- yield
215
-
216
- async def retry_inputs():
217
- nonlocal inputs_retried
218
- async for retriable_idxs in queue_batch_iterator(retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
219
- # For each index, use the context in the manager to create a FunctionRetryInputsItem.
220
- # This will also update the context state to RETRYING.
221
- inputs: list[api_pb2.FunctionRetryInputsItem] = await map_items_manager.prepare_items_for_retry(
222
- retriable_idxs
223
- )
224
- request = api_pb2.FunctionRetryInputsRequest(
225
- function_call_jwt=function_call_jwt,
226
- inputs=inputs,
227
- )
228
- resp = await send_inputs(client.stub.FunctionRetryInputs, request)
229
- # Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
230
- # to the new value in the response.
231
- map_items_manager.handle_retry_response(resp.input_jwts)
232
- logger.debug(f"Successfully pushed retry for {len(inputs)} to server.")
233
- inputs_retried += len(inputs)
234
- yield
235
-
236
- async def send_inputs(
237
- fn: "modal.client.UnaryUnaryWrapper",
238
- request: typing.Union[api_pb2.FunctionPutInputsRequest, api_pb2.FunctionRetryInputsRequest],
239
- ) -> typing.Union[api_pb2.FunctionPutInputsResponse, api_pb2.FunctionRetryInputsResponse]:
240
- # with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
241
- retry_warning_message = RetryWarningMessage(
242
- message=f"Warning: map progress for function {function._function_name} is limited."
243
- " Common bottlenecks include slow iteration over results, or function backlogs.",
244
- warning_interval=8,
245
- errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
246
- )
247
- return await retry_transient_errors(
248
- fn,
249
- request,
250
- max_retries=None,
251
- max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
252
- additional_status_codes=[Status.RESOURCE_EXHAUSTED],
253
- retry_warning_message=retry_warning_message,
254
- )
255
-
256
443
  async def get_all_outputs():
257
444
  assert client.stub
258
445
  nonlocal \
@@ -395,8 +582,11 @@ async def _map_invocation(
395
582
  def log_stats():
396
583
  logger.debug(
397
584
  f"Map stats: sync_client_retries_enabled={sync_client_retries_enabled} "
398
- f"have_all_inputs={have_all_inputs} inputs_created={inputs_created} input_sent={inputs_sent} "
399
- f"inputs_retried={inputs_retried} outputs_received={outputs_received} "
585
+ f"have_all_inputs={have_all_inputs} "
586
+ f"inputs_created={inputs_created} "
587
+ f"input_sent={input_pumper.inputs_sent} "
588
+ f"inputs_retried={input_pumper.inputs_retried} "
589
+ f"outputs_received={outputs_received} "
400
590
  f"successful_completions={successful_completions} failed_completions={failed_completions} "
401
591
  f"no_context_duplicates={no_context_duplicates} old_retry_duplicates={stale_retry_duplicates} "
402
592
  f"already_complete_duplicates={already_complete_duplicates} "
@@ -415,7 +605,12 @@ async def _map_invocation(
415
605
 
416
606
  log_debug_stats_task = asyncio.create_task(log_debug_stats())
417
607
  async with aclosing(
418
- async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), retry_inputs())
608
+ async_merge(
609
+ input_preprocessor.drain_input_generator(),
610
+ input_pumper.pump_inputs(),
611
+ input_pumper.retry_inputs(),
612
+ poll_outputs(),
613
+ )
419
614
  ) as streamer:
420
615
  async for response in streamer:
421
616
  if response is not None: # type: ignore[unreachable]
@@ -962,6 +1157,56 @@ def _map_sync(
962
1157
  )
963
1158
 
964
1159
 
1160
+ async def _experimental_spawn_map_async(self, *input_iterators, kwargs={}) -> "modal.functions._FunctionCall":
1161
+ async_input_gen = async_zip(*[sync_or_async_iter(it) for it in input_iterators])
1162
+ return await _spawn_map_helper(self, async_input_gen, kwargs)
1163
+
1164
+
1165
+ async def _spawn_map_helper(
1166
+ self: "modal.functions.Function", async_input_gen, kwargs={}
1167
+ ) -> "modal.functions._FunctionCall":
1168
+ raw_input_queue: Any = SynchronizedQueue() # type: ignore
1169
+ await raw_input_queue.init.aio()
1170
+
1171
+ async def feed_queue():
1172
+ async with aclosing(async_input_gen) as streamer:
1173
+ async for args in streamer:
1174
+ await raw_input_queue.put.aio((args, kwargs))
1175
+ await raw_input_queue.put.aio(None) # end-of-input sentinel
1176
+
1177
+ fc, _ = await asyncio.gather(self._spawn_map.aio(raw_input_queue), feed_queue())
1178
+ return fc
1179
+
1180
+
1181
+ def _experimental_spawn_map_sync(self, *input_iterators, kwargs={}) -> "modal.functions._FunctionCall":
1182
+ """mdmd:hidden
1183
+ Spawn parallel execution over a set of inputs, returning as soon as the inputs are created.
1184
+
1185
+ Unlike `modal.Function.map`, this method does not block on completion of the remote execution but
1186
+ returns a `modal.FunctionCall` object that can be used to poll status and retrieve results later.
1187
+
1188
+ Takes one iterator argument per argument in the function being mapped over.
1189
+
1190
+ Example:
1191
+ ```python
1192
+ @app.function()
1193
+ def my_func(a, b):
1194
+ return a ** b
1195
+
1196
+
1197
+ @app.local_entrypoint()
1198
+ def main():
1199
+ fc = my_func.spawn_map([1, 2], [3, 4])
1200
+ ```
1201
+
1202
+ """
1203
+
1204
+ return run_coroutine_in_temporary_event_loop(
1205
+ _experimental_spawn_map_async(self, *input_iterators, kwargs=kwargs),
1206
+ "You can't run Function.spawn_map() from an async function. Use Function.spawn_map.aio() instead.",
1207
+ )
1208
+
1209
+
965
1210
  async def _spawn_map_async(self, *input_iterators, kwargs={}) -> None:
966
1211
  """This runs in an event loop on the main thread. It consumes inputs from the input iterators and creates async
967
1212
  function calls for each.
modal/parallel_map.pyi CHANGED
@@ -1,5 +1,6 @@
1
1
  import asyncio
2
2
  import asyncio.events
3
+ import asyncio.queues
3
4
  import collections.abc
4
5
  import enum
5
6
  import modal._functions
@@ -60,6 +61,86 @@ class _OutputValue:
60
61
  """Return self==value."""
61
62
  ...
62
63
 
64
+ class InputPreprocessor:
65
+ """Constructs FunctionPutInputsItem objects from the raw-input queue, and puts them in the processed-input queue."""
66
+ def __init__(
67
+ self,
68
+ client: modal.client._Client,
69
+ *,
70
+ raw_input_queue: _SynchronizedQueue,
71
+ processed_input_queue: asyncio.queues.Queue,
72
+ function: modal._functions._Function,
73
+ created_callback: collections.abc.Callable[[int], None],
74
+ done_callback: collections.abc.Callable[[], None],
75
+ ):
76
+ """Initialize self. See help(type(self)) for accurate signature."""
77
+ ...
78
+
79
+ def input_iter(self): ...
80
+ def create_input_factory(self): ...
81
+ def drain_input_generator(self): ...
82
+
83
+ class InputPumper:
84
+ """Reads inputs from a queue of FunctionPutInputsItems, and sends them to the server."""
85
+ def __init__(
86
+ self,
87
+ client: modal.client._Client,
88
+ *,
89
+ input_queue: asyncio.queues.Queue,
90
+ function: modal._functions._Function,
91
+ function_call_id: str,
92
+ map_items_manager: typing.Optional[_MapItemsManager] = None,
93
+ ):
94
+ """Initialize self. See help(type(self)) for accurate signature."""
95
+ ...
96
+
97
+ def pump_inputs(self): ...
98
+ async def _send_inputs(
99
+ self,
100
+ fn: modal.client.UnaryUnaryWrapper,
101
+ request: typing.Union[
102
+ modal_proto.api_pb2.FunctionPutInputsRequest, modal_proto.api_pb2.FunctionRetryInputsRequest
103
+ ],
104
+ ) -> typing.Union[
105
+ modal_proto.api_pb2.FunctionPutInputsResponse, modal_proto.api_pb2.FunctionRetryInputsResponse
106
+ ]: ...
107
+
108
+ class SyncInputPumper(InputPumper):
109
+ """Reads inputs from a queue of FunctionPutInputsItems, and sends them to the server."""
110
+ def __init__(
111
+ self,
112
+ client: modal.client._Client,
113
+ *,
114
+ input_queue: asyncio.queues.Queue,
115
+ retry_queue: modal._utils.async_utils.TimestampPriorityQueue,
116
+ function: modal._functions._Function,
117
+ function_call_jwt: str,
118
+ function_call_id: str,
119
+ map_items_manager: _MapItemsManager,
120
+ ):
121
+ """Initialize self. See help(type(self)) for accurate signature."""
122
+ ...
123
+
124
+ def retry_inputs(self): ...
125
+
126
+ class AsyncInputPumper(InputPumper):
127
+ """Reads inputs from a queue of FunctionPutInputsItems, and sends them to the server."""
128
+ def __init__(
129
+ self,
130
+ client: modal.client._Client,
131
+ *,
132
+ input_queue: asyncio.queues.Queue,
133
+ function: modal._functions._Function,
134
+ function_call_id: str,
135
+ ):
136
+ """Initialize self. See help(type(self)) for accurate signature."""
137
+ ...
138
+
139
+ def pump_inputs(self): ...
140
+
141
+ async def _spawn_map_invocation(
142
+ function: modal._functions._Function, raw_input_queue: _SynchronizedQueue, client: modal.client._Client
143
+ ) -> tuple[str, int]: ...
63
144
  def _map_invocation(
64
145
  function: modal._functions._Function,
65
146
  raw_input_queue: _SynchronizedQueue,
@@ -179,6 +260,33 @@ def _map_sync(
179
260
  """
180
261
  ...
181
262
 
263
+ async def _experimental_spawn_map_async(self, *input_iterators, kwargs={}) -> modal._functions._FunctionCall: ...
264
+ async def _spawn_map_helper(
265
+ self: modal.functions.Function, async_input_gen, kwargs={}
266
+ ) -> modal._functions._FunctionCall: ...
267
+ def _experimental_spawn_map_sync(self, *input_iterators, kwargs={}) -> modal._functions._FunctionCall:
268
+ """mdmd:hidden
269
+ Spawn parallel execution over a set of inputs, returning as soon as the inputs are created.
270
+
271
+ Unlike `modal.Function.map`, this method does not block on completion of the remote execution but
272
+ returns a `modal.FunctionCall` object that can be used to poll status and retrieve results later.
273
+
274
+ Takes one iterator argument per argument in the function being mapped over.
275
+
276
+ Example:
277
+ ```python
278
+ @app.function()
279
+ def my_func(a, b):
280
+ return a ** b
281
+
282
+
283
+ @app.local_entrypoint()
284
+ def main():
285
+ fc = my_func.spawn_map([1, 2], [3, 4])
286
+ ```
287
+ """
288
+ ...
289
+
182
290
  async def _spawn_map_async(self, *input_iterators, kwargs={}) -> None:
183
291
  """This runs in an event loop on the main thread. It consumes inputs from the input iterators and creates async
184
292
  function calls for each.
modal/proxy.py CHANGED
@@ -36,7 +36,8 @@ class _Proxy(_Object, type_prefix="pr"):
36
36
  response: api_pb2.ProxyGetResponse = await resolver.client.stub.ProxyGet(req)
37
37
  self._hydrate(response.proxy.proxy_id, resolver.client, None)
38
38
 
39
- return _Proxy._from_loader(_load, "Proxy()", is_another_app=True)
39
+ rep = _Proxy._repr(name, environment_name)
40
+ return _Proxy._from_loader(_load, rep, is_another_app=True)
40
41
 
41
42
 
42
43
  Proxy = synchronize_api(_Proxy, target_module=__name__)