modal 0.66.39__py3-none-any.whl → 0.66.48__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/_container_entrypoint.py +1 -1
- modal/_runtime/user_code_imports.py +1 -1
- modal/_utils/grpc_testing.py +33 -26
- modal/app.py +39 -24
- modal/app.pyi +4 -2
- modal/cli/import_refs.py +1 -1
- modal/cli/launch.py +6 -4
- modal/cli/run.py +2 -2
- modal/client.pyi +2 -2
- modal/cls.py +26 -19
- modal/cls.pyi +4 -4
- modal/functions.py +32 -29
- modal/functions.pyi +1 -5
- modal/image.py +49 -2
- modal/image.pyi +14 -2
- modal/io_streams.py +40 -33
- modal/io_streams.pyi +13 -13
- modal/mount.py +3 -1
- modal/partial_function.py +1 -1
- modal/runner.py +12 -6
- {modal-0.66.39.dist-info → modal-0.66.48.dist-info}/METADATA +1 -1
- {modal-0.66.39.dist-info → modal-0.66.48.dist-info}/RECORD +30 -30
- modal_proto/api.proto +2 -0
- modal_proto/api_pb2.py +244 -244
- modal_proto/api_pb2.pyi +7 -2
- modal_version/_version_generated.py +1 -1
- {modal-0.66.39.dist-info → modal-0.66.48.dist-info}/LICENSE +0 -0
- {modal-0.66.39.dist-info → modal-0.66.48.dist-info}/WHEEL +0 -0
- {modal-0.66.39.dist-info → modal-0.66.48.dist-info}/entry_points.txt +0 -0
- {modal-0.66.39.dist-info → modal-0.66.48.dist-info}/top_level.txt +0 -0
modal/_container_entrypoint.py
CHANGED
@@ -499,7 +499,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
|
|
499
499
|
call_lifecycle_functions(event_loop, container_io_manager, list(pre_snapshot_methods.values()))
|
500
500
|
|
501
501
|
# If this container is being used to create a checkpoint, checkpoint the container after
|
502
|
-
# global imports and
|
502
|
+
# global imports and initialization. Checkpointed containers run from this point onwards.
|
503
503
|
if is_snapshotting_function:
|
504
504
|
container_io_manager.memory_snapshot()
|
505
505
|
|
@@ -197,7 +197,7 @@ def get_user_class_instance(
|
|
197
197
|
modal_obj: modal.cls.Obj = cls(*args, **kwargs)
|
198
198
|
modal_obj.entered = True # ugly but prevents .local() from triggering additional enter-logic
|
199
199
|
# TODO: unify lifecycle logic between .local() and container_entrypoint
|
200
|
-
user_cls_instance = modal_obj.
|
200
|
+
user_cls_instance = modal_obj._cached_user_cls_instance()
|
201
201
|
else:
|
202
202
|
# undecorated class (non-global decoration or serialized)
|
203
203
|
user_cls_instance = cls(*args, **kwargs)
|
modal/_utils/grpc_testing.py
CHANGED
@@ -52,7 +52,7 @@ def patch_mock_servicer(cls):
|
|
52
52
|
ctx = InterceptionContext()
|
53
53
|
servicer.interception_context = ctx
|
54
54
|
yield ctx
|
55
|
-
ctx.
|
55
|
+
ctx._assert_responses_consumed()
|
56
56
|
servicer.interception_context = None
|
57
57
|
|
58
58
|
cls.intercept = intercept
|
@@ -64,7 +64,7 @@ def patch_mock_servicer(cls):
|
|
64
64
|
ctx = servicer_self.interception_context
|
65
65
|
if ctx:
|
66
66
|
intercepted_stream = await InterceptedStream(ctx, method_name, stream).initialize()
|
67
|
-
custom_responder = ctx.
|
67
|
+
custom_responder = ctx._next_custom_responder(method_name, intercepted_stream.request_message)
|
68
68
|
if custom_responder:
|
69
69
|
return await custom_responder(servicer_self, intercepted_stream)
|
70
70
|
else:
|
@@ -105,19 +105,23 @@ class InterceptionContext:
|
|
105
105
|
self.custom_responses: Dict[str, List[Tuple[Callable[[Any], bool], List[Any]]]] = defaultdict(list)
|
106
106
|
self.custom_defaults: Dict[str, Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]] = {}
|
107
107
|
|
108
|
-
def add_recv(self, method_name: str, msg):
|
109
|
-
self.calls.append((method_name, msg))
|
110
|
-
|
111
108
|
def add_response(
|
112
109
|
self, method_name: str, first_payload, *, request_filter: Callable[[Any], bool] = lambda req: True
|
113
110
|
):
|
114
|
-
|
111
|
+
"""Adds one response payload to an expected queue of responses for a method.
|
112
|
+
|
113
|
+
These responses will be used once each instead of calling the MockServicer's
|
114
|
+
implementation of the method.
|
115
|
+
|
116
|
+
The interception context will throw an exception on exit if not all of the added
|
117
|
+
responses have been consumed.
|
118
|
+
"""
|
115
119
|
self.custom_responses[method_name].append((request_filter, [first_payload]))
|
116
120
|
|
117
121
|
def set_responder(
|
118
122
|
self, method_name: str, responder: Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]
|
119
123
|
):
|
120
|
-
"""Replace the default responder
|
124
|
+
"""Replace the default responder from the MockClientServicer with a custom implementation
|
121
125
|
|
122
126
|
```python notest
|
123
127
|
def custom_responder(servicer, stream):
|
@@ -128,11 +132,28 @@ class InterceptionContext:
|
|
128
132
|
ctx.set_responder("SomeMethod", custom_responder)
|
129
133
|
```
|
130
134
|
|
131
|
-
Responses added via `.add_response()` take precedence
|
135
|
+
Responses added via `.add_response()` take precedence over the use of this replacement
|
132
136
|
"""
|
133
137
|
self.custom_defaults[method_name] = responder
|
134
138
|
|
135
|
-
def
|
139
|
+
def pop_request(self, method_name):
|
140
|
+
# fast forward to the next request of type method_name
|
141
|
+
# dropping any preceding requests if there is a match
|
142
|
+
# returns the payload of the request
|
143
|
+
for i, (_method_name, msg) in enumerate(self.calls):
|
144
|
+
if _method_name == method_name:
|
145
|
+
self.calls = self.calls[i + 1 :]
|
146
|
+
return msg
|
147
|
+
|
148
|
+
raise KeyError(f"No message of that type in call list: {self.calls}")
|
149
|
+
|
150
|
+
def get_requests(self, method_name: str) -> List[Any]:
|
151
|
+
return [msg for _method_name, msg in self.calls if _method_name == method_name]
|
152
|
+
|
153
|
+
def _add_recv(self, method_name: str, msg):
|
154
|
+
self.calls.append((method_name, msg))
|
155
|
+
|
156
|
+
def _next_custom_responder(self, method_name, request):
|
136
157
|
method_responses = self.custom_responses[method_name]
|
137
158
|
for i, (request_filter, response_messages) in enumerate(method_responses):
|
138
159
|
try:
|
@@ -159,7 +180,7 @@ class InterceptionContext:
|
|
159
180
|
|
160
181
|
return responder
|
161
182
|
|
162
|
-
def
|
183
|
+
def _assert_responses_consumed(self):
|
163
184
|
unconsumed = []
|
164
185
|
for method_name, queued_responses in self.custom_responses.items():
|
165
186
|
unconsumed += [method_name] * len(queued_responses)
|
@@ -167,23 +188,9 @@ class InterceptionContext:
|
|
167
188
|
if unconsumed:
|
168
189
|
raise ResponseNotConsumed(unconsumed)
|
169
190
|
|
170
|
-
def pop_request(self, method_name):
|
171
|
-
# fast forward to the next request of type method_name
|
172
|
-
# dropping any preceding requests if there is a match
|
173
|
-
# returns the payload of the request
|
174
|
-
for i, (_method_name, msg) in enumerate(self.calls):
|
175
|
-
if _method_name == method_name:
|
176
|
-
self.calls = self.calls[i + 1 :]
|
177
|
-
return msg
|
178
|
-
|
179
|
-
raise KeyError(f"No message of that type in call list: {self.calls}")
|
180
|
-
|
181
|
-
def get_requests(self, method_name: str) -> List[Any]:
|
182
|
-
return [msg for _method_name, msg in self.calls if _method_name == method_name]
|
183
|
-
|
184
191
|
|
185
192
|
class InterceptedStream:
|
186
|
-
def __init__(self, interception_context, method_name, stream):
|
193
|
+
def __init__(self, interception_context: InterceptionContext, method_name: str, stream):
|
187
194
|
self.interception_context = interception_context
|
188
195
|
self.method_name = method_name
|
189
196
|
self.stream = stream
|
@@ -200,7 +207,7 @@ class InterceptedStream:
|
|
200
207
|
return ret
|
201
208
|
|
202
209
|
msg = await self.stream.recv_message()
|
203
|
-
self.interception_context.add_recv(self.method_name, msg)
|
210
|
+
self.interception_context._add_recv(self.method_name, msg)
|
204
211
|
return msg
|
205
212
|
|
206
213
|
async def send_message(self, msg):
|
modal/app.py
CHANGED
@@ -177,7 +177,8 @@ class _App:
|
|
177
177
|
|
178
178
|
_name: Optional[str]
|
179
179
|
_description: Optional[str]
|
180
|
-
|
180
|
+
_functions: Dict[str, _Function]
|
181
|
+
_classes: Dict[str, _Cls]
|
181
182
|
|
182
183
|
_image: Optional[_Image]
|
183
184
|
_mounts: Sequence[_Mount]
|
@@ -223,7 +224,8 @@ class _App:
|
|
223
224
|
if image is not None and not isinstance(image, _Image):
|
224
225
|
raise InvalidError("image has to be a modal Image or AioImage object")
|
225
226
|
|
226
|
-
self.
|
227
|
+
self._functions = {}
|
228
|
+
self._classes = {}
|
227
229
|
self._image = image
|
228
230
|
self._mounts = mounts
|
229
231
|
self._secrets = secrets
|
@@ -312,6 +314,7 @@ class _App:
|
|
312
314
|
raise InvalidError(f"App attribute `{key}` with value {value!r} is not a valid Modal object")
|
313
315
|
|
314
316
|
def _add_object(self, tag, obj):
|
317
|
+
# TODO(erikbern): replace this with _add_function and _add_class
|
315
318
|
if self._running_app:
|
316
319
|
# If this is inside a container, then objects can be defined after app initialization.
|
317
320
|
# So we may have to initialize objects once they get bound to the app.
|
@@ -320,7 +323,12 @@ class _App:
|
|
320
323
|
metadata: Message = self._running_app.object_handle_metadata[object_id]
|
321
324
|
obj._hydrate(object_id, self._client, metadata)
|
322
325
|
|
323
|
-
|
326
|
+
if isinstance(obj, _Function):
|
327
|
+
self._functions[tag] = obj
|
328
|
+
elif isinstance(obj, _Cls):
|
329
|
+
self._classes[tag] = obj
|
330
|
+
else:
|
331
|
+
raise RuntimeError(f"Expected `obj` to be a _Function or _Cls (got {type(obj)}")
|
324
332
|
|
325
333
|
def __getitem__(self, tag: str):
|
326
334
|
deprecation_error((2024, 3, 25), _app_attr_error)
|
@@ -334,7 +342,7 @@ class _App:
|
|
334
342
|
if tag.startswith("__"):
|
335
343
|
# Hacky way to avoid certain issues, e.g. pickle will try to look this up
|
336
344
|
raise AttributeError(f"App has no member {tag}")
|
337
|
-
if tag not in self.
|
345
|
+
if tag not in self._functions or tag not in self._classes:
|
338
346
|
# Primarily to make hasattr work
|
339
347
|
raise AttributeError(f"App has no member {tag}")
|
340
348
|
deprecation_error((2024, 3, 25), _app_attr_error)
|
@@ -360,7 +368,9 @@ class _App:
|
|
360
368
|
|
361
369
|
def _uncreate_all_objects(self):
|
362
370
|
# TODO(erikbern): this doesn't unhydrate objects that aren't tagged
|
363
|
-
for obj in self.
|
371
|
+
for obj in self._functions.values():
|
372
|
+
obj._unhydrate()
|
373
|
+
for obj in self._classes.values():
|
364
374
|
obj._unhydrate()
|
365
375
|
|
366
376
|
@asynccontextmanager
|
@@ -459,18 +469,17 @@ class _App:
|
|
459
469
|
return [m for m in all_mounts if m.is_local()]
|
460
470
|
|
461
471
|
def _add_function(self, function: _Function, is_web_endpoint: bool):
|
462
|
-
if function.tag in self.
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
logger.warning(f"Warning: tag {function.tag} exists but is overridden by function")
|
472
|
+
if function.tag in self._functions:
|
473
|
+
if not is_notebook():
|
474
|
+
old_function: _Function = self._functions[function.tag]
|
475
|
+
logger.warning(
|
476
|
+
f"Warning: Tag '{function.tag}' collision!"
|
477
|
+
" Overriding existing function "
|
478
|
+
f"[{old_function._info.module_name}].{old_function._info.function_name}"
|
479
|
+
f" with new function [{function._info.module_name}].{function._info.function_name}"
|
480
|
+
)
|
481
|
+
if function.tag in self._classes:
|
482
|
+
logger.warning(f"Warning: tag {function.tag} exists but is overridden by function")
|
474
483
|
|
475
484
|
self._add_object(function.tag, function)
|
476
485
|
if is_web_endpoint:
|
@@ -484,21 +493,22 @@ class _App:
|
|
484
493
|
_App._container_app = running_app
|
485
494
|
|
486
495
|
# Hydrate objects on app
|
496
|
+
indexed_objects = dict(**self._functions, **self._classes)
|
487
497
|
for tag, object_id in running_app.tag_to_object_id.items():
|
488
|
-
if tag in
|
489
|
-
obj =
|
498
|
+
if tag in indexed_objects:
|
499
|
+
obj = indexed_objects[tag]
|
490
500
|
handle_metadata = running_app.object_handle_metadata[object_id]
|
491
501
|
obj._hydrate(object_id, client, handle_metadata)
|
492
502
|
|
493
503
|
@property
|
494
504
|
def registered_functions(self) -> Dict[str, _Function]:
|
495
505
|
"""All modal.Function objects registered on the app."""
|
496
|
-
return
|
506
|
+
return self._functions
|
497
507
|
|
498
508
|
@property
|
499
509
|
def registered_classes(self) -> Dict[str, _Function]:
|
500
510
|
"""All modal.Cls objects registered on the app."""
|
501
|
-
return
|
511
|
+
return self._classes
|
502
512
|
|
503
513
|
@property
|
504
514
|
def registered_entrypoints(self) -> Dict[str, _LocalEntrypoint]:
|
@@ -507,7 +517,11 @@ class _App:
|
|
507
517
|
|
508
518
|
@property
|
509
519
|
def indexed_objects(self) -> Dict[str, _Object]:
|
510
|
-
|
520
|
+
deprecation_warning(
|
521
|
+
(2024, 11, 25),
|
522
|
+
"`app.indexed_objects` is deprecated! Use `app.registered_functions` or `app.registered_classes` instead.",
|
523
|
+
)
|
524
|
+
return dict(**self._functions, **self._classes)
|
511
525
|
|
512
526
|
@property
|
513
527
|
def registered_web_endpoints(self) -> List[str]:
|
@@ -1002,8 +1016,9 @@ class _App:
|
|
1002
1016
|
bar.remote()
|
1003
1017
|
```
|
1004
1018
|
"""
|
1005
|
-
|
1006
|
-
|
1019
|
+
indexed_objects = dict(**other_app._functions, **other_app._classes)
|
1020
|
+
for tag, object in indexed_objects.items():
|
1021
|
+
existing_object = indexed_objects.get(tag)
|
1007
1022
|
if existing_object and existing_object != object:
|
1008
1023
|
logger.warning(
|
1009
1024
|
f"Named app object {tag} with existing value {existing_object} is being "
|
modal/app.pyi
CHANGED
@@ -76,7 +76,8 @@ class _App:
|
|
76
76
|
_container_app: typing.ClassVar[typing.Optional[modal.running_app.RunningApp]]
|
77
77
|
_name: typing.Optional[str]
|
78
78
|
_description: typing.Optional[str]
|
79
|
-
|
79
|
+
_functions: typing.Dict[str, modal.functions._Function]
|
80
|
+
_classes: typing.Dict[str, modal.cls._Cls]
|
80
81
|
_image: typing.Optional[modal.image._Image]
|
81
82
|
_mounts: typing.Sequence[modal.mount._Mount]
|
82
83
|
_secrets: typing.Sequence[modal.secret._Secret]
|
@@ -270,7 +271,8 @@ class App:
|
|
270
271
|
_container_app: typing.ClassVar[typing.Optional[modal.running_app.RunningApp]]
|
271
272
|
_name: typing.Optional[str]
|
272
273
|
_description: typing.Optional[str]
|
273
|
-
|
274
|
+
_functions: typing.Dict[str, modal.functions.Function]
|
275
|
+
_classes: typing.Dict[str, modal.cls.Cls]
|
274
276
|
_image: typing.Optional[modal.image.Image]
|
275
277
|
_mounts: typing.Sequence[modal.mount.Mount]
|
276
278
|
_secrets: typing.Sequence[modal.secret.Secret]
|
modal/cli/import_refs.py
CHANGED
@@ -154,7 +154,7 @@ Registered functions and local entrypoints on the selected app are:
|
|
154
154
|
# entrypoint is in entrypoint registry, for now
|
155
155
|
return app.registered_entrypoints[function_name]
|
156
156
|
|
157
|
-
function = app.
|
157
|
+
function = app.registered_functions[function_name]
|
158
158
|
assert isinstance(function, Function)
|
159
159
|
return function
|
160
160
|
|
modal/cli/launch.py
CHANGED
@@ -25,7 +25,7 @@ launch_cli = Typer(
|
|
25
25
|
)
|
26
26
|
|
27
27
|
|
28
|
-
def _launch_program(name: str, filename: str, args: Dict[str, Any]) -> None:
|
28
|
+
def _launch_program(name: str, filename: str, detach: bool, args: Dict[str, Any]) -> None:
|
29
29
|
os.environ["MODAL_LAUNCH_ARGS"] = json.dumps(args)
|
30
30
|
|
31
31
|
program_path = str(Path(__file__).parent / "programs" / filename)
|
@@ -37,7 +37,7 @@ def _launch_program(name: str, filename: str, args: Dict[str, Any]) -> None:
|
|
37
37
|
func = entrypoint.info.raw_f
|
38
38
|
isasync = inspect.iscoroutinefunction(func)
|
39
39
|
with enable_output():
|
40
|
-
with run_app(app):
|
40
|
+
with run_app(app, detach=detach):
|
41
41
|
try:
|
42
42
|
if isasync:
|
43
43
|
asyncio.run(func())
|
@@ -57,6 +57,7 @@ def jupyter(
|
|
57
57
|
add_python: Optional[str] = "3.11",
|
58
58
|
mount: Optional[str] = None, # Create a `modal.Mount` from a local directory.
|
59
59
|
volume: Optional[str] = None, # Attach a persisted `modal.Volume` by name (creating if missing).
|
60
|
+
detach: bool = False, # Run the app in "detached" mode to persist after local client disconnects
|
60
61
|
):
|
61
62
|
args = {
|
62
63
|
"cpu": cpu,
|
@@ -68,7 +69,7 @@ def jupyter(
|
|
68
69
|
"mount": mount,
|
69
70
|
"volume": volume,
|
70
71
|
}
|
71
|
-
_launch_program("jupyter", "run_jupyter.py", args)
|
72
|
+
_launch_program("jupyter", "run_jupyter.py", detach, args)
|
72
73
|
|
73
74
|
|
74
75
|
@launch_cli.command(name="vscode", help="Start Visual Studio Code on Modal.")
|
@@ -79,6 +80,7 @@ def vscode(
|
|
79
80
|
timeout: int = 3600,
|
80
81
|
mount: Optional[str] = None, # Create a `modal.Mount` from a local directory.
|
81
82
|
volume: Optional[str] = None, # Attach a persisted `modal.Volume` by name (creating if missing).
|
83
|
+
detach: bool = False, # Run the app in "detached" mode to persist after local client disconnects
|
82
84
|
):
|
83
85
|
args = {
|
84
86
|
"cpu": cpu,
|
@@ -88,4 +90,4 @@ def vscode(
|
|
88
90
|
"mount": mount,
|
89
91
|
"volume": volume,
|
90
92
|
}
|
91
|
-
_launch_program("vscode", "vscode.py", args)
|
93
|
+
_launch_program("vscode", "vscode.py", detach, args)
|
modal/cli/run.py
CHANGED
@@ -136,7 +136,7 @@ def _get_clean_app_description(func_ref: str) -> str:
|
|
136
136
|
|
137
137
|
|
138
138
|
def _get_click_command_for_function(app: App, function_tag):
|
139
|
-
function = app.
|
139
|
+
function = app.registered_functions[function_tag]
|
140
140
|
assert isinstance(function, Function)
|
141
141
|
function = typing.cast(Function, function)
|
142
142
|
if function.is_generator:
|
@@ -147,7 +147,7 @@ def _get_click_command_for_function(app: App, function_tag):
|
|
147
147
|
method_name: Optional[str] = None
|
148
148
|
if function.info.user_cls is not None:
|
149
149
|
class_name, method_name = function_tag.rsplit(".", 1)
|
150
|
-
cls = typing.cast(Cls, app.
|
150
|
+
cls = typing.cast(Cls, app.registered_classes[class_name])
|
151
151
|
cls_signature = _get_signature(function.info.user_cls)
|
152
152
|
fun_signature = _get_signature(function.info.raw_f, is_method=True)
|
153
153
|
signature = dict(**cls_signature, **fun_signature) # Pool all arguments
|
modal/client.pyi
CHANGED
@@ -31,7 +31,7 @@ class _Client:
|
|
31
31
|
server_url: str,
|
32
32
|
client_type: int,
|
33
33
|
credentials: typing.Optional[typing.Tuple[str, str]],
|
34
|
-
version: str = "0.66.39",
|
34
|
+
version: str = "0.66.48",
|
35
35
|
): ...
|
36
36
|
def is_closed(self) -> bool: ...
|
37
37
|
@property
|
@@ -90,7 +90,7 @@ class Client:
|
|
90
90
|
server_url: str,
|
91
91
|
client_type: int,
|
92
92
|
credentials: typing.Optional[typing.Tuple[str, str]],
|
93
|
-
version: str = "0.66.39",
|
93
|
+
version: str = "0.66.48",
|
94
94
|
): ...
|
95
95
|
def is_closed(self) -> bool: ...
|
96
96
|
@property
|
modal/cls.py
CHANGED
@@ -113,7 +113,7 @@ class _Obj:
|
|
113
113
|
method = self._instance_service_function._bind_instance_method(class_bound_method)
|
114
114
|
self._method_functions[method_name] = method
|
115
115
|
else:
|
116
|
-
# <v0.63 classes - bind each individual method to the new parameters
|
116
|
+
# looked up <v0.63 classes - bind each individual method to the new parameters
|
117
117
|
self._instance_service_function = None
|
118
118
|
for method_name, class_bound_method in classbound_methods.items():
|
119
119
|
method = class_bound_method._bind_parameters(self, from_other_workspace, options, args, kwargs)
|
@@ -125,12 +125,14 @@ class _Obj:
|
|
125
125
|
self._user_cls = user_cls
|
126
126
|
self._construction_args = (args, kwargs) # used for lazy construction in case of explicit constructors
|
127
127
|
|
128
|
-
def
|
128
|
+
def _new_user_cls_instance(self):
|
129
129
|
args, kwargs = self._construction_args
|
130
130
|
if not _use_annotation_parameters(self._user_cls):
|
131
131
|
# TODO(elias): deprecate this code path eventually
|
132
132
|
user_cls_instance = self._user_cls(*args, **kwargs)
|
133
133
|
else:
|
134
|
+
# ignore constructor (assumes there is no custom constructor,
|
135
|
+
# which is guaranteed by _use_annotation_parameters)
|
134
136
|
# set the attributes on the class corresponding to annotations
|
135
137
|
# with = parameter() specifications
|
136
138
|
sig = _get_class_constructor_signature(self._user_cls)
|
@@ -139,6 +141,7 @@ class _Obj:
|
|
139
141
|
user_cls_instance = self._user_cls.__new__(self._user_cls) # new instance without running __init__
|
140
142
|
user_cls_instance.__dict__.update(bound_vars.arguments)
|
141
143
|
|
144
|
+
# TODO: always use Obj instances instead of making modifications to user cls
|
142
145
|
user_cls_instance._modal_functions = self._method_functions # Needed for PartialFunction.__get__
|
143
146
|
return user_cls_instance
|
144
147
|
|
@@ -163,10 +166,12 @@ class _Obj:
|
|
163
166
|
)
|
164
167
|
await self._instance_service_function.keep_warm(warm_pool_size)
|
165
168
|
|
166
|
-
def
|
167
|
-
"""
|
169
|
+
def _cached_user_cls_instance(self):
|
170
|
+
"""Get or construct the local object
|
171
|
+
|
172
|
+
Used for .local() calls and getting attributes of classes"""
|
168
173
|
if not self._user_cls_instance:
|
169
|
-
self._user_cls_instance = self.
|
174
|
+
self._user_cls_instance = self._new_user_cls_instance() # Instantiate object
|
170
175
|
|
171
176
|
return self._user_cls_instance
|
172
177
|
|
@@ -196,7 +201,7 @@ class _Obj:
|
|
196
201
|
@synchronizer.nowrap
|
197
202
|
async def aenter(self):
|
198
203
|
if not self.entered:
|
199
|
-
user_cls_instance = self.
|
204
|
+
user_cls_instance = self._cached_user_cls_instance()
|
200
205
|
if hasattr(user_cls_instance, "__aenter__"):
|
201
206
|
await user_cls_instance.__aenter__()
|
202
207
|
elif hasattr(user_cls_instance, "__enter__"):
|
@@ -205,20 +210,22 @@ class _Obj:
|
|
205
210
|
|
206
211
|
def __getattr__(self, k):
|
207
212
|
if k in self._method_functions:
|
208
|
-
#
|
209
|
-
#
|
210
|
-
#
|
213
|
+
# If we know the user is accessing a *method* and not another attribute,
|
214
|
+
# we don't have to create an instance of the user class yet.
|
215
|
+
# This is because it might just be a call to `.remote()` on it which
|
216
|
+
# doesn't require a local instance.
|
217
|
+
# As long as we have the service function or params, we can do remote calls
|
218
|
+
# without calling the constructor of the class in the calling context.
|
211
219
|
return self._method_functions[k]
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
raise AttributeError(k)
|
220
|
+
|
221
|
+
# if it's *not* a method, it *might* be an attribute of the class,
|
222
|
+
# so we construct it and proxy the attribute
|
223
|
+
# TODO: To get lazy loading (from_name) of classes to work, we need to avoid
|
224
|
+
# this path, otherwise local initialization will happen regardless if user
|
225
|
+
# only runs .remote(), since we don't know methods for the class until we
|
226
|
+
# load it
|
227
|
+
user_cls_instance = self._cached_user_cls_instance()
|
228
|
+
return getattr(user_cls_instance, k)
|
222
229
|
|
223
230
|
|
224
231
|
Obj = synchronize_api(_Obj)
|
modal/cls.pyi
CHANGED
@@ -37,9 +37,9 @@ class _Obj:
|
|
37
37
|
args,
|
38
38
|
kwargs,
|
39
39
|
): ...
|
40
|
-
def
|
40
|
+
def _new_user_cls_instance(self): ...
|
41
41
|
async def keep_warm(self, warm_pool_size: int) -> None: ...
|
42
|
-
def
|
42
|
+
def _cached_user_cls_instance(self): ...
|
43
43
|
def enter(self): ...
|
44
44
|
@property
|
45
45
|
def entered(self): ...
|
@@ -66,7 +66,7 @@ class Obj:
|
|
66
66
|
kwargs,
|
67
67
|
): ...
|
68
68
|
def _uses_common_service_function(self): ...
|
69
|
-
def
|
69
|
+
def _new_user_cls_instance(self): ...
|
70
70
|
|
71
71
|
class __keep_warm_spec(typing_extensions.Protocol):
|
72
72
|
def __call__(self, warm_pool_size: int) -> None: ...
|
@@ -74,7 +74,7 @@ class Obj:
|
|
74
74
|
|
75
75
|
keep_warm: __keep_warm_spec
|
76
76
|
|
77
|
-
def
|
77
|
+
def _cached_user_cls_instance(self): ...
|
78
78
|
def enter(self): ...
|
79
79
|
@property
|
80
80
|
def entered(self): ...
|