modal 0.72.4__py3-none-any.whl → 0.72.48__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. modal/_container_entrypoint.py +5 -10
  2. modal/_object.py +297 -0
  3. modal/_resolver.py +7 -5
  4. modal/_runtime/container_io_manager.py +0 -11
  5. modal/_runtime/user_code_imports.py +7 -7
  6. modal/_serialization.py +4 -3
  7. modal/_tunnel.py +1 -1
  8. modal/app.py +14 -61
  9. modal/app.pyi +25 -25
  10. modal/cli/app.py +3 -2
  11. modal/cli/container.py +1 -1
  12. modal/cli/import_refs.py +185 -113
  13. modal/cli/launch.py +10 -5
  14. modal/cli/programs/run_jupyter.py +2 -2
  15. modal/cli/programs/vscode.py +3 -3
  16. modal/cli/run.py +134 -68
  17. modal/client.py +1 -0
  18. modal/client.pyi +18 -14
  19. modal/cloud_bucket_mount.py +4 -0
  20. modal/cloud_bucket_mount.pyi +4 -0
  21. modal/cls.py +33 -5
  22. modal/cls.pyi +20 -5
  23. modal/container_process.pyi +8 -6
  24. modal/dict.py +1 -1
  25. modal/dict.pyi +32 -29
  26. modal/environments.py +1 -1
  27. modal/environments.pyi +2 -1
  28. modal/experimental.py +47 -11
  29. modal/experimental.pyi +29 -0
  30. modal/file_io.pyi +30 -28
  31. modal/file_pattern_matcher.py +32 -25
  32. modal/functions.py +31 -23
  33. modal/functions.pyi +57 -50
  34. modal/gpu.py +19 -26
  35. modal/image.py +47 -19
  36. modal/image.pyi +28 -21
  37. modal/io_streams.pyi +14 -12
  38. modal/mount.py +14 -5
  39. modal/mount.pyi +28 -25
  40. modal/network_file_system.py +7 -7
  41. modal/network_file_system.pyi +27 -24
  42. modal/object.py +2 -265
  43. modal/object.pyi +46 -130
  44. modal/parallel_map.py +2 -2
  45. modal/parallel_map.pyi +10 -7
  46. modal/partial_function.py +22 -3
  47. modal/partial_function.pyi +45 -27
  48. modal/proxy.py +1 -1
  49. modal/proxy.pyi +2 -1
  50. modal/queue.py +1 -1
  51. modal/queue.pyi +26 -23
  52. modal/runner.py +14 -3
  53. modal/sandbox.py +11 -7
  54. modal/sandbox.pyi +30 -27
  55. modal/secret.py +1 -1
  56. modal/secret.pyi +2 -1
  57. modal/token_flow.pyi +6 -4
  58. modal/volume.py +1 -1
  59. modal/volume.pyi +36 -33
  60. {modal-0.72.4.dist-info → modal-0.72.48.dist-info}/METADATA +2 -2
  61. {modal-0.72.4.dist-info → modal-0.72.48.dist-info}/RECORD +73 -71
  62. modal_proto/api.proto +151 -4
  63. modal_proto/api_grpc.py +113 -0
  64. modal_proto/api_pb2.py +998 -795
  65. modal_proto/api_pb2.pyi +430 -11
  66. modal_proto/api_pb2_grpc.py +233 -1
  67. modal_proto/api_pb2_grpc.pyi +75 -3
  68. modal_proto/modal_api_grpc.py +7 -0
  69. modal_version/_version_generated.py +1 -1
  70. {modal-0.72.4.dist-info → modal-0.72.48.dist-info}/LICENSE +0 -0
  71. {modal-0.72.4.dist-info → modal-0.72.48.dist-info}/WHEEL +0 -0
  72. {modal-0.72.4.dist-info → modal-0.72.48.dist-info}/entry_points.txt +0 -0
  73. {modal-0.72.4.dist-info → modal-0.72.48.dist-info}/top_level.txt +0 -0
modal/functions.py CHANGED
@@ -26,6 +26,7 @@ from modal_proto import api_pb2
  from modal_proto.modal_api_grpc import ModalClientModal
 
  from ._location import parse_cloud_provider
+ from ._object import _get_environment_name, _Object, live_method, live_method_gen
  from ._pty import get_pty_info
  from ._resolver import Resolver
  from ._resources import convert_fn_config_to_resources_config
@@ -71,7 +72,6 @@ from .gpu import GPU_T, parse_gpu_config
  from .image import _Image
  from .mount import _get_client_mount, _Mount, get_auto_mounts
  from .network_file_system import _NetworkFileSystem, network_file_system_mount_protos
- from .object import _get_environment_name, _Object, live_method, live_method_gen
  from .output import _get_output_manager
  from .parallel_map import (
  _for_each_async,
@@ -383,12 +383,15 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  _serve_mounts: frozenset[_Mount] # set at load time, only by loader
  _app: Optional["modal.app._App"] = None
  _obj: Optional["modal.cls._Obj"] = None # only set for InstanceServiceFunctions and bound instance methods
- _web_url: Optional[str]
+
+ _webhook_config: Optional[api_pb2.WebhookConfig] = None # this is set in definition scope, only locally
+ _web_url: Optional[str] # this is set on hydration
+
  _function_name: Optional[str]
  _is_method: bool
  _spec: Optional[_FunctionSpec] = None
  _tag: str
- _raw_f: Callable[..., Any]
+ _raw_f: Optional[Callable[..., Any]] # this is set to None for a "class service [function]"
  _build_args: dict
 
  _is_generator: Optional[bool] = None
@@ -474,7 +477,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  _experimental_buffer_containers: Optional[int] = None,
  _experimental_proxy_ip: Optional[str] = None,
  _experimental_custom_scaling_factor: Optional[float] = None,
- ) -> None:
+ ) -> "_Function":
  """mdmd:hidden"""
  # Needed to avoid circular imports
  from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags
@@ -573,7 +576,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  )
  image = _Image._from_args(
  base_images={"base": image},
- build_function=snapshot_function,
+ build_function=snapshot_function, # type: ignore # TODO: separate functions.py and _functions.py
  force_build=image.force_build or pf.force_build,
  )
 
@@ -785,7 +788,8 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  task_idle_timeout_secs=container_idle_timeout or 0,
  concurrency_limit=concurrency_limit or 0,
  pty_info=pty_info,
- cloud_provider=cloud_provider,
+ cloud_provider=cloud_provider, # Deprecated at some point
+ cloud_provider_str=cloud.upper() if cloud else "", # Supersedes cloud_provider
  warm_pool_size=keep_warm or 0,
  runtime=config.get("function_runtime"),
  runtime_debug=config.get("function_runtime_debug"),
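Note: `cloud` in the hunk above is the user-facing argument on the function decorator, now forwarded to the backend as a plain string. A minimal usage sketch (the app name and function body are illustrative, not from this package):

    import modal

    app = modal.App("example-app")  # placeholder name

    @app.function(cloud="aws")  # forwarded in the definition request as cloud_provider_str="AWS"
    def f():
        ...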
@@ -911,6 +915,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  obj._cluster_size = cluster_size
  obj._is_method = False
  obj._spec = function_spec # needed for modal shell
+ obj._webhook_config = webhook_config # only set locally
 
  # Used to check whether we should rebuild a modal.Image which uses `run_function`.
  gpus: list[GPU_T] = gpu if isinstance(gpu, list) else [gpu]
@@ -962,7 +967,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  f"The {identity} has not been hydrated with the metadata it needs to run on Modal{reason}."
  )
 
- assert parent._client.stub
+ assert parent._client and parent._client.stub
 
  if can_use_parent:
  # We can end up here if parent wasn't hydrated when class was instantiated, but has been since.
@@ -983,9 +988,9 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  else:
  serialized_params = serialize((args, kwargs))
  environment_name = _get_environment_name(None, resolver)
- assert parent is not None
+ assert parent is not None and parent.is_hydrated
  req = api_pb2.FunctionBindParamsRequest(
- function_id=parent._object_id,
+ function_id=parent.object_id,
  serialized_params=serialized_params,
  function_options=options,
  environment_name=environment_name
@@ -1032,11 +1037,10 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  """
  )
  )
- assert self._client and self._client.stub
  request = api_pb2.FunctionUpdateSchedulingParamsRequest(
- function_id=self._object_id, warm_pool_size_override=warm_pool_size
+ function_id=self.object_id, warm_pool_size_override=warm_pool_size
  )
- await retry_transient_errors(self._client.stub.FunctionUpdateSchedulingParams, request)
+ await retry_transient_errors(self.client.stub.FunctionUpdateSchedulingParams, request)
 
  @classmethod
  @renamed_parameter((2024, 12, 18), "tag", "name")
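The `keep_warm` hunk above only swaps the private `_client`/`_object_id` accessors for their public counterparts; the user-facing behavior is unchanged. A hedged usage sketch, assuming a deployed app (names are placeholders):

    import modal

    f = modal.Function.from_name("my-app", "my-function")
    f.keep_warm(2)  # sends FunctionUpdateSchedulingParams with warm_pool_size_override=2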
@@ -1138,11 +1142,15 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  assert self._spec
  return self._spec
 
+ def _is_web_endpoint(self) -> bool:
+ # only defined in definition scope/locally, and not for class methods at the moment
+ return bool(self._webhook_config and self._webhook_config.type != api_pb2.WEBHOOK_TYPE_UNSPECIFIED)
+
  def get_build_def(self) -> str:
  """mdmd:hidden"""
  # Plaintext source and arg definition for the function, so it's part of the image
  # hash. We can't use the cloudpickle hash because it's not very stable.
- assert hasattr(self, "_raw_f") and hasattr(self, "_build_args")
+ assert hasattr(self, "_raw_f") and hasattr(self, "_build_args") and self._raw_f is not None
  return f"{inspect.getsource(self._raw_f)}\n{repr(self._build_args)}"
 
  # Live handle methods
@@ -1207,12 +1215,13 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  async def is_generator(self) -> bool:
  """mdmd:hidden"""
  # hacky: kind of like @live_method, but not hydrating if we have the value already from local source
+ # TODO(michael) use a common / lightweight method for handling unhydrated metadata properties
  if self._is_generator is not None:
  # this is set if the function or class is local
  return self._is_generator
 
  # not set - this is a from_name lookup - hydrate
- await self.resolve()
+ await self.hydrate()
  assert self._is_generator is not None # should be set now
  return self._is_generator
 
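The `resolve()` → `hydrate()` change above follows the hydration API on Modal objects. A hedged sketch of the lookup-then-hydrate flow this supports (names are placeholders; assumes the public blocking `hydrate()` wrapper):

    import modal

    f = modal.Function.from_name("my-app", "my-function")  # lazy reference, no server call yet
    f.hydrate()         # fetches metadata (object id, web URL, is_generator, ...) from the server
    print(f.object_id)  # only available after hydration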
@@ -1248,7 +1257,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  _map_invocation(
  self, # type: ignore
  input_queue,
- self._client,
+ self.client,
  order_outputs,
  return_exceptions,
  count_update_callback,
@@ -1266,7 +1275,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  self,
  args,
  kwargs,
- client=self._client,
+ client=self.client,
  function_call_invocation_type=function_call_invocation_type,
  )
 
@@ -1276,7 +1285,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  self, args, kwargs, function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
  ) -> _Invocation:
  return await _Invocation.create(
- self, args, kwargs, client=self._client, function_call_invocation_type=function_call_invocation_type
+ self, args, kwargs, client=self.client, function_call_invocation_type=function_call_invocation_type
  )
 
  @warn_if_generator_is_not_consumed()
@@ -1287,7 +1296,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  self,
  args,
  kwargs,
- client=self._client,
+ client=self.client,
  function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
  )
  async for res in invocation.run_generator():
@@ -1303,7 +1312,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
  self,
  args,
  kwargs,
- client=self._client,
+ client=self.client,
  function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY,
  )
 
@@ -1452,14 +1461,14 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
 
  def get_raw_f(self) -> Callable[..., Any]:
  """Return the inner Python object wrapped by this Modal Function."""
+ assert self._raw_f is not None
  return self._raw_f
 
  @live_method
  async def get_current_stats(self) -> FunctionStats:
  """Return a `FunctionStats` object describing the current function's queue and runner counts."""
- assert self._client.stub
  resp = await retry_transient_errors(
- self._client.stub.FunctionGetCurrentStats,
+ self.client.stub.FunctionGetCurrentStats,
  api_pb2.FunctionGetCurrentStatsRequest(function_id=self.object_id),
  total_timeout=10.0,
  )
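`get_current_stats` is a public live-handle method; the hunk above only routes it through the public `client` property. A hedged usage sketch (names are placeholders):

    import modal

    f = modal.Function.from_name("my-app", "my-function")
    stats = f.get_current_stats()  # FunctionStats describing queue backlog and runner counts
    print(stats)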
@@ -1491,8 +1500,7 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
  _is_generator: bool = False
 
  def _invocation(self):
- assert self._client.stub
- return _Invocation(self._client.stub, self.object_id, self._client)
+ return _Invocation(self.client.stub, self.object_id, self.client)
 
  async def get(self, timeout: Optional[float] = None) -> ReturnType:
  """Get the result of the function call.
modal/functions.pyi CHANGED
@@ -1,5 +1,6 @@
  import collections.abc
  import google.protobuf.message
+ import modal._object
  import modal._utils.async_utils
  import modal._utils.function_utils
  import modal.app
@@ -133,17 +134,20 @@ ReturnType = typing.TypeVar("ReturnType", covariant=True)
 
  OriginalReturnType = typing.TypeVar("OriginalReturnType", covariant=True)
 
- class _Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object._Object):
+ SUPERSELF = typing.TypeVar("SUPERSELF", covariant=True)
+
+ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], modal._object._Object):
  _info: typing.Optional[modal._utils.function_utils.FunctionInfo]
  _serve_mounts: frozenset[modal.mount._Mount]
  _app: typing.Optional[modal.app._App]
  _obj: typing.Optional[modal.cls._Obj]
+ _webhook_config: typing.Optional[modal_proto.api_pb2.WebhookConfig]
  _web_url: typing.Optional[str]
  _function_name: typing.Optional[str]
  _is_method: bool
  _spec: typing.Optional[_FunctionSpec]
  _tag: str
- _raw_f: typing.Callable[..., typing.Any]
+ _raw_f: typing.Optional[collections.abc.Callable[..., typing.Any]]
  _build_args: dict
  _is_generator: typing.Optional[bool]
  _cluster_size: typing.Optional[int]
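The new `SUPERSELF` type variable lets each generated method-spec Protocol carry the concrete class it is bound on (via `typing_extensions.Self`), which is what the `__map_spec[typing_extensions.Self]`-style annotations below rely on. A minimal, self-contained illustration of the pattern, not taken from the Modal stubs:

    from __future__ import annotations

    import typing
    import typing_extensions

    SUPERSELF = typing.TypeVar("SUPERSELF", covariant=True)

    class __double_spec(typing_extensions.Protocol[SUPERSELF]):
        # blocking and async flavors of the same bound method
        def __call__(self, x: int) -> int: ...
        async def aio(self, x: int) -> int: ...

    class Number:
        # parameterizing the spec with Self tells type checkers which class
        # the callable attribute belongs to
        double: __double_spec[typing_extensions.Self]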
@@ -197,7 +201,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.
  _experimental_buffer_containers: typing.Optional[int] = None,
  _experimental_proxy_ip: typing.Optional[str] = None,
  _experimental_custom_scaling_factor: typing.Optional[float] = None,
- ) -> None: ...
+ ) -> _Function: ...
  def _bind_parameters(
  self,
  obj: modal.cls._Obj,
@@ -228,6 +232,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.
  def info(self) -> modal._utils.function_utils.FunctionInfo: ...
  @property
  def spec(self) -> _FunctionSpec: ...
+ def _is_web_endpoint(self) -> bool: ...
  def get_build_def(self) -> str: ...
  def _initialize_from_empty(self): ...
  def _hydrate_metadata(self, metadata: typing.Optional[google.protobuf.message.Message]): ...
@@ -254,10 +259,10 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.
  def local(self, *args: P.args, **kwargs: P.kwargs) -> OriginalReturnType: ...
  async def _experimental_spawn(self, *args: P.args, **kwargs: P.kwargs) -> _FunctionCall[ReturnType]: ...
  async def spawn(self, *args: P.args, **kwargs: P.kwargs) -> _FunctionCall[ReturnType]: ...
- def get_raw_f(self) -> typing.Callable[..., typing.Any]: ...
+ def get_raw_f(self) -> collections.abc.Callable[..., typing.Any]: ...
  async def get_current_stats(self) -> FunctionStats: ...
 
- class __map_spec(typing_extensions.Protocol):
+ class __map_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(
  self, *input_iterators, kwargs={}, order_outputs: bool = True, return_exceptions: bool = False
  ) -> modal._utils.async_utils.AsyncOrSyncIterable: ...
@@ -269,9 +274,9 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.
  return_exceptions: bool = False,
  ) -> typing.AsyncGenerator[typing.Any, None]: ...
 
- map: __map_spec
+ map: __map_spec[typing_extensions.Self]
 
- class __starmap_spec(typing_extensions.Protocol):
+ class __starmap_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(
  self,
  input_iterator: typing.Iterable[typing.Sequence[typing.Any]],
@@ -289,13 +294,13 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.
  return_exceptions: bool = False,
  ) -> typing.AsyncIterable[typing.Any]: ...
 
- starmap: __starmap_spec
+ starmap: __starmap_spec[typing_extensions.Self]
 
- class __for_each_spec(typing_extensions.Protocol):
+ class __for_each_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False): ...
  async def aio(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False): ...
 
- for_each: __for_each_spec
+ for_each: __for_each_spec[typing_extensions.Self]
 
  ReturnType_INNER = typing.TypeVar("ReturnType_INNER", covariant=True)
 
@@ -306,12 +311,13 @@ class Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.O
  _serve_mounts: frozenset[modal.mount.Mount]
  _app: typing.Optional[modal.app.App]
  _obj: typing.Optional[modal.cls.Obj]
+ _webhook_config: typing.Optional[modal_proto.api_pb2.WebhookConfig]
  _web_url: typing.Optional[str]
  _function_name: typing.Optional[str]
  _is_method: bool
  _spec: typing.Optional[_FunctionSpec]
  _tag: str
- _raw_f: typing.Callable[..., typing.Any]
+ _raw_f: typing.Optional[collections.abc.Callable[..., typing.Any]]
  _build_args: dict
  _is_generator: typing.Optional[bool]
  _cluster_size: typing.Optional[int]
@@ -366,7 +372,7 @@ class Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.O
  _experimental_buffer_containers: typing.Optional[int] = None,
  _experimental_proxy_ip: typing.Optional[str] = None,
  _experimental_custom_scaling_factor: typing.Optional[float] = None,
- ) -> None: ...
+ ) -> Function: ...
  def _bind_parameters(
  self,
  obj: modal.cls.Obj,
@@ -375,11 +381,11 @@ class Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.O
  kwargs: dict[str, typing.Any],
  ) -> Function: ...
 
- class __keep_warm_spec(typing_extensions.Protocol):
+ class __keep_warm_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self, warm_pool_size: int) -> None: ...
  async def aio(self, warm_pool_size: int) -> None: ...
 
- keep_warm: __keep_warm_spec
+ keep_warm: __keep_warm_spec[typing_extensions.Self]
 
  @classmethod
  def from_name(
@@ -416,6 +422,7 @@ class Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.O
  def info(self) -> modal._utils.function_utils.FunctionInfo: ...
  @property
  def spec(self) -> _FunctionSpec: ...
+ def _is_web_endpoint(self) -> bool: ...
  def get_build_def(self) -> str: ...
  def _initialize_from_empty(self): ...
  def _hydrate_metadata(self, metadata: typing.Optional[google.protobuf.message.Message]): ...
@@ -428,7 +435,7 @@ class Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.O
  @property
  def cluster_size(self) -> int: ...
 
- class ___map_spec(typing_extensions.Protocol):
+ class ___map_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(
  self, input_queue: modal.parallel_map.SynchronizedQueue, order_outputs: bool, return_exceptions: bool
  ) -> typing.Generator[typing.Any, None, None]: ...
@@ -436,70 +443,70 @@ class Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.O
  self, input_queue: modal.parallel_map.SynchronizedQueue, order_outputs: bool, return_exceptions: bool
  ) -> collections.abc.AsyncGenerator[typing.Any, None]: ...
 
- _map: ___map_spec
+ _map: ___map_spec[typing_extensions.Self]
 
- class ___call_function_spec(typing_extensions.Protocol[ReturnType_INNER]):
+ class ___call_function_spec(typing_extensions.Protocol[ReturnType_INNER, SUPERSELF]):
  def __call__(self, args, kwargs) -> ReturnType_INNER: ...
  async def aio(self, args, kwargs) -> ReturnType_INNER: ...
 
- _call_function: ___call_function_spec[ReturnType]
+ _call_function: ___call_function_spec[ReturnType, typing_extensions.Self]
 
- class ___call_function_nowait_spec(typing_extensions.Protocol):
+ class ___call_function_nowait_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self, args, kwargs, function_call_invocation_type: int) -> _Invocation: ...
  async def aio(self, args, kwargs, function_call_invocation_type: int) -> _Invocation: ...
 
- _call_function_nowait: ___call_function_nowait_spec
+ _call_function_nowait: ___call_function_nowait_spec[typing_extensions.Self]
 
- class ___call_generator_spec(typing_extensions.Protocol):
+ class ___call_generator_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self, args, kwargs): ...
  def aio(self, args, kwargs): ...
 
- _call_generator: ___call_generator_spec
+ _call_generator: ___call_generator_spec[typing_extensions.Self]
 
- class ___call_generator_nowait_spec(typing_extensions.Protocol):
+ class ___call_generator_nowait_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self, args, kwargs): ...
  async def aio(self, args, kwargs): ...
 
- _call_generator_nowait: ___call_generator_nowait_spec
+ _call_generator_nowait: ___call_generator_nowait_spec[typing_extensions.Self]
 
- class __remote_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER]):
+ class __remote_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
  def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
  async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
 
- remote: __remote_spec[P, ReturnType]
+ remote: __remote_spec[ReturnType, P, typing_extensions.Self]
 
- class __remote_gen_spec(typing_extensions.Protocol):
+ class __remote_gen_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self, *args, **kwargs) -> typing.Generator[typing.Any, None, None]: ...
  def aio(self, *args, **kwargs) -> collections.abc.AsyncGenerator[typing.Any, None]: ...
 
- remote_gen: __remote_gen_spec
+ remote_gen: __remote_gen_spec[typing_extensions.Self]
 
  def _is_local(self): ...
  def _get_info(self) -> modal._utils.function_utils.FunctionInfo: ...
  def _get_obj(self) -> typing.Optional[modal.cls.Obj]: ...
  def local(self, *args: P.args, **kwargs: P.kwargs) -> OriginalReturnType: ...
 
- class ___experimental_spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER]):
+ class ___experimental_spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
  def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
  async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
 
- _experimental_spawn: ___experimental_spawn_spec[P, ReturnType]
+ _experimental_spawn: ___experimental_spawn_spec[ReturnType, P, typing_extensions.Self]
 
- class __spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER]):
+ class __spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
  def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
  async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
 
- spawn: __spawn_spec[P, ReturnType]
+ spawn: __spawn_spec[ReturnType, P, typing_extensions.Self]
 
- def get_raw_f(self) -> typing.Callable[..., typing.Any]: ...
+ def get_raw_f(self) -> collections.abc.Callable[..., typing.Any]: ...
 
- class __get_current_stats_spec(typing_extensions.Protocol):
+ class __get_current_stats_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self) -> FunctionStats: ...
  async def aio(self) -> FunctionStats: ...
 
- get_current_stats: __get_current_stats_spec
+ get_current_stats: __get_current_stats_spec[typing_extensions.Self]
 
- class __map_spec(typing_extensions.Protocol):
+ class __map_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(
  self, *input_iterators, kwargs={}, order_outputs: bool = True, return_exceptions: bool = False
  ) -> modal._utils.async_utils.AsyncOrSyncIterable: ...
@@ -511,9 +518,9 @@ class Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.O
  return_exceptions: bool = False,
  ) -> typing.AsyncGenerator[typing.Any, None]: ...
 
- map: __map_spec
+ map: __map_spec[typing_extensions.Self]
 
- class __starmap_spec(typing_extensions.Protocol):
+ class __starmap_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(
  self,
  input_iterator: typing.Iterable[typing.Sequence[typing.Any]],
@@ -531,15 +538,15 @@ class Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.O
  return_exceptions: bool = False,
  ) -> typing.AsyncIterable[typing.Any]: ...
 
- starmap: __starmap_spec
+ starmap: __starmap_spec[typing_extensions.Self]
 
- class __for_each_spec(typing_extensions.Protocol):
+ class __for_each_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False): ...
  async def aio(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False): ...
 
- for_each: __for_each_spec
+ for_each: __for_each_spec[typing_extensions.Self]
 
- class _FunctionCall(typing.Generic[ReturnType], modal.object._Object):
+ class _FunctionCall(typing.Generic[ReturnType], modal._object._Object):
  _is_generator: bool
 
  def _invocation(self): ...
@@ -558,29 +565,29 @@ class FunctionCall(typing.Generic[ReturnType], modal.object.Object):
  def __init__(self, *args, **kwargs): ...
  def _invocation(self): ...
 
- class __get_spec(typing_extensions.Protocol[ReturnType_INNER]):
+ class __get_spec(typing_extensions.Protocol[ReturnType_INNER, SUPERSELF]):
  def __call__(self, timeout: typing.Optional[float] = None) -> ReturnType_INNER: ...
  async def aio(self, timeout: typing.Optional[float] = None) -> ReturnType_INNER: ...
 
- get: __get_spec[ReturnType]
+ get: __get_spec[ReturnType, typing_extensions.Self]
 
- class __get_gen_spec(typing_extensions.Protocol):
+ class __get_gen_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self) -> typing.Generator[typing.Any, None, None]: ...
  def aio(self) -> collections.abc.AsyncGenerator[typing.Any, None]: ...
 
- get_gen: __get_gen_spec
+ get_gen: __get_gen_spec[typing_extensions.Self]
 
- class __get_call_graph_spec(typing_extensions.Protocol):
+ class __get_call_graph_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self) -> list[modal.call_graph.InputInfo]: ...
  async def aio(self) -> list[modal.call_graph.InputInfo]: ...
 
- get_call_graph: __get_call_graph_spec
+ get_call_graph: __get_call_graph_spec[typing_extensions.Self]
 
- class __cancel_spec(typing_extensions.Protocol):
+ class __cancel_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self, terminate_containers: bool = False): ...
  async def aio(self, terminate_containers: bool = False): ...
 
- cancel: __cancel_spec
+ cancel: __cancel_spec[typing_extensions.Self]
 
  class __from_id_spec(typing_extensions.Protocol):
  def __call__(
modal/gpu.py CHANGED
@@ -9,8 +9,9 @@ from .exception import InvalidError
 
  @dataclass(frozen=True)
  class _GPUConfig:
- type: "api_pb2.GPUType.V"
+ type: "api_pb2.GPUType.V" # Deprecated, at some point
  count: int
+ gpu_type: str
  memory: int = 0
 
  def _to_proto(self) -> api_pb2.GPUConfig:
@@ -19,6 +20,7 @@ class _GPUConfig:
  type=self.type,
  count=self.count,
  memory=self.memory,
+ gpu_type=self.gpu_type,
  )
 
 
@@ -26,14 +28,14 @@ class T4(_GPUConfig):
  """
  [NVIDIA T4 Tensor Core](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPU class.
 
- A low-cost data center GPU based on the Turing architecture, providing 16GiB of GPU memory.
+ A low-cost data center GPU based on the Turing architecture, providing 16GB of GPU memory.
  """
 
  def __init__(
  self,
  count: int = 1, # Number of GPUs per container. Defaults to 1.
  ):
- super().__init__(api_pb2.GPU_TYPE_T4, count, 0)
+ super().__init__(api_pb2.GPU_TYPE_T4, count, "T4")
 
  def __repr__(self):
  return f"GPU(T4, count={self.count})"
@@ -43,7 +45,7 @@ class L4(_GPUConfig):
  """
  [NVIDIA L4 Tensor Core](https://www.nvidia.com/en-us/data-center/l4/) GPU class.
 
- A mid-tier data center GPU based on the Ada Lovelace architecture, providing 24GiB of GPU memory.
+ A mid-tier data center GPU based on the Ada Lovelace architecture, providing 24GB of GPU memory.
  Includes RTX (ray tracing) support.
  """
 
@@ -51,7 +53,7 @@ class L4(_GPUConfig):
  self,
  count: int = 1, # Number of GPUs per container. Defaults to 1.
  ):
- super().__init__(api_pb2.GPU_TYPE_L4, count, 0)
+ super().__init__(api_pb2.GPU_TYPE_L4, count, "L4")
 
  def __repr__(self):
  return f"GPU(L4, count={self.count})"
@@ -61,30 +63,21 @@ class A100(_GPUConfig):
  """
  [NVIDIA A100 Tensor Core](https://www.nvidia.com/en-us/data-center/a100/) GPU class.
 
- The flagship data center GPU of the Ampere architecture. Available in 40GiB and 80GiB GPU memory configurations.
+ The flagship data center GPU of the Ampere architecture. Available in 40GB and 80GB GPU memory configurations.
  """
 
  def __init__(
  self,
  *,
  count: int = 1, # Number of GPUs per container. Defaults to 1.
- size: Union[str, None] = None, # Select GiB configuration of GPU device: "40GB" or "80GB". Defaults to "40GB".
+ size: Union[str, None] = None, # Select GB configuration of GPU device: "40GB" or "80GB". Defaults to "40GB".
  ):
- allowed_size_values = {"40GB", "80GB"}
-
- if size:
- if size not in allowed_size_values:
- raise ValueError(
- f"size='{size}' is invalid. A100s can only have memory values of {allowed_size_values}."
- )
- memory = int(size.replace("GB", ""))
+ if size == "40GB" or not size:
+ super().__init__(api_pb2.GPU_TYPE_A100, count, "A100-40GB", 40)
+ elif size == "80GB":
+ super().__init__(api_pb2.GPU_TYPE_A100_80GB, count, "A100-80GB", 80)
  else:
- memory = 40
-
- if memory == 80:
- super().__init__(api_pb2.GPU_TYPE_A100_80GB, count, memory)
- else:
- super().__init__(api_pb2.GPU_TYPE_A100, count, memory)
+ raise ValueError(f"size='{size}' is invalid. A100s can only have memory values of 40GB or 80GB.")
 
  def __repr__(self):
  if self.memory == 80:
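The rewritten constructor above maps `size` straight onto a GPU type string instead of going through a separate memory check. A usage sketch (the app name and function body are illustrative):

    import modal

    app = modal.App("gpu-example")  # placeholder name

    @app.function(gpu=modal.gpu.A100(count=1, size="80GB"))  # omitting size selects the 40GB variant
    def train():
        ...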
@@ -97,7 +90,7 @@ class A10G(_GPUConfig):
  """
  [NVIDIA A10G Tensor Core](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) GPU class.
 
- A mid-tier data center GPU based on the Ampere architecture, providing 24 GiB of memory.
+ A mid-tier data center GPU based on the Ampere architecture, providing 24 GB of memory.
  A10G GPUs deliver up to 3.3x better ML training performance, 3x better ML inference performance,
  and 3x better graphics performance, in comparison to NVIDIA T4 GPUs.
  """
@@ -109,7 +102,7 @@ class A10G(_GPUConfig):
  # Useful if you have very large models that don't fit on a single GPU.
  count: int = 1,
  ):
- super().__init__(api_pb2.GPU_TYPE_A10G, count)
+ super().__init__(api_pb2.GPU_TYPE_A10G, count, "A10G")
 
  def __repr__(self):
  return f"GPU(A10G, count={self.count})"
@@ -131,7 +124,7 @@ class H100(_GPUConfig):
  # Useful if you have very large models that don't fit on a single GPU.
  count: int = 1,
  ):
- super().__init__(api_pb2.GPU_TYPE_H100, count)
+ super().__init__(api_pb2.GPU_TYPE_H100, count, "H100")
 
  def __repr__(self):
  return f"GPU(H100, count={self.count})"
@@ -152,7 +145,7 @@ class L40S(_GPUConfig):
  # Useful if you have very large models that don't fit on a single GPU.
  count: int = 1,
  ):
- super().__init__(api_pb2.GPU_TYPE_L40S, count)
+ super().__init__(api_pb2.GPU_TYPE_L40S, count, "L40S")
 
  def __repr__(self):
  return f"GPU(L40S, count={self.count})"
@@ -162,7 +155,7 @@ class Any(_GPUConfig):
  """Selects any one of the GPU classes available within Modal, according to availability."""
 
  def __init__(self, *, count: int = 1):
- super().__init__(api_pb2.GPU_TYPE_ANY, count)
+ super().__init__(api_pb2.GPU_TYPE_ANY, count, "ANY")
 
  def __repr__(self):
  return f"GPU(Any, count={self.count})"