modal 0.67.7__py3-none-any.whl → 0.67.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/_clustered_functions.py +2 -2
- modal/_clustered_functions.pyi +2 -2
- modal/_container_entrypoint.py +5 -4
- modal/_output.py +29 -28
- modal/_pty.py +2 -2
- modal/_resolver.py +6 -5
- modal/_resources.py +3 -3
- modal/_runtime/asgi.py +7 -6
- modal/_runtime/container_io_manager.py +22 -26
- modal/_runtime/execution_context.py +2 -2
- modal/_runtime/telemetry.py +1 -2
- modal/_runtime/user_code_imports.py +11 -13
- modal/_serialization.py +3 -7
- modal/_traceback.py +5 -5
- modal/_tunnel.py +4 -3
- modal/_tunnel.pyi +2 -2
- modal/_utils/async_utils.py +8 -15
- modal/_utils/blob_utils.py +4 -3
- modal/_utils/function_utils.py +11 -10
- modal/_utils/grpc_testing.py +7 -6
- modal/_utils/grpc_utils.py +2 -3
- modal/_utils/hash_utils.py +2 -2
- modal/_utils/mount_utils.py +5 -4
- modal/_utils/package_utils.py +2 -3
- modal/_utils/pattern_matcher.py +6 -6
- modal/_utils/rand_pb_testing.py +3 -3
- modal/_utils/shell_utils.py +2 -1
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +8 -7
- modal/app.py +29 -34
- modal/app.pyi +102 -97
- modal/call_graph.py +6 -6
- modal/cli/_download.py +3 -2
- modal/cli/_traceback.py +4 -4
- modal/cli/app.py +4 -4
- modal/cli/container.py +4 -4
- modal/cli/dict.py +1 -1
- modal/cli/environment.py +2 -3
- modal/cli/launch.py +2 -2
- modal/cli/network_file_system.py +1 -1
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +2 -2
- modal/cli/programs/vscode.py +3 -3
- modal/cli/queues.py +1 -1
- modal/cli/run.py +6 -6
- modal/cli/secret.py +3 -3
- modal/cli/utils.py +2 -1
- modal/cli/volume.py +3 -3
- modal/client.py +6 -11
- modal/client.pyi +18 -27
- modal/cloud_bucket_mount.py +3 -3
- modal/cloud_bucket_mount.pyi +2 -2
- modal/cls.py +16 -15
- modal/cls.pyi +23 -22
- modal/config.py +2 -2
- modal/dict.py +4 -3
- modal/dict.pyi +10 -9
- modal/environments.py +3 -3
- modal/environments.pyi +3 -3
- modal/exception.py +2 -3
- modal/functions.py +20 -27
- modal/functions.pyi +44 -47
- modal/image.py +45 -48
- modal/image.pyi +102 -101
- modal/io_streams.py +4 -7
- modal/io_streams.pyi +14 -13
- modal/mount.py +23 -22
- modal/mount.pyi +28 -29
- modal/network_file_system.py +7 -6
- modal/network_file_system.pyi +12 -11
- modal/object.py +9 -8
- modal/object.pyi +47 -34
- modal/output.py +2 -1
- modal/parallel_map.py +4 -4
- modal/partial_function.py +9 -13
- modal/partial_function.pyi +17 -18
- modal/queue.py +9 -8
- modal/queue.pyi +23 -22
- modal/runner.py +8 -7
- modal/runner.pyi +8 -14
- modal/running_app.py +3 -3
- modal/sandbox.py +14 -13
- modal/sandbox.pyi +67 -72
- modal/scheduler_placement.py +2 -1
- modal/secret.py +7 -7
- modal/secret.pyi +12 -12
- modal/serving.py +4 -3
- modal/serving.pyi +5 -4
- modal/token_flow.py +3 -2
- modal/token_flow.pyi +3 -3
- modal/volume.py +7 -12
- modal/volume.pyi +17 -16
- {modal-0.67.7.dist-info → modal-0.67.8.dist-info}/METADATA +1 -1
- modal-0.67.8.dist-info/RECORD +168 -0
- modal_docs/mdmd/signatures.py +1 -2
- modal_version/_version_generated.py +1 -1
- modal-0.67.7.dist-info/RECORD +0 -168
- {modal-0.67.7.dist-info → modal-0.67.8.dist-info}/LICENSE +0 -0
- {modal-0.67.7.dist-info → modal-0.67.8.dist-info}/WHEEL +0 -0
- {modal-0.67.7.dist-info → modal-0.67.8.dist-info}/entry_points.txt +0 -0
- {modal-0.67.7.dist-info → modal-0.67.8.dist-info}/top_level.txt +0 -0
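Taken together, the hunks below read mostly as a typing cleanup: `typing.List`/`Dict`/`Tuple` annotations become the PEP 585 builtin generics (`list`, `dict`, `tuple`), abstract containers such as `Sequence`, `Generator`, `Hashable`, `AsyncGenerator`, and `AsyncIterator` move to `collections.abc` imports, and a few modules (notably `modal/_output.py`) switch `Optional[X]` to the PEP 604 `X | None` form. A minimal before/after sketch of the pattern, using a hypothetical function rather than code taken from the diff:

from __future__ import annotations  # lets the union syntax below parse on pre-3.10 interpreters

# 0.67.7-style annotations
from typing import Dict, List, Optional, Tuple

def lookup_old(ids: List[str], cache: Dict[str, Tuple[int, int]]) -> Optional[str]:
    return ids[0] if ids and ids[0] in cache else None

# 0.67.8-style annotations (PEP 585 builtin generics; PEP 604 unions in some modules)
def lookup_new(ids: list[str], cache: dict[str, tuple[int, int]]) -> str | None:
    return ids[0] if ids and ids[0] in cache else None

print(lookup_new(["a"], {"a": (1, 2)}))  # -> a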
modal/_clustered_functions.py
CHANGED
@@ -2,7 +2,7 @@
 import os
 import socket
 from dataclasses import dataclass
-from typing import
+from typing import Optional

 from modal._utils.async_utils import synchronize_api
 from modal._utils.grpc_utils import retry_transient_errors
@@ -14,7 +14,7 @@ from modal_proto import api_pb2
 @dataclass
 class ClusterInfo:
     rank: int
-    container_ips:
+    container_ips: list[str]


 cluster_info: Optional[ClusterInfo] = None
modal/_clustered_functions.pyi
CHANGED
@@ -4,9 +4,9 @@ import typing_extensions

 class ClusterInfo:
     rank: int
-    container_ips:
+    container_ips: list[str]

-    def __init__(self, rank: int, container_ips:
+    def __init__(self, rank: int, container_ips: list[str]) -> None: ...
     def __repr__(self): ...
     def __eq__(self, other): ...

modal/_container_entrypoint.py
CHANGED
@@ -19,7 +19,8 @@ import signal
 import sys
 import threading
 import time
-from
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, Callable, Optional

 from google.protobuf.message import Message

@@ -175,7 +176,7 @@ class UserCodeEventLoop:
 def call_function(
     user_code_event_loop: UserCodeEventLoop,
     container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager",
-    finalized_functions:
+    finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
     batch_max_size: int,
     batch_wait_ms: int,
 ):
@@ -473,7 +474,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
     # 1. Enable lazy hydration for all objects
     # 2. Fully deprecate .new() objects
     if service.code_deps is not None:  # this is not set for serialized or non-global scope functions
-        dep_object_ids:
+        dep_object_ids: list[str] = [dep.object_id for dep in function_def.object_dependencies]
         if len(service.code_deps) != len(dep_object_ids):
             raise ExecutionError(
                 f"Function has {len(service.code_deps)} dependencies"
@@ -595,7 +596,7 @@ if __name__ == "__main__":
     # from shutting down. The sleep(0) here is needed for finished ThreadPoolExecutor resources to
     # shut down without triggering this warning (e.g., `@wsgi_app()`).
     time.sleep(0)
-    lingering_threads:
+    lingering_threads: list[threading.Thread] = []
     for thread in threading.enumerate():
         current_thread = threading.get_ident()
         if thread.ident is not None and thread.ident != current_thread and not thread.daemon and thread.is_alive():
modal/_output.py
CHANGED
@@ -9,8 +9,9 @@ import platform
 import re
 import socket
 import sys
+from collections.abc import Generator
 from datetime import timedelta
-from typing import Callable, ClassVar
+from typing import Callable, ClassVar

 from grpclib.exceptions import GRPCError, StreamTerminatedError
 from rich.console import Console, Group, RenderableType
@@ -121,25 +122,25 @@ class LineBufferedOutput(io.StringIO):


 class OutputManager:
-    _instance: ClassVar[
+    _instance: ClassVar[OutputManager | None] = None

     _console: Console
-    _task_states:
-    _task_progress_items:
-    _current_render_group:
-    _function_progress:
-    _function_queueing_progress:
-    _snapshot_progress:
-    _line_buffers:
+    _task_states: dict[str, int]
+    _task_progress_items: dict[tuple[str, int], TaskID]
+    _current_render_group: Group | None
+    _function_progress: Progress | None
+    _function_queueing_progress: Progress | None
+    _snapshot_progress: Progress | None
+    _line_buffers: dict[int, LineBufferedOutput]
     _status_spinner: Spinner
-    _app_page_url:
+    _app_page_url: str | None
     _show_image_logs: bool
-    _status_spinner_live:
+    _status_spinner_live: Live | None

     def __init__(
         self,
         *,
-        stdout:
+        stdout: io.TextIOWrapper | None = None,
         status_spinner_text: str = "Running app...",
     ):
         self._stdout = stdout or sys.stdout
@@ -164,12 +165,12 @@ class OutputManager:
         cls._instance = None

     @classmethod
-    def get(cls) ->
+    def get(cls) -> OutputManager | None:
         return cls._instance

     @classmethod
     @contextlib.contextmanager
-    def enable_output(cls, show_progress: bool = True) -> Generator[None
+    def enable_output(cls, show_progress: bool = True) -> Generator[None]:
         if show_progress:
             cls._instance = OutputManager()
             try:
@@ -252,7 +253,7 @@ class OutputManager:
         self._current_render_group.renderables.append(self._function_queueing_progress)
         return self._function_queueing_progress

-    def function_progress_callback(self, tag: str, total:
+    def function_progress_callback(self, tag: str, total: int | None) -> Callable[[int, int], None]:
         """Adds a task to the current function_progress instance, and returns a callback
         to update task progress with new completed and total counts."""

@@ -330,7 +331,7 @@ class OutputManager:
         pass

     def update_queueing_progress(
-        self, *, function_id: str, completed: int, total:
+        self, *, function_id: str, completed: int, total: int | None, description: str | None
     ) -> None:
         """Handle queueing updates, ignoring completion updates for functions that have no queue progress bar."""
         task_key = (function_id, api_pb2.FUNCTION_QUEUED)
@@ -449,13 +450,13 @@ class ProgressHandler:

     def progress(
         self,
-        task_id:
-        advance:
-        name:
-        size:
-        reset:
-        complete:
-    ) ->
+        task_id: TaskID | None = None,
+        advance: float | None = None,
+        name: str | None = None,
+        size: float | None = None,
+        reset: bool | None = False,
+        complete: bool | None = False,
+    ) -> TaskID | None:
         try:
             if task_id is not None:
                 if reset:
@@ -527,15 +528,15 @@ async def put_pty_content(log: api_pb2.TaskLogs, stdout):
 async def get_app_logs_loop(
     client: _Client,
     output_mgr: OutputManager,
-    app_id:
-    task_id:
-    app_logs_url:
+    app_id: str | None = None,
+    task_id: str | None = None,
+    app_logs_url: str | None = None,
 ):
     last_log_batch_entry_id = ""

     pty_shell_stdout = None
-    pty_shell_finish_event:
-    pty_shell_task_id:
+    pty_shell_finish_event: asyncio.Event | None = None
+    pty_shell_task_id: str | None = None

     async def stop_pty_shell():
         nonlocal pty_shell_finish_event
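Note that the new `modal/_output.py` annotations include the self-referential `ClassVar[OutputManager | None]`. As a hedged aside (the module's own import block is not part of this diff), that spelling normally relies on postponed evaluation so the class can refer to itself and so `X | None` stays valid before Python 3.10; a tiny illustration with hypothetical names:

from __future__ import annotations  # annotations become strings, so the self-reference is fine

from typing import Optional

class Manager:
    # equivalent spellings of "Manager or None"
    _instance: Manager | None = None
    _legacy: Optional[Manager] = None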
modal/_pty.py
CHANGED
@@ -2,12 +2,12 @@
 import contextlib
 import os
 import sys
-from typing import Optional
+from typing import Optional

 from modal_proto import api_pb2


-def get_winsz(fd) ->
+def get_winsz(fd) -> tuple[Optional[int], Optional[int]]:
     try:
         import fcntl
         import struct
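Only the new `tuple[Optional[int], Optional[int]]` return annotation of `get_winsz` appears in this hunk; the body is not shown. For context, a terminal-size probe of that shape typically looks like the sketch below (an assumption about the general pattern on Unix, not Modal's actual implementation):

import fcntl
import struct
import termios
from typing import Optional

def get_window_size(fd: int) -> tuple[Optional[int], Optional[int]]:
    try:
        # TIOCGWINSZ fills a struct winsize: rows, cols, xpixel, ypixel
        packed = fcntl.ioctl(fd, termios.TIOCGWINSZ, struct.pack("HHHH", 0, 0, 0, 0))
        rows, cols, _, _ = struct.unpack("HHHH", packed)
        return rows, cols
    except OSError:
        return None, None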
modal/_resolver.py
CHANGED
@@ -3,7 +3,8 @@ import asyncio
 import contextlib
 import typing
 from asyncio import Future
-from
+from collections.abc import Hashable
+from typing import TYPE_CHECKING, Optional

 from grpclib import GRPCError, Status

@@ -40,10 +41,10 @@ class StatusRow:


 class Resolver:
-    _local_uuid_to_future:
+    _local_uuid_to_future: dict[str, Future]
     _environment_name: Optional[str]
     _app_id: Optional[str]
-    _deduplication_cache:
+    _deduplication_cache: dict[Hashable, Future]
     _client: _Client

     def __init__(
@@ -153,8 +154,8 @@ class Resolver:
         # TODO(elias): print original exception/trace rather than the Resolver-internal trace
         return await cached_future

-    def objects(self) ->
-        unique_objects:
+    def objects(self) -> list["_Object"]:
+        unique_objects: dict[str, "_Object"] = {}
         for fut in self._local_uuid_to_future.values():
             if not fut.done():
                 # this will raise an exception if not all loads have been awaited, but that *should* never happen
modal/_resources.py
CHANGED
@@ -1,5 +1,5 @@
 # Copyright Modal Labs 2024
-from typing import Optional,
+from typing import Optional, Union

 from modal_proto import api_pb2

@@ -9,8 +9,8 @@ from .gpu import GPU_T, parse_gpu_config

 def convert_fn_config_to_resources_config(
     *,
-    cpu: Optional[Union[float,
-    memory: Optional[Union[int,
+    cpu: Optional[Union[float, tuple[float, float]]],
+    memory: Optional[Union[int, tuple[int, int]]],
     gpu: GPU_T,
     ephemeral_disk: Optional[int],
 ) -> api_pb2.Resources:
modal/_runtime/asgi.py
CHANGED
@@ -1,6 +1,7 @@
 # Copyright Modal Labs 2022
 import asyncio
-from
+from collections.abc import AsyncGenerator
+from typing import Any, Callable, NoReturn, Optional, cast

 import aiohttp

@@ -80,8 +81,8 @@ class LifespanManager:
         await self.shutdown


-def asgi_app_wrapper(asgi_app, container_io_manager) ->
-    state:
+def asgi_app_wrapper(asgi_app, container_io_manager) -> tuple[Callable[..., AsyncGenerator], LifespanManager]:
+    state: dict[str, Any] = {}  # used for lifespan state

     async def fn(scope):
         if "state" in scope:
@@ -92,8 +93,8 @@ def asgi_app_wrapper(asgi_app, container_io_manager) -> Tuple[Callable[..., Asyn
         function_call_id = current_function_call_id()
         assert function_call_id, "internal error: function_call_id not set in asgi_app() scope"

-        messages_from_app: asyncio.Queue[
-        messages_to_app: asyncio.Queue[
+        messages_from_app: asyncio.Queue[dict[str, Any]] = asyncio.Queue(1)
+        messages_to_app: asyncio.Queue[dict[str, Any]] = asyncio.Queue(1)

         async def disconnect_app():
             if scope["type"] == "http":
@@ -415,7 +416,7 @@ async def _proxy_websocket_request(session: aiohttp.ClientSession, scope, receiv
             raise ExecutionError(f"Unexpected message type: {client_message['type']}")

     async def upstream_to_client():
-        msg:
+        msg: dict[str, Any] = {
             "type": "websocket.accept",
             "subprotocol": upstream_ws.protocol,
         }
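The new annotations in `asgi_app_wrapper` spell out the message plumbing: two single-slot `asyncio.Queue[dict[str, Any]]` instances carry ASGI messages to and from the app, and the wrapper returns the handler together with a `LifespanManager`. A small stand-alone sketch of the bounded-queue handoff (hypothetical names, not Modal's code):

import asyncio
from typing import Any

async def demo() -> None:
    messages_from_app: asyncio.Queue[dict[str, Any]] = asyncio.Queue(1)

    async def app_send(message: dict[str, Any]) -> None:
        # with maxsize=1, the sender suspends until the consumer drains the slot
        await messages_from_app.put(message)

    await app_send({"type": "http.response.start", "status": 200})
    print(await messages_from_app.get())

asyncio.run(demo())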
modal/_runtime/container_io_manager.py
CHANGED
@@ -9,19 +9,15 @@ import signal
 import sys
 import time
 import traceback
+from collections.abc import AsyncGenerator, AsyncIterator
 from contextlib import AsyncExitStack
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
     Any,
-    AsyncGenerator,
-    AsyncIterator,
     Callable,
     ClassVar,
-    Dict,
-    List,
     Optional,
-    Tuple,
 )

 from google.protobuf.empty_pb2 import Empty
@@ -68,8 +64,8 @@ class IOContext:
     in a batched or single input context.
     """

-    input_ids:
-    function_call_ids:
+    input_ids: list[str]
+    function_call_ids: list[str]
     finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"

     _cancel_issued: bool = False
@@ -77,10 +73,10 @@ class IOContext:

     def __init__(
         self,
-        input_ids:
-        function_call_ids:
+        input_ids: list[str],
+        function_call_ids: list[str],
         finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
-        function_inputs:
+        function_inputs: list[api_pb2.FunctionInput],
         is_batched: bool,
         client: _Client,
     ):
@@ -95,8 +91,8 @@ class IOContext:
     async def create(
         cls,
         client: _Client,
-        finalized_functions:
-        inputs:
+        finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
+        inputs: list[tuple[str, str, api_pb2.FunctionInput]],
         is_batched: bool,
     ) -> "IOContext":
         assert len(inputs) >= 1 if is_batched else len(inputs) == 1
@@ -136,7 +132,7 @@ class IOContext:
         # between creating a new task for an input and attaching the cancellation callback
         logger.warning("Unexpected: Could not cancel input")

-    def _args_and_kwargs(self) ->
+    def _args_and_kwargs(self) -> tuple[tuple[Any, ...], dict[str, list[Any]]]:
         # deserializing here instead of the constructor
         # to make sure we handle user exceptions properly
         # and don't retry
@@ -153,7 +149,7 @@ class IOContext:
             param_names.append(param.name)

         # aggregate args and kwargs of all inputs into a kwarg dict
-        kwargs_by_inputs:
+        kwargs_by_inputs: list[dict[str, Any]] = [{} for _ in range(len(self.input_ids))]

         for i, (args, kwargs) in enumerate(deserialized_args):
             # check that all batched inputs should have the same number of args and kwargs
@@ -187,7 +183,7 @@ class IOContext:
         logger.debug(f"Finished input {self.input_ids}")
         return res

-    def validate_output_data(self, data: Any) ->
+    def validate_output_data(self, data: Any) -> list[Any]:
         if not self._is_batched:
             return [data]

@@ -263,7 +259,7 @@ class _ContainerIOManager:
     calls_completed: int
     total_user_time: float
     current_input_id: Optional[str]
-    current_inputs:
+    current_inputs: dict[str, IOContext]  # input_id -> IOContext
     current_input_started_at: Optional[float]

     _target_concurrency: int
@@ -472,7 +468,7 @@ class _ContainerIOManager:
             client=self._client,
         )

-    async def get_serialized_function(self) ->
+    async def get_serialized_function(self) -> tuple[Optional[Any], Optional[Callable[..., Any]]]:
         # Fetch the serialized function definition
         request = api_pb2.FunctionGetSerializedRequest(function_id=self.function_id)
         response = await self._client.stub.FunctionGetSerialized(request)
@@ -498,7 +494,7 @@ class _ContainerIOManager:
     def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
         return serialize_data_format(obj, data_format)

-    async def format_blob_data(self, data: bytes) ->
+    async def format_blob_data(self, data: bytes) -> dict[str, Any]:
         return (
             {"data_blob_id": await blob_upload(data, self._client.stub)}
             if len(data) > MAX_OBJECT_SIZE_BYTES
@@ -515,7 +511,7 @@ class _ContainerIOManager:
         function_call_id: str,
         start_index: int,
         data_format: int,
-        messages_bytes:
+        messages_bytes: list[Any],
     ) -> None:
         """Put data onto the `data_out` stream of a function call.

@@ -523,7 +519,7 @@ class _ContainerIOManager:
         was introduced as a performance optimization in client version 0.57, so older clients will
         still use the previous Postgres-backed system based on `FunctionPutOutputs()`.
         """
-        data_chunks:
+        data_chunks: list[api_pb2.DataChunk] = []
         for i, message_bytes in enumerate(messages_bytes):
             chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i)  # type: ignore
             if len(message_bytes) > MAX_OBJECT_SIZE_BYTES:
@@ -588,7 +584,7 @@ class _ContainerIOManager:
         self,
         batch_max_size: int,
         batch_wait_ms: int,
-    ) -> AsyncIterator[
+    ) -> AsyncIterator[list[tuple[str, str, api_pb2.FunctionInput]]]:
         request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
         iteration = 0
         while self._fetching_inputs:
@@ -645,7 +641,7 @@ class _ContainerIOManager:
     @synchronizer.no_io_translation
     async def run_inputs_outputs(
         self,
-        finalized_functions:
+        finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
         batch_max_size: int = 0,
         batch_wait_ms: int = 0,
     ) -> AsyncIterator[IOContext]:
@@ -675,7 +671,7 @@ class _ContainerIOManager:
         io_context: IOContext,
         started_at: float,
         data_format: "modal_proto.api_pb2.DataFormat.ValueType",
-        results:
+        results: list[api_pb2.GenericResult],
     ) -> None:
         output_created_at = time.time()
         outputs = [
@@ -704,7 +700,7 @@ class _ContainerIOManager:
         logger.info(err)
         return self.serialize(SerializationError(err))

-    def serialize_traceback(self, exc: BaseException) ->
+    def serialize_traceback(self, exc: BaseException) -> tuple[Optional[bytes], Optional[bytes]]:
         serialized_tb, tb_line_cache = None, None

         try:
@@ -831,7 +827,7 @@ class _ContainerIOManager:
         )
         self.exit_context(started_at, io_context.input_ids)

-    def exit_context(self, started_at, input_ids:
+    def exit_context(self, started_at, input_ids: list[str]):
         self.total_user_time += time.time() - started_at
         self.calls_completed += 1

@@ -934,7 +930,7 @@ class _ContainerIOManager:
         self._waiting_for_memory_snapshot = False
         self.heartbeat_condition.notify_all()

-    async def volume_commit(self, volume_ids:
+    async def volume_commit(self, volume_ids: list[str]) -> None:
         """
         Perform volume commit for given `volume_ids`.
         Only used on container exit to persist uncommitted changes on behalf of user.
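`IOContext._args_and_kwargs` now advertises its return shape, `tuple[tuple[Any, ...], dict[str, list[Any]]]`: for a batched call, each parameter name maps to a list holding one value per input. A minimal sketch of that aggregation step (an assumed helper for illustration, not the actual method):

from typing import Any

def aggregate_batch(
    param_names: list[str], inputs: list[dict[str, Any]]
) -> tuple[tuple[Any, ...], dict[str, list[Any]]]:
    # one slot per parameter; append each input's value in arrival order
    kwargs_by_name: dict[str, list[Any]] = {name: [] for name in param_names}
    for single_input in inputs:
        for name in param_names:
            kwargs_by_name[name].append(single_input.get(name))
    return (), kwargs_by_name

# two batched inputs to f(x, y) become x=[1, 3], y=[2, 4]
print(aggregate_batch(["x", "y"], [{"x": 1, "y": 2}, {"x": 3, "y": 4}]))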
modal/_runtime/execution_context.py
CHANGED
@@ -1,6 +1,6 @@
 # Copyright Modal Labs 2024
 from contextvars import ContextVar
-from typing import Callable,
+from typing import Callable, Optional

 from modal._utils.async_utils import synchronize_api
 from modal.exception import InvalidError
@@ -71,7 +71,7 @@ def current_function_call_id() -> Optional[str]:
     return None


-def _set_current_context_ids(input_ids:
+def _set_current_context_ids(input_ids: list[str], function_call_ids: list[str]) -> Callable[[], None]:
     assert len(input_ids) == len(function_call_ids) and len(input_ids) > 0
     input_id = input_ids[0]
     function_call_id = function_call_ids[0]
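`_set_current_context_ids` takes the per-input IDs and returns a `Callable[[], None]`, which reads like the usual `contextvars` set-then-reset pattern. A self-contained sketch of that pattern with hypothetical names (the real function also tracks function call IDs and handles lists):

from contextvars import ContextVar
from typing import Callable

_current_input_id: ContextVar[str] = ContextVar("_current_input_id")

def set_input_id(input_id: str) -> Callable[[], None]:
    token = _current_input_id.set(input_id)
    # the returned callback restores whatever value was active before
    return lambda: _current_input_id.reset(token)

reset = set_input_id("in-123")
print(_current_input_id.get())  # "in-123"
reset()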
modal/_runtime/telemetry.py
CHANGED
@@ -7,7 +7,6 @@ import socket
 import sys
 import threading
 import time
-import typing
 import uuid
 from importlib.util import find_spec, module_from_spec
 from struct import pack
@@ -65,7 +64,7 @@ class InterceptedModuleLoader(importlib.abc.Loader):


 class ImportInterceptor(importlib.abc.MetaPathFinder):
-    loading:
+    loading: dict[str, tuple[str, float]]
     tracing_socket: socket.socket
     events: queue.Queue

modal/_runtime/user_code_imports.py
CHANGED
@@ -3,7 +3,7 @@ import importlib
 import typing
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Callable,
+from typing import Any, Callable, Optional

 import modal._runtime.container_io_manager
 import modal.cls
@@ -49,12 +49,12 @@ class Service(metaclass=ABCMeta):

     user_cls_instance: Any
     app: Optional["modal.app._App"]
-    code_deps: Optional[
+    code_deps: Optional[list["modal.object._Object"]]

     @abstractmethod
     def get_finalized_functions(
         self, fun_def: api_pb2.Function, container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager"
-    ) ->
+    ) -> dict[str, "FinalizedFunction"]:
         ...


@@ -99,13 +99,13 @@ def construct_webhook_callable(
 class ImportedFunction(Service):
     user_cls_instance: Any
     app: Optional["modal.app._App"]
-    code_deps: Optional[
+    code_deps: Optional[list["modal.object._Object"]]

     _user_defined_callable: Callable[..., Any]

     def get_finalized_functions(
         self, fun_def: api_pb2.Function, container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager"
-    ) ->
+    ) -> dict[str, "FinalizedFunction"]:
         # Check this property before we turn it into a method (overriden by webhooks)
         is_async = get_is_async(self._user_defined_callable)
         # Use the function definition for whether this is a generator (overriden by webhooks)
@@ -142,13 +142,13 @@ class ImportedFunction(Service):
 class ImportedClass(Service):
     user_cls_instance: Any
     app: Optional["modal.app._App"]
-    code_deps: Optional[
+    code_deps: Optional[list["modal.object._Object"]]

-    _partial_functions:
+    _partial_functions: dict[str, "modal.partial_function._PartialFunction"]

     def get_finalized_functions(
         self, fun_def: api_pb2.Function, container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager"
-    ) ->
+    ) -> dict[str, "FinalizedFunction"]:
         finalized_functions = {}
         for method_name, partial in self._partial_functions.items():
             partial = synchronizer._translate_in(partial)  # ugly
@@ -184,9 +184,7 @@ class ImportedClass(Service):
         return finalized_functions


-def get_user_class_instance(
-    cls: typing.Union[type, modal.cls.Cls], args: typing.Tuple, kwargs: Dict[str, Any]
-) -> typing.Any:
+def get_user_class_instance(cls: typing.Union[type, modal.cls.Cls], args: tuple, kwargs: dict[str, Any]) -> typing.Any:
     """Returns instance of the underlying class to be used as the `self`

     The input `cls` can either be the raw Python class the user has declared ("user class"),
@@ -236,7 +234,7 @@ def import_single_function_service(
     """
     user_defined_callable: Callable
     function: Optional[_Function] = None
-    code_deps: Optional[
+    code_deps: Optional[list["modal.object._Object"]] = None
     active_app: Optional[modal.app._App] = None

     if ser_fun is not None:
@@ -311,7 +309,7 @@ def import_class_service(
     See import_function.
     """
     active_app: Optional["modal.app._App"]
-    code_deps: Optional[
+    code_deps: Optional[list["modal.object._Object"]]
     cls: typing.Union[type, modal.cls.Cls]

     if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
modal/_serialization.py
CHANGED
@@ -398,10 +398,8 @@ PARAM_TYPE_MAPPING = {
 }


-def serialize_proto_params(
-
-) -> bytes:
-    proto_params: typing.List[api_pb2.ClassParameterValue] = []
+def serialize_proto_params(python_params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]) -> bytes:
+    proto_params: list[api_pb2.ClassParameterValue] = []
     for schema_param in schema:
         type_info = PARAM_TYPE_MAPPING.get(schema_param.type)
         if not type_info:
@@ -426,9 +424,7 @@ def serialize_proto_params(
     return proto_bytes


-def deserialize_proto_params(
-    serialized_params: bytes, schema: typing.List[api_pb2.ClassParameterSpec]
-) -> typing.Dict[str, Any]:
+def deserialize_proto_params(serialized_params: bytes, schema: list[api_pb2.ClassParameterSpec]) -> dict[str, Any]:
     proto_struct = api_pb2.ClassParameterSet()
     proto_struct.ParseFromString(serialized_params)
     value_by_name = {p.name: p for p in proto_struct.parameters}
modal/_traceback.py
CHANGED
@@ -8,15 +8,15 @@ import re
 import sys
 import traceback
 from types import TracebackType
-from typing import Any,
+from typing import Any, Optional

 from ._vendor.tblib import Traceback as TBLibTraceback

-TBDictType =
-LineCacheType =
+TBDictType = dict[str, Any]
+LineCacheType = dict[tuple[str, str], str]


-def extract_traceback(exc: BaseException, task_id: str) ->
+def extract_traceback(exc: BaseException, task_id: str) -> tuple[TBDictType, LineCacheType]:
     """Given an exception, extract a serializable traceback (with task ID markers included),
     and a line cache that maps (filename, lineno) to line contents. The latter is used to show
     a helpful traceback to the user, even if they don't have packages installed locally that
@@ -103,7 +103,7 @@ def traceback_contains_remote_call(tb: Optional[TracebackType]) -> bool:
     return False


-def print_exception(exc: Optional[
+def print_exception(exc: Optional[type[BaseException]], value: Optional[BaseException], tb: Optional[TracebackType]):
     """Add backwards compatibility for printing exceptions with "notes" for Python<3.11."""
     traceback.print_exception(exc, value, tb)
     if sys.version_info < (3, 11) and value is not None:
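`LineCacheType` is now spelled out as `dict[tuple[str, str], str]`, and the `extract_traceback` docstring explains its purpose: map `(filename, lineno)` to line contents so the user sees a useful traceback even without the remote packages installed locally. A rough sketch of building such a cache with the standard library (not Modal's implementation; keying the line number as a string is an inference from the annotation):

import linecache
import traceback

def build_line_cache(exc: BaseException) -> dict[tuple[str, str], str]:
    cache: dict[tuple[str, str], str] = {}
    for frame, lineno in traceback.walk_tb(exc.__traceback__):
        filename = frame.f_code.co_filename
        line = linecache.getline(filename, lineno).rstrip()
        if line:  # only cache lines whose source is available on this machine
            cache[(filename, str(lineno))] = line
    return cache

try:
    1 / 0
except ZeroDivisionError as exc:
    print(build_line_cache(exc))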