modal 0.73.146__py3-none-any.whl → 0.73.148__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,7 @@ import enum
3
3
  import inspect
4
4
  import typing
5
5
  from collections.abc import Coroutine, Iterable
6
+ from dataclasses import asdict, dataclass
6
7
  from typing import (
7
8
  Any,
8
9
  Callable,
@@ -29,80 +30,180 @@ if typing.TYPE_CHECKING:
29
30
 
30
31
 
31
32
  class _PartialFunctionFlags(enum.IntFlag):
32
- FUNCTION = 1
33
- BUILD = 2
34
- ENTER_PRE_SNAPSHOT = 4
35
- ENTER_POST_SNAPSHOT = 8
36
- EXIT = 16
37
- BATCHED = 32
38
- CLUSTERED = 64 # Experimental: Clustered functions
33
+ # Lifecycle method flags
34
+ BUILD = 1 # Deprecated, will be removed
35
+ ENTER_PRE_SNAPSHOT = 2
36
+ ENTER_POST_SNAPSHOT = 4
37
+ EXIT = 8
38
+ # Interface flags
39
+ CALLABLE_INTERFACE = 16
40
+ WEB_INTERFACE = 32
41
+ # Service decorator flags
42
+ # It's, unclear if we need these, as we can also generally infer based on some params being set
43
+ # In the current state where @modal.batched is used _instead_ of `@modal.method`, we need to give
44
+ # `@modal.batched` two roles (exposing the callable interface, adding batching semantics).
45
+ # But it's probably better to make `@modal.batched` and `@modal.method` stackable, or to move
46
+ # `@modal.batched` to be a class-level decorator since it primarily governs service behavior.
47
+ BATCHED = 64
48
+ CONCURRENT = 128
49
+ CLUSTERED = 256 # Experimental: Clustered functions
39
50
 
40
51
  @staticmethod
41
52
  def all() -> int:
42
53
  return ~_PartialFunctionFlags(0)
43
54
 
55
+ @staticmethod
56
+ def lifecycle_flags() -> int:
57
+ return (
58
+ _PartialFunctionFlags.BUILD # Deprecated, will be removed
59
+ | _PartialFunctionFlags.ENTER_PRE_SNAPSHOT
60
+ | _PartialFunctionFlags.ENTER_POST_SNAPSHOT
61
+ | _PartialFunctionFlags.EXIT
62
+ )
63
+
64
+ @staticmethod
65
+ def interface_flags() -> int:
66
+ return _PartialFunctionFlags.CALLABLE_INTERFACE | _PartialFunctionFlags.WEB_INTERFACE
67
+
68
+
69
+ @dataclass
70
+ class _PartialFunctionParams:
71
+ webhook_config: Optional[api_pb2.WebhookConfig] = None
72
+ is_generator: Optional[bool] = None
73
+ force_build: Optional[bool] = None
74
+ batch_max_size: Optional[int] = None
75
+ batch_wait_ms: Optional[int] = None
76
+ cluster_size: Optional[int] = None
77
+ max_concurrent_inputs: Optional[int] = None
78
+ target_concurrent_inputs: Optional[int] = None
79
+ build_timeout: Optional[int] = None
80
+
81
+ def update(self, other: "_PartialFunctionParams") -> None:
82
+ """Update self with params set in other."""
83
+ for key, val in asdict(other).items():
84
+ if val is not None:
85
+ if getattr(self, key, None) is not None:
86
+ raise InvalidError(f"Cannot set `{key}` twice.")
87
+ setattr(self, key, val)
88
+
44
89
 
45
90
  P = typing_extensions.ParamSpec("P")
46
91
  ReturnType = typing_extensions.TypeVar("ReturnType", covariant=True)
47
92
  OriginalReturnType = typing_extensions.TypeVar("OriginalReturnType", covariant=True)
93
+ NullaryFuncOrMethod = Union[Callable[[], Any], Callable[[Any], Any]]
94
+ NullaryMethod = Callable[[Any], Any]
48
95
 
49
96
 
50
97
  class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
51
- """Intermediate function, produced by @enter, @build, @method, @web_endpoint, or @batched"""
98
+ """Object produced by a decorator in the `modal` namespace
99
+
100
+ The object will eventually by consumed by an App decorator.
101
+ """
52
102
 
53
- raw_f: Callable[P, ReturnType]
103
+ raw_f: Optional[Callable[P, ReturnType]] # function or method
104
+ user_cls: Optional[type] = None # class
54
105
  flags: _PartialFunctionFlags
55
- webhook_config: Optional[api_pb2.WebhookConfig]
56
- is_generator: bool
57
- batch_max_size: Optional[int]
58
- batch_wait_ms: Optional[int]
59
- force_build: bool
60
- cluster_size: Optional[int] # Experimental: Clustered functions
61
- build_timeout: Optional[int]
62
- max_concurrent_inputs: Optional[int]
63
- target_concurrent_inputs: Optional[int]
106
+ params: _PartialFunctionParams
107
+ registered: bool
64
108
 
65
109
  def __init__(
66
110
  self,
67
- raw_f: Callable[P, ReturnType],
111
+ obj: Union[Callable[P, ReturnType], type],
68
112
  flags: _PartialFunctionFlags,
69
- *,
70
- webhook_config: Optional[api_pb2.WebhookConfig] = None,
71
- is_generator: Optional[bool] = None,
72
- batch_max_size: Optional[int] = None,
73
- batch_wait_ms: Optional[int] = None,
74
- cluster_size: Optional[int] = None, # Experimental: Clustered functions
75
- force_build: bool = False,
76
- build_timeout: Optional[int] = None,
77
- max_concurrent_inputs: Optional[int] = None,
78
- target_concurrent_inputs: Optional[int] = None,
113
+ params: _PartialFunctionParams,
79
114
  ):
80
- self.raw_f = raw_f
81
- self.flags = flags
82
- self.webhook_config = webhook_config
83
- if is_generator is None:
84
- # auto detect - doesn't work if the function *returns* a generator
85
- final_is_generator = inspect.isgeneratorfunction(raw_f) or inspect.isasyncgenfunction(raw_f)
115
+ if isinstance(obj, type):
116
+ self.user_cls = obj
117
+ self.raw_f = None
86
118
  else:
87
- final_is_generator = is_generator
88
-
89
- self.is_generator = final_is_generator
90
- self.wrapped = False # Make sure that this was converted into a FunctionHandle
91
- self.batch_max_size = batch_max_size
92
- self.batch_wait_ms = batch_wait_ms
93
- self.cluster_size = cluster_size # Experimental: Clustered functions
94
- self.force_build = force_build
95
- self.build_timeout = build_timeout
96
- self.max_concurrent_inputs = max_concurrent_inputs
97
- self.target_concurrent_inputs = target_concurrent_inputs
119
+ self.raw_f = obj
120
+ self.user_cls = None
121
+ self.flags = flags
122
+ self.params = params
123
+ self.registered = False
124
+ self.validate_flag_composition()
125
+
126
+ def stack(self, flags: _PartialFunctionFlags, params: _PartialFunctionParams) -> typing_extensions.Self:
127
+ """Implement decorator composition by combining the flags and params."""
128
+ self.flags |= flags
129
+ self.params.update(params)
130
+ self.validate_flag_composition()
131
+ return self
132
+
133
+ def validate_flag_composition(self) -> None:
134
+ """Validate decorator composition based on PartialFunctionFlags."""
135
+ uses_interface_flags = self.flags & _PartialFunctionFlags.interface_flags()
136
+ uses_lifecycle_flags = self.flags & _PartialFunctionFlags.lifecycle_flags()
137
+ if uses_interface_flags and uses_lifecycle_flags:
138
+ self.registered = True # Hacky, avoid false-positive warning
139
+ raise InvalidError("Interface decorators cannot be combined with lifecycle decorators.")
140
+
141
+ has_web_interface = self.flags & _PartialFunctionFlags.WEB_INTERFACE
142
+ has_callable_interface = self.flags & _PartialFunctionFlags.CALLABLE_INTERFACE
143
+ if has_web_interface and has_callable_interface:
144
+ self.registered = True # Hacky, avoid false-positive warning
145
+ raise InvalidError("Callable decorators cannot be combined with web interface decorators.")
146
+
147
+ def validate_obj_compatibility(
148
+ self, decorator_name: str, require_sync: bool = False, require_nullary: bool = False
149
+ ) -> None:
150
+ """Enforce compatibility with the wrapped object; called from individual decorator functions."""
151
+ from .cls import _Cls # Avoid circular import
152
+
153
+ uses_lifecycle_flags = self.flags & _PartialFunctionFlags.lifecycle_flags()
154
+ uses_interface_flags = self.flags & _PartialFunctionFlags.interface_flags()
155
+ if self.user_cls is not None and (uses_lifecycle_flags or uses_interface_flags):
156
+ self.registered = True # Hacky, avoid false-positive warning
157
+ raise InvalidError(
158
+ f"Cannot apply `@modal.{decorator_name}` to a class. Hint: consider applying to a method instead."
159
+ )
160
+
161
+ wrapped_object = self.raw_f or self.user_cls
162
+ if isinstance(wrapped_object, _Function):
163
+ self.registered = True # Hacky, avoid false-positive warning
164
+ raise InvalidError(
165
+ f"Cannot stack `@modal.{decorator_name}` on top of `@app.function`."
166
+ " Hint: swap the order of the decorators."
167
+ )
168
+ elif isinstance(wrapped_object, _Cls):
169
+ self.registered = True # Hacky, avoid false-positive warning
170
+ raise InvalidError(
171
+ f"Cannot stack `@modal.{decorator_name}` on top of `@app.cls()`."
172
+ " Hint: swap the order of the decorators."
173
+ )
174
+
175
+ # Run some assertions about a callable wrappee defined by the specific decorator used
176
+ if self.raw_f is not None:
177
+ if not callable(self.raw_f):
178
+ self.registered = True # Hacky, avoid false-positive warning
179
+ raise InvalidError(f"The object wrapped by `@modal.{decorator_name}` must be callable.")
180
+
181
+ if require_sync and inspect.iscoroutinefunction(self.raw_f):
182
+ self.registered = True # Hacky, avoid false-positive warning
183
+ raise InvalidError(f"`@modal.{decorator_name}` can't be applied to an async function.")
184
+
185
+ if require_nullary and callable_has_non_self_params(self.raw_f):
186
+ self.registered = True # Hacky, avoid false-positive warning
187
+ if callable_has_non_self_non_default_params(self.raw_f):
188
+ raise InvalidError(f"Functions obj by `@modal.{decorator_name}` can't have parameters.")
189
+ else:
190
+ # TODO(michael): probably fine to just make this an error at this point
191
+ # but best to do it in a separate PR
192
+ deprecation_warning(
193
+ (2024, 9, 4),
194
+ f"The function obj by `@modal.{decorator_name}` has default parameters, "
195
+ "but shouldn't have any parameters - Modal will drop support for "
196
+ "default parameters in a future release.",
197
+ )
98
198
 
99
199
  def _get_raw_f(self) -> Callable[P, ReturnType]:
200
+ assert self.raw_f is not None
100
201
  return self.raw_f
101
202
 
102
203
  def _is_web_endpoint(self) -> bool:
103
- if self.webhook_config is None:
204
+ if self.params.webhook_config is None:
104
205
  return False
105
- return self.webhook_config.type != api_pb2.WEBHOOK_TYPE_UNSPECIFIED
206
+ return self.params.webhook_config.type != api_pb2.WEBHOOK_TYPE_UNSPECIFIED
106
207
 
107
208
  def __get__(self, obj, objtype=None) -> _Function[P, ReturnType, OriginalReturnType]:
108
209
  # to type checkers, any @method or similar function on a modal class, would appear to be
@@ -111,6 +212,7 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
111
212
  # However, modal classes are *actually* Cls instances (which isn't reflected in type checkers
112
213
  # due to Python's lack of type chekcing intersection types), so at runtime the Cls instance would
113
214
  # use its __getattr__ rather than this descriptor.
215
+ assert self.raw_f is not None # Should only be relevant in a method context
114
216
  k = self.raw_f.__name__
115
217
  if obj: # accessing the method on an instance of a class, e.g. `MyClass().fun``
116
218
  if hasattr(obj, "_modal_functions"):
@@ -126,37 +228,26 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
126
228
  return self.raw_f.__get__(obj, objtype)
127
229
 
128
230
  else: # accessing a method directly on the class, e.g. `MyClass.fun`
129
- # This happens mainly during serialization of the wrapped underlying class of a Cls
231
+ # This happens mainly during serialization of the obj underlying class of a Cls
130
232
  # since we don't have the instance info here we just return the PartialFunction itself
131
233
  # to let it be bound to a variable and become a Function later on
132
234
  return self # type: ignore # this returns a PartialFunction in a special internal case
133
235
 
134
236
  def __del__(self):
135
- if (self.flags & _PartialFunctionFlags.FUNCTION) and self.wrapped is False:
237
+ if self.registered is False:
238
+ if self.raw_f is not None:
239
+ name, object_type, suggestion = self.raw_f.__name__, "function", "@app.function or @app.cls"
240
+ elif self.user_cls is not None:
241
+ name, object_type, suggestion = self.user_cls.__name__, "class", "@app.cls"
136
242
  logger.warning(
137
- f"Method or web function {self.raw_f} was never turned into a function."
138
- " Did you forget a @app.function or @app.cls decorator?"
243
+ f"The `{name}` {object_type} was never registered with the App."
244
+ f" Did you forget an {suggestion} decorator?"
139
245
  )
140
246
 
141
- def add_flags(self, flags) -> "_PartialFunction":
142
- # Helper method used internally when stacking decorators
143
- self.wrapped = True
144
- return _PartialFunction(
145
- raw_f=self.raw_f,
146
- flags=(self.flags | flags),
147
- webhook_config=self.webhook_config,
148
- batch_max_size=self.batch_max_size,
149
- batch_wait_ms=self.batch_wait_ms,
150
- force_build=self.force_build,
151
- build_timeout=self.build_timeout,
152
- max_concurrent_inputs=self.max_concurrent_inputs,
153
- target_concurrent_inputs=self.target_concurrent_inputs,
154
- )
155
-
156
247
 
157
248
  def _find_partial_methods_for_user_cls(user_cls: type[Any], flags: int) -> dict[str, _PartialFunction]:
158
249
  """Grabs all method on a user class, and returns partials. Includes legacy methods."""
159
- from .partial_function import PartialFunction # wrapped type
250
+ from .partial_function import PartialFunction # obj type
160
251
 
161
252
  partial_functions: dict[str, _PartialFunction] = {}
162
253
  for parent_cls in reversed(user_cls.mro()):
@@ -173,7 +264,11 @@ def _find_partial_methods_for_user_cls(user_cls: type[Any], flags: int) -> dict[
173
264
  def _find_callables_for_obj(user_obj: Any, flags: int) -> dict[str, Callable[..., Any]]:
174
265
  """Grabs all methods for an object, and binds them to the class"""
175
266
  user_cls: type = type(user_obj)
176
- return {k: pf.raw_f.__get__(user_obj) for k, pf in _find_partial_methods_for_user_cls(user_cls, flags).items()}
267
+ return {
268
+ k: pf.raw_f.__get__(user_obj)
269
+ for k, pf in _find_partial_methods_for_user_cls(user_cls, flags).items()
270
+ if pf.raw_f is not None # Should be true for output of _find_partial_methods_for_user_cls, but hard to annotate
271
+ }
177
272
 
178
273
 
179
274
  class _MethodDecoratorType:
@@ -222,23 +317,24 @@ def _method(
222
317
  "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.method()`."
223
318
  )
224
319
 
225
- def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
226
- nonlocal is_generator
227
- if isinstance(raw_f, _PartialFunction) and raw_f.webhook_config:
228
- raw_f.wrapped = True # suppress later warning
229
- raise InvalidError(
230
- "Web endpoints on classes should not be wrapped by `@method`. "
231
- "Suggestion: remove the `@method` decorator."
232
- )
233
- if isinstance(raw_f, _PartialFunction) and raw_f.batch_max_size is not None:
234
- raw_f.wrapped = True # suppress later warning
235
- raise InvalidError(
236
- "Batched function on classes should not be wrapped by `@method`. "
237
- "Suggestion: remove the `@method` decorator."
238
- )
239
- return _PartialFunction(raw_f, _PartialFunctionFlags.FUNCTION, is_generator=is_generator)
320
+ def wrapper(obj: Union[Callable[..., Any], _PartialFunction]) -> _PartialFunction:
321
+ flags = _PartialFunctionFlags.CALLABLE_INTERFACE
240
322
 
241
- return wrapper # type: ignore # synchronicity issue with wrapped vs unwrapped types and protocols
323
+ nonlocal is_generator # TODO(michael): we are likely to deprecate the explicit is_generator param
324
+ if is_generator is None:
325
+ callable = obj.raw_f if isinstance(obj, _PartialFunction) else obj
326
+ is_generator = inspect.isgeneratorfunction(callable) or inspect.isasyncgenfunction(callable)
327
+ params = _PartialFunctionParams(is_generator=is_generator)
328
+
329
+ if isinstance(obj, _PartialFunction):
330
+ pf = obj.stack(flags, params)
331
+ else:
332
+ pf = _PartialFunction(obj, flags, params)
333
+ pf.validate_obj_compatibility("method")
334
+ return pf
335
+
336
+ # TODO(michael) verify that we still need the type: ignore
337
+ return wrapper # type: ignore # synchronicity issue with obj vs unwrapped types and protocols
242
338
 
243
339
 
244
340
  def _parse_custom_domains(custom_domains: Optional[Iterable[str]] = None) -> list[api_pb2.CustomDomainConfig]:
@@ -259,7 +355,10 @@ def _fastapi_endpoint(
259
355
  custom_domains: Optional[Iterable[str]] = None, # Custom fully-qualified domain name (FQDN) for the endpoint.
260
356
  docs: bool = False, # Whether to enable interactive documentation for this endpoint at /docs.
261
357
  requires_proxy_auth: bool = False, # Require Modal-Key and Modal-Secret HTTP Headers on requests.
262
- ) -> Callable[[Callable[P, ReturnType]], _PartialFunction[P, ReturnType, ReturnType]]:
358
+ ) -> Callable[
359
+ [Union[_PartialFunction[P, ReturnType, ReturnType], Callable[P, ReturnType]]],
360
+ _PartialFunction[P, ReturnType, ReturnType],
361
+ ]:
263
362
  """Convert a function into a basic web endpoint by wrapping it with a FastAPI App.
264
363
 
265
364
  Modal will internally use [FastAPI](https://fastapi.tiangolo.com/) to expose a
@@ -285,27 +384,28 @@ def _fastapi_endpoint(
285
384
  "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.fastapi_endpoint()`."
286
385
  )
287
386
 
288
- def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
289
- if isinstance(raw_f, _Function):
290
- raw_f = raw_f.get_raw_f()
291
- raise InvalidError(
292
- f"Applying decorators for {raw_f} in the wrong order!\nUsage:\n\n"
293
- "@app.function()\n@app.fastapi_endpoint()\ndef my_webhook():\n ..."
294
- )
387
+ webhook_config = api_pb2.WebhookConfig(
388
+ type=api_pb2.WEBHOOK_TYPE_FUNCTION,
389
+ method=method,
390
+ web_endpoint_docs=docs,
391
+ requested_suffix=label or "",
392
+ async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
393
+ custom_domains=_parse_custom_domains(custom_domains),
394
+ requires_proxy_auth=requires_proxy_auth,
395
+ )
295
396
 
296
- return _PartialFunction(
297
- raw_f,
298
- _PartialFunctionFlags.FUNCTION,
299
- webhook_config=api_pb2.WebhookConfig(
300
- type=api_pb2.WEBHOOK_TYPE_FUNCTION,
301
- method=method,
302
- web_endpoint_docs=docs,
303
- requested_suffix=label or "",
304
- async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
305
- custom_domains=_parse_custom_domains(custom_domains),
306
- requires_proxy_auth=requires_proxy_auth,
307
- ),
308
- )
397
+ flags = _PartialFunctionFlags.WEB_INTERFACE
398
+ params = _PartialFunctionParams(webhook_config=webhook_config)
399
+
400
+ def wrapper(
401
+ obj: Union[_PartialFunction[P, ReturnType, ReturnType], Callable[P, ReturnType]],
402
+ ) -> _PartialFunction[P, ReturnType, ReturnType]:
403
+ if isinstance(obj, _PartialFunction):
404
+ pf = obj.stack(flags, params)
405
+ else:
406
+ pf = _PartialFunction(obj, flags, params)
407
+ pf.validate_obj_compatibility("fastapi_endpoint")
408
+ return pf
309
409
 
310
410
  return wrapper
311
411
 
@@ -320,7 +420,10 @@ def _web_endpoint(
320
420
  Iterable[str]
321
421
  ] = None, # Create an endpoint using a custom domain fully-qualified domain name (FQDN).
322
422
  requires_proxy_auth: bool = False, # Require Modal-Key and Modal-Secret HTTP Headers on requests.
323
- ) -> Callable[[Callable[P, ReturnType]], _PartialFunction[P, ReturnType, ReturnType]]:
423
+ ) -> Callable[
424
+ [Union[_PartialFunction[P, ReturnType, ReturnType], Callable[P, ReturnType]]],
425
+ _PartialFunction[P, ReturnType, ReturnType],
426
+ ]:
324
427
  """Register a basic web endpoint with this application.
325
428
 
326
429
  DEPRECATED: This decorator has been renamed to `@modal.fastapi_endpoint`.
@@ -349,27 +452,28 @@ def _web_endpoint(
349
452
  (2025, 3, 5), "The `@modal.web_endpoint` decorator has been renamed to `@modal.fastapi_endpoint`."
350
453
  )
351
454
 
352
- def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
353
- if isinstance(raw_f, _Function):
354
- raw_f = raw_f.get_raw_f()
355
- raise InvalidError(
356
- f"Applying decorators for {raw_f} in the wrong order!\nUsage:\n\n"
357
- "@app.function()\n@modal.web_endpoint()\ndef my_webhook():\n ..."
358
- )
455
+ webhook_config = api_pb2.WebhookConfig(
456
+ type=api_pb2.WEBHOOK_TYPE_FUNCTION,
457
+ method=method,
458
+ web_endpoint_docs=docs,
459
+ requested_suffix=label or "",
460
+ async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
461
+ custom_domains=_parse_custom_domains(custom_domains),
462
+ requires_proxy_auth=requires_proxy_auth,
463
+ )
359
464
 
360
- return _PartialFunction(
361
- raw_f,
362
- _PartialFunctionFlags.FUNCTION,
363
- webhook_config=api_pb2.WebhookConfig(
364
- type=api_pb2.WEBHOOK_TYPE_FUNCTION,
365
- method=method,
366
- web_endpoint_docs=docs,
367
- requested_suffix=label,
368
- async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
369
- custom_domains=_parse_custom_domains(custom_domains),
370
- requires_proxy_auth=requires_proxy_auth,
371
- ),
372
- )
465
+ flags = _PartialFunctionFlags.WEB_INTERFACE
466
+ params = _PartialFunctionParams(webhook_config=webhook_config)
467
+
468
+ def wrapper(
469
+ obj: Union[_PartialFunction[P, ReturnType, ReturnType], Callable[P, ReturnType]],
470
+ ) -> _PartialFunction[P, ReturnType, ReturnType]:
471
+ if isinstance(obj, _PartialFunction):
472
+ pf = obj.stack(flags, params)
473
+ else:
474
+ pf = _PartialFunction(obj, flags, params)
475
+ pf.validate_obj_compatibility("web_endpoint")
476
+ return pf
373
477
 
374
478
  return wrapper
375
479
 
@@ -380,7 +484,7 @@ def _asgi_app(
380
484
  label: Optional[str] = None, # Label for created endpoint. Final subdomain will be <workspace>--<label>.modal.run.
381
485
  custom_domains: Optional[Iterable[str]] = None, # Deploy this endpoint on a custom domain.
382
486
  requires_proxy_auth: bool = False, # Require Modal-Key and Modal-Secret HTTP Headers on requests.
383
- ) -> Callable[[Callable[..., Any]], _PartialFunction]:
487
+ ) -> Callable[[Union[_PartialFunction, NullaryFuncOrMethod]], _PartialFunction]:
384
488
  """Decorator for registering an ASGI app with a Modal function.
385
489
 
386
490
  Asynchronous Server Gateway Interface (ASGI) is a standard for Python
@@ -409,35 +513,24 @@ def _asgi_app(
409
513
  "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.asgi_app()`."
410
514
  )
411
515
 
412
- def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
413
- if callable_has_non_self_params(raw_f):
414
- if callable_has_non_self_non_default_params(raw_f):
415
- raise InvalidError(
416
- f"ASGI app function {raw_f.__name__} can't have parameters. See https://modal.com/docs/guide/webhooks#asgi."
417
- )
418
- else:
419
- deprecation_warning(
420
- (2024, 9, 4),
421
- f"ASGI app function {raw_f.__name__} has default parameters, but shouldn't have any parameters - "
422
- f"Modal will drop support for default parameters in a future release.",
423
- )
516
+ webhook_config = api_pb2.WebhookConfig(
517
+ type=api_pb2.WEBHOOK_TYPE_ASGI_APP,
518
+ requested_suffix=label or "",
519
+ async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
520
+ custom_domains=_parse_custom_domains(custom_domains),
521
+ requires_proxy_auth=requires_proxy_auth,
522
+ )
424
523
 
425
- if inspect.iscoroutinefunction(raw_f):
426
- raise InvalidError(
427
- f"ASGI app function {raw_f.__name__} is an async function. Only sync Python functions are supported."
428
- )
524
+ flags = _PartialFunctionFlags.WEB_INTERFACE
525
+ params = _PartialFunctionParams(webhook_config=webhook_config)
429
526
 
430
- return _PartialFunction(
431
- raw_f,
432
- _PartialFunctionFlags.FUNCTION,
433
- webhook_config=api_pb2.WebhookConfig(
434
- type=api_pb2.WEBHOOK_TYPE_ASGI_APP,
435
- requested_suffix=label,
436
- async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
437
- custom_domains=_parse_custom_domains(custom_domains),
438
- requires_proxy_auth=requires_proxy_auth,
439
- ),
440
- )
527
+ def wrapper(obj: Union[_PartialFunction, NullaryFuncOrMethod]) -> _PartialFunction:
528
+ if isinstance(obj, _PartialFunction):
529
+ pf = obj.stack(flags, params)
530
+ else:
531
+ pf = _PartialFunction(obj, flags, params)
532
+ pf.validate_obj_compatibility("asgi_app", require_sync=True, require_nullary=True)
533
+ return pf
441
534
 
442
535
  return wrapper
443
536
 
@@ -448,7 +541,7 @@ def _wsgi_app(
448
541
  label: Optional[str] = None, # Label for created endpoint. Final subdomain will be <workspace>--<label>.modal.run.
449
542
  custom_domains: Optional[Iterable[str]] = None, # Deploy this endpoint on a custom domain.
450
543
  requires_proxy_auth: bool = False, # Require Modal-Key and Modal-Secret HTTP Headers on requests.
451
- ) -> Callable[[Callable[..., Any]], _PartialFunction]:
544
+ ) -> Callable[[Union[_PartialFunction, NullaryFuncOrMethod]], _PartialFunction]:
452
545
  """Decorator for registering a WSGI app with a Modal function.
