modal 1.2.2.dev30__py3-none-any.whl → 1.2.2.dev31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modal/cli/dict.py CHANGED
@@ -4,6 +4,7 @@ from typing import Optional
4
4
  import typer
5
5
  from typer import Argument, Option, Typer
6
6
 
7
+ from modal._load_context import LoadContext
7
8
  from modal._output import make_console
8
9
  from modal._resolver import Resolver
9
10
  from modal._utils.async_utils import synchronizer
@@ -29,8 +30,10 @@ async def create(name: str, *, env: Optional[str] = ENV_OPTION):
29
30
  """
30
31
  d = _Dict.from_name(name, environment_name=env, create_if_missing=True)
31
32
  client = await _Client.from_env()
32
- resolver = Resolver(client=client)
33
- await resolver.load(d)
33
+ resolver = Resolver()
34
+
35
+ load_context = LoadContext(client=client, environment_name=env)
36
+ await resolver.load(d, load_context)
34
37
 
35
38
 
36
39
  @dict_cli.command(name="list", rich_help_panel="Management")
modal/cli/queues.py CHANGED
@@ -5,6 +5,7 @@ from typing import Optional
5
5
  import typer
6
6
  from typer import Argument, Option, Typer
7
7
 
8
+ from modal._load_context import LoadContext
8
9
  from modal._output import make_console
9
10
  from modal._resolver import Resolver
10
11
  from modal._utils.async_utils import synchronizer
@@ -38,8 +39,9 @@ async def create(name: str, *, env: Optional[str] = ENV_OPTION):
38
39
  """
39
40
  q = _Queue.from_name(name, environment_name=env, create_if_missing=True)
40
41
  client = await _Client.from_env()
41
- resolver = Resolver(client=client)
42
- await resolver.load(q)
42
+ resolver = Resolver()
43
+ load_context = LoadContext(client=client, environment_name=env)
44
+ await resolver.load(q, load_context)
43
45
 
44
46
 
45
47
  @queue_cli.command(name="delete", rich_help_panel="Management")
modal/client.pyi CHANGED
@@ -33,7 +33,7 @@ class _Client:
33
33
  server_url: str,
34
34
  client_type: int,
35
35
  credentials: typing.Optional[tuple[str, str]],
36
- version: str = "1.2.2.dev30",
36
+ version: str = "1.2.2.dev31",
37
37
  ):
38
38
  """mdmd:hidden
39
39
  The Modal client object is not intended to be instantiated directly by users.
@@ -164,7 +164,7 @@ class Client:
164
164
  server_url: str,
165
165
  client_type: int,
166
166
  credentials: typing.Optional[tuple[str, str]],
167
- version: str = "1.2.2.dev30",
167
+ version: str = "1.2.2.dev31",
168
168
  ):
169
169
  """mdmd:hidden
170
170
  The Modal client object is not intended to be instantiated directly by users.
modal/cls.py CHANGED
@@ -12,6 +12,7 @@ from grpclib import GRPCError, Status
12
12
  from modal_proto import api_pb2
13
13
 
14
14
  from ._functions import _Function, _parse_retries
15
+ from ._load_context import LoadContext
15
16
  from ._object import _Object, live_method
