modal 1.0.3.dev28__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those versions as they appear in their respective public registries.
@@ -789,12 +789,12 @@ def _batched(
789
789
  )
790
790
  if max_batch_size < 1:
791
791
  raise InvalidError("max_batch_size must be a positive integer.")
792
- if max_batch_size >= MAX_MAX_BATCH_SIZE:
793
- raise InvalidError(f"max_batch_size must be less than {MAX_MAX_BATCH_SIZE}.")
792
+ if max_batch_size > MAX_MAX_BATCH_SIZE:
793
+ raise InvalidError(f"max_batch_size cannot be greater than {MAX_MAX_BATCH_SIZE}.")
794
794
  if wait_ms < 0:
795
795
  raise InvalidError("wait_ms must be a non-negative integer.")
796
- if wait_ms >= MAX_BATCH_WAIT_MS:
797
- raise InvalidError(f"wait_ms must be less than {MAX_BATCH_WAIT_MS}.")
796
+ if wait_ms > MAX_BATCH_WAIT_MS:
797
+ raise InvalidError(f"wait_ms cannot be greater than {MAX_BATCH_WAIT_MS}.")
798
798
 
799
799
  flags = _PartialFunctionFlags.CALLABLE_INTERFACE | _PartialFunctionFlags.BATCHED
800
800
  params = _PartialFunctionParams(batch_max_size=max_batch_size, batch_wait_ms=wait_ms)
@@ -647,7 +647,9 @@ class StopSentinelType: ...
647
647
  STOP_SENTINEL = StopSentinelType()
648
648
 
649
649
 
