modal 1.2.2.dev30__py3-none-any.whl → 1.2.2.dev36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. modal/_functions.py +77 -52
  2. modal/_load_context.py +105 -0
  3. modal/_object.py +47 -18
  4. modal/_resolver.py +21 -35
  5. modal/app.py +7 -0
  6. modal/app.pyi +3 -0
  7. modal/cli/dict.py +5 -2
  8. modal/cli/queues.py +4 -2
  9. modal/client.pyi +2 -2
  10. modal/cls.py +71 -32
  11. modal/cls.pyi +3 -0
  12. modal/dict.py +14 -5
  13. modal/dict.pyi +2 -0
  14. modal/environments.py +16 -7
  15. modal/environments.pyi +6 -2
  16. modal/functions.pyi +10 -4
  17. modal/image.py +22 -22
  18. modal/mount.py +35 -25
  19. modal/mount.pyi +33 -7
  20. modal/network_file_system.py +14 -5
  21. modal/network_file_system.pyi +12 -2
  22. modal/object.pyi +35 -8
  23. modal/proxy.py +14 -6
  24. modal/proxy.pyi +10 -2
  25. modal/queue.py +14 -5
  26. modal/queue.pyi +12 -2
  27. modal/runner.py +43 -47
  28. modal/runner.pyi +2 -2
  29. modal/sandbox.py +21 -12
  30. modal/secret.py +57 -39
  31. modal/secret.pyi +21 -4
  32. modal/serving.py +7 -11
  33. modal/serving.pyi +7 -8
  34. modal/snapshot.py +11 -5
  35. modal/volume.py +25 -7
  36. modal/volume.pyi +2 -0
  37. {modal-1.2.2.dev30.dist-info → modal-1.2.2.dev36.dist-info}/METADATA +1 -1
  38. {modal-1.2.2.dev30.dist-info → modal-1.2.2.dev36.dist-info}/RECORD +46 -45
  39. modal_proto/api.proto +4 -0
  40. modal_proto/api_pb2.py +684 -684
  41. modal_proto/api_pb2.pyi +24 -3
  42. modal_version/__init__.py +1 -1
  43. {modal-1.2.2.dev30.dist-info → modal-1.2.2.dev36.dist-info}/WHEEL +0 -0
  44. {modal-1.2.2.dev30.dist-info → modal-1.2.2.dev36.dist-info}/entry_points.txt +0 -0
  45. {modal-1.2.2.dev30.dist-info → modal-1.2.2.dev36.dist-info}/licenses/LICENSE +0 -0
  46. {modal-1.2.2.dev30.dist-info → modal-1.2.2.dev36.dist-info}/top_level.txt +0 -0
modal/_functions.py CHANGED
@@ -19,7 +19,8 @@ from synchronicity.combined_types import MethodWithAio
19
19
  from modal_proto import api_pb2
20
20
  from modal_proto.modal_api_grpc import ModalClientModal
21
21
 
22
- from ._object import _get_environment_name, _Object, live_method, live_method_gen
22
+ from ._load_context import LoadContext
23
+ from ._object import _Object, live_method, live_method_gen
23
24
  from ._pty import get_pty_info
24
25
  from ._resolver import Resolver
25
26
  from ._resources import convert_fn_config_to_resources_config
@@ -656,7 +657,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
656
657
  @staticmethod
657
658
  def from_local(
658
659
  info: FunctionInfo,
659
- app,
660
+ app: Optional["modal.app._App"], # App here should only be None in case of Image.run_function
660
661
  image: _Image,
661
662
  env: Optional[dict[str, Optional[str]]] = None,
662
663
  secrets: Optional[Collection[_Secret]] = None,
@@ -882,12 +883,12 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
882
883
  is_web_endpoint, is_generator, restrict_output
883
884
  )
884
885
 