16
17
  from ._partial_function import (
17
18
  _find_callables_for_obj,
@@ -31,8 +32,8 @@ from ._utils.deprecation import (
31
32
  warn_on_renamed_autoscaler_settings,
32
33
  )
33
34
  from ._utils.mount_utils import validate_volumes
35
+ from .client import _Client
34
36
  from .cloud_bucket_mount import _CloudBucketMount
35
- from .config import config
36
37
  from .exception import ExecutionError, InvalidError, NotFoundError
37
38
  from .gpu import GPU_T
38
39
  from .retries import Retries
@@ -134,7 +135,7 @@ def _bind_instance_method(cls: "_Cls", service_function: _Function, method_name:
134
135
  method_metadata = cls._method_metadata[method_name]
135
136
  new_function._hydrate(service_function.object_id, service_function.client, method_metadata)
136
137
 
137
- async def _load(fun: "_Function", resolver: Resolver, existing_object_id: Optional[str]):
138
+ async def _load(fun: "_Function", resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]):
138
139
  # there is currently no actual loading logic executed to create each method on
139
140
  # the *parametrized* instance of a class - it uses the parameter-bound service-function
140
141
  # for the instance. This load method just makes sure to set all attributes after the
@@ -153,11 +154,14 @@ def _bind_instance_method(cls: "_Cls", service_function: _Function, method_name:
153
154
 
154
155
  rep = f"Method({cls._name}.{method_name})"
155
156
 
157
+ # Bound methods should *reference* their parent Cls's LoadContext
158
+ # so that it can be modified in place on the parent and be reflected in the method
156
159
  fun = _Function._from_loader(
157
160
  _load,
158
161
  rep,
159
162
  deps=_deps,
160
163
  hydrate_lazily=True,
164
+ load_context_overrides=cls._load_context_overrides,
161
165
  )
162
166
  if service_function.is_hydrated:
163
167
  # Eager hydration (skip load) if the instance service function is already loaded
@@ -421,14 +425,13 @@ class _Obj:
421
425
 
422
426
  # Not hydrated Cls, and we don't have the class - typically a Cls.from_name that
423
427
  # has not yet been loaded. So use a special loader that loads it lazily:
424
- async def method_loader(fun, resolver: Resolver, existing_object_id):
425
- await resolver.load(self._cls) # load class so we get info about methods
428
+ async def method_loader(fun, resolver: Resolver, load_context: LoadContext, existing_object_id):
426
429
  method_function = _get_maybe_method()
427
430
  if method_function is None:
428
431
  raise NotFoundError(
429
432
  f"Class has no method {k}, and attributes can't be accessed for `Cls.from_name` instances"
430
433
  )
431
- await resolver.load(method_function) # get the appropriate method handle (lazy)
434
+ await resolver.load(method_function, load_context) # get the appropriate method handle (lazy)
432
435
  fun._hydrate_from_other(method_function)
433
436
 
434
437
  # The reason we don't *always* use this lazy loader is because it precludes attribute access
@@ -436,8 +439,9 @@ class _Obj:
436
439
  return _Function._from_loader(
437
440
  method_loader,
438
441
  rep=f"Method({self._cls._name}.{k})",
439
- deps=lambda: [], # TODO: use cls as dep instead of loading inside method_loader?
442
+ deps=lambda: [self._cls],
440
443
  hydrate_lazily=True,
444
+ load_context_overrides=self._cls._load_context_overrides,
441
445
  )
442
446
 
443
447
 
@@ -484,6 +488,7 @@ class _Cls(_Object, type_prefix="cs"):
484
488
  self._callables = other._callables
485
489
  self._name = other._name
486
490
  self._method_metadata = other._method_metadata
491
+ self._load_context_overrides = other._load_context_overrides
487
492
 
488
493
  def _get_partial_functions(self) -> dict[str, _PartialFunction]:
489
494
  if not self._user_cls:
@@ -595,15 +600,18 @@ More information on class parameterization can be found here: https://modal.com/
595
600
  def _deps() -> list[_Function]:
596
601
  return [class_service_function]
597
602
 
598
- async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[str]):
603
+ async def _load(self: "_Cls", resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]):
599
604
  req = api_pb2.ClassCreateRequest(
600
- app_id=resolver.app_id, existing_class_id=existing_object_id, only_class_function=True
605
+ app_id=load_context.app_id, existing_class_id=existing_object_id, only_class_function=True
601
606
  )
602
- resp = await resolver.client.stub.ClassCreate(req)
603
- self._hydrate(resp.class_id, resolver.client, resp.handle_metadata)
607
+ resp = await load_context.client.stub.ClassCreate(req)
608
+ self._hydrate(resp.class_id, load_context.client, resp.handle_metadata)
604
609
 
605
610
  rep = f"Cls({user_cls.__name__})"
606
- cls: _Cls = _Cls._from_loader(_load, rep, deps=_deps)
611
+ # Pass a *reference* to the App's LoadContext - this is important since the App is
612
+ # the only way to infer a LoadContext for an `@app.cls`, and the App doesn't
613
+ # get its client until *after* the Cls is created.
614
+ cls: _Cls = _Cls._from_loader(_load, rep, deps=_deps, load_context_overrides=app._root_load_context)
607
615
  cls._app = app