453
546
 
454
547
  Web Server Gateway Interface (WSGI) is a standard for synchronous Python web apps.
@@ -477,35 +570,24 @@ def _wsgi_app(
477
570
  "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.wsgi_app()`."
478
571
  )
479
572
 
480
- def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
481
- if callable_has_non_self_params(raw_f):
482
- if callable_has_non_self_non_default_params(raw_f):
483
- raise InvalidError(
484
- f"WSGI app function {raw_f.__name__} can't have parameters. See https://modal.com/docs/guide/webhooks#wsgi."
485
- )
486
- else:
487
- deprecation_warning(
488
- (2024, 9, 4),
489
- f"WSGI app function {raw_f.__name__} has default parameters, but shouldn't have any parameters - "
490
- f"Modal will drop support for default parameters in a future release.",
491
- )
573
+ webhook_config = api_pb2.WebhookConfig(
574
+ type=api_pb2.WEBHOOK_TYPE_WSGI_APP,
575
+ requested_suffix=label or "",
576
+ async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
577
+ custom_domains=_parse_custom_domains(custom_domains),
578
+ requires_proxy_auth=requires_proxy_auth,
579
+ )
492
580
 
493
- if inspect.iscoroutinefunction(raw_f):
494
- raise InvalidError(
495
- f"WSGI app function {raw_f.__name__} is an async function. Only sync Python functions are supported."
496
- )
581
+ flags = _PartialFunctionFlags.WEB_INTERFACE
582
+ params = _PartialFunctionParams(webhook_config=webhook_config)
497
583
 
