modal 1.2.2.dev22__py3-none-any.whl → 1.2.2.dev31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. modal/_functions.py +69 -37
  2. modal/_grpc_client.py +25 -2
  3. modal/_load_context.py +105 -0
  4. modal/_object.py +47 -18
  5. modal/_resolver.py +21 -35
  6. modal/_utils/function_utils.py +1 -2
  7. modal/app.py +7 -0
  8. modal/app.pyi +3 -0
  9. modal/cli/dict.py +5 -2
  10. modal/cli/queues.py +4 -2
  11. modal/client.pyi +2 -2
  12. modal/cloud_bucket_mount.py +2 -0
  13. modal/cloud_bucket_mount.pyi +4 -0
  14. modal/cls.py +71 -32
  15. modal/cls.pyi +3 -0
  16. modal/dict.py +14 -5
  17. modal/dict.pyi +2 -0
  18. modal/environments.py +16 -7
  19. modal/environments.pyi +6 -2
  20. modal/experimental/flash.py +2 -3
  21. modal/functions.pyi +10 -3
  22. modal/image.py +25 -25
  23. modal/image.pyi +9 -9
  24. modal/mount.py +34 -24
  25. modal/mount.pyi +33 -7
  26. modal/network_file_system.py +14 -5
  27. modal/network_file_system.pyi +12 -2
  28. modal/object.pyi +35 -8
  29. modal/proxy.py +14 -6
  30. modal/proxy.pyi +10 -2
  31. modal/queue.py +14 -5
  32. modal/queue.pyi +12 -2
  33. modal/runner.py +54 -50
  34. modal/runner.pyi +4 -3
  35. modal/sandbox.py +21 -12
  36. modal/secret.py +34 -17
  37. modal/secret.pyi +12 -2
  38. modal/serving.py +7 -11
  39. modal/serving.pyi +7 -8
  40. modal/snapshot.py +11 -5
  41. modal/volume.py +25 -7
  42. modal/volume.pyi +2 -0
  43. {modal-1.2.2.dev22.dist-info → modal-1.2.2.dev31.dist-info}/METADATA +2 -2
  44. {modal-1.2.2.dev22.dist-info → modal-1.2.2.dev31.dist-info}/RECORD +52 -51
  45. modal_proto/api.proto +10 -0
  46. modal_proto/api_pb2.py +841 -838
  47. modal_proto/api_pb2.pyi +25 -2
  48. modal_version/__init__.py +1 -1
  49. {modal-1.2.2.dev22.dist-info → modal-1.2.2.dev31.dist-info}/WHEEL +0 -0
  50. {modal-1.2.2.dev22.dist-info → modal-1.2.2.dev31.dist-info}/entry_points.txt +0 -0
  51. {modal-1.2.2.dev22.dist-info → modal-1.2.2.dev31.dist-info}/licenses/LICENSE +0 -0
  52. {modal-1.2.2.dev22.dist-info → modal-1.2.2.dev31.dist-info}/top_level.txt +0 -0
modal/proxy.py CHANGED
@@ -3,9 +3,11 @@ from typing import Optional
3
3
 
4
4
  from modal_proto import api_pb2
5
5
 
6
- from ._object import _get_environment_name, _Object
6
+ from ._load_context import LoadContext
7
+ from ._object import _Object
7
8
  from ._resolver import Resolver
8
9
  from ._utils.async_utils import synchronize_api
10
+ from .client import _Client
9
11
 
10
12
 
11
13
  class _Proxy(_Object, type_prefix="pr"):
@@ -20,6 +22,7 @@ class _Proxy(_Object, type_prefix="pr"):
20
22
  name: str,
21
23
  *,
22
24
  environment_name: Optional[str] = None,
25
+ client: Optional[_Client] = None,
23
26
  ) -> "_Proxy":
24
27
  """Reference a Proxy by its name.
25
28
 
@@ -28,16 +31,21 @@ class _Proxy(_Object, type_prefix="pr"):
28
31
 
29
32
  """
30
33
 
31
- async def _load(self: _Proxy, resolver: Resolver, existing_object_id: Optional[str]):
34
+ async def _load(self: _Proxy, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]):
32
35
  req = api_pb2.ProxyGetRequest(
33
36
  name=name,
34
- environment_name=_get_environment_name(environment_name, resolver),
37
+ environment_name=load_context.environment_name,
35
38
  )