608
616
  cls._user_cls = user_cls
609
617
  cls._class_service_function = class_service_function
@@ -620,6 +628,7 @@ More information on class parameterization can be found here: https://modal.com/
620
628
  *,
621
629
  namespace: Any = None, # mdmd:line-hidden
622
630
  environment_name: Optional[str] = None,
631
+ client: Optional["_Client"] = None,
623
632
  ) -> "_Cls":
624
633
  """Reference a Cls from a deployed App by its name.
625
634
 
@@ -632,19 +641,22 @@ More information on class parameterization can be found here: https://modal.com/
632
641
  ```
633
642
  """
634
643
  warn_if_passing_namespace(namespace, "modal.Cls.from_name")
635
- _environment_name = environment_name or config.get("environment")
636
644
 
637
- async def _load_remote(self: _Cls, resolver: Resolver, existing_object_id: Optional[str]):
645
+ async def _load_remote(
646
+ self: _Cls, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
647
+ ):
638
648
  request = api_pb2.ClassGetRequest(
639
649
  app_name=app_name,
640
650
  object_tag=name,
641
- environment_name=_environment_name,
651
+ environment_name=load_context.environment_name,
642
652
  only_class_function=True,
643
653
  )
644
654
  try:
645
- response = await resolver.client.stub.ClassGet(request)
655
+ response = await load_context.client.stub.ClassGet(request)
646
656
  except NotFoundError as exc:
647
- env_context = f" (in the '{environment_name}' environment)" if environment_name else ""
657
+ env_context = (
658
+ f" (in the '{load_context.environment_name}' environment)" if load_context.environment_name else ""
659
+ )
648
660
  raise NotFoundError(
649
661
  f"Lookup failed for Cls '{name}' from the '{app_name}' app{env_context}: {exc}."
650
662
  ) from None
@@ -655,19 +667,26 @@ More information on class parameterization can be found here: https://modal.com/
655
667
  raise
656
668
 
657
669
  print_server_warnings(response.server_warnings)
658
- await resolver.load(self._class_service_function)
659
- self._hydrate(response.class_id, resolver.client, response.handle_metadata)
670
+ await resolver.load(self._class_service_function, load_context)
671
+ self._hydrate(response.class_id, load_context.client, response.handle_metadata)
660
672
 
661
673
  environment_rep = f", environment_name={environment_name!r}" if environment_name else ""
662
674
  rep = f"Cls.from_name({app_name!r}, {name!r}{environment_rep})"
663
- cls = cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True)
675
+
676
+ load_context_overrides = LoadContext(client=client, environment_name=environment_name)
677
+ cls = cls._from_loader(
678
+ _load_remote,
679
+ rep,
680
+ is_another_app=True,
681
+ hydrate_lazily=True,
682
+ load_context_overrides=load_context_overrides,
683
+ )
664
684
 
665
685
  class_service_name = f"{name}.*" # special name of the base service function for the class
666
686
  cls._class_service_function = _Function._from_name(
667
687
  app_name,
668
688
  class_service_name,
669
- namespace=namespace,
670
- environment_name=_environment_name,
689
+ load_context_overrides=load_context_overrides,
671
690
  )
672
691
  cls._name = name
673
692
  return cls
@@ -736,7 +755,7 @@ More information on class parameterization can be found here: https://modal.com/
736
755
  " please use the `.with_concurrency` method instead.",
737
756
  )
738
757
 
739
- async def _load_from_base(new_cls, resolver, existing_object_id):
758
+ async def _load_from_base(new_cls, resolver, load_context, existing_object_id):
740
759
  # this is a bit confusing, the cls will always have the same metadata
741
760
  # since it has the same *class* service function (i.e. "template")
742
761
  # But the (instance) service function for each Obj will be different
@@ -745,14 +764,21 @@ More information on class parameterization can be found here: https://modal.com/
745
764
  if not self.is_hydrated:
746
765
  # this should only happen for Cls.from_name instances
747
766
  # other classes should already be hydrated!
748
- await resolver.load(self)
767
+ await resolver.load(self, load_context)
749
768
 
750
769
  new_cls._initialize_from_other(self)