650
- async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
650
+ async def async_merge(
651
+ *generators: AsyncGenerator[T, None], cancellation_timeout: float = 10.0
652
+ ) -> AsyncGenerator[T, None]:
651
653
  """
652
654
  Asynchronously merges multiple async generators into a single async generator.
653
655
 
@@ -692,8 +694,9 @@ async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T,
692
694
 
693
695
  async def producer(generator: AsyncGenerator[T, None]):
694
696
  try:
695
- async for item in generator:
696
- await queue.put(ValueWrapper(item))
697
+ async with aclosing(generator) as stream:
698
+ async for item in stream:
699
+ await queue.put(ValueWrapper(item))
697
700
  except Exception as e:
698
701
  await queue.put(ExceptionWrapper(e))
699
702
 
@@ -735,15 +738,20 @@ async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T,
735
738
  new_output_task = asyncio.create_task(queue.get())
736
739
 
737
740
  finally:
738
- if not new_output_task.done():
739
- new_output_task.cancel()
740
- for task in tasks:
741
- if not task.done():
742
- try:
743
- task.cancel()
744
- await task
745
- except asyncio.CancelledError:
746
- pass
741
+ unfinished_tasks = [t for t in tasks | {new_output_task} if not t.done()]
742
+ for t in unfinished_tasks:
743
+ t.cancel()
744
+ try:
745
+ await asyncio.wait_for(
746
+ asyncio.shield(
747
+ # we need to `shield` the `gather` to ensure cooperation with the timeout
748
+ # all underlying tasks have been marked as cancelled at this point anyway
749
+ asyncio.gather(*unfinished_tasks, return_exceptions=True)
750
+ ),
751
+ timeout=cancellation_timeout,
752
+ )
753
+ except asyncio.TimeoutError:
754
+ logger.debug("Timed out while cleaning up async_merge")
747
755
 
748
756
 
749
757
  async def callable_to_agen(awaitable: Callable[[], Awaitable[T]]) -> AsyncGenerator[T, None]:
@@ -761,16 +769,34 @@ async def gather_cancel_on_exc(*coros_or_futures):
761
769
  raise
762
770
 
763
771
 
772
+ async def prevent_cancellation_abortion(coro):
773
+ # if this is cancelled, it will wait for coro cancellation handling
774
+ # and then unconditionally re-raises a CancelledError, even if the underlying coro
775
+ # doesn't re-raise the cancellation itself
776
+ t = asyncio.create_task(coro)
777
+ try:
778
+ return await asyncio.shield(t)
779
+ except asyncio.CancelledError:
780
+ if t.cancelled():
781
+ # coro cancelled itself - reraise
782
+ raise
783
+ t.cancel() # cancel task
784
+ await t # this *normally* reraises
785
+ raise # if the above somehow resolved, by swallowing cancellation - we still raise
786
+
787
+
764
788
  async def async_map(
765
789
  input_generator: AsyncGenerator[T, None],
766
790
  async_mapper_func: Callable[[T], Awaitable[V]],
767
791
  concurrency: int,
792
+ cancellation_timeout: float = 10.0,
768
793
  ) -> AsyncGenerator[V, None]:
769
794
  queue: asyncio.Queue[Union[ValueWrapper[T], StopSentinelType]] = asyncio.Queue(maxsize=concurrency * 2)
770
795
 
771
796
  async def producer() -> AsyncGenerator[V, None]:
772
- async for item in input_generator:
773
- await queue.put(ValueWrapper(item))
797
+ async with aclosing(input_generator) as stream:
798
+ async for item in stream:
799
+ await queue.put(ValueWrapper(item))
774
800
 
775
801
  for _ in range(concurrency):
776
802
  await queue.put(STOP_SENTINEL)
@@ -784,14 +810,17 @@ async def async_map(
784
810
  while True:
785
811
  item = await queue.get()
786
812
  if isinstance(item, ValueWrapper):
787
- yield await async_mapper_func(item.value)
813
+ res = await prevent_cancellation_abortion(async_mapper_func(item.value))
814
+ yield res
788
815
  elif isinstance(item, ExceptionWrapper):
789
816
  raise item.value
790
817
  else:
791
818
  assert_type(item, StopSentinelType)
792
819
  break
793
820
 
794
- async with aclosing(async_merge(*[worker() for _ in range(concurrency)], producer())) as stream:
821
+ async with aclosing(
822
+ async_merge(*[worker() for i in range(concurrency)], producer(), cancellation_timeout=cancellation_timeout)
823
+ ) as stream:
795
824
  async for item in stream:
796
825
  yield item
797
826
 
@@ -203,8 +203,11 @@ async def rm(
203
203
  ):
204
204
  ensure_env(env)
205
205
  volume = _NetworkFileSystem.from_name(volume_name)
206
+ console = Console()
206
207
  try:
207
208
  await volume.remove_file(remote_path, recursive=recursive)
209
+ console.print(OutputManager.step_completed(f"{remote_path} was deleted successfully!"))
210
+
208
211
  except GRPCError as exc:
209
212
  if exc.status in (Status.NOT_FOUND, Status.INVALID_ARGUMENT):
210
213
  raise UsageError(exc.message)
modal/cli/run.py CHANGED
@@ -27,6 +27,7 @@ from ..functions import Function
27
27
  from ..image import Image
28
28
  from ..output import enable_output
29
29
  from ..runner import deploy_app, interactive_shell, run_app
30
+ from ..secret import Secret
30
31
  from ..serving import serve_app
31
32
  from ..volume import Volume
32
33
  from .import_refs import (
@@ -531,6 +532,10 @@ def shell(
531
532
  " Can be used multiple times."
532
533
  ),
533
534
  ),
535
+ secret: Optional[list[str]] = typer.Option(
536
+ default=None,
537
+ help=("Name of a `modal.Secret` to mount inside the shell (if not using REF). Can be used multiple times."),
538
+ ),
534
539
  cpu: Optional[int] = typer.Option(default=None, help="Number of CPUs to allocate to the shell (if not using REF)."),
535
540
  memory: Optional[int] = typer.Option(
536
541
  default=None, help="Memory to allocate for the shell, in MiB (if not using REF)."
@@ -660,6 +665,7 @@ def shell(
660
665
  else:
661
666
  modal_image = Image.from_registry(image, add_python=add_python) if image else None
662
667
  volumes = {} if volume is None else {f"/mnt/{vol}": Volume.from_name(vol) for vol in volume}
668
+ secrets = [] if secret is None else [Secret.from_name(s) for s in secret]
663
669
  start_shell = partial(
664
670
  interactive_shell,
665
671
  image=modal_image,
@@ -668,6 +674,7 @@ def shell(
668
674
  gpu=gpu,
669
675
  cloud=cloud,
670
676
  volumes=volumes,
677
+ secrets=secrets,
671
678
  region=region.split(",") if region else [],
672
679
  pty=pty,
673
680
  )
modal/cli/volume.py CHANGED
@@ -245,8 +245,10 @@ async def rm(
245
245
  ):
246
246
  ensure_env(env)
247
247
  volume = _Volume.from_name(volume_name, environment_name=env)
248
+ console = Console()
248
249
  try:
249
250
  await volume.remove_file(remote_path, recursive=recursive)
251
+ console.print(OutputManager.step_completed(f"{remote_path} was deleted successfully!"))
250
252
  except GRPCError as exc:
251
253
  if exc.status in (Status.NOT_FOUND, Status.INVALID_ARGUMENT):
252
254
  raise UsageError(exc.message)
modal/client.pyi CHANGED
@@ -27,11 +27,7 @@ class _Client:
27
27
  _snapshotted: bool
28
28
 
29
29
  def __init__(
30
- self,
31
- server_url: str,
32
- client_type: int,
33
- credentials: typing.Optional[tuple[str, str]],
34
- version: str = "1.0.3.dev28",
30
+ self, server_url: str, client_type: int, credentials: typing.Optional[tuple[str, str]], version: str = "1.0.4"
35
31
  ): ...
36
32
  def is_closed(self) -> bool: ...
37
33
  @property
@@ -90,11 +86,7 @@ class Client:
90
86
  _snapshotted: bool
91
87
 
92
88
  def __init__(
93
- self,
94
- server_url: str,
95
- client_type: int,
96
- credentials: typing.Optional[tuple[str, str]],
97
- version: str = "1.0.3.dev28",
89
+ self, server_url: str, client_type: int, credentials: typing.Optional[tuple[str, str]], version: str = "1.0.4"
98
90
  ): ...
99
91
  def is_closed(self) -> bool: ...
100
92
  @property
modal/cls.py CHANGED
@@ -75,6 +75,7 @@ def _get_class_constructor_signature(user_cls: type) -> inspect.Signature:
75
75
 
76
76
  @dataclasses.dataclass()
77
77
  class _ServiceOptions:
78
+ # Note that default values should always be "untruthy" so we can detect when they are not set
78
79
  secrets: typing.Collection[_Secret] = ()
79
80
  validated_volumes: typing.Sequence[tuple[str, _Volume]] = ()
80
81
  resources: Optional[api_pb2.Resources] = None
@@ -88,6 +89,25 @@ class _ServiceOptions:
88
89
  batch_max_size: Optional[int] = None
89
90
  batch_wait_ms: Optional[int] = None
90
91
 
92
+ def merge_options(self, new_options: "_ServiceOptions") -> "_ServiceOptions":
93
+ """Implement protobuf-like MergeFrom semantics for this dataclass.
94
+
95
+ This mostly exists to support "stacking" of `.with_options()` calls.
96
+ """
97
+ new_options_dict = dataclasses.asdict(new_options)
98
+
99
+ # Resources needs special merge handling because individual fields are parameters in the public API
100
+ merged_resources = api_pb2.Resources()
101
+ if self.resources:
102
+ merged_resources.MergeFrom(self.resources)
103
+ if new_resources := new_options_dict.pop("resources"):
104
+ merged_resources.MergeFrom(new_resources)
105
+ self.resources = merged_resources
106
+
107
+ for key, value in new_options_dict.items():
108
+ if value: # Only overwrite data when the value was set in the new options
109
+ setattr(self, key, value)
110
+
91
111
 
92
112
  def _bind_instance_method(cls: "_Cls", service_function: _Function, method_name: str):
93
113
  """Binds an "instance service function" to a specific method using metadata for that method