36
- response: api_pb2.ProxyGetResponse = await resolver.client.stub.ProxyGet(req)
37
- self._hydrate(response.proxy.proxy_id, resolver.client, None)
39
+ response: api_pb2.ProxyGetResponse = await load_context.client.stub.ProxyGet(req)
40
+ self._hydrate(response.proxy.proxy_id, load_context.client, None)
38
41
 
39
42
  rep = _Proxy._repr(name, environment_name)
40
- return _Proxy._from_loader(_load, rep, is_another_app=True)
43
+ return _Proxy._from_loader(
44
+ _load,
45
+ rep,
46
+ is_another_app=True,
47
+ load_context_overrides=LoadContext(client=client, environment_name=environment_name),
48
+ )
41
49
 
42
50
 
43
51
  Proxy = synchronize_api(_Proxy, target_module=__name__)
modal/proxy.pyi CHANGED
@@ -1,4 +1,5 @@
1
1
  import modal._object
2
+ import modal.client
2
3
  import modal.object
3
4
  import typing
4
5
 
@@ -9,7 +10,12 @@ class _Proxy(modal._object._Object):
9
10
  a database. See [the guide](https://modal.com/docs/guide/proxy-ips) for more information.
10
11
  """
11
12
  @staticmethod
12
- def from_name(name: str, *, environment_name: typing.Optional[str] = None) -> _Proxy:
13
+ def from_name(
14
+ name: str,
15
+ *,
16
+ environment_name: typing.Optional[str] = None,
17
+ client: typing.Optional[modal.client._Client] = None,
18
+ ) -> _Proxy:
13
19
  """Reference a Proxy by its name.
14
20
 
15
21
  In contrast to most other Modal objects, new Proxy objects must be
@@ -28,7 +34,9 @@ class Proxy(modal.object.Object):
28
34
  ...
29
35
 
30
36
  @staticmethod
31
- def from_name(name: str, *, environment_name: typing.Optional[str] = None) -> Proxy:
37
+ def from_name(
38
+ name: str, *, environment_name: typing.Optional[str] = None, client: typing.Optional[modal.client.Client] = None
39
+ ) -> Proxy:
32
40
  """Reference a Proxy by its name.
33
41
 
34
42
  In contrast to most other Modal objects, new Proxy objects must be
modal/queue.py CHANGED
@@ -14,6 +14,7 @@ from synchronicity.async_wrap import asynccontextmanager
14
14
 
15
15
  from modal_proto import api_pb2
16
16
 
17
+ from ._load_context import LoadContext
17
18
  from ._object import (
18
19
  EPHEMERAL_OBJECT_HEARTBEAT_SLEEP,
19
20
  _get_environment_name,
@@ -361,6 +362,7 @@ class _Queue(_Object, type_prefix="qu"):
361
362
  namespace=None, # mdmd:line-hidden
362
363
  environment_name: Optional[str] = None,
363
364
  create_if_missing: bool = False,
365
+ client: Optional[_Client] = None,
364
366
  ) -> "_Queue":
365
367
  """Reference a named Queue, creating if necessary.
366
368
 
@@ -376,17 +378,24 @@ class _Queue(_Object, type_prefix="qu"):
376
378
  check_object_name(name, "Queue")
377
379
  warn_if_passing_namespace(namespace, "modal.Queue.from_name")
378
380
 
379
- async def _load(self: _Queue, resolver: Resolver, existing_object_id: Optional[str]):
381
+ async def _load(self: _Queue, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]):
380
382
  req = api_pb2.QueueGetOrCreateRequest(
381
383
  deployment_name=name,
382
- environment_name=_get_environment_name(environment_name, resolver),
384
+ environment_name=load_context.environment_name,
383
385
  object_creation_type=(api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING if create_if_missing else None),
384
386
  )
385
- response = await resolver.client.stub.QueueGetOrCreate(req)
386
- self._hydrate(response.queue_id, resolver.client, response.metadata)
387
+ response = await load_context.client.stub.QueueGetOrCreate(req)
388
+ self._hydrate(response.queue_id, load_context.client, response.metadata)
387
389
 
388
390
  rep = _Queue._repr(name, environment_name)