751
770
 
752
771
  def _deps():
753
772
  return []
754
773
 
755
- cls = _Cls._from_loader(_load_from_base, rep=f"{self._name}.with_options(...)", is_another_app=True, deps=_deps)
774
+ cls = _Cls._from_loader(
775
+ _load_from_base,
776
+ rep=f"{self._name}.with_options(...)",
777
+ is_another_app=True,
778
+ deps=_deps,
779
+ load_context_overrides=self._load_context_overrides,
780
+ hydrate_lazily=True,
781
+ )
756
782
  cls._initialize_from_other(self)
757
783
 
758
784
  # Validate volumes
@@ -797,16 +823,21 @@ More information on class parameterization can be found here: https://modal.com/
797
823
  ```
798
824
  """
799
825
 
800
- async def _load_from_base(new_cls, resolver, existing_object_id):
826
+ async def _load_from_base(new_cls, resolver, load_context, existing_object_id):
801
827
  if not self.is_hydrated:
802
- await resolver.load(self)
828
+ await resolver.load(self, load_context)
803
829
  new_cls._initialize_from_other(self)
804
830
 
805
831
  def _deps():
806
832
  return []
807
833
 
808
834
  cls = _Cls._from_loader(
809
- _load_from_base, rep=f"{self._name}.with_concurrency(...)", is_another_app=True, deps=_deps
835
+ _load_from_base,
836
+ rep=f"{self._name}.with_concurrency(...)",
837
+ is_another_app=True,
838
+ deps=_deps,
839
+ load_context_overrides=self._load_context_overrides,
840
+ hydrate_lazily=True,
810
841
  )
811
842
  cls._initialize_from_other(self)
812
843
 
@@ -826,16 +857,21 @@ More information on class parameterization can be found here: https://modal.com/
826
857
  ```
827
858
  """
828
859
 
829
- async def _load_from_base(new_cls, resolver, existing_object_id):
860
+ async def _load_from_base(new_cls, resolver, load_context, existing_object_id):
830
861
  if not self.is_hydrated:
831
- await resolver.load(self)
862
+ await resolver.load(self, load_context)
832
863
  new_cls._initialize_from_other(self)
833
864
 
834
865
  def _deps():
835
866
  return []
836
867
 
837
868
  cls = _Cls._from_loader(
838
- _load_from_base, rep=f"{self._name}.with_concurrency(...)", is_another_app=True, deps=_deps
869
+ _load_from_base,
870
+ rep=f"{self._name}.with_concurrency(...)",
871
+ is_another_app=True,
872
+ deps=_deps,
873
+ load_context_overrides=self._load_context_overrides,
874
+ hydrate_lazily=True,
839
875
  )
840
876
  cls._initialize_from_other(self)
841
877
 
@@ -865,17 +901,20 @@ More information on class parameterization can be found here: https://modal.com/
865
901
  # We create a synthetic dummy Function that is guaranteed to raise an AttributeError when
866
902
  # a user tries to use any of its "live methods" - this lets us raise exceptions for users
867
903
  # only if they try to access methods on a Cls as if they were methods on the instance.
868
- async def method_loader(fun: _Function, resolver: Resolver, existing_object_id: Optional[str]):
904
+ async def error_loader(
905
+ fun: _Function, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
906
+ ):
869
907
  raise AttributeError(
870
908
  "You can't access methods on a Cls directly - Did you forget to instantiate the class first?\n"
871
909
  "e.g. instead of MyClass.method.remote(), do MyClass().method.remote()"
872
910
  )
873
911
 
874
912
  return _Function._from_loader(
875
- method_loader,
913
+ error_loader,
876
914
  rep=f"UnboundMethod({self._name}.{k})",
877
915
  deps=lambda: [],
878
916
  hydrate_lazily=True,
917
+ load_context_overrides=self._load_context_overrides,
879
918
  )
880
919
 
881
920
  def _is_local(self) -> bool:
modal/cls.pyi CHANGED
@@ -5,6 +5,7 @@ import modal._functions
5
5
  import modal._object
6
6
  import modal._partial_function
7
7
  import modal.app
8
+ import modal.client
8
9
  import modal.cloud_bucket_mount
9
10
  import modal.functions
10
11
  import modal.gpu
@@ -379,6 +380,7 @@ class _Cls(modal._object._Object):
379
380
  *,
380
381
  namespace: typing.Any = None,
381
382
  environment_name: typing.Optional[str] = None,
383
+ client: typing.Optional[modal.client._Client] = None,
382
384
  ) -> _Cls:
383
385
  """Reference a Cls from a deployed App by its name.