885
- async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
886
- assert resolver.client and resolver.client.stub
887
-
888
- assert resolver.app_id
886
+ async def _preload(
887
+ self: _Function, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
888
+ ):
889
+ assert load_context.app_id
889
890
  req = api_pb2.FunctionPrecreateRequest(
890
- app_id=resolver.app_id,
891
+ app_id=load_context.app_id,
891
892
  function_name=info.function_name,
892
893
  function_type=function_type,
893
894
  existing_function_id=existing_object_id or "",
@@ -903,11 +904,12 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
903
904
  elif webhook_config:
904
905
  req.webhook_config.CopyFrom(webhook_config)
905
906
 
906
- response = await resolver.client.stub.FunctionPrecreate(req)
907
- self._hydrate(response.function_id, resolver.client, response.handle_metadata)
907
+ response = await load_context.client.stub.FunctionPrecreate(req)
908
+ self._hydrate(response.function_id, load_context.client, response.handle_metadata)
908
909
 
909
- async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
910
- assert resolver.client and resolver.client.stub
910
+ async def _load(
911
+ self: _Function, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
912
+ ):
911
913
  with FunctionCreationStatus(resolver, tag) as function_creation_status:
912
914
  timeout_secs = timeout
913
915
 
@@ -1103,16 +1105,16 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1103
1105
  ),
1104
1106
  )
1105
1107
 
1106
- assert resolver.app_id
1108
+ assert load_context.app_id
1107
1109
  assert (function_definition is None) != (function_data is None) # xor
1108
1110
  request = api_pb2.FunctionCreateRequest(
1109
- app_id=resolver.app_id,
1111
+ app_id=load_context.app_id,
1110
1112
  function=function_definition,
1111
1113
  function_data=function_data,
1112
1114
  existing_function_id=existing_object_id or "",
1113
1115
  )
1114
1116
  try:
1115
- response: api_pb2.FunctionCreateResponse = await resolver.client.stub.FunctionCreate(request)
1117
+ response: api_pb2.FunctionCreateResponse = await load_context.client.stub.FunctionCreate(request)
1116
1118
  except GRPCError as exc:
1117
1119
  if exc.status == Status.INVALID_ARGUMENT:
1118
1120
  raise InvalidError(exc.message)
@@ -1127,10 +1129,14 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1127
1129
  serve_mounts = {m for m in all_mounts if m.is_local()}
1128
1130
  serve_mounts |= image._serve_mounts
1129
1131
  obj._serve_mounts = frozenset(serve_mounts)
1130
- self._hydrate(response.function_id, resolver.client, response.handle_metadata)
1132
+ self._hydrate(response.function_id, load_context.client, response.handle_metadata)
1131
1133
 
1132
1134
  rep = f"Function({tag})"
1133
- obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps)
1135
+ # Pass a *reference* to the App's LoadContext - this is important since the App is
1136
+ # the only way to infer a LoadContext for an `@app.function`, and the App doesn't
1137
+ # get its client until *after* the Function is created.
1138
+ load_context = app._root_load_context if app else LoadContext.empty()
1139
+ obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps, load_context_overrides=load_context)
1134
1140
 
1135
1141
  obj._raw_f = info.raw_f
1136
1142
  obj._info = info
@@ -1172,7 +1178,12 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1172
1178
 
1173
1179
  parent = self
1174
1180
 
1175
- async def _load(param_bound_func: _Function, resolver: Resolver, existing_object_id: Optional[str]):
1181
+ async def _load(
1182
+ param_bound_func: _Function,
1183
+ resolver: Resolver,
1184
+ load_context: LoadContext,
1185
+ existing_object_id: Optional[str],
1186
+ ):
1176
1187
  if not parent.is_hydrated:
1177
1188
  # While the base Object.hydrate() method appears to be idempotent, it's not always safe
1178
1189
  await parent.hydrate()
@@ -1205,7 +1216,6 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1205
1216
  param_bound_func._hydrate_from_other(parent)
1206
1217
  return
1207
1218
 
1208
- environment_name = _get_environment_name(None, resolver)
1209
1219
  assert parent is not None and parent.is_hydrated
1210
1220
 
1211
1221
  if options:
@@ -1245,7 +1255,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1245
1255
  function_id=parent.object_id,