@@ -664,15 +684,32 @@ More information on class parameterization can be found here: https://modal.com/
664
684
  container_idle_timeout: Optional[int] = None, # Now called `scaledown_window`
665
685
  allow_concurrent_inputs: Optional[int] = None, # See `.with_concurrency`
666
686
  ) -> "_Cls":
667
- """Create an instance of the Cls with configuration options overridden with new values.
687
+ """Override the static Function configuration at runtime.
688
+
689
+ This method will return a new instance of the cls that will autoscale independently of the
690
+ original instance. Note that options cannot be "unset" with this method (i.e., if a GPU
691
+ is configured in the `@app.cls()` decorator, passing `gpu=None` here will not create a
692
+ CPU-only instance).
668
693
 
669
694
  **Usage:**
670
695
 
696
+ You can use this method after looking up the Cls from a deployed App or if you have a
697
+ direct reference to a Cls from another Function or local entrypoint on its App:
698
+
671
699
  ```python notest
672
700
  Model = modal.Cls.from_name("my_app", "Model")
673
701
  ModelUsingGPU = Model.with_options(gpu="A100")
674
- ModelUsingGPU().generate.remote(42) # will run with an A100 GPU
702
+ ModelUsingGPU().generate.remote(input_prompt) # Run with an A100 GPU
675
703
  ```