384
386
 
@@ -537,6 +539,7 @@ class Cls(modal.object.Object):
537
539
  *,
538
540
  namespace: typing.Any = None,
539
541
  environment_name: typing.Optional[str] = None,
542
+ client: typing.Optional[modal.client.Client] = None,
540
543
  ) -> Cls:
541
544
  """Reference a Cls from a deployed App by its name.
542
545
 
modal/dict.py CHANGED
@@ -12,6 +12,7 @@ from synchronicity.async_wrap import asynccontextmanager
12
12
  from modal._utils.grpc_utils import Retry
13
13
  from modal_proto import api_pb2
14
14
 
15
+ from ._load_context import LoadContext
15
16
  from ._object import (
16
17
  EPHEMERAL_OBJECT_HEARTBEAT_SLEEP,
17
18
  _get_environment_name,
@@ -347,6 +348,7 @@ class _Dict(_Object, type_prefix="di"):
347
348
  namespace=None, # mdmd:line-hidden
348
349
  environment_name: Optional[str] = None,
349
350
  create_if_missing: bool = False,
351
+ client: Optional[_Client] = None,
350
352
  ) -> "_Dict":
351
353
  """Reference a named Dict, creating if necessary.
352
354
 
@@ -368,20 +370,27 @@ class _Dict(_Object, type_prefix="di"):
368
370
  "Passing data to `modal.Dict.from_name` is deprecated and will stop working in a future release.",
369
371
  )
370
372
 
371
- async def _load(self: _Dict, resolver: Resolver, existing_object_id: Optional[str]):
373
+ async def _load(self: _Dict, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]):
372
374
  serialized = _serialize_dict(data if data is not None else {})
373
375
  req = api_pb2.DictGetOrCreateRequest(
374
376
  deployment_name=name,
375
- environment_name=_get_environment_name(environment_name, resolver),
377
+ environment_name=load_context.environment_name,
376
378
  object_creation_type=(api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING if create_if_missing else None),
377
379
  data=serialized,
378
380
  )
379
- response = await resolver.client.stub.DictGetOrCreate(req)
381
+ response = await load_context.client.stub.DictGetOrCreate(req)
380
382
  logger.debug(f"Created dict with id {response.dict_id}")
381
- self._hydrate(response.dict_id, resolver.client, response.metadata)
383
+ self._hydrate(response.dict_id, load_context.client, response.metadata)
382
384
 
383
385
  rep = _Dict._repr(name, environment_name)
384
- return _Dict._from_loader(_load, rep, is_another_app=True, hydrate_lazily=True, name=name)
386
+ return _Dict._from_loader(
387
+ _load,
388
+ rep,
389
+ is_another_app=True,
390
+ hydrate_lazily=True,
391
+ name=name,
392
+ load_context_overrides=LoadContext(environment_name=environment_name, client=client),
393
+ )
385
394
 
386
395
  @staticmethod
387
396
  async def delete(
modal/dict.pyi CHANGED
@@ -448,6 +448,7 @@ class _Dict(modal._object._Object):
448
448
  namespace=None,
449
449
  environment_name: typing.Optional[str] = None,
450
450
  create_if_missing: bool = False,
451
+ client: typing.Optional[modal.client._Client] = None,
451
452
  ) -> _Dict:
452
453
  """Reference a named Dict, creating if necessary.
453
454
 
@@ -666,6 +667,7 @@ class Dict(modal.object.Object):
666
667
  namespace=None,
667
668
  environment_name: typing.Optional[str] = None,
668
669
  create_if_missing: bool = False,
670
+ client: typing.Optional[modal.client.Client] = None,
669
671
  ) -> Dict:
670
672
  """Reference a named Dict, creating if necessary.
671
673
 