1246
1256
  serialized_params=serialized_params,
1247
1257
  function_options=options_pb,
1248
- environment_name=environment_name
1258
+ environment_name=load_context.environment_name
1249
1259
  or "", # TODO: investigate shouldn't environment name always be specified here?
1250
1260
  )
1251
1261
 
@@ -1262,7 +1272,13 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1262
1272
  return [dep for dep in all_deps if not dep.is_hydrated]
1263
1273
  return []
1264
1274
 
1265
- fun: _Function = _Function._from_loader(_load, "Function(parametrized)", hydrate_lazily=True, deps=_deps)
1275
+ fun: _Function = _Function._from_loader(
1276
+ _load,
1277
+ "Function(parametrized)",
1278
+ hydrate_lazily=True,
1279
+ deps=_deps,
1280
+ load_context_overrides=self._load_context_overrides,
1281
+ )
1266
1282
 
1267
1283
  fun._info = self._info
1268
1284
  fun._obj = obj
@@ -1360,34 +1376,43 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1360
1376
  cls,
1361
1377
  app_name: str,
1362
1378
  name: str,
1363
- namespace=None, # mdmd:line-hidden
1364
- environment_name: Optional[str] = None,
1379
+ *,
1380
+ load_context_overrides: LoadContext,
1365
1381
  ):
1366
1382
  # internal function lookup implementation that allows lookup of class "service functions"
1367
1383
  # in addition to non-class functions
1368
- async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
1369
- assert resolver.client and resolver.client.stub
1384
+ async def _load_remote(
1385
+ self: _Function, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
1386
+ ):
1370
1387
  request = api_pb2.FunctionGetRequest(
1371
1388
  app_name=app_name,
1372
1389
  object_tag=name,
1373
- environment_name=_get_environment_name(environment_name, resolver) or "",
1390
+ environment_name=load_context.environment_name,
1374
1391
  )
1375
1392
  try:
1376
- response = await resolver.client.stub.FunctionGet(request)
1393
+ response = await load_context.client.stub.FunctionGet(request)
1377
1394
  except NotFoundError as exc:
1378
1395
  # refine the error message
1379
- env_context = f" (in the '{environment_name}' environment)" if environment_name else ""
1396
+ env_context = (
1397
+ f" (in the '{load_context.environment_name}' environment)" if load_context.environment_name else ""
1398
+ )
1380
1399
  raise NotFoundError(
1381
1400
  f"Lookup failed for Function '{name}' from the '{app_name}' app{env_context}: {exc}."
1382
1401
  ) from None
1383
1402
 
1384
1403
  print_server_warnings(response.server_warnings)
1385
1404
 
1386
- self._hydrate(response.function_id, resolver.client, response.handle_metadata)
1405
+ self._hydrate(response.function_id, load_context.client, response.handle_metadata)
1387
1406
 
1388
- environment_rep = f", environment_name={environment_name!r}" if environment_name else ""
1407
+ environment_rep = (
1408
+ f", environment_name={load_context_overrides.environment_name!r}"
1409
+ if load_context_overrides._environment_name # slightly ugly - checking if _environment_name is overridden
1410
+ else ""
1411
+ )
1389
1412
  rep = f"modal.Function.from_name('{app_name}', '{name}'{environment_rep})"
1390
- return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True)
1413
+ return cls._from_loader(
1414
+ _load_remote, rep, is_another_app=True, hydrate_lazily=True, load_context_overrides=load_context_overrides
1415
+ )
1391
1416
 
1392
1417
  @classmethod
1393
1418
  def from_name(
@@ -1397,6 +1422,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1397
1422
  *,
1398
1423
  namespace=None, # mdmd:line-hidden
1399
1424
  environment_name: Optional[str] = None,
1425
+ client: Optional[_Client] = None,
1400
1426
  ) -> "_Function":