389
- return _Queue._from_loader(_load, rep, is_another_app=True, hydrate_lazily=True, name=name)
391
+ return _Queue._from_loader(
392
+ _load,
393
+ rep,
394
+ is_another_app=True,
395
+ hydrate_lazily=True,
396
+ name=name,
397
+ load_context_overrides=LoadContext(environment_name=environment_name, client=client),
398
+ )
390
399
 
391
400
  @staticmethod
392
401
  async def delete(name: str, *, client: Optional[_Client] = None, environment_name: Optional[str] = None):
modal/queue.pyi CHANGED
@@ -464,7 +464,12 @@ class _Queue(modal._object._Object):
464
464
 
465
465
  @staticmethod
466
466
  def from_name(
467
- name: str, *, namespace=None, environment_name: typing.Optional[str] = None, create_if_missing: bool = False
467
+ name: str,
468
+ *,
469
+ namespace=None,
470
+ environment_name: typing.Optional[str] = None,
471
+ create_if_missing: bool = False,
472
+ client: typing.Optional[modal.client._Client] = None,
468
473
  ) -> _Queue:
469
474
  """Reference a named Queue, creating if necessary.
470
475
 
@@ -721,7 +726,12 @@ class Queue(modal.object.Object):
721
726
 
722
727
  @staticmethod
723
728
  def from_name(
724
- name: str, *, namespace=None, environment_name: typing.Optional[str] = None, create_if_missing: bool = False
729
+ name: str,
730
+ *,
731
+ namespace=None,
732
+ environment_name: typing.Optional[str] = None,
733
+ create_if_missing: bool = False,
734
+ client: typing.Optional[modal.client.Client] = None,
725
735
  ) -> Queue:
726
736
  """Reference a named Queue, creating if necessary.
727
737
 
modal/runner.py CHANGED
@@ -8,7 +8,6 @@ import asyncio
8
8
  import dataclasses
9
9
  import os
10
10
  import time
11
- import typing
12
11
  from collections.abc import AsyncGenerator
13
12
  from contextlib import nullcontext
14
13
  from multiprocessing.synchronize import Event
@@ -19,6 +18,7 @@ from synchronicity.async_wrap import asynccontextmanager
19
18
 
20
19
  import modal._runtime.execution_context
21
20
  import modal_proto.api_pb2
21
+ from modal._load_context import LoadContext
22
22
  from modal._utils.grpc_utils import Retry
23
23
  from modal_proto import api_pb2
24
24
 