498
- return _PartialFunction(
499
- raw_f,
500
- _PartialFunctionFlags.FUNCTION,
501
- webhook_config=api_pb2.WebhookConfig(
502
- type=api_pb2.WEBHOOK_TYPE_WSGI_APP,
503
- requested_suffix=label,
504
- async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
505
- custom_domains=_parse_custom_domains(custom_domains),
506
- requires_proxy_auth=requires_proxy_auth,
507
- ),
508
- )
584
+ def wrapper(obj: Union[_PartialFunction, NullaryFuncOrMethod]) -> _PartialFunction:
585
+ if isinstance(obj, _PartialFunction):
586
+ pf = obj.stack(flags, params)
587
+ else:
588
+ pf = _PartialFunction(obj, flags, params)
589
+ pf.validate_obj_compatibility("wsgi_app", require_sync=True, require_nullary=True)
590
+ return pf
509
591
 
510
592
  return wrapper
511
593
 
@@ -517,7 +599,7 @@ def _web_server(
517
599
  label: Optional[str] = None, # Label for created endpoint. Final subdomain will be <workspace>--<label>.modal.run.
518
600
  custom_domains: Optional[Iterable[str]] = None, # Deploy this endpoint on a custom domain.
519
601
  requires_proxy_auth: bool = False, # Require Modal-Key and Modal-Secret HTTP Headers on requests.
520
- ) -> Callable[[Callable[..., Any]], _PartialFunction]:
602
+ ) -> Callable[[Union[_PartialFunction, NullaryFuncOrMethod]], _PartialFunction]:
521
603
  """Decorator that registers an HTTP web server inside the container.
522
604
 
523
605
  This is similar to `@asgi_app` and `@wsgi_app`, but it allows you to expose a full HTTP server
@@ -549,33 +631,33 @@ def _web_server(
549
631
  if startup_timeout <= 0:
550
632
  raise InvalidError("The `startup_timeout` argument of `@web_server` must be positive.")
551
633
 
552
- def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
553
- return _PartialFunction(
554
- raw_f,
555
- _PartialFunctionFlags.FUNCTION,
556
- webhook_config=api_pb2.WebhookConfig(
557
- type=api_pb2.WEBHOOK_TYPE_WEB_SERVER,
558
- requested_suffix=label,
559
- async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
560
- custom_domains=_parse_custom_domains(custom_domains),
561
- web_server_port=port,
562
- web_server_startup_timeout=startup_timeout,
563
- requires_proxy_auth=requires_proxy_auth,
564
- ),
565
- )
634
+ webhook_config = api_pb2.WebhookConfig(
635
+ type=api_pb2.WEBHOOK_TYPE_WEB_SERVER,
636
+ requested_suffix=label or "",
637
+ async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
638
+ custom_domains=_parse_custom_domains(custom_domains),
639
+ web_server_port=port,
640
+ web_server_startup_timeout=startup_timeout,
641
+ requires_proxy_auth=requires_proxy_auth,
642
+ )
566
643
 
567
- return wrapper
644
+ flags = _PartialFunctionFlags.WEB_INTERFACE
645
+ params = _PartialFunctionParams(webhook_config=webhook_config)
568
646
 
647
+ def wrapper(obj: Union[_PartialFunction, NullaryFuncOrMethod]) -> _PartialFunction:
648
+ if isinstance(obj, _PartialFunction):
649
+ pf = obj.stack(flags, params)
650
+ else:
651
+ pf = _PartialFunction(obj, flags, params)
652
+ pf.validate_obj_compatibility("web_server", require_sync=True, require_nullary=True)
653
+ return pf
569
654
 
570
- def _disallow_wrapping_method(f: _PartialFunction, wrapper: str) -> None:
571
- if f.flags & _PartialFunctionFlags.FUNCTION:
572
- f.wrapped = True # Hack to avoid warning about not using @app.cls()
573
- raise InvalidError(f"Cannot use `@{wrapper}` decorator with `@method`.")
655
+ return wrapper
574
656
 
575
657
 
576
658
  def _build(
577
659
  _warn_parentheses_missing=None, *, force: bool = False, timeout: int = 86400
578
- ) -> Callable[[Union[Callable[[Any], Any], _PartialFunction]], _PartialFunction]:
660
+ ) -> Callable[[Union[_PartialFunction, NullaryMethod]], _PartialFunction]:
579
661
  """