1401
1427
  """Reference a Function from a deployed App by its name.
1402
1428
 
@@ -1420,7 +1446,9 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1420
1446
  )
1421
1447
 
1422
1448
  warn_if_passing_namespace(namespace, "modal.Function.from_name")
1423
- return cls._from_name(app_name, name, environment_name=environment_name)
1449
+ return cls._from_name(
1450
+ app_name, name, load_context_overrides=LoadContext(environment_name=environment_name, client=client)
1451
+ )
1424
1452
 
1425
1453
  @property
1426
1454
  def tag(self) -> str:
@@ -1643,8 +1671,8 @@ Use the `Function.get_web_url()` method instead.
1643
1671
  input_queue,
1644
1672
  self.client,
1645
1673
  )
1646
- metadata = api_pb2.FunctionCallFromIdResponse(function_call_id=function_call_id, num_inputs=num_inputs)
1647
- fc: _FunctionCall[ReturnType] = _FunctionCall._new_hydrated(function_call_id, self.client, metadata)
1674
+ fc: _FunctionCall[ReturnType] = _FunctionCall._new_hydrated(function_call_id, self.client, None)
1675
+ fc._num_inputs = num_inputs # set the cached value of num_inputs
1648
1676
  return fc
1649
1677
 
1650
1678
  async def _call_function(self, args, kwargs) -> ReturnType:
@@ -1913,19 +1941,16 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
1913
1941
  def _invocation(self):
1914
1942
  return _Invocation(self.client.stub, self.object_id, self.client)
1915
1943
 
1916
- def _hydrate_metadata(self, metadata: Optional[Message]):
1917
- if not metadata:
1918
- return
1919
- assert isinstance(metadata, api_pb2.FunctionCallFromIdResponse)
1920
- self._num_inputs = metadata.num_inputs
1921
-
1922
1944
  @live_method
1923
1945
  async def num_inputs(self) -> int:
1924
1946
  """Get the number of inputs in the function call."""
1925
- # Should have been hydrated.
1926
- assert self._num_inputs is not None
1947
+ if self._num_inputs is None:
1948
+ request = api_pb2.FunctionCallFromIdRequest(function_call_id=self.object_id)
1949
+ resp = await self.client.stub.FunctionCallFromId(request)
1950
+ self._num_inputs = resp.num_inputs # cached
1927
1951
  return self._num_inputs
1928
1952
 
1953
+ @live_method
1929
1954
  async def get(self, timeout: Optional[float] = None, *, index: int = 0) -> ReturnType:
1930
1955
  """Get the result of the index-th input of the function call.
1931
1956
  `.spawn()` calls have a single output, so only specifying `index=0` is valid.
@@ -1969,6 +1994,7 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
1969
1994
  async for _, item in self._invocation().enumerate(start_index=start, end_index=end):
1970
1995
  yield item
1971
1996
 
1997
+ @live_method
1972
1998
  async def get_call_graph(self) -> list[InputInfo]:
1973
1999
  """Returns a structure representing the call graph from a given root
1974
2000
  call ID, along with the status of execution for each node.
@@ -1981,6 +2007,7 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
1981
2007
  response = await self._client.stub.FunctionGetCallGraph(request)
1982
2008
  return _reconstruct_call_graph(response)
1983
2009
 
2010
+ @live_method
1984
2011
  async def cancel(
1985
2012
  self,
1986
2013
  # if true, containers running the inputs are forcibly terminated
@@ -2018,20 +2045,18 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"):
2018
2045
  if you no longer have access to the original object returned from `Function.spawn`.
2019
2046
 
2020
2047
  """
2021
- if client is None:
2022
- client = await _Client.from_env()
2023
2048
 
2024
- async def _load(self: _FunctionCall, resolver: Resolver, existing_object_id: Optional[str]):
2025
- request = api_pb2.FunctionCallFromIdRequest(function_call_id=function_call_id)
2026
- resp = await resolver.client.stub.FunctionCallFromId(request)
2027
- self._hydrate(function_call_id, resolver.client, resp)
2049
+ async def _load(
2050
+ self: _FunctionCall, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
2051
+ ):
2052
+ # this loader doesn't do anything in practice, but it will get the client from the load_context
2053
+ self._hydrate(function_call_id, load_context.client, None)
2028
2054
 
