modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__init__.py +13 -9
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +402 -398
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -60
- modal/_resources.py +26 -7
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1025 -0
- modal/{execution_context.py → _runtime/execution_context.py} +11 -2
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +123 -6
- modal/_traceback.py +47 -187
- modal/_tunnel.py +50 -14
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +386 -104
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +299 -98
- modal/_utils/grpc_testing.py +47 -34
- modal/_utils/grpc_utils.py +54 -21
- modal/_utils/hash_utils.py +51 -10
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +3 -3
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +12 -10
- modal/app.py +561 -323
- modal/app.pyi +474 -262
- modal/call_graph.py +7 -6
- modal/cli/_download.py +22 -6
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +203 -42
- modal/cli/config.py +12 -5
- modal/cli/container.py +61 -13
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +21 -48
- modal/cli/launch.py +28 -14
- modal/cli/network_file_system.py +57 -21
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +34 -9
- modal/cli/programs/vscode.py +58 -8
- modal/cli/queues.py +131 -0
- modal/cli/run.py +199 -96
- modal/cli/secret.py +5 -4
- modal/cli/token.py +7 -2
- modal/cli/utils.py +74 -8
- modal/cli/volume.py +97 -56
- modal/client.py +248 -144
- modal/client.pyi +156 -124
- modal/cloud_bucket_mount.py +43 -30
- modal/cloud_bucket_mount.pyi +32 -25
- modal/cls.py +528 -141
- modal/cls.pyi +189 -145
- modal/config.py +32 -15
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +50 -54
- modal/dict.pyi +120 -164
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +30 -43
- modal/experimental.py +62 -2
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +196 -0
- modal/functions.py +846 -428
- modal/functions.pyi +446 -387
- modal/gpu.py +57 -44
- modal/image.py +943 -417
- modal/image.pyi +584 -245
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +223 -90
- modal/mount.pyi +241 -243
- modal/network_file_system.py +85 -86
- modal/network_file_system.pyi +151 -110
- modal/object.py +66 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +73 -47
- modal/parallel_map.pyi +51 -63
- modal/partial_function.py +272 -107
- modal/partial_function.pyi +219 -120
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +96 -72
- modal/queue.pyi +210 -135
- modal/requirements/2024.04.txt +2 -1
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +45 -4
- modal/runner.py +325 -203
- modal/runner.pyi +124 -110
- modal/running_app.py +27 -4
- modal/sandbox.py +509 -231
- modal/sandbox.pyi +396 -169
- modal/schedule.py +2 -2
- modal/scheduler_placement.py +20 -3
- modal/secret.py +41 -25
- modal/secret.pyi +62 -42
- modal/serving.py +39 -49
- modal/serving.pyi +37 -43
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +123 -137
- modal/volume.pyi +228 -221
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
- modal-0.72.13.dist-info/RECORD +174 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +1 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1231 -531
- modal_proto/api_grpc.py +750 -430
- modal_proto/api_pb2.py +2102 -1176
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1329 -675
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_exec.py +0 -128
- modal/_container_io_manager.py +0 -646
- modal/_container_io_manager.pyi +0 -412
- modal/_sandbox_shell.py +0 -49
- modal/app_utils.py +0 -20
- modal/app_utils.pyi +0 -17
- modal/execution_context.pyi +0 -37
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal-0.62.115.dist-info/RECORD +0 -207
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -279
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -674
- test/client_test.py +0 -203
- test/cloud_bucket_mount_test.py +0 -22
- test/cls_test.py +0 -636
- test/config_test.py +0 -149
- test/conftest.py +0 -1485
- test/container_app_test.py +0 -50
- test/container_test.py +0 -1405
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -51
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -791
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -82
- test/helpers.py +0 -47
- test/image_test.py +0 -814
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -327
- test/network_file_system_test.py +0 -188
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -115
- test/resolver_test.py +0 -59
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -57
- test/secret_test.py +0 -89
- test/serialization_test.py +0 -50
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -361
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -397
- test/watcher_test.py +0 -58
- test/webhook_test.py +0 -145
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
modal/_utils/blob_utils.py
CHANGED
(In the hunks below, "…" marks line content that the original side-by-side rendering did not preserve.)

@@ -5,23 +5,26 @@ import hashlib
 import io
 import os
 import platform
-…
+import time
+from collections.abc import AsyncIterator
+from contextlib import AbstractContextManager, contextmanager
 from pathlib import Path, PurePosixPath
-from typing import …
+from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional, Union
 from urllib.parse import urlparse
 
-from aiohttp import BytesIOPayload
-from aiohttp.abc import AbstractStreamWriter
-
 from modal_proto import api_pb2
+from modal_proto.modal_api_grpc import ModalClientModal
 
 from ..exception import ExecutionError
-from .async_utils import retry
+from .async_utils import TaskContext, retry
 from .grpc_utils import retry_transient_errors
-from .hash_utils import UploadHashes, …
-from .http_utils import …
+from .hash_utils import UploadHashes, get_upload_hashes
+from .http_utils import ClientSessionRegistry
 from .logger import logger
 
+if TYPE_CHECKING:
+    from .bytes_io_segment_payload import BytesIOSegmentPayload
+
 # Max size for function inputs and outputs.
 MAX_OBJECT_SIZE_BYTES = 2 * 1024 * 1024  # 2 MiB
 
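The import reshuffle above follows a deferred-import pattern: the new `bytes_io_segment_payload` module imports aiohttp at module scope (which its own comments say costs ~50ms), so `blob_utils` only names the class under `TYPE_CHECKING` and imports it locally at the point of use. A minimal sketch of the pattern, using aiohttp as the deferred dependency; the `send` function is illustrative, not part of the package:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only - no aiohttp import cost at runtime.
    from aiohttp import BytesIOPayload


def send(payload: "BytesIOPayload") -> None:
    # Deferred to first call, so importing this module stays cheap.
    from aiohttp import BytesIOPayload

    assert isinstance(payload, BytesIOPayload)
```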
@@ -35,129 +38,59 @@ BLOB_MAX_PARALLELISM = 10
 # read ~16MiB chunks by default
 DEFAULT_SEGMENT_CHUNK_SIZE = 2**24
 
-
-class BytesIOSegmentPayload(BytesIOPayload):
-    """Modified bytes payload for concurrent sends of chunks from the same file.
-
-    Adds:
-    * read limit using remaining_bytes, in order to split files across streams
-    * larger read chunk (to prevent excessive read contention between parts)
-    * calculates an md5 for the segment
-
-    Feels like this should be in some standard lib...
-    """
-
-    def __init__(
-        self,
-        bytes_io: BinaryIO,  # should *not* be shared as IO position modification is not locked
-        segment_start: int,
-        segment_length: int,
-        chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
-    ):
-        # not thread safe constructor!
-        super().__init__(bytes_io)
-        self.initial_seek_pos = bytes_io.tell()
-        self.segment_start = segment_start
-        self.segment_length = segment_length
-        # seek to start of file segment we are interested in, in order to make .size() evaluate correctly
-        self._value.seek(self.initial_seek_pos + segment_start)
-        assert self.segment_length <= super().size
-        self.chunk_size = chunk_size
-        self.reset_state()
-
-    def reset_state(self):
-        self._md5_checksum = hashlib.md5()
-        self.num_bytes_read = 0
-        self._value.seek(self.initial_seek_pos)
-
-    @contextmanager
-    def reset_on_error(self):
-        try:
-            yield
-        finally:
-            self.reset_state()
-
-    @property
-    def size(self) -> int:
-        return self.segment_length
-
-    def md5_checksum(self):
-        return self._md5_checksum
-
-    async def write(self, writer: AbstractStreamWriter):
-        loop = asyncio.get_event_loop()
-
-        async def safe_read():
-            read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
-            self._value.seek(read_start)
-            num_bytes = min(self.chunk_size, self.remaining_bytes())
-            chunk = await loop.run_in_executor(None, self._value.read, num_bytes)
-
-            await loop.run_in_executor(None, self._md5_checksum.update, chunk)
-            self.num_bytes_read += len(chunk)
-            return chunk
-
-        chunk = await safe_read()
-        while chunk and self.remaining_bytes() > 0:
-            await writer.write(chunk)
-            chunk = await safe_read()
-        if chunk:
-            await writer.write(chunk)
-
-    def remaining_bytes(self):
-        return self.segment_length - self.num_bytes_read
+# Files larger than this will be multipart uploaded. The server might request multipart upload for smaller files as
+# well, but the limit will never be raised.
+# TODO(dano): remove this once we stop requiring md5 for blobs
+MULTIPART_UPLOAD_THRESHOLD = 1024**3
 
 
 @retry(n_attempts=5, base_delay=0.5, timeout=None)
 async def _upload_to_s3_url(
     upload_url,
-    payload: BytesIOSegmentPayload,
+    payload: "BytesIOSegmentPayload",
     content_md5_b64: Optional[str] = None,
     content_type: Optional[str] = "application/octet-stream",  # set to None to force omission of ContentType header
 ) -> str:
     """Returns etag of s3 object which is a md5 hex checksum of the uploaded content"""
     with payload.reset_on_error():  # ensure retries read the same data
-        …
-        )
-
-        return remote_md5
+        headers = {}
+        if content_md5_b64 and use_md5(upload_url):
+            headers["Content-MD5"] = content_md5_b64
+        if content_type:
+            headers["Content-Type"] = content_type
+
+        async with ClientSessionRegistry.get_session().put(
+            upload_url,
+            data=payload,
+            headers=headers,
+            skip_auto_headers=["content-type"] if content_type is None else [],
+        ) as resp:
+            # S3 signal to slow down request rate.
+            if resp.status == 503:
+                logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
+                await asyncio.sleep(1)
+
+            if resp.status != 200:
+                try:
+                    text = await resp.text()
+                except Exception:
+                    text = "<no body>"
+                raise ExecutionError(f"Put to url {upload_url} failed with status {resp.status}: {text}")
+
+            # client side ETag checksum verification
+            # the s3 ETag of a single part upload is a quoted md5 hex of the uploaded content
+            etag = resp.headers["ETag"].strip()
+            if etag.startswith(("W/", "w/")):  # see https://www.rfc-editor.org/rfc/rfc7232#section-2.3
+                etag = etag[2:]
+            if etag[0] == '"' and etag[-1] == '"':
+                etag = etag[1:-1]
+            remote_md5 = etag
+
+            local_md5_hex = payload.md5_checksum().hexdigest()
+            if local_md5_hex != remote_md5:
+                raise ExecutionError(f"Local data and remote data checksum mismatch ({local_md5_hex} vs {remote_md5})")
+
+            return remote_md5
 
 
 async def perform_multipart_upload(
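The new `_upload_to_s3_url` body verifies every single-part PUT client-side: per the diff's own comments, the S3 ETag of a single-part upload is the quoted md5 hex of the body, possibly behind a weak-validator prefix (RFC 7232 §2.3). A standalone sketch of that normalization and check; `verify_etag` is an illustrative name, not the package's API:

```python
import hashlib


def verify_etag(etag_header: str, data: bytes) -> None:
    etag = etag_header.strip()
    if etag.startswith(("W/", "w/")):  # weak validator prefix, RFC 7232 section 2.3
        etag = etag[2:]
    if etag[0] == '"' and etag[-1] == '"':  # ETags arrive quoted
        etag = etag[1:-1]
    local = hashlib.md5(data).hexdigest()
    if local != etag:
        raise ValueError(f"checksum mismatch ({local} vs {etag})")


# The remote ETag matches the local md5 of the payload:
verify_etag('W/"9e107d9d372bb6826bd81d3542a419d6"', b"The quick brown fox jumps over the lazy dog")
```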
@@ -165,17 +98,20 @@ async def perform_multipart_upload(
     *,
     content_length: int,
     max_part_size: int,
-    part_urls: …
+    part_urls: list[str],
     completion_url: str,
     upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
-):
+    progress_report_cb: Optional[Callable] = None,
+) -> None:
+    from .bytes_io_segment_payload import BytesIOSegmentPayload
+
     upload_coros = []
     file_offset = 0
     num_bytes_left = content_length
 
     # Give each part its own IO reader object to avoid needing to
     # lock access to the reader's position pointer.
-    data_file_readers: …
+    data_file_readers: list[BinaryIO]
     if isinstance(data_file, io.BytesIO):
         view = data_file.getbuffer()  # does not copy data
         data_file_readers = [io.BytesIO(view) for _ in range(len(part_urls))]
@@ -190,12 +126,13 @@ async def perform_multipart_upload(
             segment_start=file_offset,
             segment_length=part_length_bytes,
             chunk_size=upload_chunk_size,
+            progress_report_cb=progress_report_cb,
         )
         upload_coros.append(_upload_to_s3_url(part_url, payload=part_payload, content_type=None))
         num_bytes_left -= part_length_bytes
         file_offset += part_length_bytes
 
-    part_etags = await …
+    part_etags = await TaskContext.gather(*upload_coros)
 
     # The body of the complete_multipart_upload command needs some data in xml format:
     completion_body = "<CompleteMultipartUpload>\n"
@@ -207,25 +144,24 @@ async def perform_multipart_upload(
     bin_hash_parts = [bytes.fromhex(etag) for etag in part_etags]
 
     expected_multipart_etag = hashlib.md5(b"".join(bin_hash_parts)).hexdigest() + f"-{len(part_etags)}"
-    …
-    )
+    resp = await ClientSessionRegistry.get_session().post(
+        completion_url, data=completion_body.encode("ascii"), skip_auto_headers=["content-type"]
+    )
+    if resp.status != 200:
+        try:
+            msg = await resp.text()
+        except Exception:
+            msg = "<no body>"
+        raise ExecutionError(f"Error when completing multipart upload: {resp.status}\n{msg}")
+    else:
+        response_body = await resp.text()
+        if expected_multipart_etag not in response_body:
+            raise ExecutionError(
+                f"Hash mismatch on multipart upload assembly: {expected_multipart_etag} not in {response_body}"
+            )
 
 
-def get_content_length(data: BinaryIO):
+def get_content_length(data: BinaryIO) -> int:
     # *Remaining* length of file from current seek position
     pos = data.tell()
     data.seek(0, os.SEEK_END)
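The assembly check above relies on S3's multipart ETag convention: md5 over the concatenation of each part's *binary* md5 digest, suffixed with `-{number_of_parts}`. A small self-contained example with two made-up part payloads:

```python
import hashlib

parts = [b"part one bytes", b"part two bytes"]            # hypothetical part payloads
part_etags = [hashlib.md5(p).hexdigest() for p in parts]  # per-part ETags as S3 would report them

# md5 of the concatenated binary digests, plus "-<part count>"
bin_hash_parts = [bytes.fromhex(etag) for etag in part_etags]
expected = hashlib.md5(b"".join(bin_hash_parts)).hexdigest() + f"-{len(part_etags)}"
print(expected)  # ends in "-2"; compared against the ETag in the completion response body
```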
@@ -234,7 +170,9 @@ def get_content_length(data: BinaryIO):
     return content_length - pos
 
 
-async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub) -> str:
+async def _blob_upload(
+    upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub, progress_report_cb: Optional[Callable] = None
+) -> str:
     if isinstance(data, bytes):
         data = io.BytesIO(data)
 
@@ -257,9 +195,14 @@ async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO]
             part_urls=resp.multipart.upload_urls,
             completion_url=resp.multipart.completion_url,
             upload_chunk_size=DEFAULT_SEGMENT_CHUNK_SIZE,
+            progress_report_cb=progress_report_cb,
         )
     else:
-        …
+        from .bytes_io_segment_payload import BytesIOSegmentPayload
+
+        payload = BytesIOSegmentPayload(
+            data, segment_start=0, segment_length=content_length, progress_report_cb=progress_report_cb
+        )
         await _upload_to_s3_url(
             resp.upload_url,
             payload,
@@ -267,79 +210,103 @@ async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO]
             content_md5_b64=upload_hashes.md5_base64,
         )
 
+    if progress_report_cb:
+        progress_report_cb(complete=True)
+
     return blob_id
 
 
-async def blob_upload(payload: bytes, stub) -> str:
+async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
+    size_mib = len(payload) / 1024 / 1024
+    logger.debug(f"Uploading large blob of size {size_mib:.2f} MiB")
+    t0 = time.time()
     if isinstance(payload, str):
         logger.warning("Blob uploading string, not bytes - auto-encoding as utf8")
         payload = payload.encode("utf8")
     upload_hashes = get_upload_hashes(payload)
-    …
+    blob_id = await _blob_upload(upload_hashes, payload, stub)
+    dur_s = max(time.time() - t0, 0.001)  # avoid division by zero
+    throughput_mib_s = (size_mib) / dur_s
+    logger.debug(f"Uploaded large blob of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s)." f" {blob_id}")
+    return blob_id
 
 
-async def blob_upload_file(
-    …
-    …
+async def blob_upload_file(
+    file_obj: BinaryIO,
+    stub: ModalClientModal,
+    progress_report_cb: Optional[Callable] = None,
+    sha256_hex: Optional[str] = None,
+    md5_hex: Optional[str] = None,
+) -> str:
+    upload_hashes = get_upload_hashes(file_obj, sha256_hex=sha256_hex, md5_hex=md5_hex)
+    return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)
 
 
 @retry(n_attempts=5, base_delay=0.1, timeout=None)
-async def _download_from_url(download_url) -> bytes:
-    async with …
-    …
+async def _download_from_url(download_url: str) -> bytes:
+    async with ClientSessionRegistry.get_session().get(download_url) as s3_resp:
+        # S3 signal to slow down request rate.
+        if s3_resp.status == 503:
+            logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
+            await asyncio.sleep(1)
+
+        if s3_resp.status != 200:
+            text = await s3_resp.text()
+            raise ExecutionError(f"Get from url failed with status {s3_resp.status}: {text}")
+        return await s3_resp.read()
+
+
+async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes:
+    """Convenience function for reading all of the downloaded file into memory."""
+    logger.debug(f"Downloading large blob {blob_id}")
+    t0 = time.time()
     req = api_pb2.BlobGetRequest(blob_id=blob_id)
     resp = await retry_transient_errors(stub.BlobGet, req)
-    …
+    data = await _download_from_url(resp.download_url)
+    size_mib = len(data) / 1024 / 1024
+    dur_s = max(time.time() - t0, 0.001)  # avoid division by zero
+    throughput_mib_s = size_mib / dur_s
+    logger.debug(f"Downloaded large blob {blob_id} of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s)")
+    return data
 
 
-async def blob_iter(blob_id, stub) -> AsyncIterator[bytes]:
+async def blob_iter(blob_id: str, stub: ModalClientModal) -> AsyncIterator[bytes]:
     req = api_pb2.BlobGetRequest(blob_id=blob_id)
     resp = await retry_transient_errors(stub.BlobGet, req)
     download_url = resp.download_url
-    async with …
-        …
-        await asyncio.sleep(1)
+    async with ClientSessionRegistry.get_session().get(download_url) as s3_resp:
+        # S3 signal to slow down request rate.
+        if s3_resp.status == 503:
+            logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
+            await asyncio.sleep(1)
 
-        …
+        if s3_resp.status != 200:
+            text = await s3_resp.text()
+            raise ExecutionError(f"Get from url failed with status {s3_resp.status}: {text}")
 
-        …
+        async for chunk in s3_resp.content.iter_any():
+            yield chunk
 
 
 @dataclasses.dataclass
 class FileUploadSpec:
-    source: Callable[[], BinaryIO]
+    source: Callable[[], Union[AbstractContextManager, BinaryIO]]
     source_description: Any
     mount_filename: str
 
     use_blob: bool
     content: Optional[bytes]  # typically None if using blob, required otherwise
     sha256_hex: str
+    md5_hex: str
     mode: int  # file permission bits (last 12 bits of st_mode)
     size: int
 
 
 def _get_file_upload_spec(
-    source: Callable[[], BinaryIO], …
+    source: Callable[[], Union[AbstractContextManager, BinaryIO]],
+    source_description: Any,
+    mount_filename: PurePosixPath,
+    mode: int,
 ) -> FileUploadSpec:
     with source() as fp:
         # Current position is ignored - we always upload from position 0
@@ -348,13 +315,15 @@ def _get_file_upload_spec(
         fp.seek(0)
 
         if size >= LARGE_FILE_LIMIT:
+            # TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs
+            md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None
             use_blob = True
             content = None
-            …
+            hashes = get_upload_hashes(fp, md5_hex=md5_hex)
         else:
             use_blob = False
             content = fp.read()
-            …
+            hashes = get_upload_hashes(content)
 
     return FileUploadSpec(
         source=source,
@@ -362,7 +331,8 @@ def _get_file_upload_spec(
         mount_filename=mount_filename.as_posix(),
         use_blob=use_blob,
         content=content,
-        sha256_hex=sha256_hex,
+        sha256_hex=hashes.sha256_hex(),
+        md5_hex=hashes.md5_hex(),
         mode=mode & 0o7777,
         size=size,
     )
@@ -383,10 +353,11 @@ def get_file_upload_spec_from_path(
 
 
 def get_file_upload_spec_from_fileobj(fp: BinaryIO, mount_filename: PurePosixPath, mode: int) -> FileUploadSpec:
+    @contextmanager
     def source():
         # We ignore position in stream and always upload from position 0
         fp.seek(0)
-        …
+        yield fp
 
     return _get_file_upload_spec(
         source,
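`FileUploadSpec.source` is now a zero-argument callable returning a context manager (or a bare file object), so path-based and fileobj-based specs can both be consumed uniformly via `with source() as fp:`. A sketch of the two shapes, assuming illustrative helper names; note the fileobj variant yields without closing, since the caller owns the file object:

```python
import io
from contextlib import contextmanager


def source_from_path(path: str):
    # open() already returns a context manager that closes the file on exit
    return lambda: open(path, "rb")


def source_from_fileobj(fp: io.BufferedIOBase):
    @contextmanager
    def source():
        fp.seek(0)  # caller's stream position is ignored; upload always starts at 0
        yield fp    # do not close: the caller owns the file object

    return source


buf = io.BytesIO(b"hello")
with source_from_fileobj(buf)() as fp:
    print(fp.read())  # b"hello"
```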
@@ -403,7 +374,7 @@ def use_md5(url: str) -> bool:
     https://github.com/spulec/moto/issues/816
     """
     host = urlparse(url).netloc.split(":")[0]
-    if host.endswith(".amazonaws.com"):
+    if host.endswith(".amazonaws.com") or host.endswith(".r2.cloudflarestorage.com"):
         return True
     elif host in ["127.0.0.1", "localhost", "172.21.0.1"]:
         return False
modal/_utils/bytes_io_segment_payload.py
ADDED
@@ -0,0 +1,97 @@
+# Copyright Modal Labs 2024
+
+import asyncio
+import hashlib
+from contextlib import contextmanager
+from typing import BinaryIO, Callable, Optional
+
+# Note: this module needs to import aiohttp in global scope
+# This takes about 50ms and isn't needed in many cases for Modal execution
+# To avoid this, we import it in local scope when needed (blob_utils.py)
+from aiohttp import BytesIOPayload
+from aiohttp.abc import AbstractStreamWriter
+
+# read ~16MiB chunks by default
+DEFAULT_SEGMENT_CHUNK_SIZE = 2**24
+
+
+class BytesIOSegmentPayload(BytesIOPayload):
+    """Modified bytes payload for concurrent sends of chunks from the same file.
+
+    Adds:
+    * read limit using remaining_bytes, in order to split files across streams
+    * larger read chunk (to prevent excessive read contention between parts)
+    * calculates an md5 for the segment
+
+    Feels like this should be in some standard lib...
+    """
+
+    def __init__(
+        self,
+        bytes_io: BinaryIO,  # should *not* be shared as IO position modification is not locked
+        segment_start: int,
+        segment_length: int,
+        chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
+        progress_report_cb: Optional[Callable] = None,
+    ):
+        # not thread safe constructor!
+        super().__init__(bytes_io)
+        self.initial_seek_pos = bytes_io.tell()
+        self.segment_start = segment_start
+        self.segment_length = segment_length
+        # seek to start of file segment we are interested in, in order to make .size() evaluate correctly
+        self._value.seek(self.initial_seek_pos + segment_start)
+        assert self.segment_length <= super().size
+        self.chunk_size = chunk_size
+        self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
+        self.reset_state()
+
+    def reset_state(self):
+        self._md5_checksum = hashlib.md5()
+        self.num_bytes_read = 0
+        self._value.seek(self.initial_seek_pos)
+
+    @contextmanager
+    def reset_on_error(self):
+        try:
+            yield
+        except Exception as exc:
+            try:
+                self.progress_report_cb(reset=True)
+            except Exception as cb_exc:
+                raise cb_exc from exc
+            raise exc
+        finally:
+            self.reset_state()
+
+    @property
+    def size(self) -> int:
+        return self.segment_length
+
+    def md5_checksum(self):
+        return self._md5_checksum
+
+    async def write(self, writer: "AbstractStreamWriter"):
+        loop = asyncio.get_event_loop()
+
+        async def safe_read():
+            read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
+            self._value.seek(read_start)
+            num_bytes = min(self.chunk_size, self.remaining_bytes())
+            chunk = await loop.run_in_executor(None, self._value.read, num_bytes)
+
+            await loop.run_in_executor(None, self._md5_checksum.update, chunk)
+            self.num_bytes_read += len(chunk)
+            return chunk
+
+        chunk = await safe_read()
+        while chunk and self.remaining_bytes() > 0:
+            await writer.write(chunk)
+            self.progress_report_cb(advance=len(chunk))
+            chunk = await safe_read()
+        if chunk:
+            await writer.write(chunk)
+            self.progress_report_cb(advance=len(chunk))
+
+    def remaining_bytes(self):
+        return self.segment_length - self.num_bytes_read
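Across both files, `progress_report_cb` is an optional callable invoked with keyword-style signals: `advance=n` after each chunk is written, `reset=True` when a failed segment rewinds for retry, and `complete=True` when the whole blob finishes. A minimal conforming callback, for illustration only (`make_progress_cb` is not part of the package):

```python
def make_progress_cb(total_bytes: int):
    state = {"sent": 0}

    def progress_report_cb(advance: int = 0, reset: bool = False, complete: bool = False, **_):
        if reset:
            state["sent"] = 0  # a retried segment is re-sent from its start
        if advance:
            state["sent"] += advance
        if complete:
            state["sent"] = total_bytes
        print(f"\ruploaded {state['sent']}/{total_bytes} bytes", end="")

    return progress_report_cb
```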