704
+
705
+ The method can be called multiple times to "stack" updates:
706
+
707
+ ```python notest
708
+ Model.with_options(gpu="A100").with_options(scaledown_window=300) # Use an A100 with slow scaledown
709
+ ```
710
+
711
+ Note that container arguments (i.e. `volumes` and `secrets`) passed in subsequent calls
712
+ will not be merged.
676
713
  """
677
714
  retry_policy = _parse_retries(retries, f"Class {self.__name__}" if self._user_cls else "")
678
715
  if gpu or cpu or memory:
@@ -705,21 +742,23 @@ More information on class parameterization can be found here: https://modal.com/
705
742
 
706
743
  cls = _Cls._from_loader(_load_from_base, rep=f"{self._name}.with_options(...)", is_another_app=True, deps=_deps)
707
744
  cls._initialize_from_other(self)
708
- cls._options = dataclasses.replace(
709
- cls._options,
745
+
746
+ new_options = _ServiceOptions(
710
747
  secrets=secrets,
748
+ validated_volumes=validate_volumes(volumes),
711
749
  resources=resources,
712
750
  retry_policy=retry_policy,
713
751
  max_containers=max_containers,
714
752
  buffer_containers=buffer_containers,
715
753
  scaledown_window=scaledown_window,
716
754
  timeout_secs=timeout,
717
- validated_volumes=validate_volumes(volumes),
718
755
  # Note: set both for backwards / forwards compatibility
719
756
  # But going forward `.with_concurrency` is the preferred method with distinct parameterization
720
757
  max_concurrent_inputs=allow_concurrent_inputs,
721
758
  target_concurrent_inputs=allow_concurrent_inputs,
722
759
  )
760
+
761
+ cls._options.merge_options(new_options)
723
762
  return cls
724
763
 
725
764
  def with_concurrency(self: "_Cls", *, max_inputs: int, target_inputs: Optional[int] = None) -> "_Cls":
@@ -746,9 +785,9 @@ More information on class parameterization can be found here: https://modal.com/
746
785
  _load_from_base, rep=f"{self._name}.with_concurrency(...)", is_another_app=True, deps=_deps
747
786
  )
748
787
  cls._initialize_from_other(self)
749
- cls._options = dataclasses.replace(
750
- cls._options, max_concurrent_inputs=max_inputs, target_concurrent_inputs=target_inputs
751
- )
788
+
789
+ concurrency_options = _ServiceOptions(max_concurrent_inputs=max_inputs, target_concurrent_inputs=target_inputs)
790
+ cls._options.merge_options(concurrency_options)
752
791
  return cls
753
792
 
754
793
  def with_batching(self: "_Cls", *, max_batch_size: int, wait_ms: int) -> "_Cls":
@@ -775,7 +814,9 @@ More information on class parameterization can be found here: https://modal.com/
775
814
  _load_from_base, rep=f"{self._name}.with_concurrency(...)", is_another_app=True, deps=_deps
776
815
  )
777
816
  cls._initialize_from_other(self)
778
- cls._options = dataclasses.replace(cls._options, batch_max_size=max_batch_size, batch_wait_ms=wait_ms)
817
+
818
+ batching_options = _ServiceOptions(batch_max_size=max_batch_size, batch_wait_ms=wait_ms)
819
+ cls._options.merge_options(batching_options)
779
820
  return cls
780
821
 
781
822
  @staticmethod
modal/cls.pyi CHANGED
@@ -37,6 +37,7 @@ class _ServiceOptions:
37
37
  batch_max_size: typing.Optional[int]
38
38
  batch_wait_ms: typing.Optional[int]
39
39
 
40
+ def merge_options(self, new_options: _ServiceOptions) -> _ServiceOptions: ...
40
41
  def __init__(
41
42
  self,
42
43
  secrets: typing.Collection[modal.secret._Secret] = (),
modal/file_io.py CHANGED
@@ -144,11 +144,12 @@ class _FileIO(Generic[T]):
144
144
  _task_id: str = ""
145
145
  _file_descriptor: str = ""
146
146
  _client: _Client
147
- _watch_output_buffer: list[Optional[bytes]] = []
147
+ _watch_output_buffer: list[Union[Optional[bytes],Exception]] = []
148
148
 
149
149
  def __init__(self, client: _Client, task_id: str) -> None:
150
150
  self._client = client
151
151
  self._task_id = task_id
152
+ self._watch_output_buffer = []
152
153
 
153
154
  def _validate_mode(self, mode: str) -> None:
154
155
  if not any(char in mode for char in "rwax"):
@@ -173,11 +174,7 @@ class _FileIO(Generic[T]):
173
174
  raise ValueError(f"Invalid file mode: {mode}")
174
175
  seen_chars.add(char)
175
176
 
176
- def _handle_error(self, error: api_pb2.SystemErrorMessage) -> None:
177
- error_class = ERROR_MAPPING.get(error.error_code, FilesystemExecutionError)
178
- raise error_class(error.error_message)
179
-
180
- async def _consume_output(self, exec_id: str) -> AsyncIterator[Optional[bytes]]:
177
+ async def _consume_output(self, exec_id: str) -> AsyncIterator[Union[Optional[bytes], Exception]]:
181
178
  req = api_pb2.ContainerFilesystemExecGetOutputRequest(
182
179
  exec_id=exec_id,
183
180
  timeout=55,
@@ -187,7 +184,8 @@ class _FileIO(Generic[T]):
187
184
  yield None
188
185
  break
189
186
  if batch.HasField("error"):
190
- self._handle_error(batch.error)
187
+ error_class = ERROR_MAPPING.get(batch.error.error_code, FilesystemExecutionError)
188
+ yield error_class(batch.error.error_message)
191
189
  for message in batch.output:
192
190
  yield message
193
191
 
@@ -236,6 +234,8 @@ class _FileIO(Generic[T]):
236
234
  if data is None:
237
235
  completed = True
238
236
  break
237
+ if isinstance(data, Exception):
238
+ raise data
239
239
  output += data
240
240
  except (GRPCError, StreamTerminatedError) as exc:
241
241
  if retries_remaining > 0:
@@ -475,6 +475,8 @@ class _FileIO(Generic[T]):
475
475
  item = self._watch_output_buffer.pop(0)
476
476
  if item is None:
477
477
  break
478
+ if isinstance(item, Exception):
479
+ raise item
478
480
  buffer += item
479
481
  # a single event may be split across multiple messages
480
482
  # the end of an event is marked by two newlines
modal/file_io.pyi CHANGED
@@ -1,7 +1,6 @@
1
1
  import _typeshed
2
2
  import enum
3
3
  import modal.client
4
- import modal_proto.api_pb2
5
4
  import typing
6
5
  import typing_extensions
7
6
 
@@ -33,12 +32,11 @@ class _FileIO(typing.Generic[T]):
33
32
  _task_id: str
34
33
  _file_descriptor: str
35
34
  _client: modal.client._Client
36
- _watch_output_buffer: list[typing.Optional[bytes]]
35
+ _watch_output_buffer: list[typing.Union[bytes, None, Exception]]
37
36
 
38
37
  def __init__(self, client: modal.client._Client, task_id: str) -> None: ...
39
38
  def _validate_mode(self, mode: str) -> None: ...
40
- def _handle_error(self, error: modal_proto.api_pb2.SystemErrorMessage) -> None: ...
41
- def _consume_output(self, exec_id: str) -> typing.AsyncIterator[typing.Optional[bytes]]: ...
39
+ def _consume_output(self, exec_id: str) -> typing.AsyncIterator[typing.Union[bytes, None, Exception]]: ...
42
40
  async def _consume_watch_output(self, exec_id: str) -> None: ...
43
41
  async def _parse_watch_output(self, event: bytes) -> typing.Optional[FileWatchEvent]: ...
44
42
  async def _wait(self, exec_id: str) -> bytes: ...
@@ -112,15 +110,14 @@ class FileIO(typing.Generic[T]):
112
110
  _task_id: str
113
111
  _file_descriptor: str
114
112
  _client: modal.client.Client
115
- _watch_output_buffer: list[typing.Optional[bytes]]
113
+ _watch_output_buffer: list[typing.Union[bytes, None, Exception]]
116
114
 
117
115
  def __init__(self, client: modal.client.Client, task_id: str) -> None: ...
118
116
  def _validate_mode(self, mode: str) -> None: ...
119
- def _handle_error(self, error: modal_proto.api_pb2.SystemErrorMessage) -> None: ...
120
117
 
121
118
  class ___consume_output_spec(typing_extensions.Protocol[SUPERSELF]):
122
- def __call__(self, /, exec_id: str) -> typing.Iterator[typing.Optional[bytes]]: ...
123
- def aio(self, /, exec_id: str) -> typing.AsyncIterator[typing.Optional[bytes]]: ...
119
+ def __call__(self, /, exec_id: str) -> typing.Iterator[typing.Union[bytes, None, Exception]]: ...
120
+ def aio(self, /, exec_id: str) -> typing.AsyncIterator[typing.Union[bytes, None, Exception]]: ...
124
121
 
125
122
  _consume_output: ___consume_output_spec[typing_extensions.Self]
126
123
 
modal/functions.pyi CHANGED
@@ -227,11 +227,11 @@ class Function(
227
227
 
228
228
  _call_generator: ___call_generator_spec[typing_extensions.Self]
229
229
 
230
- class __remote_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
230
+ class __remote_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
231
231
  def __call__(self, /, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
232
232
  async def aio(self, /, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
233
233
 
234
- remote: __remote_spec[modal._functions.P, modal._functions.ReturnType, typing_extensions.Self]
234
+ remote: __remote_spec[modal._functions.ReturnType, modal._functions.P, typing_extensions.Self]
235
235
 
236
236
  class __remote_gen_spec(typing_extensions.Protocol[SUPERSELF]):
237
237
  def __call__(self, /, *args, **kwargs) -> typing.Generator[typing.Any, None, None]: ...
@@ -246,12 +246,12 @@ class Function(
246
246
  self, *args: modal._functions.P.args, **kwargs: modal._functions.P.kwargs
247
247
  ) -> modal._functions.OriginalReturnType: ...
248
248
 
249
- class ___experimental_spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
249
+ class ___experimental_spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
250
250
  def __call__(self, /, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
251
251
  async def aio(self, /, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
252
252
 
253
253
  _experimental_spawn: ___experimental_spawn_spec[
254
- modal._functions.P, modal._functions.ReturnType, typing_extensions.Self
254
+ modal._functions.ReturnType, modal._functions.P, typing_extensions.Self
255
255
  ]
256
256
 
257
257
  class ___spawn_map_inner_spec(typing_extensions.Protocol[P_INNER, SUPERSELF]):
@@ -260,11 +260,11 @@ class Function(
260
260
 
261
261
  _spawn_map_inner: ___spawn_map_inner_spec[modal._functions.P, typing_extensions.Self]
262
262
 
263
- class __spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
263
+ class __spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
264
264
  def __call__(self, /, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
265
265
  async def aio(self, /, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
266
266
 
267
- spawn: __spawn_spec[modal._functions.P, modal._functions.ReturnType, typing_extensions.Self]
267
+ spawn: __spawn_spec[modal._functions.ReturnType, modal._functions.P, typing_extensions.Self]
268
268
 
269
269
  def get_raw_f(self) -> collections.abc.Callable[..., typing.Any]: ...
270
270
 
modal/mount.py CHANGED
@@ -290,6 +290,7 @@ class _Mount(_Object, type_prefix="mo"):
290
290
  _deployment_name: Optional[str] = None
291
291
  _namespace: Optional[int] = None
292
292
  _environment_name: Optional[str] = None
293
+ _allow_overwrite: bool = False
293
294
  _content_checksum_sha256_hex: Optional[str] = None
294
295
 
295
296
  @staticmethod
@@ -600,11 +601,16 @@ class _Mount(_Object, type_prefix="mo"):
600
601
  # Build the mount.
601
602
  status_row.message(f"Creating mount {message_label}: Finalizing index of {len(files)} files")
602
603
  if self._deployment_name:
604
+ creation_type = (
605
+ api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING
606
+ if self._allow_overwrite
607
+ else api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS
608
+ )
603
609
  req = api_pb2.MountGetOrCreateRequest(
604
610
  deployment_name=self._deployment_name,
605
611
  namespace=self._namespace,
606
612
  environment_name=self._environment_name,
607
- object_creation_type=api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS,
613
+ object_creation_type=creation_type,
608
614
  files=files,
609
615
  )
610
616
  elif resolver.app_id is not None:
@@ -736,7 +742,9 @@ class _Mount(_Object, type_prefix="mo"):
736
742
  self: "_Mount",
737
743
  deployment_name: Optional[str] = None,
738
744
  namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
745
+ *,
739
746
  environment_name: Optional[str] = None,
747
+ allow_overwrite: bool = False,
740
748
  client: Optional[_Client] = None,
741
749
  ) -> None:
742
750
  check_object_name(deployment_name, "Mount")
@@ -744,6 +752,7 @@ class _Mount(_Object, type_prefix="mo"):
744
752
  self._deployment_name = deployment_name
745
753
  self._namespace = namespace
746
754
  self._environment_name = environment_name
755
+ self._allow_overwrite = allow_overwrite
747
756
  if client is None:
748
757
  client = await _Client.from_env()
749
758
  resolver = Resolver(client=client, environment_name=environment_name)
@@ -826,35 +835,34 @@ import sys; sys.path.append('{REMOTE_PACKAGES_PATH}')
826
835
  """.strip()
