modal 0.67.43__py3-none-any.whl → 0.68.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__init__.py +2 -0
- modal/_container_entrypoint.py +4 -1
- modal/_ipython.py +3 -13
- modal/_runtime/asgi.py +4 -0
- modal/_runtime/container_io_manager.py +3 -0
- modal/_runtime/user_code_imports.py +17 -20
- modal/_traceback.py +16 -2
- modal/_utils/blob_utils.py +27 -92
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/function_utils.py +5 -1
- modal/_utils/grpc_testing.py +6 -2
- modal/_utils/hash_utils.py +51 -10
- modal/_utils/http_utils.py +19 -10
- modal/_utils/{pattern_matcher.py → pattern_utils.py} +1 -70
- modal/_utils/shell_utils.py +11 -5
- modal/cli/_traceback.py +11 -4
- modal/cli/run.py +25 -12
- modal/client.py +6 -37
- modal/client.pyi +2 -6
- modal/cls.py +132 -62
- modal/cls.pyi +13 -7
- modal/exception.py +20 -0
- modal/file_io.py +380 -0
- modal/file_io.pyi +185 -0
- modal/file_pattern_matcher.py +121 -0
- modal/functions.py +33 -11
- modal/functions.pyi +11 -9
- modal/image.py +88 -8
- modal/image.pyi +20 -4
- modal/mount.py +49 -9
- modal/mount.pyi +19 -4
- modal/network_file_system.py +4 -1
- modal/object.py +4 -2
- modal/partial_function.py +22 -10
- modal/partial_function.pyi +10 -2
- modal/runner.py +5 -4
- modal/runner.pyi +2 -1
- modal/sandbox.py +40 -0
- modal/sandbox.pyi +18 -0
- modal/volume.py +5 -1
- {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/METADATA +2 -2
- {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/RECORD +52 -48
- modal_docs/gen_reference_docs.py +1 -0
- modal_proto/api.proto +33 -1
- modal_proto/api_pb2.py +813 -737
- modal_proto/api_pb2.pyi +160 -13
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +1 -1
- {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/LICENSE +0 -0
- {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/WHEEL +0 -0
- {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/entry_points.txt +0 -0
- {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/top_level.txt +0 -0
modal/__init__.py
CHANGED
@@ -17,6 +17,7 @@ try:
     from .cls import Cls, parameter
     from .dict import Dict
     from .exception import Error
+    from .file_pattern_matcher import FilePatternMatcher
    from .functions import Function
     from .image import Image
     from .mount import Mount
@@ -48,6 +49,7 @@ __all__ = [
     "Cron",
     "Dict",
     "Error",
+    "FilePatternMatcher",
     "Function",
     "Image",
     "Mount",
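The new top-level export gives dockerignore-style file filtering. A minimal sketch of constructing one (the variadic pattern-string constructor is an assumption based on the new modal/file_pattern_matcher.py module added in this release):

    import modal

    # Hypothetical usage; pattern syntax assumed to follow dockerignore-style rules.
    ignore = modal.FilePatternMatcher("**/*.pyc", "**/__pycache__")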
modal/_container_entrypoint.py
CHANGED
@@ -6,7 +6,7 @@ from modal._runtime.user_code_imports import Service, import_class_service, impo
 
 telemetry_socket = os.environ.get("MODAL_TELEMETRY_SOCKET")
 if telemetry_socket:
-    from modal._runtime.telemetry import instrument_imports
+    from ._runtime.telemetry import instrument_imports
 
     instrument_imports(telemetry_socket)
 
@@ -415,6 +415,9 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
 
     _client: _Client = synchronizer._translate_in(client)  # TODO(erikbern): ugly
 
+    # Call ContainerHello - currently a noop but might be used later for things
+    container_io_manager.hello()
+
     with container_io_manager.heartbeats(is_snapshotting_function), UserCodeEventLoop() as event_loop:
         # If this is a serialized function, fetch the definition from the server
         if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
modal/_ipython.py
CHANGED
@@ -1,21 +1,11 @@
 # Copyright Modal Labs 2022
 import sys
-import warnings
-
-ipy_outstream = None
-try:
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        import ipykernel.iostream
-
-        ipy_outstream = ipykernel.iostream.OutStream
-except ImportError:
-    pass
 
 
 def is_notebook(stdout=None):
-    if ipy_outstream is None:
+    ipykernel_iostream = sys.modules.get("ipykernel.iostream")
+    if ipykernel_iostream is None:
         return False
     if stdout is None:
         stdout = sys.stdout
-    return isinstance(stdout, ipy_outstream)
+    return isinstance(stdout, ipykernel_iostream.OutStream)
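The rewrite swaps an eager, warning-suppressed ipykernel import for a sys.modules lookup: the notebook check now only consults ipykernel.iostream if something else has already imported it, so Modal never pays for (or triggers) the import itself. A minimal sketch of the pattern with hypothetical names:

    import sys

    def stream_is_ipykernel(stream) -> bool:
        # Look the module up instead of importing it: if ipykernel was never
        # imported by the host process, we can't be running in a notebook.
        mod = sys.modules.get("ipykernel.iostream")
        return mod is not None and isinstance(stream, mod.OutStream)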
modal/_runtime/asgi.py
CHANGED
@@ -1,4 +1,8 @@
 # Copyright Modal Labs 2022
+
+# Note: this module isn't imported unless it's needed.
+# This is because aiohttp is a pretty big dependency that adds significant latency when imported
+
 import asyncio
 from collections.abc import AsyncGenerator
 from typing import Any, Callable, NoReturn, Optional, cast
modal/_runtime/container_io_manager.py
CHANGED
@@ -335,6 +335,9 @@ class _ContainerIOManager:
         """Only used for tests."""
         cls._singleton = None
 
+    async def hello(self):
+        await self._client.stub.ContainerHello(Empty())
+
     async def _run_heartbeat_loop(self):
         while 1:
             t0 = time.monotonic()
modal/_runtime/user_code_imports.py
CHANGED
@@ -9,15 +9,6 @@ import modal._runtime.container_io_manager
 import modal.cls
 import modal.object
 from modal import Function
-from modal._runtime.asgi import (
-    LifespanManager,
-    asgi_app_wrapper,
-    get_ip_address,
-    wait_for_web_server,
-    web_server_proxy,
-    webhook_asgi_app,
-    wsgi_app_wrapper,
-)
 from modal._utils.async_utils import synchronizer
 from modal._utils.function_utils import LocalFunctionError, is_async as get_is_async, is_global_object
 from modal.exception import ExecutionError, InvalidError
@@ -28,6 +19,7 @@ from modal_proto import api_pb2
 if typing.TYPE_CHECKING:
     import modal.app
     import modal.partial_function
+    from modal._runtime.asgi import LifespanManager
 
 
 @dataclass
@@ -36,7 +28,7 @@ class FinalizedFunction:
     is_async: bool
     is_generator: bool
     data_format: int  # api_pb2.DataFormat
-    lifespan_manager: Optional[LifespanManager] = None
+    lifespan_manager: Optional["LifespanManager"] = None
 
 
 class Service(metaclass=ABCMeta):
@@ -63,19 +55,22 @@ def construct_webhook_callable(
     webhook_config: api_pb2.WebhookConfig,
     container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager",
 ):
+    # Note: aiohttp is a significant dependency of the `asgi` module, so we import it locally
+    from modal._runtime import asgi
+
     # For webhooks, the user function is used to construct an asgi app:
     if webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP:
         # Function returns an asgi_app, which we can use as a callable.
-        return asgi_app_wrapper(user_defined_callable(), container_io_manager)
+        return asgi.asgi_app_wrapper(user_defined_callable(), container_io_manager)
 
     elif webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP:
-        # Function returns an wsgi_app, which we can use as a callable
-        return wsgi_app_wrapper(user_defined_callable(), container_io_manager)
+        # Function returns an wsgi_app, which we can use as a callable
+        return asgi.wsgi_app_wrapper(user_defined_callable(), container_io_manager)
 
     elif webhook_config.type == api_pb2.WEBHOOK_TYPE_FUNCTION:
         # Function is a webhook without an ASGI app. Create one for it.
-        return asgi_app_wrapper(
-            webhook_asgi_app(user_defined_callable, webhook_config.method, webhook_config.web_endpoint_docs),
+        return asgi.asgi_app_wrapper(
+            asgi.webhook_asgi_app(user_defined_callable, webhook_config.method, webhook_config.web_endpoint_docs),
             container_io_manager,
         )
 
@@ -86,11 +81,11 @@ def construct_webhook_callable(
         # We intentionally try to connect to the external interface instead of the loopback
         # interface here so users are forced to expose the server. This allows us to potentially
         # change the implementation to use an external bridge in the future.
-        host = get_ip_address(b"eth0")
+        host = asgi.get_ip_address(b"eth0")
         port = webhook_config.web_server_port
         startup_timeout = webhook_config.web_server_startup_timeout
-        wait_for_web_server(host, port, timeout=startup_timeout)
-        return asgi_app_wrapper(web_server_proxy(host, port), container_io_manager)
+        asgi.wait_for_web_server(host, port, timeout=startup_timeout)
+        return asgi.asgi_app_wrapper(asgi.web_server_proxy(host, port), container_io_manager)
     else:
         raise InvalidError(f"Unrecognized web endpoint type {webhook_config.type}")
 
@@ -269,10 +264,12 @@ def import_single_function_service(
             # The cls decorator is in global scope
             _cls = synchronizer._translate_in(cls)
            user_defined_callable = _cls._callables[fun_name]
-            function = _cls._method_functions.get(fun_name)
+            function = _cls._method_functions.get(
+                fun_name
+            )  # bound to the class service function - there is no instance
             active_app = _cls._app
         else:
-            # This is
+            # This is non-decorated class
             user_defined_callable = getattr(cls, fun_name)
     else:
         raise InvalidError(f"Invalid function qualname {qual_name}")
modal/_traceback.py
CHANGED
@@ -1,16 +1,21 @@
 # Copyright Modal Labs 2022
-"""Helper functions related to operating on traceback objects.
+"""Helper functions related to operating on exceptions, warnings, and traceback objects.
 
 Functions related to *displaying* tracebacks should go in `modal/cli/_traceback.py`
 so that Rich is not a dependency of the container Client.
 """
+
 import re
 import sys
 import traceback
+import warnings
 from types import TracebackType
-from typing import Any, Optional
+from typing import Any, Iterable, Optional
+
+from modal_proto import api_pb2
 
 from ._vendor.tblib import Traceback as TBLibTraceback
+from .exception import ServerWarning
 
 TBDictType = dict[str, Any]
 LineCacheType = dict[tuple[str, str], str]
@@ -109,3 +114,12 @@ def print_exception(exc: Optional[type[BaseException]], value: Optional[BaseExce
     if sys.version_info < (3, 11) and value is not None:
         notes = getattr(value, "__notes__", [])
         print(*notes, sep="\n", file=sys.stderr)
+
+
+def print_server_warnings(server_warnings: Iterable[api_pb2.Warning]):
+    """Issue a warning originating from the server with empty metadata about local origin.
+
+    When using the Modal CLI, these warnings should get caught and coerced into Rich panels.
+    """
+    for warning in server_warnings:
+        warnings.warn_explicit(warning.message, ServerWarning, "<modal-server>", 0)
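The new helper attributes server-sent warnings to a synthetic <modal-server> location rather than whatever local frame happened to call it. A runnable sketch of that mechanism, with a stand-in for the ServerWarning class this release adds to modal/exception.py:

    import warnings

    class ServerWarning(UserWarning):
        """Stand-in for modal.exception.ServerWarning."""

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # warn_explicit pins the reported filename/lineno to the server,
        # not to this local call site.
        warnings.warn_explicit("deprecated image builder", ServerWarning, "<modal-server>", 0)

    assert caught[0].category is ServerWarning
    assert caught[0].filename == "<modal-server>"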
modal/_utils/blob_utils.py
CHANGED
@@ -9,22 +9,22 @@ import time
 from collections.abc import AsyncIterator
 from contextlib import AbstractContextManager, contextmanager
 from pathlib import Path, PurePosixPath
-from typing import Any, BinaryIO, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional, Union
 from urllib.parse import urlparse
 
-from aiohttp import BytesIOPayload
-from aiohttp.abc import AbstractStreamWriter
-
 from modal_proto import api_pb2
 from modal_proto.modal_api_grpc import ModalClientModal
 
 from ..exception import ExecutionError
 from .async_utils import TaskContext, retry
 from .grpc_utils import retry_transient_errors
-from .hash_utils import UploadHashes, get_sha256_hex, get_upload_hashes
+from .hash_utils import UploadHashes, get_upload_hashes
 from .http_utils import ClientSessionRegistry
 from .logger import logger
 
+if TYPE_CHECKING:
+    from .bytes_io_segment_payload import BytesIOSegmentPayload
+
 # Max size for function inputs and outputs.
 MAX_OBJECT_SIZE_BYTES = 2 * 1024 * 1024  # 2 MiB
 
@@ -38,93 +38,16 @@ BLOB_MAX_PARALLELISM = 10
 # read ~16MiB chunks by default
 DEFAULT_SEGMENT_CHUNK_SIZE = 2**24
 
-
-class BytesIOSegmentPayload(BytesIOPayload):
-    """Modified bytes payload for concurrent sends of chunks from the same file.
-
-    Adds:
-    * read limit using remaining_bytes, in order to split files across streams
-    * larger read chunk (to prevent excessive read contention between parts)
-    * calculates an md5 for the segment
-
-    Feels like this should be in some standard lib...
-    """
-
-    def __init__(
-        self,
-        bytes_io: BinaryIO,  # should *not* be shared as IO position modification is not locked
-        segment_start: int,
-        segment_length: int,
-        chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
-        progress_report_cb: Optional[Callable] = None,
-    ):
-        # not thread safe constructor!
-        super().__init__(bytes_io)
-        self.initial_seek_pos = bytes_io.tell()
-        self.segment_start = segment_start
-        self.segment_length = segment_length
-        # seek to start of file segment we are interested in, in order to make .size() evaluate correctly
-        self._value.seek(self.initial_seek_pos + segment_start)
-        assert self.segment_length <= super().size
-        self.chunk_size = chunk_size
-        self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
-        self.reset_state()
-
-    def reset_state(self):
-        self._md5_checksum = hashlib.md5()
-        self.num_bytes_read = 0
-        self._value.seek(self.initial_seek_pos)
-
-    @contextmanager
-    def reset_on_error(self):
-        try:
-            yield
-        except Exception as exc:
-            try:
-                self.progress_report_cb(reset=True)
-            except Exception as cb_exc:
-                raise cb_exc from exc
-            raise exc
-        finally:
-            self.reset_state()
-
-    @property
-    def size(self) -> int:
-        return self.segment_length
-
-    def md5_checksum(self):
-        return self._md5_checksum
-
-    async def write(self, writer: AbstractStreamWriter):
-        loop = asyncio.get_event_loop()
-
-        async def safe_read():
-            read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
-            self._value.seek(read_start)
-            num_bytes = min(self.chunk_size, self.remaining_bytes())
-            chunk = await loop.run_in_executor(None, self._value.read, num_bytes)
-
-            await loop.run_in_executor(None, self._md5_checksum.update, chunk)
-            self.num_bytes_read += len(chunk)
-            return chunk
-
-        chunk = await safe_read()
-        while chunk and self.remaining_bytes() > 0:
-            await writer.write(chunk)
-            self.progress_report_cb(advance=len(chunk))
-            chunk = await safe_read()
-        if chunk:
-            await writer.write(chunk)
-            self.progress_report_cb(advance=len(chunk))
-
-    def remaining_bytes(self):
-        return self.segment_length - self.num_bytes_read
+# Files larger than this will be multipart uploaded. The server might request multipart upload for smaller files as
+# well, but the limit will never be raised.
+# TODO(dano): remove this once we stop requiring md5 for blobs
+MULTIPART_UPLOAD_THRESHOLD = 1024**3
 
 
 @retry(n_attempts=5, base_delay=0.5, timeout=None)
 async def _upload_to_s3_url(
     upload_url,
-    payload: BytesIOSegmentPayload,
+    payload: "BytesIOSegmentPayload",
     content_md5_b64: Optional[str] = None,
     content_type: Optional[str] = "application/octet-stream",  # set to None to force omission of ContentType header
 ) -> str:
@@ -180,6 +103,8 @@ async def perform_multipart_upload(
     upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
     progress_report_cb: Optional[Callable] = None,
 ) -> None:
+    from .bytes_io_segment_payload import BytesIOSegmentPayload
+
     upload_coros = []
     file_offset = 0
     num_bytes_left = content_length
@@ -273,6 +198,8 @@ async def _blob_upload(
             progress_report_cb=progress_report_cb,
         )
     else:
+        from .bytes_io_segment_payload import BytesIOSegmentPayload
+
        payload = BytesIOSegmentPayload(
             data, segment_start=0, segment_length=content_length, progress_report_cb=progress_report_cb
         )
@@ -305,9 +232,13 @@ async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
 
 
 async def blob_upload_file(
-    file_obj: BinaryIO,
+    file_obj: BinaryIO,
+    stub: ModalClientModal,
+    progress_report_cb: Optional[Callable] = None,
+    sha256_hex: Optional[str] = None,
+    md5_hex: Optional[str] = None,
 ) -> str:
-    upload_hashes = get_upload_hashes(file_obj)
+    upload_hashes = get_upload_hashes(file_obj, sha256_hex=sha256_hex, md5_hex=md5_hex)
     return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)
 
 
@@ -366,6 +297,7 @@ class FileUploadSpec:
     use_blob: bool
     content: Optional[bytes]  # typically None if using blob, required otherwise
     sha256_hex: str
+    md5_hex: str
     mode: int  # file permission bits (last 12 bits of st_mode)
     size: int
 
@@ -383,13 +315,15 @@ def _get_file_upload_spec(
     fp.seek(0)
 
     if size >= LARGE_FILE_LIMIT:
+        # TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs
+        md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None
         use_blob = True
         content = None
-        sha256_hex = get_sha256_hex(fp)
+        hashes = get_upload_hashes(fp, md5_hex=md5_hex)
     else:
         use_blob = False
         content = fp.read()
-        sha256_hex = get_sha256_hex(content)
+        hashes = get_upload_hashes(content)
 
     return FileUploadSpec(
         source=source,
@@ -397,7 +331,8 @@
         mount_filename=mount_filename.as_posix(),
         use_blob=use_blob,
         content=content,
-        sha256_hex=sha256_hex,
+        sha256_hex=hashes.sha256_hex(),
+        md5_hex=hashes.md5_hex(),
         mode=mode & 0o7777,
         size=size,
     )
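The new MULTIPART_UPLOAD_THRESHOLD (1024**3 = 1 GiB) controls when a real md5 is skipped: above it, the placeholder digest is passed in so get_upload_hashes only computes sha256, roughly halving hashing work on huge files. A small sketch of that decision (the LARGE_FILE_LIMIT value here is an assumption for illustration):

    LARGE_FILE_LIMIT = 4 * 1024 * 1024    # assumed value, for illustration only
    MULTIPART_UPLOAD_THRESHOLD = 1024**3  # 1 GiB, from the diff above

    def needs_placeholder_md5(size: int) -> bool:
        # Mirrors the branch in _get_file_upload_spec: blob files above the
        # multipart threshold get a fixed placeholder md5 instead of a real one.
        return size >= LARGE_FILE_LIMIT and size > MULTIPART_UPLOAD_THRESHOLD

    assert not needs_placeholder_md5(10 * 1024 * 1024)  # 10 MiB: real hashes
    assert needs_placeholder_md5(2 * 1024**3)           # 2 GiB: placeholder md5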
modal/_utils/bytes_io_segment_payload.py
ADDED
@@ -0,0 +1,97 @@
+# Copyright Modal Labs 2024
+
+import asyncio
+import hashlib
+from contextlib import contextmanager
+from typing import BinaryIO, Callable, Optional
+
+# Note: this module needs to import aiohttp in global scope
+# This takes about 50ms and isn't needed in many cases for Modal execution
+# To avoid this, we import it in local scope when needed (blob_utils.py)
+from aiohttp import BytesIOPayload
+from aiohttp.abc import AbstractStreamWriter
+
+# read ~16MiB chunks by default
+DEFAULT_SEGMENT_CHUNK_SIZE = 2**24
+
+
+class BytesIOSegmentPayload(BytesIOPayload):
+    """Modified bytes payload for concurrent sends of chunks from the same file.
+
+    Adds:
+    * read limit using remaining_bytes, in order to split files across streams
+    * larger read chunk (to prevent excessive read contention between parts)
+    * calculates an md5 for the segment
+
+    Feels like this should be in some standard lib...
+    """
+
+    def __init__(
+        self,
+        bytes_io: BinaryIO,  # should *not* be shared as IO position modification is not locked
+        segment_start: int,
+        segment_length: int,
+        chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
+        progress_report_cb: Optional[Callable] = None,
+    ):
+        # not thread safe constructor!
+        super().__init__(bytes_io)
+        self.initial_seek_pos = bytes_io.tell()
+        self.segment_start = segment_start
+        self.segment_length = segment_length
+        # seek to start of file segment we are interested in, in order to make .size() evaluate correctly
+        self._value.seek(self.initial_seek_pos + segment_start)
+        assert self.segment_length <= super().size
+        self.chunk_size = chunk_size
+        self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
+        self.reset_state()
+
+    def reset_state(self):
+        self._md5_checksum = hashlib.md5()
+        self.num_bytes_read = 0
+        self._value.seek(self.initial_seek_pos)
+
+    @contextmanager
+    def reset_on_error(self):
+        try:
+            yield
+        except Exception as exc:
+            try:
+                self.progress_report_cb(reset=True)
+            except Exception as cb_exc:
+                raise cb_exc from exc
+            raise exc
+        finally:
+            self.reset_state()
+
+    @property
+    def size(self) -> int:
+        return self.segment_length
+
+    def md5_checksum(self):
+        return self._md5_checksum
+
+    async def write(self, writer: "AbstractStreamWriter"):
+        loop = asyncio.get_event_loop()
+
+        async def safe_read():
+            read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
+            self._value.seek(read_start)
+            num_bytes = min(self.chunk_size, self.remaining_bytes())
+            chunk = await loop.run_in_executor(None, self._value.read, num_bytes)
+
+            await loop.run_in_executor(None, self._md5_checksum.update, chunk)
+            self.num_bytes_read += len(chunk)
+            return chunk
+
+        chunk = await safe_read()
+        while chunk and self.remaining_bytes() > 0:
+            await writer.write(chunk)
+            self.progress_report_cb(advance=len(chunk))
+            chunk = await safe_read()
+        if chunk:
+            await writer.write(chunk)
+            self.progress_report_cb(advance=len(chunk))
+
+    def remaining_bytes(self):
+        return self.segment_length - self.num_bytes_read
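perform_multipart_upload (in blob_utils.py above) slices a file into (segment_start, segment_length) pairs and gives each one its own BytesIOSegmentPayload. A small illustrative helper (not Modal's API) showing how the file_offset/num_bytes_left bookkeeping from that function partitions a file:

    def split_segments(content_length: int, chunk_size: int) -> list[tuple[int, int]]:
        # Each tuple is a (segment_start, segment_length) pair for one payload.
        segments = []
        file_offset = 0
        num_bytes_left = content_length
        while num_bytes_left > 0:
            segment_length = min(chunk_size, num_bytes_left)
            segments.append((file_offset, segment_length))
            file_offset += segment_length
            num_bytes_left -= segment_length
        return segments

    assert split_segments(5, 2) == [(0, 2), (2, 2), (4, 1)]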
modal/_utils/function_utils.py
CHANGED
@@ -99,7 +99,11 @@ def get_function_type(is_generator: Optional[bool]) -> "api_pb2.Function.Functio
 
 
 class FunctionInfo:
-    """Class that helps us extract a bunch of information about a function."""
+    """Class that helps us extract a bunch of information about a locally defined function.
+
+    Used for populating the definition of a remote function, and for making .local() calls
+    on a host with the local definition available.
+    """
 
     raw_f: Optional[Callable[..., Any]]  # if None - this is a "class service function"
     function_name: str
modal/_utils/grpc_testing.py
CHANGED
@@ -50,7 +50,7 @@ def patch_mock_servicer(cls):
 
     @contextlib.contextmanager
     def intercept(servicer):
-        ctx = InterceptionContext()
+        ctx = InterceptionContext(servicer)
         servicer.interception_context = ctx
         yield ctx
         ctx._assert_responses_consumed()
@@ -101,7 +101,8 @@ class ResponseNotConsumed(Exception):
 
 
 class InterceptionContext:
-    def __init__(self):
+    def __init__(self, servicer):
+        self._servicer = servicer
         self.calls: list[tuple[str, Any]] = []  # List[Tuple[method_name, message]]
         self.custom_responses: dict[str, list[tuple[Callable[[Any], bool], list[Any]]]] = defaultdict(list)
         self.custom_defaults: dict[str, Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]] = {}
@@ -149,6 +150,9 @@ class InterceptionContext:
         raise KeyError(f"No message of that type in call list: {self.calls}")
 
     def get_requests(self, method_name: str) -> list[Any]:
+        if not hasattr(self._servicer, method_name):
+            # we check this to prevent things like `assert ctx.get_requests("ASdfFunctionCreate") == 0` passing
+            raise ValueError(f"{method_name} not in MockServicer - did you spell it right?")
         return [msg for _method_name, msg in self.calls if _method_name == method_name]
 
     def _add_recv(self, method_name: str, msg):
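Passing the servicer into InterceptionContext exists to catch misspelled method names in tests: previously get_requests("ASdfFunctionCreate") returned an empty list, so assertions against [] passed vacuously. A trimmed-down, runnable sketch of the guard:

    class FakeServicer:
        def FunctionCreate(self, stream):  # stand-in gRPC handler
            ...

    class InterceptionContext:
        """Reduced sketch of the test helper above."""
        def __init__(self, servicer):
            self._servicer = servicer
            self.calls = []  # (method_name, message) pairs

        def get_requests(self, method_name):
            if not hasattr(self._servicer, method_name):
                # A typo'd name now fails loudly instead of silently returning [].
                raise ValueError(f"{method_name} not in servicer - did you spell it right?")
            return [msg for name, msg in self.calls if name == method_name]

    ctx = InterceptionContext(FakeServicer())
    assert ctx.get_requests("FunctionCreate") == []  # valid name, no calls recorded
    try:
        ctx.get_requests("FunctionCreat")  # misspelled: raises
    except ValueError:
        pass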
modal/_utils/hash_utils.py
CHANGED
@@ -2,12 +2,15 @@
 import base64
 import dataclasses
 import hashlib
-from typing import BinaryIO, Callable, Union
+import time
+from typing import BinaryIO, Callable, Optional, Sequence, Union
 
-
+from modal.config import logger
 
+HASH_CHUNK_SIZE = 65536
 
-def _update(hashers: list[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
+
+def _update(hashers: Sequence[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
     if isinstance(data, bytes):
         for hasher in hashers:
             hasher(data)
@@ -26,20 +29,26 @@ def _update(hashers: list[Callable[[bytes], None]], data: Union[bytes, BinaryIO]
 
 
 def get_sha256_hex(data: Union[bytes, BinaryIO]) -> str:
+    t0 = time.monotonic()
     hasher = hashlib.sha256()
     _update([hasher.update], data)
+    logger.debug("get_sha256_hex took %.3fs", time.monotonic() - t0)
     return hasher.hexdigest()
 
 
 def get_sha256_base64(data: Union[bytes, BinaryIO]) -> str:
+    t0 = time.monotonic()
     hasher = hashlib.sha256()
     _update([hasher.update], data)
+    logger.debug("get_sha256_base64 took %.3fs", time.monotonic() - t0)
     return base64.b64encode(hasher.digest()).decode("ascii")
 
 
 def get_md5_base64(data: Union[bytes, BinaryIO]) -> str:
+    t0 = time.monotonic()
     hasher = hashlib.md5()
     _update([hasher.update], data)
+    logger.debug("get_md5_base64 took %.3fs", time.monotonic() - t0)
     return base64.b64encode(hasher.digest()).decode("utf-8")
 
 
@@ -48,12 +57,44 @@ class UploadHashes:
     md5_base64: str
     sha256_base64: str
 
+    def md5_hex(self) -> str:
+        return base64.b64decode(self.md5_base64).hex()
+
+    def sha256_hex(self) -> str:
+        return base64.b64decode(self.sha256_base64).hex()
+
+
+def get_upload_hashes(
+    data: Union[bytes, BinaryIO], sha256_hex: Optional[str] = None, md5_hex: Optional[str] = None
+) -> UploadHashes:
+    t0 = time.monotonic()
+    hashers = {}
+
+    if not sha256_hex:
+        sha256 = hashlib.sha256()
+        hashers["sha256"] = sha256
+    if not md5_hex:
+        md5 = hashlib.md5()
+        hashers["md5"] = md5
+
+    if hashers:
+        updaters = [h.update for h in hashers.values()]
+        _update(updaters, data)
 
-def get_upload_hashes(data: Union[bytes, BinaryIO]) -> UploadHashes:
-    sha256 = hashlib.sha256()
-    md5 = hashlib.md5()
-    _update([sha256.update, md5.update], data)
-    return UploadHashes(
-        md5_base64=base64.b64encode(md5.digest()).decode("ascii"),
-        sha256_base64=base64.b64encode(sha256.digest()).decode("ascii"),
+    if sha256_hex:
+        sha256_base64 = base64.b64encode(bytes.fromhex(sha256_hex)).decode("ascii")
+    else:
+        sha256_base64 = base64.b64encode(hashers["sha256"].digest()).decode("ascii")
+
+    if md5_hex:
+        md5_base64 = base64.b64encode(bytes.fromhex(md5_hex)).decode("ascii")
+    else:
+        md5_base64 = base64.b64encode(hashers["md5"].digest()).decode("ascii")
+
+    hashes = UploadHashes(
+        md5_base64=md5_base64,
+        sha256_base64=sha256_base64,
     )
+
+    logger.debug("get_upload_hashes took %.3fs (%s)", time.monotonic() - t0, hashers.keys())
+    return hashes