modal/environments.py CHANGED
@@ -8,6 +8,7 @@ from google.protobuf.wrappers_pb2 import StringValue
8
8
 
9
9
  from modal_proto import api_pb2
10
10
 
11
+ from ._load_context import LoadContext
11
12
  from ._object import _Object
12
13
  from ._resolver import Resolver
13
14
  from ._utils.async_utils import synchronize_api, synchronizer
@@ -52,6 +53,7 @@ class _Environment(_Object, type_prefix="en"):
52
53
  name: str,
53
54
  *,
54
55
  create_if_missing: bool = False,
56
+ client: Optional[_Client] = None,
55
57
  ):
56
58
  if name:
57
59
  # Allow null names for the case where we want to look up the "default" environment,
@@ -61,7 +63,9 @@ class _Environment(_Object, type_prefix="en"):
61
63
  # environments as part of public API when we make this class more useful.
62
64
  check_object_name(name, "Environment")
63
65
 
64
- async def _load(self: _Environment, resolver: Resolver, existing_object_id: Optional[str]):
66
+ async def _load(
67
+ self: _Environment, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
68
+ ):
65
69
  request = api_pb2.EnvironmentGetOrCreateRequest(
66
70
  deployment_name=name,
67
71
  object_creation_type=(
@@ -70,12 +74,17 @@ class _Environment(_Object, type_prefix="en"):
70
74
  else api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED
71
75
  ),
72
76
  )
73
- response = await resolver.client.stub.EnvironmentGetOrCreate(request)
77
+ response = await load_context.client.stub.EnvironmentGetOrCreate(request)
74
78
  logger.debug(f"Created environment with id {response.environment_id}")
75
- self._hydrate(response.environment_id, resolver.client, response.metadata)
76
-
77
- # TODO environment name (and id?) in the repr? (We should make reprs consistently more useful)
78
- return _Environment._from_loader(_load, "Environment()", is_another_app=True, hydrate_lazily=True)
79
+ self._hydrate(response.environment_id, load_context.client, response.metadata)
80
+
81
+ return _Environment._from_loader(
82
+ _load,
83
+ f"Environment.from_name({name!r})",
84
+ is_another_app=True,
85
+ hydrate_lazily=True,
86
+ load_context_overrides=LoadContext(client=client),
87
+ )
79
88
 
80
89
 
81
90
  Environment = synchronize_api(_Environment)
@@ -88,7 +97,7 @@ ENVIRONMENT_CACHE: dict[str, _Environment] = {}
88
97
  async def _get_environment_cached(name: str, client: _Client) -> _Environment:
89
98
  if name in ENVIRONMENT_CACHE:
90
99
  return ENVIRONMENT_CACHE[name]
91
- environment = await _Environment.from_name(name).hydrate(client)
100
+ environment = await _Environment.from_name(name, client=client).hydrate()
92
101
  ENVIRONMENT_CACHE[name] = environment
93
102
  return environment
94
103
 
modal/environments.pyi CHANGED
@@ -45,7 +45,9 @@ class _Environment(modal._object._Object):
45
45
 
46
46
  def _hydrate_metadata(self, metadata: google.protobuf.message.Message): ...
47
47
  @staticmethod
48
- def from_name(name: str, *, create_if_missing: bool = False): ...
48
+ def from_name(
49
+ name: str, *, create_if_missing: bool = False, client: typing.Optional[modal.client._Client] = None
50
+ ): ...
49
51
 
50
52
  class Environment(modal.object.Object):
51
53
  _settings: EnvironmentSettings
@@ -56,7 +58,9 @@ class Environment(modal.object.Object):
56
58
 
57
59
  def _hydrate_metadata(self, metadata: google.protobuf.message.Message): ...
58
60
  @staticmethod
59
- def from_name(name: str, *, create_if_missing: bool = False): ...
61
+ def from_name(
62
+ name: str, *, create_if_missing: bool = False, client: typing.Optional[modal.client.Client] = None
63
+ ): ...
60
64
 
61
65
  async def _get_environment_cached(name: str, client: modal.client._Client) -> _Environment: ...
62
66
 
modal/functions.pyi CHANGED
@@ -1,6 +1,7 @@
1
1
  import collections.abc
2
2
  import google.protobuf.message