@@ -76,6 +76,7 @@ async def _init_local_app_existing(client: _Client, existing_app_id: str, enviro
76
76
  async def _init_local_app_new(
77
77
  client: _Client,
78
78
  description: str,
79
+ tags: dict[str, str],
79
80
  app_state: int, # ValueType
80
81
  environment_name: str = "",
81
82
  interactive: bool = False,
@@ -84,6 +85,7 @@ async def _init_local_app_new(
84
85
  description=description,
85
86
  environment_name=environment_name,
86
87
  app_state=app_state, # type: ignore
88
+ tags=tags,
87
89
  )
88
90
  app_resp, _ = await gather_cancel_on_exc( # TODO: use TaskGroup?
89
91
  client.stub.AppCreate(app_req),
@@ -102,6 +104,7 @@ async def _init_local_app_new(
102
104
  async def _init_local_app_from_name(
103
105
  client: _Client,
104
106
  name: str,
107
+ tags: dict[str, str],
105
108
  environment_name: str = "",
106
109
  ) -> RunningApp:
107
110
  # Look up any existing deployment
@@ -117,23 +120,19 @@ async def _init_local_app_from_name(
117
120
  return await _init_local_app_existing(client, existing_app_id, environment_name)
118
121
  else:
119
122
  return await _init_local_app_new(
120
- client, name, api_pb2.APP_STATE_INITIALIZING, environment_name=environment_name
123
+ client, name, tags, api_pb2.APP_STATE_INITIALIZING, environment_name=environment_name
121
124
  )
122
125
 
123
126
 
124
127
  async def _create_all_objects(
125
- client: _Client,
126
128
  running_app: RunningApp,
127
129
  local_app_state: "modal.app._LocalAppState",
128
- environment_name: str,
130
+ load_context: LoadContext,
129
131
  ) -> None:
130
132
  """Create objects that have been defined but not created on the server."""
131
133
  indexed_objects: dict[str, _Object] = {**local_app_state.functions, **local_app_state.classes}
132
- resolver = Resolver(
133
- client,
134
- environment_name=environment_name,
135
- app_id=running_app.app_id,
136
- )
134
+
135
+ resolver = Resolver()
137
136
  with resolver.display():
138
137
  # Get current objects, and reset all objects
139
138
  tag_to_object_id = {**running_app.function_ids, **running_app.class_ids}
@@ -156,7 +155,7 @@ async def _create_all_objects(
156
155
  # Note: preload only currently implemented for Functions, returns None otherwise
157
156
  # this is to ensure that directly referenced functions from the global scope has
158
157
  # ids associated with them when they are serialized into other functions
159
- await resolver.preload(obj, existing_object_id)
158
+ await resolver.preload(obj, load_context, existing_object_id)
160
159
  if obj.is_hydrated:
161
160
  tag_to_object_id[tag] = obj.object_id
162
161
 
@@ -164,7 +163,8 @@ async def _create_all_objects(
164
163
 
165
164
  async def _load(tag, obj):
166
165
  existing_object_id = tag_to_object_id.get(tag)
167
- await resolver.load(obj, existing_object_id)
166
+ # Pass load_context so dependencies can inherit app_id, client, etc.
167
+ await resolver.load(obj, load_context, existing_object_id=existing_object_id)
168
168
  if _Function._is_id_type(obj.object_id):
169
169
  running_app.function_ids[tag] = obj.object_id
170
170
  elif _Cls._is_id_type(obj.object_id):
@@ -263,8 +263,9 @@ async def _run_app(
263
263
  interactive: bool = False,
264
264
  ) -> AsyncGenerator["modal.app._App", None]:
265
265
  """mdmd:hidden"""
266
- if environment_name is None:
267
- environment_name = typing.cast(str, config.get("environment"))
266
+ load_context = await app._root_load_context.reset().in_place_upgrade(
267
+ client=client, environment_name=environment_name
268
+ )
268
269
 
269
270
  if modal._runtime.execution_context._is_currently_importing:
270
271
  raise InvalidError("Can not run an app in global scope within a container")
@@ -285,9 +286,6 @@ async def _run_app(
285
286
  # https://docs.python.org/3/library/__main__.html#import-main
286
287
  app.set_description(__main__.__name__)
287
288
 
288
- if client is None:
289
- client = await _Client.from_env()
290
-
291
289
  app_state = api_pb2.APP_STATE_DETACHED if detach else api_pb2.APP_STATE_EPHEMERAL
292
290
 
293
291
  output_mgr = _get_output_manager()
@@ -295,21 +293,25 @@ async def _run_app(
295
293
  msg = "Interactive mode requires output to be enabled. (Use the the `modal.enable_output()` context manager.)"
296
294
  raise InvalidError(msg)
297
295
 
296
+ local_app_state = app._local_state
297
+
298
298
  running_app: RunningApp = await _init_local_app_new(
299
- client,
299
+ load_context.client,
300
300
  app.description or "",
301
- environment_name=environment_name or "",
301
+ local_app_state.tags,
302
+ environment_name=load_context.environment_name,
302
303
  app_state=app_state,
303
304
  interactive=interactive,
304
305
  )
306
+ await load_context.in_place_upgrade(app_id=running_app.app_id)
305
307
 
306
308
  logs_timeout = config["logs_timeout"]
307
- async with app._set_local_app(client, running_app), TaskContext(grace=logs_timeout) as tc:
309
+ async with app._set_local_app(load_context.client, running_app), TaskContext(grace=logs_timeout) as tc:
308
310
  # Start heartbeats loop to keep the client alive
309
311
  # we don't log heartbeat exceptions in detached mode
310
312
  # as losing the local connection will not affect the running app
311
313
  def heartbeat():
312
- return _heartbeat(client, running_app.app_id)
314
+ return _heartbeat(load_context.client, running_app.app_id)
313
315
 
314
316
  heartbeat_loop = tc.infinite_loop(heartbeat, sleep=HEARTBEAT_INTERVAL, log_exception=not detach)
315
317
  logs_loop: Optional[asyncio.Task] = None
@@ -330,25 +332,26 @@ async def _run_app(
330
332
  # Start logs loop
331
333
 
332
334
  logs_loop = tc.create_task(
333
- get_app_logs_loop(client, output_mgr, app_id=running_app.app_id, app_logs_url=running_app.app_logs_url)
335
+ get_app_logs_loop(
336
+ load_context.client, output_mgr, app_id=running_app.app_id, app_logs_url=running_app.app_logs_url
337
+ )
334
338
  )
335
339
 
336
- local_app_state = app._local_state
337
340
  try:
338
341
  # Create all members
339
- await _create_all_objects(client, running_app, local_app_state, environment_name)
342
+ await _create_all_objects(running_app, local_app_state, load_context)
340
343
 
341
344
  # Publish the app
342
- await _publish_app(client, running_app, app_state, local_app_state)
345
+ await _publish_app(load_context.client, running_app, app_state, local_app_state)
343
346
  except asyncio.CancelledError as e:
344
347
  # this typically happens on sigint/ctrl-C during setup (the KeyboardInterrupt happens in the main thread)
345
348
  if output_mgr := _get_output_manager():
346
349
  output_mgr.print("Aborting app initialization...\n")
347
350
 
348
- await _status_based_disconnect(client, running_app.app_id, e)
351
+ await _status_based_disconnect(load_context.client, running_app.app_id, e)
349
352
  raise
350
353
  except BaseException as e:
351
- await _status_based_disconnect(client, running_app.app_id, e)
354
+ await _status_based_disconnect(load_context.client, running_app.app_id, e)
352
355
  raise
353
356
 
354
357
  detached_disconnect_msg = (
@@ -376,7 +379,7 @@ async def _run_app(
376
379
  yield app
377
380
  # successful completion!
378
381
  heartbeat_loop.cancel()
379
- await _status_based_disconnect(client, running_app.app_id, exc_info=None)
382
+ await _status_based_disconnect(load_context.client, running_app.app_id, exc_info=None)
380
383
  except KeyboardInterrupt as e:
381
384
  # this happens only if sigint comes in during the yield block above
382
385
  if detach:
@@ -385,13 +388,13 @@ async def _run_app(
385
388
  output_mgr.print(detached_disconnect_msg)
386
389
  if logs_loop:
387
390
  logs_loop.cancel()
388
- await _status_based_disconnect(client, running_app.app_id, e)
391
+ await _status_based_disconnect(load_context.client, running_app.app_id, e)
389
392
  else:
390
393
  if output_mgr := _get_output_manager():
391
394
  output_mgr.print(
392
395
  "Disconnecting from Modal - This will terminate your Modal app in a few seconds.\n"
393
396
  )
394
- await _status_based_disconnect(client, running_app.app_id, e)
397
+ await _status_based_disconnect(load_context.client, running_app.app_id, e)
395
398
  if logs_loop:
396
399
  try:
397
400
  await asyncio.wait_for(logs_loop, timeout=logs_timeout)
@@ -416,7 +419,7 @@ async def _run_app(
416
419
  raise
417
420
  except BaseException as e:
418
421
  logger.info("Exception during app run")
419
- await _status_based_disconnect(client, running_app.app_id, e)
422
+ await _status_based_disconnect(load_context.client, running_app.app_id, e)
420
423
  raise
421
424
 
422
425
  # wait for logs gracefully, even though the task context would do the same
@@ -444,21 +447,17 @@ async def _serve_update(
444
447
  ) -> None:
445
448
  """mdmd:hidden"""
446
449
  # Used by child process to reinitialize a served app
447
- client = await _Client.from_env()
450
+ load_context = await app._root_load_context.reset().in_place_upgrade(environment_name=environment_name)
448
451
  try:
449
- running_app: RunningApp = await _init_local_app_existing(client, existing_app_id, environment_name)
452
+ running_app: RunningApp = await _init_local_app_existing(load_context.client, existing_app_id, environment_name)
453
+ await load_context.in_place_upgrade(app_id=running_app.app_id)
450
454
  local_app_state = app._local_state
451
455
  # Create objects
452
- await _create_all_objects(
453
- client,
454
- running_app,
455
- local_app_state,
456
- environment_name,
457
- )
456
+ await _create_all_objects(running_app, local_app_state, load_context)
458
457
 
459
458
  # Publish the updated app
460
459
  await _publish_app(
461
- client,
460
+ load_context.client,
462
461
  running_app,
463
462
  app_state=api_pb2.APP_STATE_UNSPECIFIED,
464
463
  app_local_state=local_app_state,
@@ -493,9 +492,6 @@ async def _deploy_app(
493
492
 
494
493
  Users should prefer the `modal deploy` CLI or the `App.deploy` method.
495
494
  """
496
- if environment_name is None:
497
- environment_name = typing.cast(str, config.get("environment"))
498
-
499
495
  warn_if_passing_namespace(namespace, "modal.runner.deploy_app")
500
496
 
501
497
  name = name or app.name or ""
@@ -521,12 +517,25 @@ async def _deploy_app(
521
517
  if client is None:
522
518
  client = await _Client.from_env()
523
519
 
520
+ local_app_state = app._local_state
524
521
  t0 = time.time()
525
522
 
526
523
  # Get git information to track deployment history
527
524
  commit_info_task = asyncio.create_task(get_git_commit_info())
528
525
 
529
- running_app: RunningApp = await _init_local_app_from_name(client, name, environment_name=environment_name)
526
+ # We need to do in-place replacement of fields in self._root_load_context in case it has already "spread"
527
+ # to with_options() instances or similar before load
528
+ root_load_context = await app._root_load_context.reset().in_place_upgrade(
529
+ client=client,
530
+ environment_name=environment_name,
531
+ )
532
+ running_app: RunningApp = await _init_local_app_from_name(
533
+ root_load_context.client, name, local_app_state.tags, environment_name=root_load_context.environment_name
534
+ )
535
+
536
+ await root_load_context.in_place_upgrade(
537
+ app_id=running_app.app_id,
538
+ )
530
539
 
531
540
  async with TaskContext(0) as tc:
532
541
  # Start heartbeats loop to keep the client alive
@@ -537,12 +546,7 @@ async def _deploy_app(
537
546
 
538
547
  try:
539
548
  # Create all members
540
- await _create_all_objects(
541
- client,
542
- running_app,
543
- app._local_state,
544
- environment_name=environment_name,
545
- )
549
+ await _create_all_objects(running_app, local_app_state, root_load_context)
546
550
 
547
551
  commit_info = None
548
552
  try:
@@ -554,7 +558,7 @@ async def _deploy_app(
554
558
  client,
555
559
  running_app,
556
560
  api_pb2.APP_STATE_DEPLOYED,
557
- app._local_state,
561
+ local_app_state,
558
562
  name=name,
559
563
  deployment_tag=tag,
560
564
  commit_info=commit_info,
modal/runner.pyi CHANGED
@@ -1,3 +1,4 @@
1
+ import modal._load_context
1
2
  import modal.app
2
3
  import modal.client
3
4
  import modal.running_app
@@ -16,18 +17,18 @@ async def _init_local_app_existing(
16
17
  async def _init_local_app_new(
17
18
  client: modal.client._Client,
18
19
  description: str,
20
+ tags: dict[str, str],
19
21
  app_state: int,
20
22
  environment_name: str = "",
21
23
  interactive: bool = False,
22
24
  ) -> modal.running_app.RunningApp: ...
23
25
  async def _init_local_app_from_name(
24
- client: modal.client._Client, name: str, environment_name: str = ""
26
+ client: modal.client._Client, name: str, tags: dict[str, str], environment_name: str = ""
25
27
  ) -> modal.running_app.RunningApp: ...
26
28
  async def _create_all_objects(
27
- client: modal.client._Client,
28
29
  running_app: modal.running_app.RunningApp,
29
30
  local_app_state: modal.app._LocalAppState,
30
- environment_name: str,
31
+ load_context: modal._load_context.LoadContext,
31
32
  ) -> None:
32
33
  """Create objects that have been defined but not created on the server."""
33
34
  ...
modal/sandbox.py CHANGED
@@ -23,6 +23,7 @@ from modal.mount import _Mount
23
23
  from modal.volume import _Volume
24
24
  from modal_proto import api_pb2, task_command_router_pb2 as sr_pb2
25
25
 
26
+ from ._load_context import LoadContext
26
27
  from ._object import _get_environment_name, _Object
27
28
  from ._resolver import Resolver
28
29
  from ._resources import convert_fn_config_to_resources_config
@@ -191,7 +192,9 @@ class _Sandbox(_Object, type_prefix="sb"):
191
192
  deps.append(proxy)
192
193
  return deps
193
194
 
194
- async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optional[str]):
195
+ async def _load(
196
+ self: _Sandbox, resolver: Resolver, load_context: LoadContext, _existing_object_id: Optional[str]
197
+ ):
195
198
  # Relies on dicts being ordered (true as of Python 3.6).
196
199
  volume_mounts = [
197
200
  api_pb2.VolumeMount(
@@ -260,18 +263,18 @@ class _Sandbox(_Object, type_prefix="sb"):
260
263
  experimental_options=experimental_options,
261
264
  )
262
265
 
263
- create_req = api_pb2.SandboxCreateRequest(app_id=resolver.app_id, definition=definition)
266
+ create_req = api_pb2.SandboxCreateRequest(app_id=load_context.app_id, definition=definition)
264
267
  try:
265
- create_resp = await resolver.client.stub.SandboxCreate(create_req)
268
+ create_resp = await load_context.client.stub.SandboxCreate(create_req)
266
269
  except GRPCError as exc:
267
270
  if exc.status == Status.ALREADY_EXISTS:
268
271
  raise AlreadyExistsError(exc.message)
269
272
  raise exc
270
273
 
271
274
  sandbox_id = create_resp.sandbox_id
272
- self._hydrate(sandbox_id, resolver.client, None)
275
+ self._hydrate(sandbox_id, load_context.client, None)
273
276
 
274
- return _Sandbox._from_loader(_load, "Sandbox()", deps=_deps)
277
+ return _Sandbox._from_loader(_load, "Sandbox()", deps=_deps, load_context_overrides=LoadContext.empty())
275
278
 
276
279
  @staticmethod
277
280
  async def create(
@@ -486,6 +489,7 @@ class _Sandbox(_Object, type_prefix="sb"):
486
489
  app_id = app.app_id
487
490
  app_client = app._client
488
491
  elif (container_app := _App._get_container_app()) is not None:
492
+ # implicit app/client provided by running in a modal Function
489
493
  app_id = container_app.app_id
490
494
  app_client = container_app._client
491
495
  else:
@@ -498,10 +502,11 @@ class _Sandbox(_Object, type_prefix="sb"):
498
502
  "```",
499
503
  )
500
504
 
501
- client = client or app_client or await _Client.from_env()
505
+ client = client or app_client
502
506
 
503
- resolver = Resolver(client, app_id=app_id)
504
- await resolver.load(obj)
507
+ resolver = Resolver()
508
+ load_context = LoadContext(client=client, app_id=app_id)
509
+ await resolver.load(obj, load_context)
505
510
  return obj
506
511
 
507
512
  def _hydrate_metadata(self, handle_metadata: Optional[Message]):
@@ -606,12 +611,13 @@ class _Sandbox(_Object, type_prefix="sb"):
606
611
  image_id = resp.image_id
607
612
  metadata = resp.image_metadata
608
613
 
609
- async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]):
614
+ async def _load(self: _Image, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]):
610
615
  # no need to hydrate again since we do it eagerly below
611
616
  pass
612
617
 
613
618
  rep = "Image()"
614
- image = _Image._from_loader(_load, rep, hydrate_lazily=True)
619
+ # TODO: use ._new_hydrated instead
620
+ image = _Image._from_loader(_load, rep, hydrate_lazily=True, load_context_overrides=LoadContext.empty())
615
621
  image._hydrate(image_id, self._client, metadata) # hydrating eagerly since we have all of the data
616
622
 
617
623
  return image
@@ -990,12 +996,15 @@ class _Sandbox(_Object, type_prefix="sb"):
990
996
  if wait_resp.result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
991
997
  raise ExecutionError(wait_resp.result.exception)
992
998
 
993
- async def _load(self: _SandboxSnapshot, resolver: Resolver, existing_object_id: Optional[str]):
999
+ async def _load(
1000
+ self: _SandboxSnapshot, resolver: Resolver, load_context: LoadContext, existing_object_id: Optional[str]
1001
+ ):
994
1002
  # we eagerly hydrate the sandbox snapshot below
995
1003
  pass
996
1004
 
997
1005
  rep = "SandboxSnapshot()"
998
- obj = _SandboxSnapshot._from_loader(_load, rep, hydrate_lazily=True)
1006
+ # TODO: use ._new_hydrated instead
1007
+ obj = _SandboxSnapshot._from_loader(_load, rep, hydrate_lazily=True, load_context_overrides=LoadContext.empty())
999
1008
  obj._hydrate(snapshot_id, self._client, None)
1000
1009
 
1001
1010
  return obj