580
662
  Decorator for methods that execute at _build time_ to create a new Image layer.
581
663
 
@@ -612,14 +694,16 @@ def _build(
612
694
  "\n\nSee https://modal.com/docs/guide/modal-1-0-migration for more information.",
613
695
  )
614
696
 
615
- def wrapper(f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction:
616
- if isinstance(f, _PartialFunction):
617
- _disallow_wrapping_method(f, "build")
618
- f.force_build = force
619
- f.build_timeout = timeout
620
- return f.add_flags(_PartialFunctionFlags.BUILD)
697
+ flags = _PartialFunctionFlags.BUILD
698
+ params = _PartialFunctionParams(force_build=force, build_timeout=timeout)
699
+
700
+ def wrapper(obj: Union[_PartialFunction, NullaryMethod]) -> _PartialFunction:
701
+ if isinstance(obj, _PartialFunction):
702
+ pf = obj.stack(flags, params)
621
703
  else:
622
- return _PartialFunction(f, _PartialFunctionFlags.BUILD, force_build=force, build_timeout=timeout)
704
+ pf = _PartialFunction(obj, flags, params)
705
+ pf.validate_obj_compatibility("build")
706
+ return pf
623
707
 
624
708
  return wrapper
625
709
 
@@ -628,7 +712,7 @@ def _enter(
628
712
  _warn_parentheses_missing=None,
629
713
  *,
630
714
  snap: bool = False,
631
- ) -> Callable[[Union[Callable[[Any], Any], _PartialFunction]], _PartialFunction]:
715
+ ) -> Callable[[Union[_PartialFunction, NullaryMethod]], _PartialFunction]:
632
716
  """Decorator for methods which should be executed when a new container is started.
633
717
 
634
718
  See the [lifeycle function guide](https://modal.com/docs/guide/lifecycle-functions#enter) for more information."""