2029
2055
  rep = f"FunctionCall.from_id({function_call_id!r})"
2030
- fc: _FunctionCall[Any] = _FunctionCall._from_loader(_load, rep, hydrate_lazily=True)
2031
- # We already know the object ID, so we can set it directly
2032
- fc._object_id = function_call_id
2033
- fc._client = client
2034
- return fc
2056
+
2057
+ return _FunctionCall._from_loader(
2058
+ _load, rep, hydrate_lazily=True, load_context_overrides=LoadContext(client=client)
2059
+ )
2035
2060
 
2036
2061
  @staticmethod
2037
2062
  async def gather(*function_calls: "_FunctionCall[T]") -> typing.Sequence[T]:
modal/_load_context.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright Modal Labs 2025
2
+ from typing import Optional
3
+
4
+ from .client import _Client
5
+ from .config import config
6
+
7
+
8
+ class LoadContext:
9
+ """Encapsulates optional metadata values used during object loading.
10
+
11
+ This metadata is set during object construction and propagated through
12
+ parent-child relationships (e.g., App -> Function, Cls -> Obj -> bound methods).
13
+ """
14
+
15
+ _client: Optional[_Client] = None
16
+ _environment_name: Optional[str] = None
17
+ _app_id: Optional[str] = None
18
+
19
+ def __init__(
20
+ self,
21
+ *,
22
+ client: Optional[_Client] = None,
23
+ environment_name: Optional[str] = None,
24
+ app_id: Optional[str] = None,
25
+ ):
26
+ self._client = client
27
+ self._environment_name = environment_name
28
+ self._app_id = app_id
29
+
30
+ @property
31
+ def client(self) -> _Client:
32
+ assert self._client is not None
33
+ return self._client
34
+
35
+ @property
36
+ def environment_name(self) -> str:
37
+ assert self._environment_name is not None
38
+ return self._environment_name
39
+
40
+ @property
41
+ def app_id(self) -> Optional[str]:
42
+ return self._app_id
43
+
44
+ @classmethod
45
+ def empty(cls) -> "LoadContext":
46
+ """Create an empty LoadContext with all fields set to None.
47
+
48
+ Used when loading objects that don't have a parent context.
49
+ """
50
+ return cls(client=None, environment_name=None, app_id=None)
51
+
52
+ def merged_with(self, parent: "LoadContext") -> "LoadContext":
53
+ """Create a new LoadContext with parent values filling in None fields.
54
+
55
+ Returns a new LoadContext without mutating self or parent.
56
+ Values from self take precedence over values from parent.
57
+ """
58
+ return LoadContext(
59
+ client=self._client if self._client is not None else parent._client,
60
+ environment_name=self._environment_name if self._environment_name is not None else parent._environment_name,
61
+ app_id=self._app_id if self._app_id is not None else parent._app_id,
62
+ ) # TODO (elias): apply_defaults?
63
+
64
+ async def apply_defaults(self) -> "LoadContext":
65
+ """Infer default client and environment_name if not present
66
+
67
+ Returns a new instance (no in place mutation)"""
68
+
69
+ return LoadContext(
70
+ client=await _Client.from_env() if self._client is None else self.client,
71
+ environment_name=self._environment_name or config.get("environment") or "",
72
+ app_id=self._app_id,
73
+ )
74
+
75
+ def reset(self) -> "LoadContext":
76
+ self._client = None
77
+ self._environment_name = None
78
+ self._app_id = None
79
+ return self
80
+
81
+ async def in_place_upgrade(
82
+ self, client: Optional[_Client] = None, environment_name: Optional[str] = None, app_id: Optional[str] = None
83
+ ) -> "LoadContext":
84
+ """In-place set values if they aren't already set, or set default values
85
+
86
+ Intended for Function/Cls hydration specifically
87
+
88
+ In those cases, it's important to in-place upgrade/apply_defaults since any "sibling" of the function/cls
89
+ would share the load context with its parent, and the initial load context overrides may not be sufficient
90
+ since an `app.deploy()` etc could get arguments that set a new client etc.
91
+
92
+ E.g.
93
+ @app.function()
94
+ def f():
95
+ ...
96
+
97
+ f2 = Function.with_options(...)
98
+
99
+ with app.run(client=...): # hydrates f and f2 at this point
100
+ ...
101
+ """
102
+ self._client = self._client or client or await _Client.from_env()
103
+ self._environment_name = self._environment_name or environment_name or config.get("environment") or ""
104
+ self._app_id = self._app_id or app_id
105
+ return self
modal/_object.py CHANGED
@@ -10,6 +10,7 @@ from typing_extensions import Self
10
10
 
