modal 0.73.2__py3-none-any.whl → 0.73.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modal/functions.py CHANGED
@@ -1,1608 +1,7 @@
1
- # Copyright Modal Labs 2023
2
- import dataclasses
3
- import inspect
4
- import textwrap
5
- import time
6
- import typing
7
- import warnings
8
- from collections.abc import AsyncGenerator, Collection, Sequence, Sized
9
- from dataclasses import dataclass
10
- from pathlib import PurePosixPath
11
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union
1
+ # Copyright Modal Labs 2025
2
+ from ._functions import _Function, _FunctionCall, _gather
3
+ from ._utils.async_utils import synchronize_api
12
4
 
13
- import typing_extensions
14
- from google.protobuf.message import Message
15
- from grpclib import GRPCError, Status
16
- from synchronicity.combined_types import MethodWithAio
17
- from synchronicity.exceptions import UserCodeException
18
-
19
- from modal_proto import api_pb2
20
- from modal_proto.modal_api_grpc import ModalClientModal
21
-
22
- from ._location import parse_cloud_provider
23
- from ._object import _get_environment_name, _Object, live_method, live_method_gen
24
- from ._pty import get_pty_info
25
- from ._resolver import Resolver
26
- from ._resources import convert_fn_config_to_resources_config
27
- from ._runtime.execution_context import current_input_id, is_local
28
- from ._serialization import serialize, serialize_proto_params
29
- from ._traceback import print_server_warnings
30
- from ._utils.async_utils import (
31
- TaskContext,
32
- aclosing,
33
- async_merge,
34
- callable_to_agen,
35
- synchronize_api,
36
- synchronizer,
37
- warn_if_generator_is_not_consumed,
38
- )
39
- from ._utils.deprecation import deprecation_warning, renamed_parameter
40
- from ._utils.function_utils import (
41
- ATTEMPT_TIMEOUT_GRACE_PERIOD,
42
- OUTPUTS_TIMEOUT,
43
- FunctionCreationStatus,
44
- FunctionInfo,
45
- IncludeSourceMode,
46
- _create_input,
47
- _process_result,
48
- _stream_function_call_data,
49
- get_function_type,
50
- get_include_source_mode,
51
- is_async,
52
- )
53
- from ._utils.grpc_utils import retry_transient_errors
54
- from ._utils.mount_utils import validate_network_file_systems, validate_volumes
55
- from .call_graph import InputInfo, _reconstruct_call_graph
56
- from .client import _Client
57
- from .cloud_bucket_mount import _CloudBucketMount, cloud_bucket_mounts_to_proto
58
- from .config import config
59
- from .exception import (
60
- ExecutionError,
61
- FunctionTimeoutError,
62
- InternalFailure,
63
- InvalidError,
64
- NotFoundError,
65
- OutputExpiredError,
66
- )
67
- from .gpu import GPU_T, parse_gpu_config
68
- from .image import _Image
69
- from .mount import _get_client_mount, _Mount, get_auto_mounts
70
- from .network_file_system import _NetworkFileSystem, network_file_system_mount_protos
71
- from .output import _get_output_manager
72
- from .parallel_map import (
73
- _for_each_async,
74
- _for_each_sync,
75
- _map_async,
76
- _map_invocation,
77
- _map_sync,
78
- _starmap_async,
79
- _starmap_sync,
80
- _SynchronizedQueue,
81
- )
82
- from .proxy import _Proxy
83
- from .retries import Retries, RetryManager
84
- from .schedule import Schedule
85
- from .scheduler_placement import SchedulerPlacement
86
- from .secret import _Secret
87
- from .volume import _Volume
88
-
89
- if TYPE_CHECKING:
90
- import modal.app
91
- import modal.cls
92
- import modal.partial_function
93
-
94
-
95
- @dataclasses.dataclass
96
- class _RetryContext:
97
- function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
98
- retry_policy: api_pb2.FunctionRetryPolicy
99
- function_call_jwt: str
100
- input_jwt: str
101
- input_id: str
102
- item: api_pb2.FunctionPutInputsItem
103
-
104
-
105
- class _Invocation:
106
- """Internal client representation of a single-input call to a Modal Function or Generator"""
107
-
108
- stub: ModalClientModal
109
-
110
- def __init__(
111
- self,
112
- stub: ModalClientModal,
113
- function_call_id: str,
114
- client: _Client,
115
- retry_context: Optional[_RetryContext] = None,
116
- ):
117
- self.stub = stub
118
- self.client = client # Used by the deserializer.
119
- self.function_call_id = function_call_id # TODO: remove and use only input_id
120
- self._retry_context = retry_context
121
-
122
- @staticmethod
123
- async def create(
124
- function: "_Function",
125
- args,
126
- kwargs,
127
- *,
128
- client: _Client,
129
- function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
130
- ) -> "_Invocation":
131
- assert client.stub
132
- function_id = function.object_id
133
- item = await _create_input(args, kwargs, client, method_name=function._use_method_name)
134
-
135
- request = api_pb2.FunctionMapRequest(
136
- function_id=function_id,
137
- parent_input_id=current_input_id() or "",
138
- function_call_type=api_pb2.FUNCTION_CALL_TYPE_UNARY,
139
- pipelined_inputs=[item],
140
- function_call_invocation_type=function_call_invocation_type,
141
- )
142
- response = await retry_transient_errors(client.stub.FunctionMap, request)
143
- function_call_id = response.function_call_id
144
-
145
- if response.pipelined_inputs:
146
- assert len(response.pipelined_inputs) == 1
147
- input = response.pipelined_inputs[0]
148
- retry_context = _RetryContext(
149
- function_call_invocation_type=function_call_invocation_type,
150
- retry_policy=response.retry_policy,
151
- function_call_jwt=response.function_call_jwt,
152
- input_jwt=input.input_jwt,
153
- input_id=input.input_id,
154
- item=item,
155
- )
156
- return _Invocation(client.stub, function_call_id, client, retry_context)
157
-
158
- request_put = api_pb2.FunctionPutInputsRequest(
159
- function_id=function_id, inputs=[item], function_call_id=function_call_id
160
- )
161
- inputs_response: api_pb2.FunctionPutInputsResponse = await retry_transient_errors(
162
- client.stub.FunctionPutInputs,
163
- request_put,
164
- )
165
- processed_inputs = inputs_response.inputs
166
- if not processed_inputs:
167
- raise Exception("Could not create function call - the input queue seems to be full")
168
- input = inputs_response.inputs[0]
169
- retry_context = _RetryContext(
170
- function_call_invocation_type=function_call_invocation_type,
171
- retry_policy=response.retry_policy,
172
- function_call_jwt=response.function_call_jwt,
173
- input_jwt=input.input_jwt,
174
- input_id=input.input_id,
175
- item=item,
176
- )
177
- return _Invocation(client.stub, function_call_id, client, retry_context)
178
-
179
- async def pop_function_call_outputs(
180
- self, timeout: Optional[float], clear_on_success: bool, input_jwts: Optional[list[str]] = None
181
- ) -> api_pb2.FunctionGetOutputsResponse:
182
- t0 = time.time()
183
- if timeout is None:
184
- backend_timeout = OUTPUTS_TIMEOUT
185
- else:
186
- backend_timeout = min(OUTPUTS_TIMEOUT, timeout) # refresh backend call every 55s
187
-
188
- while True:
189
- # always execute at least one poll for results, regardless if timeout is 0
190
- request = api_pb2.FunctionGetOutputsRequest(
191
- function_call_id=self.function_call_id,
192
- timeout=backend_timeout,
193
- last_entry_id="0-0",
194
- clear_on_success=clear_on_success,
195
- requested_at=time.time(),
196
- input_jwts=input_jwts,
197
- )
198
- response: api_pb2.FunctionGetOutputsResponse = await retry_transient_errors(
199
- self.stub.FunctionGetOutputs,
200
- request,
201
- attempt_timeout=backend_timeout + ATTEMPT_TIMEOUT_GRACE_PERIOD,
202
- )
203
-
204
- if len(response.outputs) > 0:
205
- return response
206
-
207
- if timeout is not None:
208
- # update timeout in retry loop
209
- backend_timeout = min(OUTPUTS_TIMEOUT, t0 + timeout - time.time())
210
- if backend_timeout < 0:
211
- # return the last response to check for state of num_unfinished_inputs
212
- return response
213
-
214
- async def _retry_input(self) -> None:
215
- ctx = self._retry_context
216
- if not ctx:
217
- raise ValueError("Cannot retry input when _retry_context is empty.")
218
-
219
- item = api_pb2.FunctionRetryInputsItem(input_jwt=ctx.input_jwt, input=ctx.item.input)
220
- request = api_pb2.FunctionRetryInputsRequest(function_call_jwt=ctx.function_call_jwt, inputs=[item])
221
- await retry_transient_errors(
222
- self.client.stub.FunctionRetryInputs,
223
- request,
224
- )
225
-
226
- async def _get_single_output(self, expected_jwt: Optional[str] = None) -> Any:
227
- # waits indefinitely for a single result for the function, and clear the outputs buffer after
228
- item: api_pb2.FunctionGetOutputsItem = (
229
- await self.pop_function_call_outputs(
230
- timeout=None,
231
- clear_on_success=True,
232
- input_jwts=[expected_jwt] if expected_jwt else None,
233
- )
234
- ).outputs[0]
235
- return await _process_result(item.result, item.data_format, self.stub, self.client)
236
-
237
- async def run_function(self) -> Any:
238
- # Use retry logic only if retry policy is specified and
239
- ctx = self._retry_context
240
- if (
241
- not ctx
242
- or not ctx.retry_policy
243
- or ctx.retry_policy.retries == 0
244
- or ctx.function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
245
- ):
246
- return await self._get_single_output()
247
-
248
- # User errors including timeouts are managed by the user specified retry policy.
249
- user_retry_manager = RetryManager(ctx.retry_policy)
250
-
251
- while True:
252
- try:
253
- return await self._get_single_output(ctx.input_jwt)
254
- except (UserCodeException, FunctionTimeoutError) as exc:
255
- await user_retry_manager.raise_or_sleep(exc)
256
- except InternalFailure:
257
- # For system failures on the server, we retry immediately.
258
- pass
259
- await self._retry_input()
260
-
261
- async def poll_function(self, timeout: Optional[float] = None):
262
- """Waits up to timeout for a result from a function.
263
-
264
- If timeout is `None`, waits indefinitely. This function is not
265
- cancellation-safe.
266
- """
267
- response: api_pb2.FunctionGetOutputsResponse = await self.pop_function_call_outputs(
268
- timeout=timeout, clear_on_success=False
269
- )
270
- if len(response.outputs) == 0 and response.num_unfinished_inputs == 0:
271
- # if no unfinished inputs and no outputs, then function expired
272
- raise OutputExpiredError()
273
- elif len(response.outputs) == 0:
274
- raise TimeoutError()
275
-
276
- return await _process_result(
277
- response.outputs[0].result, response.outputs[0].data_format, self.stub, self.client
278
- )
279
-
280
- async def run_generator(self):
281
- items_received = 0
282
- items_total: Union[int, None] = None # populated when self.run_function() completes
283
- async with aclosing(
284
- async_merge(
285
- _stream_function_call_data(self.client, self.function_call_id, variant="data_out"),
286
- callable_to_agen(self.run_function),
287
- )
288
- ) as streamer:
289
- async for item in streamer:
290
- if isinstance(item, api_pb2.GeneratorDone):
291
- items_total = item.items_total
292
- else:
293
- yield item
294
- items_received += 1
295
- # The comparison avoids infinite loops if a non-deterministic generator is retried
296
- # and produces less data in the second run than what was already sent.
297
- if items_total is not None and items_received >= items_total:
298
- break
299
-
300
-
301
- # Wrapper type for api_pb2.FunctionStats
302
- @dataclass(frozen=True)
303
- class FunctionStats:
304
- """Simple data structure storing stats for a running function."""
305
-
306
- backlog: int
307
- num_total_runners: int
308
-
309
- def __getattr__(self, name):
310
- if name == "num_active_runners":
311
- msg = (
312
- "'FunctionStats.num_active_runners' is deprecated."
313
- " It currently always has a value of 0,"
314
- " but it will be removed in a future release."
315
- )
316
- deprecation_warning((2024, 6, 14), msg)
317
- return 0
318
- raise AttributeError(f"'FunctionStats' object has no attribute '{name}'")
319
-
320
-
321
- def _parse_retries(
322
- retries: Optional[Union[int, Retries]],
323
- source: str = "",
324
- ) -> Optional[api_pb2.FunctionRetryPolicy]:
325
- if isinstance(retries, int):
326
- return Retries(
327
- max_retries=retries,
328
- initial_delay=1.0,
329
- backoff_coefficient=1.0,
330
- )._to_proto()
331
- elif isinstance(retries, Retries):
332
- return retries._to_proto()
333
- elif retries is None:
334
- return None
335
- else:
336
- extra = f" on {source}" if source else ""
337
- msg = f"Retries parameter must be an integer or instance of modal.Retries. Found: {type(retries)}{extra}."
338
- raise InvalidError(msg)
339
-
340
-
341
- @dataclass
342
- class _FunctionSpec:
343
- """
344
- Stores information about a Function specification.
345
- This is used for `modal shell` to support running shells with
346
- the same configuration as a user-defined Function.
347
- """
348
-
349
- image: Optional[_Image]
350
- mounts: Sequence[_Mount]
351
- secrets: Sequence[_Secret]
352
- network_file_systems: dict[Union[str, PurePosixPath], _NetworkFileSystem]
353
- volumes: dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]]
354
- gpus: Union[GPU_T, list[GPU_T]] # TODO(irfansharif): Somehow assert that it's the first kind, in sandboxes
355
- cloud: Optional[str]
356
- cpu: Optional[Union[float, tuple[float, float]]]
357
- memory: Optional[Union[int, tuple[int, int]]]
358
- ephemeral_disk: Optional[int]
359
- scheduler_placement: Optional[SchedulerPlacement]
360
- proxy: Optional[_Proxy]
361
-
362
-
363
- P = typing_extensions.ParamSpec("P")
364
- ReturnType = typing.TypeVar("ReturnType", covariant=True)
365
- OriginalReturnType = typing.TypeVar(
366
- "OriginalReturnType", covariant=True
367
- ) # differs from return type if ReturnType is coroutine
368
-
369
-
370
- class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type_prefix="fu"):
371
- """Functions are the basic units of serverless execution on Modal.
372
-
373
- Generally, you will not construct a `Function` directly. Instead, use the
374
- `App.function()` decorator to register your Python functions with your App.
375
- """
376
-
377
- # TODO: more type annotations
378
- _info: Optional[FunctionInfo]
379
- _serve_mounts: frozenset[_Mount] # set at load time, only by loader
380
- _app: Optional["modal.app._App"] = None
381
- _obj: Optional["modal.cls._Obj"] = None # only set for InstanceServiceFunctions and bound instance methods
382
-
383
- _webhook_config: Optional[api_pb2.WebhookConfig] = None # this is set in definition scope, only locally
384
- _web_url: Optional[str] # this is set on hydration
385
-
386
- _function_name: Optional[str]
387
- _is_method: bool
388
- _spec: Optional[_FunctionSpec] = None
389
- _tag: str
390
- _raw_f: Optional[Callable[..., Any]] # this is set to None for a "class service [function]"
391
- _build_args: dict
392
-
393
- _is_generator: Optional[bool] = None
394
- _cluster_size: Optional[int] = None
395
-
396
- # when this is the method of a class/object function, invocation of this function
397
- # should supply the method name in the FunctionInput:
398
- _use_method_name: str = ""
399
-
400
- _class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None
401
- _method_handle_metadata: Optional[dict[str, "api_pb2.FunctionHandleMetadata"]] = None
402
-
403
- def _bind_method(
404
- self,
405
- user_cls,
406
- method_name: str,
407
- partial_function: "modal.partial_function._PartialFunction",
408
- ):
409
- """mdmd:hidden
410
-
411
- Creates a _Function that is bound to a specific class method name. This _Function is not uniquely tied
412
- to any backend function -- its object_id is the function ID of the class service function.
413
-
414
- """
415
- class_service_function = self
416
- assert class_service_function._info # has to be a local function to be able to "bind" it
417
- assert not class_service_function._is_method # should not be used on an already bound method placeholder
418
- assert not class_service_function._obj # should only be used on base function / class service function
419
- full_name = f"{user_cls.__name__}.{method_name}"
420
-
421
- rep = f"Method({full_name})"
422
- fun = _Object.__new__(_Function)
423
- fun._init(rep)
424
- fun._tag = full_name
425
- fun._raw_f = partial_function.raw_f
426
- fun._info = FunctionInfo(
427
- partial_function.raw_f, user_cls=user_cls, serialized=class_service_function.info.is_serialized()
428
- ) # needed for .local()
429
- fun._use_method_name = method_name
430
- fun._app = class_service_function._app
431
- fun._is_generator = partial_function.is_generator
432
- fun._cluster_size = partial_function.cluster_size
433
- fun._spec = class_service_function._spec
434
- fun._is_method = True
435
- return fun
436
-
437
- @staticmethod
438
- def from_args(
439
- info: FunctionInfo,
440
- app,
441
- image: _Image,
442
- secrets: Sequence[_Secret] = (),
443
- schedule: Optional[Schedule] = None,
444
- is_generator: bool = False,
445
- gpu: Union[GPU_T, list[GPU_T]] = None,
446
- # TODO: maybe break this out into a separate decorator for notebooks.
447
- mounts: Collection[_Mount] = (),
448
- network_file_systems: dict[Union[str, PurePosixPath], _NetworkFileSystem] = {},
449
- allow_cross_region_volumes: bool = False,
450
- volumes: dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]] = {},
451
- webhook_config: Optional[api_pb2.WebhookConfig] = None,
452
- memory: Optional[Union[int, tuple[int, int]]] = None,
453
- proxy: Optional[_Proxy] = None,
454
- retries: Optional[Union[int, Retries]] = None,
455
- timeout: Optional[int] = None,
456
- concurrency_limit: Optional[int] = None,
457
- allow_concurrent_inputs: Optional[int] = None,
458
- batch_max_size: Optional[int] = None,
459
- batch_wait_ms: Optional[int] = None,
460
- container_idle_timeout: Optional[int] = None,
461
- cpu: Optional[Union[float, tuple[float, float]]] = None,
462
- keep_warm: Optional[int] = None, # keep_warm=True is equivalent to keep_warm=1
463
- cloud: Optional[str] = None,
464
- scheduler_placement: Optional[SchedulerPlacement] = None,
465
- is_builder_function: bool = False,
466
- is_auto_snapshot: bool = False,
467
- enable_memory_snapshot: bool = False,
468
- block_network: bool = False,
469
- i6pn_enabled: bool = False,
470
- cluster_size: Optional[int] = None, # Experimental: Clustered functions
471
- max_inputs: Optional[int] = None,
472
- ephemeral_disk: Optional[int] = None,
473
- # current default: first-party, future default: main-package
474
- include_source: Optional[bool] = None,
475
- _experimental_buffer_containers: Optional[int] = None,
476
- _experimental_proxy_ip: Optional[str] = None,
477
- _experimental_custom_scaling_factor: Optional[float] = None,
478
- ) -> "_Function":
479
- """mdmd:hidden"""
480
- # Needed to avoid circular imports
481
- from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags
482
-
483
- tag = info.get_tag()
484
-
485
- if info.raw_f:
486
- raw_f = info.raw_f
487
- assert callable(raw_f)
488
- if schedule is not None and not info.is_nullary():
489
- raise InvalidError(
490
- f"Function {raw_f} has a schedule, so it needs to support being called with no arguments"
491
- )
492
- else:
493
- # must be a "class service function"
494
- assert info.user_cls
495
- assert not webhook_config
496
- assert not schedule
497
-
498
- explicit_mounts = mounts
499
-
500
- if is_local():
501
- include_source_mode = get_include_source_mode(include_source)
502
- if include_source_mode != IncludeSourceMode.INCLUDE_NOTHING:
503
- entrypoint_mounts = info.get_entrypoint_mount()
504
- else:
505
- entrypoint_mounts = []
506
-
507
- all_mounts = [
508
- _get_client_mount(),
509
- *explicit_mounts,
510
- *entrypoint_mounts,
511
- ]
512
-
513
- if include_source_mode is IncludeSourceMode.INCLUDE_FIRST_PARTY:
514
- # TODO(elias): if using INCLUDE_FIRST_PARTY *and* mounts are added that haven't already been
515
- # added to the image via add_local_python_source
516
- all_mounts += get_auto_mounts()
517
- else:
518
- # skip any mount introspection/logic inside containers, since the function
519
- # should already be hydrated
520
- # TODO: maybe the entire from_args loader should be exited early if not local?
521
- # since it will be hydrated
522
- all_mounts = []
523
-
524
- retry_policy = _parse_retries(
525
- retries, f"Function '{info.get_tag()}'" if info.raw_f else f"Class '{info.get_tag()}'"
526
- )
527
-
528
- if webhook_config is not None and retry_policy is not None:
529
- raise InvalidError(
530
- "Web endpoints do not support retries.",
531
- )
532
-
533
- if is_generator and retry_policy is not None:
534
- deprecation_warning(
535
- (2024, 6, 25),
536
- "Retries for generator functions are deprecated and will soon be removed.",
537
- )
538
-
539
- if proxy:
540
- # HACK: remove this once we stop using ssh tunnels for this.
541
- if image:
542
- # TODO(elias): this will cause an error if users use prior `.add_local_*` commands without copy=True
543
- image = image.apt_install("autossh")
544
-
545
- function_spec = _FunctionSpec(
546
- mounts=all_mounts,
547
- secrets=secrets,
548
- gpus=gpu,
549
- network_file_systems=network_file_systems,
550
- volumes=volumes,
551
- image=image,
552
- cloud=cloud,
553
- cpu=cpu,
554
- memory=memory,
555
- ephemeral_disk=ephemeral_disk,
556
- scheduler_placement=scheduler_placement,
557
- proxy=proxy,
558
- )
559
-
560
- if info.user_cls and not is_auto_snapshot:
561
- build_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.BUILD).items()
562
- for k, pf in build_functions:
563
- build_function = pf.raw_f
564
- snapshot_info = FunctionInfo(build_function, user_cls=info.user_cls)
565
- snapshot_function = _Function.from_args(
566
- snapshot_info,
567
- app=None,
568
- image=image,
569
- secrets=secrets,
570
- gpu=gpu,
571
- mounts=mounts,
572
- network_file_systems=network_file_systems,
573
- volumes=volumes,
574
- memory=memory,
575
- timeout=pf.build_timeout,
576
- cpu=cpu,
577
- ephemeral_disk=ephemeral_disk,
578
- is_builder_function=True,
579
- is_auto_snapshot=True,
580
- scheduler_placement=scheduler_placement,
581
- )
582
- image = _Image._from_args(
583
- base_images={"base": image},
584
- build_function=snapshot_function, # type: ignore # TODO: separate functions.py and _functions.py
585
- force_build=image.force_build or pf.force_build,
586
- )
587
-
588
- if keep_warm is not None and not isinstance(keep_warm, int):
589
- raise TypeError(f"`keep_warm` must be an int or bool, not {type(keep_warm).__name__}")
590
-
591
- if (keep_warm is not None) and (concurrency_limit is not None) and concurrency_limit < keep_warm:
592
- raise InvalidError(
593
- f"Function `{info.function_name}` has `{concurrency_limit=}`, "
594
- f"strictly less than its `{keep_warm=}` parameter."
595
- )
596
-
597
- if _experimental_custom_scaling_factor is not None and (
598
- _experimental_custom_scaling_factor < 0 or _experimental_custom_scaling_factor > 1
599
- ):
600
- raise InvalidError("`_experimental_custom_scaling_factor` must be between 0.0 and 1.0 inclusive.")
601
-
602
- if not cloud and not is_builder_function:
603
- cloud = config.get("default_cloud")
604
- if cloud:
605
- cloud_provider = parse_cloud_provider(cloud)
606
- else:
607
- cloud_provider = None
608
-
609
- if is_generator and webhook_config:
610
- if webhook_config.type == api_pb2.WEBHOOK_TYPE_FUNCTION:
611
- raise InvalidError(
612
- """Webhooks cannot be generators. If you want a streaming response, see https://modal.com/docs/guide/streaming-endpoints
613
- """
614
- )
615
- else:
616
- raise InvalidError("Webhooks cannot be generators")
617
-
618
- if info.raw_f and batch_max_size:
619
- func_name = info.raw_f.__name__
620
- if is_generator:
621
- raise InvalidError(f"Modal batched function {func_name} cannot return generators")
622
- for arg in inspect.signature(info.raw_f).parameters.values():
623
- if arg.default is not inspect.Parameter.empty:
624
- raise InvalidError(f"Modal batched function {func_name} does not accept default arguments.")
625
-
626
- if container_idle_timeout is not None and container_idle_timeout <= 0:
627
- raise InvalidError("`container_idle_timeout` must be > 0")
628
-
629
- if max_inputs is not None:
630
- if not isinstance(max_inputs, int):
631
- raise InvalidError(f"`max_inputs` must be an int, not {type(max_inputs).__name__}")
632
- if max_inputs <= 0:
633
- raise InvalidError("`max_inputs` must be positive")
634
- if max_inputs > 1:
635
- raise InvalidError("Only `max_inputs=1` is currently supported")
636
-
637
- # Validate volumes
638
- validated_volumes = validate_volumes(volumes)
639
- cloud_bucket_mounts = [(k, v) for k, v in validated_volumes if isinstance(v, _CloudBucketMount)]
640
- validated_volumes = [(k, v) for k, v in validated_volumes if isinstance(v, _Volume)]
641
-
642
- # Validate NFS
643
- validated_network_file_systems = validate_network_file_systems(network_file_systems)
644
-
645
- # Validate image
646
- if image is not None and not isinstance(image, _Image):
647
- raise InvalidError(f"Expected modal.Image object. Got {type(image)}.")
648
-
649
- method_definitions: Optional[dict[str, api_pb2.MethodDefinition]] = None
650
-
651
- if info.user_cls:
652
- method_definitions = {}
653
- partial_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.FUNCTION)
654
- for method_name, partial_function in partial_functions.items():
655
- function_type = get_function_type(partial_function.is_generator)
656
- function_name = f"{info.user_cls.__name__}.{method_name}"
657
- method_definition = api_pb2.MethodDefinition(
658
- webhook_config=partial_function.webhook_config,
659
- function_type=function_type,
660
- function_name=function_name,
661
- )
662
- method_definitions[method_name] = method_definition
663
-
664
- function_type = get_function_type(is_generator)
665
-
666
- def _deps(only_explicit_mounts=False) -> list[_Object]:
667
- deps: list[_Object] = list(secrets)
668
- if only_explicit_mounts:
669
- # TODO: this is a bit hacky, but all_mounts may differ in the container vs locally
670
- # We don't want the function dependencies to change, so we have this way to force it to
671
- # only include its declared dependencies.
672
- # Only objects that need interaction within a user's container actually need to be
673
- # included when only_explicit_mounts=True, so omitting auto mounts here
674
- # wouldn't be a problem as long as Mounts are "passive" and only loaded by the
675
- # worker runtime
676
- deps += list(explicit_mounts)
677
- else:
678
- deps += list(all_mounts)
679
- if proxy:
680
- deps.append(proxy)
681
- if image:
682
- deps.append(image)
683
- for _, nfs in validated_network_file_systems:
684
- deps.append(nfs)
685
- for _, vol in validated_volumes:
686
- deps.append(vol)
687
- for _, cloud_bucket_mount in cloud_bucket_mounts:
688
- if cloud_bucket_mount.secret:
689
- deps.append(cloud_bucket_mount.secret)
690
-
691
- return deps
692
-
693
- async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
694
- assert resolver.client and resolver.client.stub
695
-
696
- assert resolver.app_id
697
- req = api_pb2.FunctionPrecreateRequest(
698
- app_id=resolver.app_id,
699
- function_name=info.function_name,
700
- function_type=function_type,
701
- existing_function_id=existing_object_id or "",
702
- )
703
- if method_definitions:
704
- for method_name, method_definition in method_definitions.items():
705
- req.method_definitions[method_name].CopyFrom(method_definition)
706
- elif webhook_config:
707
- req.webhook_config.CopyFrom(webhook_config)
708
- response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req)
709
- self._hydrate(response.function_id, resolver.client, response.handle_metadata)
710
-
711
- async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
712
- assert resolver.client and resolver.client.stub
713
- with FunctionCreationStatus(resolver, tag) as function_creation_status:
714
- timeout_secs = timeout
715
-
716
- if app and app.is_interactive and not is_builder_function:
717
- pty_info = get_pty_info(shell=False)
718
- else:
719
- pty_info = None
720
-
721
- if info.is_serialized():
722
- # Use cloudpickle. Used when working w/ Jupyter notebooks.
723
- # serialize at _load time, not function decoration time
724
- # otherwise we can't capture a surrounding class for lifetime methods etc.
725
- function_serialized = info.serialized_function()
726
- class_serialized = serialize(info.user_cls) if info.user_cls is not None else None
727
- # Ensure that large data in global variables does not blow up the gRPC payload,
728
- # which has maximum size 100 MiB. We set the limit lower for performance reasons.
729
- if len(function_serialized) > 16 << 20: # 16 MiB
730
- raise InvalidError(
731
- f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
732
- "This is larger than the maximum limit of 16 MiB. "
733
- "Try reducing the size of the closure by using parameters or mounts, "
734
- "not large global variables."
735
- )
736
- elif len(function_serialized) > 256 << 10: # 256 KiB
737
- warnings.warn(
738
- f"Function {info.raw_f} has size {len(function_serialized)} bytes when packaged. "
739
- "This is larger than the recommended limit of 256 KiB. "
740
- "Try reducing the size of the closure by using parameters or mounts, "
741
- "not large global variables."
742
- )
743
- else:
744
- function_serialized = None
745
- class_serialized = None
746
-
747
- app_name = ""
748
- if app and app.name:
749
- app_name = app.name
750
-
751
- # Relies on dicts being ordered (true as of Python 3.6).
752
- volume_mounts = [
753
- api_pb2.VolumeMount(
754
- mount_path=path,
755
- volume_id=volume.object_id,
756
- allow_background_commits=True,
757
- )
758
- for path, volume in validated_volumes
759
- ]
760
- loaded_mount_ids = {m.object_id for m in all_mounts} | {m.object_id for m in image._mount_layers}
761
-
762
- # Get object dependencies
763
- object_dependencies = []
764
- for dep in _deps(only_explicit_mounts=True):
765
- if not dep.object_id:
766
- raise Exception(f"Dependency {dep} isn't hydrated")
767
- object_dependencies.append(api_pb2.ObjectDependency(object_id=dep.object_id))
768
-
769
- function_data: Optional[api_pb2.FunctionData] = None
770
- function_definition: Optional[api_pb2.Function] = None
771
-
772
- # Create function remotely
773
- function_definition = api_pb2.Function(
774
- module_name=info.module_name or "",
775
- function_name=info.function_name,
776
- mount_ids=loaded_mount_ids,
777
- secret_ids=[secret.object_id for secret in secrets],
778
- image_id=(image.object_id if image else ""),
779
- definition_type=info.get_definition_type(),
780
- function_serialized=function_serialized or b"",
781
- class_serialized=class_serialized or b"",
782
- function_type=function_type,
783
- webhook_config=webhook_config,
784
- method_definitions=method_definitions,
785
- method_definitions_set=True,
786
- shared_volume_mounts=network_file_system_mount_protos(
787
- validated_network_file_systems, allow_cross_region_volumes
788
- ),
789
- volume_mounts=volume_mounts,
790
- proxy_id=(proxy.object_id if proxy else None),
791
- retry_policy=retry_policy,
792
- timeout_secs=timeout_secs or 0,
793
- task_idle_timeout_secs=container_idle_timeout or 0,
794
- concurrency_limit=concurrency_limit or 0,
795
- pty_info=pty_info,
796
- cloud_provider=cloud_provider, # Deprecated at some point
797
- cloud_provider_str=cloud.upper() if cloud else "", # Supersedes cloud_provider
798
- warm_pool_size=keep_warm or 0,
799
- runtime=config.get("function_runtime"),
800
- runtime_debug=config.get("function_runtime_debug"),
801
- runtime_perf_record=config.get("runtime_perf_record"),
802
- app_name=app_name,
803
- is_builder_function=is_builder_function,
804
- target_concurrent_inputs=allow_concurrent_inputs or 0,
805
- batch_max_size=batch_max_size or 0,
806
- batch_linger_ms=batch_wait_ms or 0,
807
- worker_id=config.get("worker_id"),
808
- is_auto_snapshot=is_auto_snapshot,
809
- is_method=bool(info.user_cls) and not info.is_service_class(),
810
- checkpointing_enabled=enable_memory_snapshot,
811
- object_dependencies=object_dependencies,
812
- block_network=block_network,
813
- max_inputs=max_inputs or 0,
814
- cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts),
815
- scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
816
- is_class=info.is_service_class(),
817
- class_parameter_info=info.class_parameter_info(),
818
- i6pn_enabled=i6pn_enabled,
819
- schedule=schedule.proto_message if schedule is not None else None,
820
- snapshot_debug=config.get("snapshot_debug"),
821
- _experimental_group_size=cluster_size or 0, # Experimental: Clustered functions
822
- _experimental_concurrent_cancellations=True,
823
- _experimental_buffer_containers=_experimental_buffer_containers or 0,
824
- _experimental_proxy_ip=_experimental_proxy_ip,
825
- _experimental_custom_scaling=_experimental_custom_scaling_factor is not None,
826
- )
827
-
828
- if isinstance(gpu, list):
829
- function_data = api_pb2.FunctionData(
830
- module_name=function_definition.module_name,
831
- function_name=function_definition.function_name,
832
- function_type=function_definition.function_type,
833
- warm_pool_size=function_definition.warm_pool_size,
834
- concurrency_limit=function_definition.concurrency_limit,
835
- task_idle_timeout_secs=function_definition.task_idle_timeout_secs,
836
- worker_id=function_definition.worker_id,
837
- timeout_secs=function_definition.timeout_secs,
838
- web_url=function_definition.web_url,
839
- web_url_info=function_definition.web_url_info,
840
- webhook_config=function_definition.webhook_config,
841
- custom_domain_info=function_definition.custom_domain_info,
842
- schedule=schedule.proto_message if schedule is not None else None,
843
- is_class=function_definition.is_class,
844
- class_parameter_info=function_definition.class_parameter_info,
845
- is_method=function_definition.is_method,
846
- use_function_id=function_definition.use_function_id,
847
- use_method_name=function_definition.use_method_name,
848
- method_definitions=function_definition.method_definitions,
849
- method_definitions_set=function_definition.method_definitions_set,
850
- _experimental_group_size=function_definition._experimental_group_size,
851
- _experimental_buffer_containers=function_definition._experimental_buffer_containers,
852
- _experimental_custom_scaling=function_definition._experimental_custom_scaling,
853
- _experimental_proxy_ip=function_definition._experimental_proxy_ip,
854
- snapshot_debug=function_definition.snapshot_debug,
855
- runtime_perf_record=function_definition.runtime_perf_record,
856
- )
857
-
858
- ranked_functions = []
859
- for rank, _gpu in enumerate(gpu):
860
- function_definition_copy = api_pb2.Function()
861
- function_definition_copy.CopyFrom(function_definition)
862
-
863
- function_definition_copy.resources.CopyFrom(
864
- convert_fn_config_to_resources_config(
865
- cpu=cpu, memory=memory, gpu=_gpu, ephemeral_disk=ephemeral_disk
866
- ),
867
- )
868
- ranked_function = api_pb2.FunctionData.RankedFunction(
869
- rank=rank,
870
- function=function_definition_copy,
871
- )
872
- ranked_functions.append(ranked_function)
873
- function_data.ranked_functions.extend(ranked_functions)
874
- function_definition = None # function_definition is not used in this case
875
- else:
876
- # TODO(irfansharif): Assert on this specific type once we get rid of python 3.9.
877
- # assert isinstance(gpu, GPU_T) # includes the case where gpu==None case
878
- function_definition.resources.CopyFrom(
879
- convert_fn_config_to_resources_config(
880
- cpu=cpu, memory=memory, gpu=gpu, ephemeral_disk=ephemeral_disk
881
- ), # type: ignore
882
- )
883
-
884
- assert resolver.app_id
885
- assert (function_definition is None) != (function_data is None) # xor
886
- request = api_pb2.FunctionCreateRequest(
887
- app_id=resolver.app_id,
888
- function=function_definition,
889
- function_data=function_data,
890
- existing_function_id=existing_object_id or "",
891
- defer_updates=True,
892
- )
893
- try:
894
- response: api_pb2.FunctionCreateResponse = await retry_transient_errors(
895
- resolver.client.stub.FunctionCreate, request
896
- )
897
- except GRPCError as exc:
898
- if exc.status == Status.INVALID_ARGUMENT:
899
- raise InvalidError(exc.message)
900
- if exc.status == Status.FAILED_PRECONDITION:
901
- raise InvalidError(exc.message)
902
- if exc.message and "Received :status = '413'" in exc.message:
903
- raise InvalidError(f"Function {info.function_name} is too large to deploy.")
904
- raise
905
- function_creation_status.set_response(response)
906
- serve_mounts = {m for m in all_mounts if m.is_local()} # needed for modal.serve file watching
907
- serve_mounts |= image._serve_mounts
908
- obj._serve_mounts = frozenset(serve_mounts)
909
- self._hydrate(response.function_id, resolver.client, response.handle_metadata)
910
-
911
- rep = f"Function({tag})"
912
- obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps)
913
-
914
- obj._raw_f = info.raw_f
915
- obj._info = info
916
- obj._tag = tag
917
- obj._app = app # needed for CLI right now
918
- obj._obj = None
919
- obj._is_generator = is_generator
920
- obj._cluster_size = cluster_size
921
- obj._is_method = False
922
- obj._spec = function_spec # needed for modal shell
923
- obj._webhook_config = webhook_config # only set locally
924
-
925
- # Used to check whether we should rebuild a modal.Image which uses `run_function`.
926
- gpus: list[GPU_T] = gpu if isinstance(gpu, list) else [gpu]
927
- obj._build_args = dict( # See get_build_def
928
- secrets=repr(secrets),
929
- gpu_config=repr([parse_gpu_config(_gpu) for _gpu in gpus]),
930
- mounts=repr(mounts),
931
- network_file_systems=repr(network_file_systems),
932
- )
933
- # these key are excluded if empty to avoid rebuilds on client upgrade
934
- if volumes:
935
- obj._build_args["volumes"] = repr(volumes)
936
- if cloud or scheduler_placement:
937
- obj._build_args["cloud"] = repr(cloud)
938
- obj._build_args["scheduler_placement"] = repr(scheduler_placement)
939
-
940
- return obj
941
-
942
- def _bind_parameters(
943
- self,
944
- obj: "modal.cls._Obj",
945
- options: Optional[api_pb2.FunctionOptions],
946
- args: Sized,
947
- kwargs: dict[str, Any],
948
- ) -> "_Function":
949
- """mdmd:hidden
950
-
951
- Binds a class-function to a specific instance of (init params, options) or a new workspace
952
- """
953
-
954
- # In some cases, reuse the base function, i.e. not create new clones of each method or the "service function"
955
- can_use_parent = len(args) + len(kwargs) == 0 and options is None
956
- parent = self
957
-
958
- async def _load(param_bound_func: _Function, resolver: Resolver, existing_object_id: Optional[str]):
959
- try:
960
- identity = f"{parent.info.function_name} class service function"
961
- except Exception:
962
- # Can't always look up the function name that way, so fall back to generic message
963
- identity = "class service function for a parametrized class"
964
- if not parent.is_hydrated:
965
- if parent.app._running_app is None:
966
- reason = ", because the App it is defined on is not running"
967
- else:
968
- reason = ""
969
- raise ExecutionError(
970
- f"The {identity} has not been hydrated with the metadata it needs to run on Modal{reason}."
971
- )
972
-
973
- assert parent._client and parent._client.stub
974
-
975
- if can_use_parent:
976
- # We can end up here if parent wasn't hydrated when class was instantiated, but has been since.
977
- param_bound_func._hydrate_from_other(parent)
978
- return
979
-
980
- if (
981
- parent._class_parameter_info
982
- and parent._class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO
983
- ):
984
- if args:
985
- # TODO(elias) - We could potentially support positional args as well, if we want to?
986
- raise InvalidError(
987
- "Can't use positional arguments with modal.parameter-based synthetic constructors.\n"
988
- "Use (<parameter_name>=value) keyword arguments when constructing classes instead."
989
- )
990
- serialized_params = serialize_proto_params(kwargs, parent._class_parameter_info.schema)
991
- else:
992
- serialized_params = serialize((args, kwargs))
993
- environment_name = _get_environment_name(None, resolver)
994
- assert parent is not None and parent.is_hydrated
995
- req = api_pb2.FunctionBindParamsRequest(
996
- function_id=parent.object_id,
997
- serialized_params=serialized_params,
998
- function_options=options,
999
- environment_name=environment_name
1000
- or "", # TODO: investigate shouldn't environment name always be specified here?
1001
- )
1002
-
1003
- response = await retry_transient_errors(parent._client.stub.FunctionBindParams, req)
1004
- param_bound_func._hydrate(response.bound_function_id, parent._client, response.handle_metadata)
1005
-
1006
- fun: _Function = _Function._from_loader(_load, "Function(parametrized)", hydrate_lazily=True)
1007
-
1008
- if can_use_parent and parent.is_hydrated:
1009
- # skip the resolver altogether:
1010
- fun._hydrate_from_other(parent)
1011
-
1012
- fun._info = self._info
1013
- fun._obj = obj
1014
- return fun
1015
-
1016
- @live_method
1017
- async def keep_warm(self, warm_pool_size: int) -> None:
1018
- """Set the warm pool size for the function.
1019
-
1020
- Please exercise care when using this advanced feature!
1021
- Setting and forgetting a warm pool on functions can lead to increased costs.
1022
-
1023
- ```python notest
1024
- # Usage on a regular function.
1025
- f = modal.Function.from_name("my-app", "function")
1026
- f.keep_warm(2)
1027
-
1028
- # Usage on a parametrized function.
1029
- Model = modal.Cls.from_name("my-app", "Model")
1030
- Model("fine-tuned-model").keep_warm(2)
1031
- ```
1032
- """
1033
- if self._is_method:
1034
- raise InvalidError(
1035
- textwrap.dedent(
1036
- """
1037
- The `.keep_warm()` method can not be used on Modal class *methods* deployed using Modal >v0.63.
1038
-
1039
- Call `.keep_warm()` on the class *instance* instead.
1040
- """
1041
- )
1042
- )
1043
- request = api_pb2.FunctionUpdateSchedulingParamsRequest(
1044
- function_id=self.object_id, warm_pool_size_override=warm_pool_size
1045
- )
1046
- await retry_transient_errors(self.client.stub.FunctionUpdateSchedulingParams, request)
1047
-
1048
- @classmethod
1049
- @renamed_parameter((2024, 12, 18), "tag", "name")
1050
- def from_name(
1051
- cls: type["_Function"],
1052
- app_name: str,
1053
- name: str,
1054
- namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
1055
- environment_name: Optional[str] = None,
1056
- ) -> "_Function":
1057
- """Reference a Function from a deployed App by its name.
1058
-
1059
- In contast to `modal.Function.lookup`, this is a lazy method
1060
- that defers hydrating the local object with metadata from
1061
- Modal servers until the first time it is actually used.
1062
-
1063
- ```python
1064
- f = modal.Function.from_name("other-app", "function")
1065
- ```
1066
- """
1067
-
1068
- async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
1069
- assert resolver.client and resolver.client.stub
1070
- request = api_pb2.FunctionGetRequest(
1071
- app_name=app_name,
1072
- object_tag=name,
1073
- namespace=namespace,
1074
- environment_name=_get_environment_name(environment_name, resolver) or "",
1075
- )
1076
- try:
1077
- response = await retry_transient_errors(resolver.client.stub.FunctionGet, request)
1078
- except GRPCError as exc:
1079
- if exc.status == Status.NOT_FOUND:
1080
- raise NotFoundError(exc.message)
1081
- else:
1082
- raise
1083
-
1084
- print_server_warnings(response.server_warnings)
1085
-
1086
- self._hydrate(response.function_id, resolver.client, response.handle_metadata)
1087
-
1088
- rep = f"Ref({app_name})"
1089
- return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True)
1090
-
1091
- @staticmethod
1092
- @renamed_parameter((2024, 12, 18), "tag", "name")
1093
- async def lookup(
1094
- app_name: str,
1095
- name: str,
1096
- namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
1097
- client: Optional[_Client] = None,
1098
- environment_name: Optional[str] = None,
1099
- ) -> "_Function":
1100
- """Lookup a Function from a deployed App by its name.
1101
-
1102
- DEPRECATED: This method is deprecated in favor of `modal.Function.from_name`.
1103
-
1104
- In contrast to `modal.Function.from_name`, this is an eager method
1105
- that will hydrate the local object with metadata from Modal servers.
1106
-
1107
- ```python notest
1108
- f = modal.Function.lookup("other-app", "function")
1109
- ```
1110
- """
1111
- deprecation_warning(
1112
- (2025, 1, 27),
1113
- "`modal.Function.lookup` is deprecated and will be removed in a future release."
1114
- " It can be replaced with `modal.Function.from_name`."
1115
- "\n\nSee https://modal.com/docs/guide/modal-1-0-migration for more information.",
1116
- )
1117
- obj = _Function.from_name(app_name, name, namespace=namespace, environment_name=environment_name)
1118
- if client is None:
1119
- client = await _Client.from_env()
1120
- resolver = Resolver(client=client)
1121
- await resolver.load(obj)
1122
- return obj
1123
-
1124
- @property
1125
- def tag(self) -> str:
1126
- """mdmd:hidden"""
1127
- assert self._tag
1128
- return self._tag
1129
-
1130
- @property
1131
- def app(self) -> "modal.app._App":
1132
- """mdmd:hidden"""
1133
- if self._app is None:
1134
- raise ExecutionError("The app has not been assigned on the function at this point")
1135
-
1136
- return self._app
1137
-
1138
- @property
1139
- def stub(self) -> "modal.app._App":
1140
- """mdmd:hidden"""
1141
- # Deprecated soon, only for backwards compatibility
1142
- return self.app
1143
-
1144
- @property
1145
- def info(self) -> FunctionInfo:
1146
- """mdmd:hidden"""
1147
- assert self._info
1148
- return self._info
1149
-
1150
- @property
1151
- def spec(self) -> _FunctionSpec:
1152
- """mdmd:hidden"""
1153
- assert self._spec
1154
- return self._spec
1155
-
1156
- def _is_web_endpoint(self) -> bool:
1157
- # only defined in definition scope/locally, and not for class methods at the moment
1158
- return bool(self._webhook_config and self._webhook_config.type != api_pb2.WEBHOOK_TYPE_UNSPECIFIED)
1159
-
1160
- def get_build_def(self) -> str:
1161
- """mdmd:hidden"""
1162
- # Plaintext source and arg definition for the function, so it's part of the image
1163
- # hash. We can't use the cloudpickle hash because it's not very stable.
1164
- assert hasattr(self, "_raw_f") and hasattr(self, "_build_args") and self._raw_f is not None
1165
- return f"{inspect.getsource(self._raw_f)}\n{repr(self._build_args)}"
1166
-
1167
- # Live handle methods
1168
-
1169
- def _initialize_from_empty(self):
1170
- # Overridden concrete implementation of base class method
1171
- self._progress = None
1172
- self._is_generator = None
1173
- self._cluster_size = None
1174
- self._web_url = None
1175
- self._function_name = None
1176
- self._info = None
1177
- self._serve_mounts = frozenset()
1178
-
1179
- def _hydrate_metadata(self, metadata: Optional[Message]):
1180
- # Overridden concrete implementation of base class method
1181
- assert metadata and isinstance(metadata, api_pb2.FunctionHandleMetadata)
1182
- self._is_generator = metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
1183
- self._web_url = metadata.web_url
1184
- self._function_name = metadata.function_name
1185
- self._is_method = metadata.is_method
1186
- self._use_method_name = metadata.use_method_name
1187
- self._class_parameter_info = metadata.class_parameter_info
1188
- self._method_handle_metadata = dict(metadata.method_handle_metadata)
1189
- self._definition_id = metadata.definition_id
1190
-
1191
- def _get_metadata(self):
1192
- # Overridden concrete implementation of base class method
1193
- assert self._function_name, f"Function name must be set before metadata can be retrieved for {self}"
1194
- return api_pb2.FunctionHandleMetadata(
1195
- function_name=self._function_name,
1196
- function_type=get_function_type(self._is_generator),
1197
- web_url=self._web_url or "",
1198
- use_method_name=self._use_method_name,
1199
- is_method=self._is_method,
1200
- class_parameter_info=self._class_parameter_info,
1201
- definition_id=self._definition_id,
1202
- method_handle_metadata=self._method_handle_metadata,
1203
- )
1204
-
1205
- def _check_no_web_url(self, fn_name: str):
1206
- if self._web_url:
1207
- raise InvalidError(
1208
- f"A webhook function cannot be invoked for remote execution with `.{fn_name}`. "
1209
- f"Invoke this function via its web url '{self._web_url}' "
1210
- + f"or call it locally: {self._function_name}.local()"
1211
- )
1212
-
1213
- # TODO (live_method on properties is not great, since it could be blocking the event loop from async contexts)
1214
- @property
1215
- @live_method
1216
- async def web_url(self) -> str:
1217
- """URL of a Function running as a web endpoint."""
1218
- if not self._web_url:
1219
- raise ValueError(
1220
- f"No web_url can be found for function {self._function_name}. web_url "
1221
- "can only be referenced from a running app context"
1222
- )
1223
- return self._web_url
1224
-
1225
- @property
1226
- async def is_generator(self) -> bool:
1227
- """mdmd:hidden"""
1228
- # hacky: kind of like @live_method, but not hydrating if we have the value already from local source
1229
- # TODO(michael) use a common / lightweight method for handling unhydrated metadata properties
1230
- if self._is_generator is not None:
1231
- # this is set if the function or class is local
1232
- return self._is_generator
1233
-
1234
- # not set - this is a from_name lookup - hydrate
1235
- await self.hydrate()
1236
- assert self._is_generator is not None # should be set now
1237
- return self._is_generator
1238
-
1239
- @property
1240
- def cluster_size(self) -> int:
1241
- """mdmd:hidden"""
1242
- return self._cluster_size or 1
1243
-
1244
- @live_method_gen
1245
- async def _map(
1246
- self, input_queue: _SynchronizedQueue, order_outputs: bool, return_exceptions: bool
1247
- ) -> AsyncGenerator[Any, None]:
1248
- """mdmd:hidden
1249
-
1250
- Synchronicity-wrapped map implementation. To be safe against invocations of user code in
1251
- the synchronicity thread it doesn't accept an [async]iterator, and instead takes a
1252
- _SynchronizedQueue instance that is fed by higher level functions like .map()
1253
-
1254
- _SynchronizedQueue is used instead of asyncio.Queue so that the main thread can put
1255
- items in the queue safely.
1256
- """
1257
- self._check_no_web_url("map")
1258
- if self._is_generator:
1259
- raise InvalidError("A generator function cannot be called with `.map(...)`.")
1260
-
1261
- assert self._function_name
1262
- if output_mgr := _get_output_manager():
1263
- count_update_callback = output_mgr.function_progress_callback(self._function_name, total=None)
1264
- else:
1265
- count_update_callback = None
1266
-
1267
- async with aclosing(
1268
- _map_invocation(
1269
- self, # type: ignore
1270
- input_queue,
1271
- self.client,
1272
- order_outputs,
1273
- return_exceptions,
1274
- count_update_callback,
1275
- )
1276
- ) as stream:
1277
- async for item in stream:
1278
- yield item
1279
-
1280
- async def _call_function(self, args, kwargs) -> ReturnType:
1281
- if config.get("client_retries"):
1282
- function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
1283
- else:
1284
- function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY
1285
- invocation = await _Invocation.create(
1286
- self,
1287
- args,
1288
- kwargs,
1289
- client=self.client,
1290
- function_call_invocation_type=function_call_invocation_type,
1291
- )
1292
-
1293
- return await invocation.run_function()
1294
-
1295
- async def _call_function_nowait(
1296
- self, args, kwargs, function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
1297
- ) -> _Invocation:
1298
- return await _Invocation.create(
1299
- self, args, kwargs, client=self.client, function_call_invocation_type=function_call_invocation_type
1300
- )
1301
-
1302
- @warn_if_generator_is_not_consumed()
1303
- @live_method_gen
1304
- @synchronizer.no_input_translation
1305
- async def _call_generator(self, args, kwargs):
1306
- invocation = await _Invocation.create(
1307
- self,
1308
- args,
1309
- kwargs,
1310
- client=self.client,
1311
- function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
1312
- )
1313
- async for res in invocation.run_generator():
1314
- yield res
1315
-
1316
- @synchronizer.no_io_translation
1317
- async def _call_generator_nowait(self, args, kwargs):
1318
- deprecation_warning(
1319
- (2024, 12, 11),
1320
- "Calling spawn on a generator function is deprecated and will soon raise an exception.",
1321
- )
1322
- return await _Invocation.create(
1323
- self,
1324
- args,
1325
- kwargs,
1326
- client=self.client,
1327
- function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY,
1328
- )
1329
-
1330
- @synchronizer.no_io_translation
1331
- @live_method
1332
- async def remote(self, *args: P.args, **kwargs: P.kwargs) -> ReturnType:
1333
- """
1334
- Calls the function remotely, executing it with the given arguments and returning the execution's result.
1335
- """
1336
- # TODO: Generics/TypeVars
1337
- self._check_no_web_url("remote")
1338
- if self._is_generator:
1339
- raise InvalidError(
1340
- "A generator function cannot be called with `.remote(...)`. Use `.remote_gen(...)` instead."
1341
- )
1342
-
1343
- return await self._call_function(args, kwargs)
1344
-
1345
- @synchronizer.no_io_translation
1346
- @live_method_gen
1347
- async def remote_gen(self, *args, **kwargs) -> AsyncGenerator[Any, None]:
1348
- """
1349
- Calls the generator remotely, executing it with the given arguments and returning the execution's result.
1350
- """
1351
- # TODO: Generics/TypeVars
1352
- self._check_no_web_url("remote_gen")
1353
-
1354
- if not self._is_generator:
1355
- raise InvalidError(
1356
- "A non-generator function cannot be called with `.remote_gen(...)`. Use `.remote(...)` instead."
1357
- )
1358
- async for item in self._call_generator(args, kwargs): # type: ignore
1359
- yield item
1360
-
1361
- def _is_local(self):
1362
- return self._info is not None
1363
-
1364
- def _get_info(self) -> FunctionInfo:
1365
- if not self._info:
1366
- raise ExecutionError("Can't get info for a function that isn't locally defined")
1367
- return self._info
1368
-
1369
- def _get_obj(self) -> Optional["modal.cls._Obj"]:
1370
- if not self._is_method:
1371
- return None
1372
- elif not self._obj:
1373
- raise ExecutionError("Method has no local object")
1374
- else:
1375
- return self._obj
1376
-
1377
- @synchronizer.nowrap
1378
- def local(self, *args: P.args, **kwargs: P.kwargs) -> OriginalReturnType:
1379
- """
1380
- Calls the function locally, executing it with the given arguments and returning the execution's result.
1381
-
1382
- The function will execute in the same environment as the caller, just like calling the underlying function
1383
- directly in Python. In particular, only secrets available in the caller environment will be available
1384
- through environment variables.
1385
- """
1386
- # TODO(erikbern): it would be nice to remove the nowrap thing, but right now that would cause
1387
- # "user code" to run on the synchronicity thread, which seems bad
1388
- if not self._is_local():
1389
- msg = (
1390
- "The definition for this function is missing here so it is not possible to invoke it locally. "
1391
- "If this function was retrieved via `Function.lookup` you need to use `.remote()`."
1392
- )
1393
- raise ExecutionError(msg)
1394
-
1395
- info = self._get_info()
1396
- if not info.raw_f:
1397
- # Here if calling .local on a service function itself which should never happen
1398
- # TODO: check if we end up here in a container for a serialized function?
1399
- raise ExecutionError("Can't call .local on service function")
1400
-
1401
- if is_local() and self.spec.volumes or self.spec.network_file_systems:
1402
- warnings.warn(
1403
- f"The {info.function_name} function is executing locally "
1404
- + "and will not have access to the mounted Volume or NetworkFileSystem data"
1405
- )
1406
-
1407
- obj: Optional["modal.cls._Obj"] = self._get_obj()
1408
-
1409
- if not obj:
1410
- fun = info.raw_f
1411
- return fun(*args, **kwargs)
1412
- else:
1413
- # This is a method on a class, so bind the self to the function
1414
- user_cls_instance = obj._cached_user_cls_instance()
1415
- fun = info.raw_f.__get__(user_cls_instance)
1416
-
1417
- # TODO: replace implicit local enter/exit with a context manager
1418
- if is_async(info.raw_f):
1419
- # We want to run __aenter__ and fun in the same coroutine
1420
- async def coro():
1421
- await obj._aenter()
1422
- return await fun(*args, **kwargs)
1423
-
1424
- return coro() # type: ignore
1425
- else:
1426
- obj._enter()
1427
- return fun(*args, **kwargs)
1428
-
1429
- @synchronizer.no_input_translation
1430
- @live_method
1431
- async def _experimental_spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_FunctionCall[ReturnType]":
1432
- """[Experimental] Calls the function with the given arguments, without waiting for the results.
1433
-
1434
- This experimental version of the spawn method allows up to 1 million inputs to be spawned.
1435
-
1436
- Returns a `modal.functions.FunctionCall` object, that can later be polled or
1437
- waited for using `.get(timeout=...)`.
1438
- Conceptually similar to `multiprocessing.pool.apply_async`, or a Future/Promise in other contexts.
1439
- """
1440
- self._check_no_web_url("_experimental_spawn")
1441
- if self._is_generator:
1442
- invocation = await self._call_generator_nowait(args, kwargs)
1443
- else:
1444
- invocation = await self._call_function_nowait(
1445
- args, kwargs, function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC
1446
- )
1447
-
1448
- fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
1449
- fc._is_generator = self._is_generator if self._is_generator else False
1450
- return fc
1451
-
1452
- @synchronizer.no_input_translation
1453
- @live_method
1454
- async def spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_FunctionCall[ReturnType]":
1455
- """Calls the function with the given arguments, without waiting for the results.
1456
-
1457
- Returns a `modal.functions.FunctionCall` object, that can later be polled or
1458
- waited for using `.get(timeout=...)`.
1459
- Conceptually similar to `multiprocessing.pool.apply_async`, or a Future/Promise in other contexts.
1460
- """
1461
- self._check_no_web_url("spawn")
1462
- if self._is_generator:
1463
- invocation = await self._call_generator_nowait(args, kwargs)
1464
- else:
1465
- invocation = await self._call_function_nowait(
1466
- args, kwargs, api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY
1467
- )
1468
-
1469
- fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
1470
- fc._is_generator = self._is_generator if self._is_generator else False
1471
- return fc
1472
-
1473
- def get_raw_f(self) -> Callable[..., Any]:
1474
- """Return the inner Python object wrapped by this Modal Function."""
1475
- assert self._raw_f is not None
1476
- return self._raw_f
1477
-
1478
- @live_method
1479
- async def get_current_stats(self) -> FunctionStats:
1480
- """Return a `FunctionStats` object describing the current function's queue and runner counts."""
1481
- resp = await retry_transient_errors(
1482
- self.client.stub.FunctionGetCurrentStats,
1483
- api_pb2.FunctionGetCurrentStatsRequest(function_id=self.object_id),
1484
- total_timeout=10.0,
1485
- )
1486
- return FunctionStats(backlog=resp.backlog, num_total_runners=resp.num_total_tasks)
1487
-
1488
- # A bit hacky - but the map-style functions need to not be synchronicity-wrapped
1489
- # in order to not execute their input iterators on the synchronicity event loop.
1490
- # We still need to wrap them using MethodWithAio to maintain a synchronicity-like
1491
- # api with `.aio` and get working type-stubs and reference docs generation:
1492
- map = MethodWithAio(_map_sync, _map_async, synchronizer)
1493
- starmap = MethodWithAio(_starmap_sync, _starmap_async, synchronizer)
1494
- for_each = MethodWithAio(_for_each_sync, _for_each_async, synchronizer)
1495
-
1496
-
1497
- Function = synchronize_api(_Function)
1498
-
1499
-
1500
- class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
1501
- """A reference to an executed function call.
1502
-
1503
- Constructed using `.spawn(...)` on a Modal function with the same
1504
- arguments that a function normally takes. Acts as a reference to
1505
- an ongoing function call that can be passed around and used to
1506
- poll or fetch function results at some later time.
1507
-
1508
- Conceptually similar to a Future/Promise/AsyncResult in other contexts and languages.
1509
- """
1510
-
1511
- _is_generator: bool = False
1512
-
1513
- def _invocation(self):
1514
- return _Invocation(self.client.stub, self.object_id, self.client)
1515
-
1516
- async def get(self, timeout: Optional[float] = None) -> ReturnType:
1517
- """Get the result of the function call.
1518
-
1519
- This function waits indefinitely by default. It takes an optional
1520
- `timeout` argument that specifies the maximum number of seconds to wait,
1521
- which can be set to `0` to poll for an output immediately.
1522
-
1523
- The returned coroutine is not cancellation-safe.
1524
- """
1525
-
1526
- if self._is_generator:
1527
- raise Exception("Cannot get the result of a generator function call. Use `get_gen` instead.")
1528
-
1529
- return await self._invocation().poll_function(timeout=timeout)
1530
-
1531
- async def get_gen(self) -> AsyncGenerator[Any, None]:
1532
- """
1533
- Calls the generator remotely, executing it with the given arguments and returning the execution's result.
1534
- """
1535
- if not self._is_generator:
1536
- raise Exception("Cannot iterate over a non-generator function call. Use `get` instead.")
1537
-
1538
- async for res in self._invocation().run_generator():
1539
- yield res
1540
-
1541
- async def get_call_graph(self) -> list[InputInfo]:
1542
- """Returns a structure representing the call graph from a given root
1543
- call ID, along with the status of execution for each node.
1544
-
1545
- See [`modal.call_graph`](/docs/reference/modal.call_graph) reference page
1546
- for documentation on the structure of the returned `InputInfo` items.
1547
- """
1548
- assert self._client and self._client.stub
1549
- request = api_pb2.FunctionGetCallGraphRequest(function_call_id=self.object_id)
1550
- response = await retry_transient_errors(self._client.stub.FunctionGetCallGraph, request)
1551
- return _reconstruct_call_graph(response)
1552
-
1553
- async def cancel(
1554
- self,
1555
- terminate_containers: bool = False, # if true, containers running the inputs are forcibly terminated
1556
- ):
1557
- """Cancels the function call, which will stop its execution and mark its inputs as
1558
- [`TERMINATED`](/docs/reference/modal.call_graph#modalcall_graphinputstatus).
1559
-
1560
- If `terminate_containers=True` - the containers running the cancelled inputs are all terminated
1561
- causing any non-cancelled inputs on those containers to be rescheduled in new containers.
1562
- """
1563
- request = api_pb2.FunctionCallCancelRequest(
1564
- function_call_id=self.object_id, terminate_containers=terminate_containers
1565
- )
1566
- assert self._client and self._client.stub
1567
- await retry_transient_errors(self._client.stub.FunctionCallCancel, request)
1568
-
1569
- @staticmethod
1570
- async def from_id(
1571
- function_call_id: str, client: Optional[_Client] = None, is_generator: bool = False
1572
- ) -> "_FunctionCall":
1573
- if client is None:
1574
- client = await _Client.from_env()
1575
-
1576
- fc = _FunctionCall._new_hydrated(function_call_id, client, None)
1577
- fc._is_generator = is_generator
1578
- return fc
1579
-
1580
-
1581
- FunctionCall = synchronize_api(_FunctionCall)
1582
-
1583
-
1584
- async def _gather(*function_calls: _FunctionCall[ReturnType]) -> typing.Sequence[ReturnType]:
1585
- """Wait until all Modal function calls have results before returning
1586
-
1587
- Accepts a variable number of FunctionCall objects as returned by `Function.spawn()`.
1588
-
1589
- Returns a list of results from each function call, or raises an exception
1590
- of the first failing function call.
1591
-
1592
- E.g.
1593
-
1594
- ```python notest
1595
- function_call_1 = slow_func_1.spawn()
1596
- function_call_2 = slow_func_2.spawn()
1597
-
1598
- result_1, result_2 = gather(function_call_1, function_call_2)
1599
- ```
1600
- """
1601
- try:
1602
- return await TaskContext.gather(*[fc.get() for fc in function_calls])
1603
- except Exception as exc:
1604
- # TODO: kill all running function calls
1605
- raise exc
1606
-
1607
-
1608
- gather = synchronize_api(_gather)
5
+ Function = synchronize_api(_Function, target_module=__name__)
6
+ FunctionCall = synchronize_api(_FunctionCall, target_module=__name__)
7
+ gather = synchronize_api(_gather, target_module=__name__)