@@ -637,32 +721,22 @@ def _enter(
637
721
  "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.enter()`."
638
722
  )
639
723
 
640
- if snap:
641
- flag = _PartialFunctionFlags.ENTER_PRE_SNAPSHOT
642
- else:
643
- flag = _PartialFunctionFlags.ENTER_POST_SNAPSHOT
724
+ flags = _PartialFunctionFlags.ENTER_PRE_SNAPSHOT if snap else _PartialFunctionFlags.ENTER_POST_SNAPSHOT
725
+ params = _PartialFunctionParams()
644
726
 
645
- def wrapper(f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction:
646
- if isinstance(f, _PartialFunction):
647
- _disallow_wrapping_method(f, "enter")
648
- return f.add_flags(flag)
727
+ def wrapper(obj: Union[_PartialFunction, NullaryMethod]) -> _PartialFunction:
728
+ # TODO: reject stacking once depreceate @modal.build
729
+ if isinstance(obj, _PartialFunction):
730
+ pf = obj.stack(flags, params)
649
731
  else:
650
- return _PartialFunction(f, flag)
732
+ pf = _PartialFunction(obj, flags, params)
733
+ pf.validate_obj_compatibility("enter") # TODO require_nullary?
734
+ return pf
651
735
 
652
736
  return wrapper
653
737
 
654
738
 
655
- ExitHandlerType = Union[
656
- # NOTE: return types of these callables should be `Union[None, Awaitable[None]]` but
657
- # synchronicity type stubs would strip Awaitable so we use Any for now
658
- # Original, __exit__ style method signature (now deprecated)
659
- Callable[[Any, Optional[type[BaseException]], Optional[BaseException], Any], Any],
660
- # Forward-looking unparametrized method
661
- Callable[[Any], Any],
662
- ]
663
-
664
-
665
- def _exit(_warn_parentheses_missing=None) -> Callable[[ExitHandlerType], _PartialFunction]:
739
+ def _exit(_warn_parentheses_missing=None) -> Callable[[NullaryMethod], _PartialFunction]:
666
740
  """Decorator for methods which should be executed when a container is about to exit.
667
741
 
668
742
  See the [lifeycle function guide](https://modal.com/docs/guide/lifecycle-functions#exit) for more information."""
@@ -671,11 +745,16 @@ def _exit(_warn_parentheses_missing=None) -> Callable[[ExitHandlerType], _Partia
671
745
  "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.exit()`."
672
746
  )
673
747
 
674
- def wrapper(f: ExitHandlerType) -> _PartialFunction:
675
- if isinstance(f, _PartialFunction):
676
- _disallow_wrapping_method(f, "exit")
748
+ flags = _PartialFunctionFlags.EXIT
749
+ params = _PartialFunctionParams()
677
750
 
678
- return _PartialFunction(f, _PartialFunctionFlags.EXIT)
751
+ def wrapper(obj: Union[_PartialFunction, NullaryMethod]) -> _PartialFunction:
752
+ if isinstance(obj, _PartialFunction):
753
+ pf = obj.stack(flags, params)
754
+ else:
755
+ pf = _PartialFunction(obj, flags, params)
756
+ pf.validate_obj_compatibility("exit") # TODO require_nullary?
757
+ return pf
679
758
 
680
759
  return wrapper
681
760
 
@@ -685,19 +764,30 @@ def _batched(
685
764
  *,
686
765
  max_batch_size: int,
687
766
  wait_ms: int,
688
- ) -> Callable[[Callable[..., Any]], _PartialFunction]:
767
+ ) -> Callable[
768
+ [Union[_PartialFunction[P, ReturnType, ReturnType], Callable[P, ReturnType]]],
769
+ _PartialFunction[P, ReturnType, ReturnType],
770
+ ]:
689
771
  """Decorator for functions or class methods that should be batched.
690
772
 
691
773
  **Usage**
692
774
 
693
- ```python notest
775
+ ```python
776
+ # Stack the decorator under `@app.function()` to enable dynamic batching
694
777
  @app.function()
695
778
  @modal.batched(max_batch_size=4, wait_ms=1000)
696
779
  async def batched_multiply(xs: list[int], ys: list[int]) -> list[int]:
697
780
  return [x * y for x, y in zip(xs, xs)]
698
781
 
699
782
  # call batched_multiply with individual inputs
700
- batched_multiply.remote.aio(2, 100)
783
+ # batched_multiply.remote.aio(2, 100)
784
+
785
+ # With `@app.cls()`, apply the decorator to a method (this may change in the future)
786
+ @app.cls()
787
+ class BatchedClass:
788
+ @modal.batched(max_batch_size=4, wait_ms=1000)
789
+ def batched_multiply(self, xs: list[int], ys: list[int]) -> list[int]:
790
+ return [x * y for x, y in zip(xs, xs)]
701
791
  ```
702
792
 
703
793
  See the [dynamic batching guide](https://modal.com/docs/guide/dynamic-batching) for more information.
@@ -715,19 +805,18 @@ def _batched(
715
805
  if wait_ms >= MAX_BATCH_WAIT_MS:
716
806
  raise InvalidError(f"wait_ms must be less than {MAX_BATCH_WAIT_MS}.")
717
807
 
718
- def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
719
- if isinstance(raw_f, _Function):
720
- raw_f = raw_f.get_raw_f()
721
- raise InvalidError(
722
- f"Applying decorators for {raw_f} in the wrong order!\nUsage:\n\n"
723
- "@app.function()\n@modal.batched()\ndef batched_function():\n ..."
724
- )
725
- return _PartialFunction(
726
- raw_f,
727
- _PartialFunctionFlags.FUNCTION | _PartialFunctionFlags.BATCHED,
728
- batch_max_size=max_batch_size,
729
- batch_wait_ms=wait_ms,
730
- )
808
+ flags = _PartialFunctionFlags.CALLABLE_INTERFACE | _PartialFunctionFlags.BATCHED
809
+ params = _PartialFunctionParams(batch_max_size=max_batch_size, batch_wait_ms=wait_ms)
810
+
811
+ def wrapper(
812
+ obj: Union[_PartialFunction[P, ReturnType, ReturnType], Callable[P, ReturnType]],
813
+ ) -> _PartialFunction[P, ReturnType, ReturnType]:
814
+ if isinstance(obj, _PartialFunction):
815
+ pf = obj.stack(flags, params)
816
+ else:
817
+ pf = _PartialFunction(obj, flags, params)
818
+ pf.validate_obj_compatibility("batched")
819
+ return pf
731
820
 
732
821
  return wrapper
733
822
 
@@ -737,7 +826,10 @@ def _concurrent(
737
826
  *,
738
827
  max_inputs: int, # Hard limit on each container's input concurrency
739
828
  target_inputs: Optional[int] = None, # Input concurrency that Modal's autoscaler should target
740
- ) -> Callable[[Union[Callable[..., Any], _PartialFunction]], _PartialFunction]:
829
+ ) -> Callable[
830
+ [Union[Callable[P, ReturnType], _PartialFunction[P, ReturnType, ReturnType]]],
831
+ _PartialFunction[P, ReturnType, ReturnType],
832
+ ]:
741
833
  """Decorator that allows individual containers to handle multiple inputs concurrently.