11
11
  from modal._traceback import suppress_tb_frames
12
12
 
13
+ from ._load_context import LoadContext
13
14
  from ._resolver import Resolver
14
15
  from ._utils.async_utils import aclosing
15
16
  from ._utils.deprecation import deprecation_warning
@@ -20,11 +21,19 @@ from .exception import ExecutionError, InvalidError
20
21
  EPHEMERAL_OBJECT_HEARTBEAT_SLEEP: int = 300
21
22
 
22
23
 
23
- def _get_environment_name(environment_name: Optional[str] = None, resolver: Optional[Resolver] = None) -> Optional[str]:
24
+ def _get_environment_name(
25
+ environment_name: Optional[str] = None,
26
+ ) -> Optional[str]:
27
+ """Get environment name from various sources.
28
+
29
+ Args:
30
+ environment_name: Explicitly provided environment name (highest priority)
31
+
32
+ Returns:
33
+ Environment name from first available source, or config default
34
+ """
24
35
  if environment_name:
25
36
  return environment_name
26
- elif resolver and resolver.environment_name:
27
- return resolver.environment_name
28
37
  else:
29
38
  return config.get("environment")
30
39
 
@@ -34,13 +43,14 @@ class _Object:
34
43
  _prefix_to_type: ClassVar[dict[str, type]] = {}
35
44
 
36
45
  # For constructors
37
- _load: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]]
38
- _preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]]
46
+ _load: Optional[Callable[[Self, Resolver, LoadContext, Optional[str]], Awaitable[None]]]
47
+ _preload: Optional[Callable[[Self, Resolver, LoadContext, Optional[str]], Awaitable[None]]]
39
48
  _rep: str
40
49
  _is_another_app: bool
41
50
  _hydrate_lazily: bool
42
51
  _deps: Optional[Callable[..., Sequence["_Object"]]]
43
52
  _deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None
53
+ _load_context_overrides: LoadContext
44
54
 
45
55
  # For hydrated objects
46
56
  _object_id: Optional[str]
@@ -66,13 +76,15 @@ class _Object:
66
76
  def _init(
67
77
  self,
68
78
  rep: str,
69
- load: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None,
79
+ load: Optional[Callable[[Self, Resolver, LoadContext, Optional[str]], Awaitable[None]]] = None,
70
80
  is_another_app: bool = False,
71
- preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None,
81
+ preload: Optional[Callable[[Self, Resolver, LoadContext, Optional[str]], Awaitable[None]]] = None,
72
82
  hydrate_lazily: bool = False,
73
83
  deps: Optional[Callable[..., Sequence["_Object"]]] = None,
74
84
  deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None,
75
85
  name: Optional[str] = None,
86
+ *,
87
+ load_context_overrides: Optional[LoadContext] = None,
76
88
  ):
77
89
  self._local_uuid = str(uuid.uuid4())
78
90
  self._load = load
@@ -82,6 +94,9 @@ class _Object:
82
94
  self._hydrate_lazily = hydrate_lazily
83
95
  self._deps = deps
84
96
  self._deduplication_key = deduplication_key
97
+ self._load_context_overrides = (
98
+ load_context_overrides if load_context_overrides is not None else LoadContext.empty()
99
+ )
85
100
 
86
101
  self._object_id = None
87
102
  self._client = None
@@ -163,18 +178,30 @@ class _Object:
163
178
  @classmethod
