modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +402 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1025 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +196 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +943 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
  128. modal-0.72.13.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
modal/_utils/blob_utils.py
@@ -5,23 +5,26 @@ import hashlib
 import io
 import os
 import platform
-from contextlib import contextmanager
+import time
+from collections.abc import AsyncIterator
+from contextlib import AbstractContextManager, contextmanager
 from pathlib import Path, PurePosixPath
-from typing import Any, AsyncIterator, BinaryIO, Callable, List, Optional, Union
+from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional, Union
 from urllib.parse import urlparse
 
-from aiohttp import BytesIOPayload
-from aiohttp.abc import AbstractStreamWriter
-
 from modal_proto import api_pb2
+from modal_proto.modal_api_grpc import ModalClientModal
 
 from ..exception import ExecutionError
-from .async_utils import retry
+from .async_utils import TaskContext, retry
 from .grpc_utils import retry_transient_errors
-from .hash_utils import UploadHashes, get_sha256_hex, get_upload_hashes
-from .http_utils import http_client_with_tls
+from .hash_utils import UploadHashes, get_upload_hashes
+from .http_utils import ClientSessionRegistry
 from .logger import logger
 
+if TYPE_CHECKING:
+    from .bytes_io_segment_payload import BytesIOSegmentPayload
+
 # Max size for function inputs and outputs.
 MAX_OBJECT_SIZE_BYTES = 2 * 1024 * 1024  # 2 MiB
 
@@ -35,129 +38,59 @@ BLOB_MAX_PARALLELISM = 10
 # read ~16MiB chunks by default
 DEFAULT_SEGMENT_CHUNK_SIZE = 2**24
 
-
-class BytesIOSegmentPayload(BytesIOPayload):
-    """Modified bytes payload for concurrent sends of chunks from the same file.
-
-    Adds:
-    * read limit using remaining_bytes, in order to split files across streams
-    * larger read chunk (to prevent excessive read contention between parts)
-    * calculates an md5 for the segment
-
-    Feels like this should be in some standard lib...
-    """
-
-    def __init__(
-        self,
-        bytes_io: BinaryIO,  # should *not* be shared as IO position modification is not locked
-        segment_start: int,
-        segment_length: int,
-        chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
-    ):
-        # not thread safe constructor!
-        super().__init__(bytes_io)
-        self.initial_seek_pos = bytes_io.tell()
-        self.segment_start = segment_start
-        self.segment_length = segment_length
-        # seek to start of file segment we are interested in, in order to make .size() evaluate correctly
-        self._value.seek(self.initial_seek_pos + segment_start)
-        assert self.segment_length <= super().size
-        self.chunk_size = chunk_size
-        self.reset_state()
-
-    def reset_state(self):
-        self._md5_checksum = hashlib.md5()
-        self.num_bytes_read = 0
-        self._value.seek(self.initial_seek_pos)
-
-    @contextmanager
-    def reset_on_error(self):
-        try:
-            yield
-        finally:
-            self.reset_state()
-
-    @property
-    def size(self) -> int:
-        return self.segment_length
-
-    def md5_checksum(self):
-        return self._md5_checksum
-
-    async def write(self, writer: AbstractStreamWriter):
-        loop = asyncio.get_event_loop()
-
-        async def safe_read():
-            read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
-            self._value.seek(read_start)
-            num_bytes = min(self.chunk_size, self.remaining_bytes())
-            chunk = await loop.run_in_executor(None, self._value.read, num_bytes)
-
-            await loop.run_in_executor(None, self._md5_checksum.update, chunk)
-            self.num_bytes_read += len(chunk)
-            return chunk
-
-        chunk = await safe_read()
-        while chunk and self.remaining_bytes() > 0:
-            await writer.write(chunk)
-            chunk = await safe_read()
-        if chunk:
-            await writer.write(chunk)
-
-    def remaining_bytes(self):
-        return self.segment_length - self.num_bytes_read
+# Files larger than this will be multipart uploaded. The server might request multipart upload for smaller files as
+# well, but the limit will never be raised.
+# TODO(dano): remove this once we stop requiring md5 for blobs
+MULTIPART_UPLOAD_THRESHOLD = 1024**3
 
 
 @retry(n_attempts=5, base_delay=0.5, timeout=None)
 async def _upload_to_s3_url(
     upload_url,
-    payload: BytesIOSegmentPayload,
+    payload: "BytesIOSegmentPayload",
     content_md5_b64: Optional[str] = None,
     content_type: Optional[str] = "application/octet-stream",  # set to None to force omission of ContentType header
 ) -> str:
     """Returns etag of s3 object which is a md5 hex checksum of the uploaded content"""
     with payload.reset_on_error():  # ensure retries read the same data
-        async with http_client_with_tls(timeout=None) as session:
-            headers = {}
-            if content_md5_b64 and use_md5(upload_url):
-                headers["Content-MD5"] = content_md5_b64
-            if content_type:
-                headers["Content-Type"] = content_type
-
-            async with session.put(
-                upload_url,
-                data=payload,
-                headers=headers,
-                skip_auto_headers=["content-type"] if content_type is None else [],
-            ) as resp:
-                # S3 signal to slow down request rate.
-                if resp.status == 503:
-                    logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
-                    await asyncio.sleep(1)
-
-                if resp.status != 200:
-                    try:
-                        text = await resp.text()
-                    except Exception:
-                        text = "<no body>"
-                    raise ExecutionError(f"Put to url {upload_url} failed with status {resp.status}: {text}")
-
-                # client side ETag checksum verification
-                # the s3 ETag of a single part upload is a quoted md5 hex of the uploaded content
-                etag = resp.headers["ETag"].strip()
-                if etag.startswith(("W/", "w/")):  # see https://www.rfc-editor.org/rfc/rfc7232#section-2.3
-                    etag = etag[2:]
-                if etag[0] == '"' and etag[-1] == '"':
-                    etag = etag[1:-1]
-                remote_md5 = etag
-
-                local_md5_hex = payload.md5_checksum().hexdigest()
-                if local_md5_hex != remote_md5:
-                    raise ExecutionError(
-                        f"Local data and remote data checksum mismatch ({local_md5_hex} vs {remote_md5})"
-                    )
-
-                return remote_md5
+        headers = {}
+        if content_md5_b64 and use_md5(upload_url):
+            headers["Content-MD5"] = content_md5_b64
+        if content_type:
+            headers["Content-Type"] = content_type
+
+        async with ClientSessionRegistry.get_session().put(
+            upload_url,
+            data=payload,
+            headers=headers,
+            skip_auto_headers=["content-type"] if content_type is None else [],
+        ) as resp:
+            # S3 signal to slow down request rate.
+            if resp.status == 503:
+                logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
+                await asyncio.sleep(1)
+
+            if resp.status != 200:
+                try:
+                    text = await resp.text()
+                except Exception:
+                    text = "<no body>"
+                raise ExecutionError(f"Put to url {upload_url} failed with status {resp.status}: {text}")
+
+            # client side ETag checksum verification
+            # the s3 ETag of a single part upload is a quoted md5 hex of the uploaded content
+            etag = resp.headers["ETag"].strip()
+            if etag.startswith(("W/", "w/")):  # see https://www.rfc-editor.org/rfc/rfc7232#section-2.3
+                etag = etag[2:]
+            if etag[0] == '"' and etag[-1] == '"':
+                etag = etag[1:-1]
+            remote_md5 = etag
+
+            local_md5_hex = payload.md5_checksum().hexdigest()
+            if local_md5_hex != remote_md5:
+                raise ExecutionError(f"Local data and remote data checksum mismatch ({local_md5_hex} vs {remote_md5})")
+
+            return remote_md5
 
 
 async def perform_multipart_upload(
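Aside: the hunk above normalizes the ETag header S3 returns before comparing it with the locally computed md5. A minimal standalone sketch of that normalization (not part of the package; the normalize_etag helper name is ours):

import hashlib

def normalize_etag(etag: str) -> str:
    etag = etag.strip()
    if etag.startswith(("W/", "w/")):  # weak validator prefix, RFC 7232 section 2.3
        etag = etag[2:]
    if etag[0] == '"' and etag[-1] == '"':  # drop the surrounding quotes
        etag = etag[1:-1]
    return etag

body = b"hello"
assert normalize_etag('W/"5d41402abc4b2a76b9719d911017c592"') == hashlib.md5(body).hexdigest()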
@@ -165,17 +98,20 @@ async def perform_multipart_upload(
     *,
     content_length: int,
     max_part_size: int,
-    part_urls: List[str],
+    part_urls: list[str],
     completion_url: str,
     upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
-):
+    progress_report_cb: Optional[Callable] = None,
+) -> None:
+    from .bytes_io_segment_payload import BytesIOSegmentPayload
+
     upload_coros = []
     file_offset = 0
     num_bytes_left = content_length
 
     # Give each part its own IO reader object to avoid needing to
     # lock access to the reader's position pointer.
-    data_file_readers: List[BinaryIO]
+    data_file_readers: list[BinaryIO]
     if isinstance(data_file, io.BytesIO):
         view = data_file.getbuffer()  # does not copy data
         data_file_readers = [io.BytesIO(view) for _ in range(len(part_urls))]
@@ -190,12 +126,13 @@ async def perform_multipart_upload(
             segment_start=file_offset,
             segment_length=part_length_bytes,
             chunk_size=upload_chunk_size,
+            progress_report_cb=progress_report_cb,
         )
         upload_coros.append(_upload_to_s3_url(part_url, payload=part_payload, content_type=None))
         num_bytes_left -= part_length_bytes
         file_offset += part_length_bytes
 
-    part_etags = await asyncio.gather(*upload_coros)
+    part_etags = await TaskContext.gather(*upload_coros)
 
     # The body of the complete_multipart_upload command needs some data in xml format:
     completion_body = "<CompleteMultipartUpload>\n"
@@ -207,25 +144,24 @@ async def perform_multipart_upload(
     bin_hash_parts = [bytes.fromhex(etag) for etag in part_etags]
 
     expected_multipart_etag = hashlib.md5(b"".join(bin_hash_parts)).hexdigest() + f"-{len(part_etags)}"
-    async with http_client_with_tls(timeout=None) as session:
-        resp = await session.post(
-            completion_url, data=completion_body.encode("ascii"), skip_auto_headers=["content-type"]
-        )
-        if resp.status != 200:
-            try:
-                msg = await resp.text()
-            except Exception:
-                msg = "<no body>"
-            raise ExecutionError(f"Error when completing multipart upload: {resp.status}\n{msg}")
-        else:
-            response_body = await resp.text()
-            if expected_multipart_etag not in response_body:
-                raise ExecutionError(
-                    f"Hash mismatch on multipart upload assembly: {expected_multipart_etag} not in {response_body}"
-                )
+    resp = await ClientSessionRegistry.get_session().post(
+        completion_url, data=completion_body.encode("ascii"), skip_auto_headers=["content-type"]
+    )
+    if resp.status != 200:
+        try:
+            msg = await resp.text()
+        except Exception:
+            msg = "<no body>"
+        raise ExecutionError(f"Error when completing multipart upload: {resp.status}\n{msg}")
+    else:
+        response_body = await resp.text()
+        if expected_multipart_etag not in response_body:
+            raise ExecutionError(
+                f"Hash mismatch on multipart upload assembly: {expected_multipart_etag} not in {response_body}"
+            )
 
 
-def get_content_length(data: BinaryIO):
+def get_content_length(data: BinaryIO) -> int:
     # *Remaining* length of file from current seek position
     pos = data.tell()
     data.seek(0, os.SEEK_END)
@@ -234,7 +170,9 @@ def get_content_length(data: BinaryIO):
     return content_length - pos
 
 
-async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub) -> str:
+async def _blob_upload(
+    upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub, progress_report_cb: Optional[Callable] = None
+) -> str:
     if isinstance(data, bytes):
         data = io.BytesIO(data)
 
@@ -257,9 +195,14 @@ async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO]
             part_urls=resp.multipart.upload_urls,
             completion_url=resp.multipart.completion_url,
             upload_chunk_size=DEFAULT_SEGMENT_CHUNK_SIZE,
+            progress_report_cb=progress_report_cb,
         )
     else:
-        payload = BytesIOSegmentPayload(data, segment_start=0, segment_length=content_length)
+        from .bytes_io_segment_payload import BytesIOSegmentPayload
+
+        payload = BytesIOSegmentPayload(
+            data, segment_start=0, segment_length=content_length, progress_report_cb=progress_report_cb
+        )
         await _upload_to_s3_url(
             resp.upload_url,
             payload,
@@ -267,79 +210,103 @@ async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO]
             content_md5_b64=upload_hashes.md5_base64,
         )
 
+    if progress_report_cb:
+        progress_report_cb(complete=True)
+
     return blob_id
 
 
-async def blob_upload(payload: bytes, stub) -> str:
+async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
+    size_mib = len(payload) / 1024 / 1024
+    logger.debug(f"Uploading large blob of size {size_mib:.2f} MiB")
+    t0 = time.time()
     if isinstance(payload, str):
         logger.warning("Blob uploading string, not bytes - auto-encoding as utf8")
         payload = payload.encode("utf8")
     upload_hashes = get_upload_hashes(payload)
-    return await _blob_upload(upload_hashes, payload, stub)
+    blob_id = await _blob_upload(upload_hashes, payload, stub)
+    dur_s = max(time.time() - t0, 0.001)  # avoid division by zero
+    throughput_mib_s = (size_mib) / dur_s
+    logger.debug(f"Uploaded large blob of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s)." f" {blob_id}")
+    return blob_id
 
 
-async def blob_upload_file(file_obj: BinaryIO, stub) -> str:
-    upload_hashes = get_upload_hashes(file_obj)
-    return await _blob_upload(upload_hashes, file_obj, stub)
+async def blob_upload_file(
+    file_obj: BinaryIO,
+    stub: ModalClientModal,
+    progress_report_cb: Optional[Callable] = None,
+    sha256_hex: Optional[str] = None,
+    md5_hex: Optional[str] = None,
+) -> str:
+    upload_hashes = get_upload_hashes(file_obj, sha256_hex=sha256_hex, md5_hex=md5_hex)
+    return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)
 
 
 @retry(n_attempts=5, base_delay=0.1, timeout=None)
-async def _download_from_url(download_url) -> bytes:
-    async with http_client_with_tls(timeout=None) as session:
-        async with session.get(download_url) as resp:
-            # S3 signal to slow down request rate.
-            if resp.status == 503:
-                logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
-                await asyncio.sleep(1)
-
-            if resp.status != 200:
-                text = await resp.text()
-                raise ExecutionError(f"Get from url failed with status {resp.status}: {text}")
-            return await resp.read()
-
-
-async def blob_download(blob_id, stub) -> bytes:
-    # convenience function reading all of the downloaded file into memory
+async def _download_from_url(download_url: str) -> bytes:
+    async with ClientSessionRegistry.get_session().get(download_url) as s3_resp:
+        # S3 signal to slow down request rate.
+        if s3_resp.status == 503:
+            logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
+            await asyncio.sleep(1)
+
+        if s3_resp.status != 200:
+            text = await s3_resp.text()
+            raise ExecutionError(f"Get from url failed with status {s3_resp.status}: {text}")
+        return await s3_resp.read()
+
+
+async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes:
+    """Convenience function for reading all of the downloaded file into memory."""
+    logger.debug(f"Downloading large blob {blob_id}")
+    t0 = time.time()
     req = api_pb2.BlobGetRequest(blob_id=blob_id)
     resp = await retry_transient_errors(stub.BlobGet, req)
-
-    return await _download_from_url(resp.download_url)
+    data = await _download_from_url(resp.download_url)
+    size_mib = len(data) / 1024 / 1024
+    dur_s = max(time.time() - t0, 0.001)  # avoid division by zero
+    throughput_mib_s = size_mib / dur_s
+    logger.debug(f"Downloaded large blob {blob_id} of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s)")
+    return data
 
 
-async def blob_iter(blob_id, stub) -> AsyncIterator[bytes]:
+async def blob_iter(blob_id: str, stub: ModalClientModal) -> AsyncIterator[bytes]:
     req = api_pb2.BlobGetRequest(blob_id=blob_id)
     resp = await retry_transient_errors(stub.BlobGet, req)
     download_url = resp.download_url
-    async with http_client_with_tls(timeout=None) as session:
-        async with session.get(download_url) as resp:
-            # S3 signal to slow down request rate.
-            if resp.status == 503:
-                logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
-                await asyncio.sleep(1)
+    async with ClientSessionRegistry.get_session().get(download_url) as s3_resp:
+        # S3 signal to slow down request rate.
+        if s3_resp.status == 503:
+            logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
+            await asyncio.sleep(1)
 
-            if resp.status != 200:
-                text = await resp.text()
-                raise ExecutionError(f"Get from url failed with status {resp.status}: {text}")
+        if s3_resp.status != 200:
+            text = await s3_resp.text()
+            raise ExecutionError(f"Get from url failed with status {s3_resp.status}: {text}")
 
-            async for chunk in resp.content.iter_any():
-                yield chunk
+        async for chunk in s3_resp.content.iter_any():
+            yield chunk
 
 
 @dataclasses.dataclass
 class FileUploadSpec:
-    source: Callable[[], BinaryIO]
+    source: Callable[[], Union[AbstractContextManager, BinaryIO]]
     source_description: Any
     mount_filename: str
 
     use_blob: bool
     content: Optional[bytes]  # typically None if using blob, required otherwise
     sha256_hex: str
+    md5_hex: str
    mode: int  # file permission bits (last 12 bits of st_mode)
     size: int
 
 
 def _get_file_upload_spec(
-    source: Callable[[], BinaryIO], source_description: Any, mount_filename: PurePosixPath, mode: int
+    source: Callable[[], Union[AbstractContextManager, BinaryIO]],
+    source_description: Any,
+    mount_filename: PurePosixPath,
+    mode: int,
 ) -> FileUploadSpec:
     with source() as fp:
         # Current position is ignored - we always upload from position 0
@@ -348,13 +315,15 @@ def _get_file_upload_spec(
         fp.seek(0)
 
         if size >= LARGE_FILE_LIMIT:
+            # TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs
+            md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None
             use_blob = True
             content = None
-            sha256_hex = get_sha256_hex(fp)
+            hashes = get_upload_hashes(fp, md5_hex=md5_hex)
         else:
             use_blob = False
             content = fp.read()
-            sha256_hex = get_sha256_hex(content)
+            hashes = get_upload_hashes(content)
 
     return FileUploadSpec(
         source=source,
@@ -362,7 +331,8 @@ def _get_file_upload_spec(
         mount_filename=mount_filename.as_posix(),
         use_blob=use_blob,
         content=content,
-        sha256_hex=sha256_hex,
+        sha256_hex=hashes.sha256_hex(),
+        md5_hex=hashes.md5_hex(),
         mode=mode & 0o7777,
         size=size,
     )
@@ -383,10 +353,11 @@ def get_file_upload_spec_from_path(
 
 
 def get_file_upload_spec_from_fileobj(fp: BinaryIO, mount_filename: PurePosixPath, mode: int) -> FileUploadSpec:
+    @contextmanager
     def source():
         # We ignore position in stream and always upload from position 0
        fp.seek(0)
-        return fp
+        yield fp
 
     return _get_file_upload_spec(
         source,
@@ -403,7 +374,7 @@ def use_md5(url: str) -> bool:
     https://github.com/spulec/moto/issues/816
     """
    host = urlparse(url).netloc.split(":")[0]
-    if host.endswith(".amazonaws.com"):
+    if host.endswith(".amazonaws.com") or host.endswith(".r2.cloudflarestorage.com"):
         return True
     elif host in ["127.0.0.1", "localhost", "172.21.0.1"]:
         return False
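Aside: perform_multipart_upload above checks the assembled object against S3's multipart ETag convention: the ETag of a completed multipart upload is the md5 of the concatenated binary md5 digests of the individual parts, suffixed with "-<part count>". A small self-contained sketch of that computation (the expected_multipart_etag helper name is ours, not from the package):

import hashlib

def expected_multipart_etag(part_etags: list[str]) -> str:
    # part_etags: the hex md5 ETags S3 returned for each uploaded part
    bin_hash_parts = [bytes.fromhex(etag) for etag in part_etags]
    return hashlib.md5(b"".join(bin_hash_parts)).hexdigest() + f"-{len(part_etags)}"

# Example with two parts:
parts = [b"a" * 1024, b"b" * 512]
etags = [hashlib.md5(p).hexdigest() for p in parts]
print(expected_multipart_etag(etags))  # ends in "-2", matching S3's completion response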
modal/_utils/bytes_io_segment_payload.py (new)
@@ -0,0 +1,97 @@
+# Copyright Modal Labs 2024
+
+import asyncio
+import hashlib
+from contextlib import contextmanager
+from typing import BinaryIO, Callable, Optional
+
+# Note: this module needs to import aiohttp in global scope
+# This takes about 50ms and isn't needed in many cases for Modal execution
+# To avoid this, we import it in local scope when needed (blob_utils.py)
+from aiohttp import BytesIOPayload
+from aiohttp.abc import AbstractStreamWriter
+
+# read ~16MiB chunks by default
+DEFAULT_SEGMENT_CHUNK_SIZE = 2**24
+
+
+class BytesIOSegmentPayload(BytesIOPayload):
+    """Modified bytes payload for concurrent sends of chunks from the same file.
+
+    Adds:
+    * read limit using remaining_bytes, in order to split files across streams
+    * larger read chunk (to prevent excessive read contention between parts)
+    * calculates an md5 for the segment
+
+    Feels like this should be in some standard lib...
+    """
+
+    def __init__(
+        self,
+        bytes_io: BinaryIO,  # should *not* be shared as IO position modification is not locked
+        segment_start: int,
+        segment_length: int,
+        chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
+        progress_report_cb: Optional[Callable] = None,
+    ):
+        # not thread safe constructor!
+        super().__init__(bytes_io)
+        self.initial_seek_pos = bytes_io.tell()
+        self.segment_start = segment_start
+        self.segment_length = segment_length
+        # seek to start of file segment we are interested in, in order to make .size() evaluate correctly
+        self._value.seek(self.initial_seek_pos + segment_start)
+        assert self.segment_length <= super().size
+        self.chunk_size = chunk_size
+        self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
+        self.reset_state()
+
+    def reset_state(self):
+        self._md5_checksum = hashlib.md5()
+        self.num_bytes_read = 0
+        self._value.seek(self.initial_seek_pos)
+
+    @contextmanager
+    def reset_on_error(self):
+        try:
+            yield
+        except Exception as exc:
+            try:
+                self.progress_report_cb(reset=True)
+            except Exception as cb_exc:
+                raise cb_exc from exc
+            raise exc
+        finally:
+            self.reset_state()
+
+    @property
+    def size(self) -> int:
+        return self.segment_length
+
+    def md5_checksum(self):
+        return self._md5_checksum
+
+    async def write(self, writer: "AbstractStreamWriter"):
+        loop = asyncio.get_event_loop()
+
+        async def safe_read():
+            read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
+            self._value.seek(read_start)
+            num_bytes = min(self.chunk_size, self.remaining_bytes())
+            chunk = await loop.run_in_executor(None, self._value.read, num_bytes)
+
+            await loop.run_in_executor(None, self._md5_checksum.update, chunk)
+            self.num_bytes_read += len(chunk)
+            return chunk
+
+        chunk = await safe_read()
+        while chunk and self.remaining_bytes() > 0:
+            await writer.write(chunk)
+            self.progress_report_cb(advance=len(chunk))
+            chunk = await safe_read()
+        if chunk:
+            await writer.write(chunk)
+            self.progress_report_cb(advance=len(chunk))
+
+    def remaining_bytes(self):
+        return self.segment_length - self.num_bytes_read
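Usage aside: a minimal sketch of driving BytesIOSegmentPayload outside aiohttp, assuming the module path shown in the file list above; DummyWriter is a hypothetical stand-in for aiohttp's AbstractStreamWriter (only .write() is exercised). Note that reset_on_error() resets payload state on exit, so the per-segment md5 must be read inside the block, just as _upload_to_s3_url does:

import asyncio
import io

from modal._utils.bytes_io_segment_payload import BytesIOSegmentPayload

class DummyWriter:
    # Collects written bytes instead of sending them over HTTP.
    def __init__(self) -> None:
        self.buffer = bytearray()

    async def write(self, chunk: bytes) -> None:
        self.buffer.extend(chunk)

async def main() -> None:
    data = b"x" * 100 + b"y" * 100
    progress = []
    for start, length in [(0, 100), (100, 100)]:
        payload = BytesIOSegmentPayload(
            io.BytesIO(data),  # each segment gets its own reader: stream positions are not locked
            segment_start=start,
            segment_length=length,
            progress_report_cb=lambda **kwargs: progress.append(kwargs),
        )
        writer = DummyWriter()
        with payload.reset_on_error():  # a retry would re-read the same bytes
            await payload.write(writer)
            segment_md5 = payload.md5_checksum().hexdigest()  # read before state reset
        assert bytes(writer.buffer) == data[start : start + length]
        assert len(segment_md5) == 32  # per-segment md5, used for ETag verification

asyncio.run(main())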