827
836
 
828
837
 
829
- async def _create_single_mount(
838
+ async def _create_single_client_dependency_mount(
830
839
  client: _Client,
831
840
  builder_version: str,
832
841
  python_version: str,
833
- platform: str,
834
842
  arch: str,
835
- uv_python_platform: str = None,
843
+ platform: str,
844
+ uv_python_platform: str,
836
845
  check_if_exists: bool = True,
846
+ allow_overwrite: bool = False,
837
847
  ):
838
- import subprocess
839
848
  import tempfile
840
849
 
841
850
  profile_environment = config.get("environment")
842
851
  abi_tag = "cp" + python_version.replace(".", "")
843
852
  mount_name = f"{builder_version}-{abi_tag}-{platform}-{arch}"
844
- uv_python_platform = uv_python_platform or f"{arch}-{platform}"
845
853
 
846
854
  if check_if_exists:
847
855
  try:
848
856
  await Mount.from_name(mount_name, namespace=api_pb2.DEPLOYMENT_NAMESPACE_GLOBAL).hydrate.aio(client)
849
- print(f" Found existing mount {mount_name} in global namespace.")
857
+ print(f" Found existing mount {mount_name} in global namespace.")
850
858
  return
851
859
  except modal.exception.NotFoundError:
852
860
  pass
853
861
 
854
- with tempfile.TemporaryDirectory() as tmpd:
862
+ with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpd:
855
863
  print(f"📦 Building {mount_name}.")
856
864
  requirements = os.path.join(os.path.dirname(__file__), f"requirements/{builder_version}.txt")
857
- subprocess.run(
865
+ cmd = " ".join(
858
866
  [
859
867
  "uv",
860
868
  "pip",
@@ -871,12 +879,21 @@ async def _create_single_mount(
871
879
  uv_python_platform,
872
880
  "--python-version",
873
881
  python_version,
874
- ],
875
- check=True,
876
- capture_output=True,
882
+ ]
877
883
  )
884
+ proc = await asyncio.create_subprocess_shell(
885
+ cmd,
886
+ stdout=asyncio.subprocess.PIPE,
887
+ stderr=asyncio.subprocess.PIPE,
888
+ )
889
+ await proc.wait()
890
+ if proc.returncode:
891
+ stdout, stderr = await proc.communicate()
892
+ print(stdout.decode("utf-8"))
893
+ print(stderr.decode("utf-8"))
894
+ raise RuntimeError(f"Subprocess failed with {proc.returncode}", proc.args)
878
895
 
879
- print(f"🌐 Downloaded and unpacked packages to {tmpd}.")
896
+ print(f"🌐 Downloaded and unpacked {mount_name} packages to {tmpd}.")
880
897
 
881
898
  python_mount = Mount._from_local_dir(tmpd, remote_path=REMOTE_PACKAGES_PATH)
882
899
 
@@ -895,6 +912,7 @@ async def _create_single_mount(
895
912
  mount_name,
896
913
  api_pb2.DEPLOYMENT_NAMESPACE_GLOBAL,
897
914
  environment_name=profile_environment,
915
+ allow_overwrite=allow_overwrite,
898
916
  client=client,
899
917
  )
900
918
  print(f"✅ Deployed mount {mount_name} to global namespace.")
@@ -902,34 +920,30 @@ async def _create_single_mount(
902
920
 
903
921
  async def _create_client_dependency_mounts(
904
922
  client=None,
905
- check_if_exists=True,
906
923
  python_versions: list[str] = list(PYTHON_STANDALONE_VERSIONS),
924
+ check_if_exists=True,
907
925
  ):
926
+ arch = "x86_64"
927
+ platform_tags = [
928
+ ("manylinux_2_17", f"{arch}-manylinux_2_17"), # glibc >= 2.17
929
+ ("musllinux_1_2", f"{arch}-unknown-linux-musl"), # musl >= 1.2
930
+ ]
908
931
  coros = []
909
- for python_version in python_versions:
910
- # glibc >= 2.17
911
- coros.append(
912
- _create_single_mount(
913
- client,
914
- "PREVIEW",
915
- python_version,
916
- "manylinux_2_17",
917
- "x86_64",
918
- check_if_exists=check_if_exists,
919
- )
920
- )
921
- # musl >= 1.2
922
- coros.append(
923
- _create_single_mount(
924
- client,
925
- "PREVIEW",
926
- python_version,
927
- "musllinux_1_2",
928
- "x86_64",
929
- uv_python_platform="x86_64-unknown-linux-musl",
930
- check_if_exists=check_if_exists,
931
- )
932
- )
932
+ for builder_version in ["PREVIEW"]:
933
+ for python_version in python_versions:
934
+ for platform, uv_python_platform in platform_tags:
935
+ coros.append(
936
+ _create_single_client_dependency_mount(
937
+ client,
938
+ builder_version,
939
+ python_version,
940
+ arch,
941
+ platform,
942
+ uv_python_platform,
943
+ check_if_exists=builder_version != "PREVIEW",
944
+ allow_overwrite=builder_version == "PREVIEW",
945
+ )
946
+ )
933
947
  await TaskContext.gather(*coros)
934
948
 
935
949