164
179
  def _from_loader(
165
180
  cls,
166
- load: Callable[[Self, Resolver, Optional[str]], Awaitable[None]],
181
+ load: Callable[[Self, Resolver, LoadContext, Optional[str]], Awaitable[None]],
167
182
  rep: str,
168
183
  is_another_app: bool = False,
169
- preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None,
184
+ preload: Optional[Callable[[Self, Resolver, LoadContext, Optional[str]], Awaitable[None]]] = None,
170
185
  hydrate_lazily: bool = False,
171
186
  deps: Optional[Callable[..., Sequence["_Object"]]] = None,
172
187
  deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None,
173
188
  name: Optional[str] = None,
189
+ *,
190
+ load_context_overrides: LoadContext,
174
191
  ):
175
192
  # TODO(erikbern): flip the order of the two first arguments
176
193
  obj = _Object.__new__(cls)
177
- obj._init(rep, load, is_another_app, preload, hydrate_lazily, deps, deduplication_key, name)
194
+ obj._init(
195
+ rep,
196
+ load,
197
+ is_another_app,
198
+ preload,
199
+ hydrate_lazily,
200
+ deps,
201
+ deduplication_key,
202
+ name,
203
+ load_context_overrides=load_context_overrides,
204
+ )
178
205
  return obj
179
206
 
180
207
  @staticmethod
@@ -275,25 +302,27 @@ class _Object:
275
302
 
276
303
  *Added in v0.72.39*: This method replaces the deprecated `.resolve()` method.