3
3
  import modal._functions
4
+ import modal._load_context
4
5
  import modal._utils.async_utils
5
6
  import modal._utils.function_utils
6
7
  import modal.app
@@ -66,7 +67,7 @@ class Function(
66
67
  @staticmethod
67
68
  def from_local(
68
69
  info: modal._utils.function_utils.FunctionInfo,
69
- app,
70
+ app: typing.Optional[modal.app.App],
70
71
  image: modal.image.Image,
71
72
  env: typing.Optional[dict[str, typing.Optional[str]]] = None,
72
73
  secrets: typing.Optional[collections.abc.Collection[modal.secret.Secret]] = None,
@@ -245,10 +246,16 @@ class Function(
245
246
  keep_warm: __keep_warm_spec[typing_extensions.Self]
246
247
 
247
248
  @classmethod
248
- def _from_name(cls, app_name: str, name: str, namespace=None, environment_name: typing.Optional[str] = None): ...
249
+ def _from_name(cls, app_name: str, name: str, *, load_context_overrides: modal._load_context.LoadContext): ...
249
250
  @classmethod
250
251
  def from_name(
251
- cls: type[Function], app_name: str, name: str, *, namespace=None, environment_name: typing.Optional[str] = None
252
+ cls: type[Function],
253
+ app_name: str,
254
+ name: str,
255
+ *,
256
+ namespace=None,
257
+ environment_name: typing.Optional[str] = None,
258
+ client: typing.Optional[modal.client.Client] = None,
252
259
  ) -> Function:
253
260
  """Reference a Function from a deployed App by its name.
254
261
 
modal/image.py CHANGED
@@ -28,6 +28,7 @@ from typing_extensions import Self
28
28
  from modal._serialization import serialize_data_format
29
29
  from modal_proto import api_pb2
30
30
 
31
+ from ._load_context import LoadContext
31
32
  from ._object import _Object, live_method_gen
32
33
  from ._resolver import Resolver
33
34
  from ._serialization import get_preferred_payload_format, serialize
@@ -434,12 +435,16 @@ class _Image(_Object, type_prefix="im"):
434
435
 
435
436
  base_image = self
436
437
 
437
- async def _load(self2: "_Image", resolver: Resolver, existing_object_id: Optional[str]):
438
+ async def _load(
439
+ self2: "_Image", resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
440
+ ):
438
441
  self2._hydrate_from_other(base_image) # same image id as base image as long as it's lazy
439
442
  self2._deferred_mounts = tuple(base_image._deferred_mounts) + (mount,)
440
443
  self2._serve_mounts = base_image._serve_mounts | ({mount} if mount.is_local() else set())
441
444
 
442
- img = _Image._from_loader(_load, "Image(local files)", deps=lambda: [base_image, mount])
445
+ img = _Image._from_loader(
446
+ _load, "Image(local files)", deps=lambda: [base_image, mount], load_context_overrides=LoadContext.empty()
447
+ )
443
448
  img._added_python_source_set = base_image._added_python_source_set
444
449
  return img
445
450
 
@@ -523,18 +528,18 @@ class _Image(_Object, type_prefix="im"):
523
528
  deps += (vol,)
524
529
  return deps
525
530
 
526
- async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]):
531
+ async def _load(self: _Image, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]):
527
532
  context_mount = context_mount_function() if context_mount_function else None
528
533
  if context_mount:
529
- await resolver.load(context_mount)
534
+ await resolver.load(context_mount, load_context)
530
535
 
531
536
  if _do_assert_no_mount_layers:
532
537
  for image in base_images.values():
533
538
  # base images can't have
534
539
  image._assert_no_mount_layers()
535
540
 
536
- assert resolver.app_id # type narrowing
537
- environment = await _get_environment_cached(resolver.environment_name or "", resolver.client)
541
+ assert load_context.app_id # type narrowing
542
+ environment = await _get_environment_cached(load_context.environment_name or "", load_context.client)
538
543
  # A bit hacky,but assume that the environment provides a valid builder version
539
544
  image_builder_version = cast(ImageBuilderVersion, environment._settings.image_builder_version)
540
545
  builder_version = _get_image_builder_version(image_builder_version)
@@ -626,7 +631,7 @@ class _Image(_Object, type_prefix="im"):
626
631
  )
627
632
 
628
633
  req = api_pb2.ImageGetOrCreateRequest(
629
- app_id=resolver.app_id,
634
+ app_id=load_context.app_id,
630
635
  image=image_definition,
631
636
  existing_image_id=existing_object_id or "", # TODO: ignored
632
637
  build_function_id=build_function_id,
@@ -638,7 +643,7 @@ class _Image(_Object, type_prefix="im"):
638
643
  allow_global_deployment=os.environ.get("MODAL_IMAGE_ALLOW_GLOBAL_DEPLOYMENT") == "1",
639
644
  ignore_cache=config.get("ignore_cache"),
640
645
  )
641
- resp = await resolver.client.stub.ImageGetOrCreate(req)
646
+ resp = await load_context.client.stub.ImageGetOrCreate(req)
642
647
  image_id = resp.image_id
643
648
  result: api_pb2.GenericResult
644
649
  metadata: Optional[api_pb2.ImageMetadata] = None
@@ -651,7 +656,7 @@ class _Image(_Object, type_prefix="im"):
651
656
  else:
652
657
  # not built or in the process of building - wait for build
653
658
  logger.debug("Waiting for image %s" % image_id)
654
- resp = await _image_await_build_result(image_id, resolver.client)
659
+ resp = await _image_await_build_result(image_id, load_context.client)
655
660
  result = resp.result
656
661
  if resp.HasField("metadata"):
657
662
  metadata = resp.metadata
@@ -681,7 +686,7 @@ class _Image(_Object, type_prefix="im"):
681
686
  else:
682
687
  raise RemoteError("Unknown status %s!" % result.status)
683
688
 
684
- self._hydrate(image_id, resolver.client, metadata)
689
+ self._hydrate(image_id, load_context.client, metadata)
685
690
  local_mounts = set()
686
691
  for base in base_images.values():
687
692
  local_mounts |= base._serve_mounts
@@ -690,7 +695,7 @@ class _Image(_Object, type_prefix="im"):
690
695
  self._serve_mounts = frozenset(local_mounts)
691
696
 
692
697
  rep = f"Image({dockerfile_function})"
693
- obj = _Image._from_loader(_load, rep, deps=_deps)
698
+ obj = _Image._from_loader(_load, rep, deps=_deps, load_context_overrides=LoadContext.empty())
694
699
  obj.force_build = force_build
695
700
  obj._added_python_source_set = frozenset.union(
696
701
  frozenset(), *(base._added_python_source_set for base in base_images.values())
@@ -863,15 +868,13 @@ class _Image(_Object, type_prefix="im"):
863
868
 
864
869
  The ID of an Image object can be accessed using `.object_id`.
865
870
  """
866
- if client is None:
867
- client = await _Client.from_env()
868
871
 
869
- async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]):
870
- resp = await client.stub.ImageFromId(api_pb2.ImageFromIdRequest(image_id=image_id))
871
- self._hydrate(resp.image_id, resolver.client, resp.metadata)
872
+ async def _load(self: _Image, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]):
873
+ resp = await load_context.client.stub.ImageFromId(api_pb2.ImageFromIdRequest(image_id=image_id))
874
+ self._hydrate(resp.image_id, load_context.client, resp.metadata)
872
875
 
873
876
  rep = f"Image.from_id({image_id!r})"
874
- obj = _Image._from_loader(_load, rep)
877
+ obj = _Image._from_loader(_load, rep, load_context_overrides=LoadContext(client=client))
875
878
 
876
879
  return obj
877
880
 
@@ -930,11 +933,8 @@ class _Image(_Object, type_prefix="im"):
930
933
  if app.app_id is None:
931
934
  raise InvalidError("App has not been initialized yet. Use the content manager `app.run()` or `App.lookup`")
932
935
 
933
- app_id = app.app_id
934
- app_client = app._client or await _Client.from_env()
935
-
936
- resolver = Resolver(app_client, app_id=app_id)
937
- await resolver.load(self)
936
+ resolver = Resolver()
937
+ await resolver.load(self, app._root_load_context)
938
938
  return self
939
939
 
940
940
  def pip_install(