742
834
 
743
835
  The concurrency mechanism depends on whether the function is async or not:
@@ -784,19 +876,55 @@ def _concurrent(
784
876
  if target_inputs and target_inputs > max_inputs:
785
877
  raise InvalidError("`target_inputs` parameter cannot be greater than `max_inputs`.")
786
878
 
787
- def wrapper(obj: Union[Callable[..., Any], _PartialFunction]) -> _PartialFunction:
879
+ flags = _PartialFunctionFlags.CONCURRENT
880
+ params = _PartialFunctionParams(max_concurrent_inputs=max_inputs, target_concurrent_inputs=target_inputs)
881
+
882
+ # Note: ideally we would have some way of declaring that this decorator cannot be used on an individual method.
883
+ # I don't think there's any clear way for the wrapper function to know it's been passed "a method" rather than
884
+ # a normal function. So we need to run that check in the `@app.cls` decorator, which is a little far removed.
885
+
886
+ def wrapper(
887
+ obj: Union[_PartialFunction[P, ReturnType, ReturnType], Callable[P, ReturnType]],
888
+ ) -> _PartialFunction[P, ReturnType, ReturnType]:
788
889
  if isinstance(obj, _PartialFunction):
789
- # Risky that we need to mutate the parameters here; should make this safer
790
- obj.max_concurrent_inputs = max_inputs
791
- obj.target_concurrent_inputs = target_inputs
792
- obj.add_flags(_PartialFunctionFlags.FUNCTION)
793
- return obj
794
-
795
- return _PartialFunction(
796
- obj,
797
- _PartialFunctionFlags.FUNCTION,
798
- max_concurrent_inputs=max_inputs,
799
- target_concurrent_inputs=target_inputs,
800
- )
890
+ pf = obj.stack(flags, params)
891
+ else:
892
+ pf = _PartialFunction(obj, flags, params)
893
+ pf.validate_obj_compatibility("concurrent")
894
+ return pf
895
+
896
+ return wrapper
897
+
898
+
899
+ # NOTE: clustered is currently exposed through modal.experimental, not the top-level namespace
900
+ def _clustered(size: int, broadcast: bool = True):
901
+ """Provision clusters of colocated and networked containers for the Function.
902
+
903
+ Parameters:
904
+ size: int
905
+ Number of containers spun up to handle each input.
906
+ broadcast: bool = True
907
+ If True, inputs will be sent simultaneously to each container. Otherwise,
908
+ inputs will be sent only to the rank-0 container, which is responsible for
909
+ delegating to the workers.
910
+ """
911
+
912
+ assert broadcast, "broadcast=False has not been implemented yet!"
913
+
914
+ if size <= 0:
915
+ raise ValueError("cluster size must be greater than 0")
916
+
917
+ flags = _PartialFunctionFlags.CLUSTERED
918
+ params = _PartialFunctionParams(cluster_size=size)
919
+
920
+ def wrapper(
921
+ obj: Union[_PartialFunction[P, ReturnType, ReturnType], Callable[P, ReturnType]],
922
+ ) -> _PartialFunction[P, ReturnType, ReturnType]:
923
+ if isinstance(obj, _PartialFunction):
924
+ pf = obj.stack(flags, params)
925
+ else:
926
+ pf = _PartialFunction(obj, flags, params)
927
+ pf.validate_obj_compatibility("clustered")
928
+ return pf
801
929
 
802
930
  return wrapper