277
304
  """
305
+ # TODO: add deprecation for the client argument here - should be added in constructors instead
278
306
  if self._is_hydrated:
279
307
  if self.client._snapshotted and not self._is_rehydrated:
280
308
  # memory snapshots capture references which must be rehydrated
281
309
  # on restore to handle staleness.
282
310
  logger.debug(f"rehydrating {self} after snapshot")
283
311
  self._is_hydrated = False # un-hydrate and re-resolve
284
- c = client if client is not None else await _Client.from_env()
285
- resolver = Resolver(c)
286
- await resolver.load(typing.cast(_Object, self))
312
+ # Set the client on LoadContext before loading
313
+ root_load_context = LoadContext(client=client)
314
+ resolver = Resolver()
315
+ await resolver.load(typing.cast(_Object, self), root_load_context)
287
316
  self._is_rehydrated = True
288
- logger.debug(f"rehydrated {self} with client {id(c)}")
317
+ logger.debug(f"rehydrated {self} with client {id(self.client)}")
289
318
  elif not self._hydrate_lazily:
290
- # TODO(michael) can remove _hydrate lazily? I think all objects support it now?
291
319
  self._validate_is_hydrated()
292
320
  else:
293
- c = client if client is not None else await _Client.from_env()
294
- resolver = Resolver(c)
321
+ # Set the client on LoadContext before loading
322
+ root_load_context = LoadContext(client=client)
323
+ resolver = Resolver()
295
324
  with suppress_tb_frames(1): # skip this frame by default
296
- await resolver.load(self)
325
+ await resolver.load(self, root_load_context)
297
326
  return self
298
327
 
299
328
 
modal/_resolver.py CHANGED
@@ -8,17 +8,16 @@ from asyncio import Future
8
8
  from collections.abc import Hashable
9
9
  from typing import TYPE_CHECKING, Optional
10
10
 
11
+ import modal._object
11
12
  from modal._traceback import suppress_tb_frames
12
13
  from modal_proto import api_pb2
13
14
 
15
+ from ._load_context import LoadContext
14
16
  from ._utils.async_utils import TaskContext
15
- from .client import _Client
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from rich.tree import Tree
19
20
 
20
- import modal._object
21
-
22
21
 
23
22
  class StatusRow:
24
23
  def __init__(self, progress: "typing.Optional[Tree]"):
@@ -48,19 +47,10 @@ class StatusRow:
48
47
 
49
48
  class Resolver:
50
49
  _local_uuid_to_future: dict[str, Future]
51
- _environment_name: Optional[str]
52
- _app_id: Optional[str]
53
50
  _deduplication_cache: dict[Hashable, Future]
54
- _client: _Client
55
51
  _build_start: float
56
52
 
57
- def __init__(
58
- self,
59
- client: _Client,
60
- *,
61
- environment_name: Optional[str] = None,
62
- app_id: Optional[str] = None,
63
- ):
53
+ def __init__(self):
64
54
  try:
65
55
  # TODO(michael) If we don't clean this up more thoroughly, it would probably
66
56
  # be good to have a single source of truth for "rich is installed" rather than
@@ -75,9 +65,6 @@ class Resolver:
75
65
 
76
66
  self._local_uuid_to_future = {}
77
67
  self._tree = tree
78
- self._client = client
79
- self._app_id = app_id
80
- self._environment_name = environment_name
81
68
  self._deduplication_cache = {}
82
69
 
83
70
  with tempfile.TemporaryFile() as temp_file:
@@ -85,27 +72,24 @@ class Resolver:
85
72
  # to the mtime on mounted files, and want those measurements to have the same resolution.
86
73
  self._build_start = os.fstat(temp_file.fileno()).st_mtime
87
74
 
88
- @property
89
- def app_id(self) -> Optional[str]:
90
- return self._app_id
91
-
92
- @property
93
- def client(self):
94
- return self._client
95
-
96
- @property
97
- def environment_name(self):
98
- return self._environment_name
99
-
100
75
  @property
101
76
  def build_start(self) -> float:
102
77
  return self._build_start
103
78
 
104
- async def preload(self, obj, existing_object_id: Optional[str]):
79
+ async def preload(
80
+ self, obj: "modal._object._Object", parent_load_context: "LoadContext", existing_object_id: Optional[str]
81
+ ):
105
82
  if obj._preload is not None:
106
- await obj._preload(obj, self, existing_object_id)
83
+ load_context = obj._load_context_overrides.merged_with(parent_load_context)
84
+ await obj._preload(obj, self, load_context, existing_object_id)
107
85
 
108
- async def load(self, obj: "modal._object._Object", existing_object_id: Optional[str] = None):
86
+ async def load(
87
+ self,
88
+ obj: "modal._object._Object",
89
+ parent_load_context: "LoadContext",
90
+ *,
91
+ existing_object_id: Optional[str] = None,
92
+ ):
109
93
  if obj._is_hydrated and obj._is_another_app:
110
94
  # No need to reload this, it won't typically change
111
95
  if obj.local_uuid not in self._local_uuid_to_future:
@@ -129,21 +113,23 @@ class Resolver:
129
113
  cached_future = self._deduplication_cache.get(deduplication_key)
130
114
  if cached_future:
131
115
  hydrated_object = await cached_future
132
- obj._hydrate(hydrated_object.object_id, self._client, hydrated_object._get_metadata())
116
+ # Use the client from the already-hydrated object
117
+ obj._hydrate(hydrated_object.object_id, hydrated_object.client, hydrated_object._get_metadata())
133
118
  return obj
134
119
 
135
120
  if not cached_future:
136
121
  # don't run any awaits within this if-block to prevent race conditions
137
122
  async def loader():
138
- # Wait for all its dependencies
123
+ load_context = await obj._load_context_overrides.merged_with(parent_load_context).apply_defaults()
124
+
139
125
  # TODO(erikbern): do we need existing_object_id for those?
140
- await TaskContext.gather(*[self.load(dep) for dep in obj.deps()])
126
+ await TaskContext.gather(*[self.load(dep, load_context) for dep in obj.deps()])
141
127
 
142
128
  # Load the object itself
143
129
  if not obj._load:
144
130
  raise Exception(f"Object {obj} has no loader function")
145
131
 
146
- await obj._load(obj, self, existing_object_id)
132
+ await obj._load(obj, self, load_context, existing_object_id)
147
133
 
148
134
  # Check that the id of functions didn't change
149
135
  # Persisted refs are ignored because their life cycle is managed independently.