modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (220)
  1. modal/__init__.py +17 -13
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +420 -937
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -59
  11. modal/_resources.py +51 -0
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/_runtime/execution_context.py +89 -0
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +134 -9
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +52 -16
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +479 -100
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +460 -171
  29. modal/_utils/grpc_testing.py +47 -31
  30. modal/_utils/grpc_utils.py +62 -109
  31. modal/_utils/hash_utils.py +61 -19
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +5 -7
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +14 -12
  43. modal/app.py +1003 -314
  44. modal/app.pyi +540 -264
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +63 -53
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +205 -45
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +62 -14
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +64 -58
  55. modal/cli/launch.py +32 -18
  56. modal/cli/network_file_system.py +64 -83
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +35 -10
  59. modal/cli/programs/vscode.py +60 -10
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +234 -131
  62. modal/cli/secret.py +8 -7
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +79 -10
  65. modal/cli/volume.py +110 -109
  66. modal/client.py +250 -144
  67. modal/client.pyi +157 -118
  68. modal/cloud_bucket_mount.py +108 -34
  69. modal/cloud_bucket_mount.pyi +32 -38
  70. modal/cls.py +535 -148
  71. modal/cls.pyi +190 -146
  72. modal/config.py +41 -19
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +111 -65
  76. modal/dict.pyi +136 -131
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +34 -43
  80. modal/experimental.py +61 -2
  81. modal/extensions/ipython.py +5 -5
  82. modal/file_io.py +537 -0
  83. modal/file_io.pyi +235 -0
  84. modal/file_pattern_matcher.py +197 -0
  85. modal/functions.py +906 -911
  86. modal/functions.pyi +466 -430
  87. modal/gpu.py +57 -44
  88. modal/image.py +1089 -479
  89. modal/image.pyi +584 -228
  90. modal/io_streams.py +434 -0
  91. modal/io_streams.pyi +122 -0
  92. modal/mount.py +314 -101
  93. modal/mount.pyi +241 -235
  94. modal/network_file_system.py +92 -92
  95. modal/network_file_system.pyi +152 -110
  96. modal/object.py +67 -36
  97. modal/object.pyi +166 -143
  98. modal/output.py +63 -0
  99. modal/parallel_map.py +434 -0
  100. modal/parallel_map.pyi +75 -0
  101. modal/partial_function.py +282 -117
  102. modal/partial_function.pyi +222 -129
  103. modal/proxy.py +15 -12
  104. modal/proxy.pyi +3 -8
  105. modal/queue.py +182 -65
  106. modal/queue.pyi +218 -118
  107. modal/requirements/2024.04.txt +29 -0
  108. modal/requirements/2024.10.txt +16 -0
  109. modal/requirements/README.md +21 -0
  110. modal/requirements/base-images.json +22 -0
  111. modal/retries.py +48 -7
  112. modal/runner.py +459 -156
  113. modal/runner.pyi +135 -71
  114. modal/running_app.py +38 -0
  115. modal/sandbox.py +514 -236
  116. modal/sandbox.pyi +397 -169
  117. modal/schedule.py +4 -4
  118. modal/scheduler_placement.py +20 -3
  119. modal/secret.py +56 -31
  120. modal/secret.pyi +62 -42
  121. modal/serving.py +51 -56
  122. modal/serving.pyi +44 -36
  123. modal/stream_type.py +15 -0
  124. modal/token_flow.py +5 -3
  125. modal/token_flow.pyi +37 -32
  126. modal/volume.py +285 -157
  127. modal/volume.pyi +249 -184
  128. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
  129. modal-0.72.11.dist-info/RECORD +174 -0
  130. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  131. modal_docs/gen_reference_docs.py +3 -1
  132. modal_docs/mdmd/mdmd.py +0 -1
  133. modal_docs/mdmd/signatures.py +5 -2
  134. modal_global_objects/images/base_images.py +28 -0
  135. modal_global_objects/mounts/python_standalone.py +2 -2
  136. modal_proto/__init__.py +1 -1
  137. modal_proto/api.proto +1288 -533
  138. modal_proto/api_grpc.py +856 -456
  139. modal_proto/api_pb2.py +2165 -1157
  140. modal_proto/api_pb2.pyi +8859 -0
  141. modal_proto/api_pb2_grpc.py +1674 -855
  142. modal_proto/api_pb2_grpc.pyi +1416 -0
  143. modal_proto/modal_api_grpc.py +149 -0
  144. modal_proto/modal_options_grpc.py +3 -0
  145. modal_proto/options_pb2.pyi +20 -0
  146. modal_proto/options_pb2_grpc.pyi +7 -0
  147. modal_proto/py.typed +0 -0
  148. modal_version/__init__.py +1 -1
  149. modal_version/_version_generated.py +2 -2
  150. modal/_asgi.py +0 -370
  151. modal/_container_entrypoint.pyi +0 -378
  152. modal/_container_exec.py +0 -128
  153. modal/_sandbox_shell.py +0 -49
  154. modal/shared_volume.py +0 -23
  155. modal/shared_volume.pyi +0 -24
  156. modal/stub.py +0 -783
  157. modal/stub.pyi +0 -332
  158. modal-0.62.16.dist-info/RECORD +0 -198
  159. modal_global_objects/images/conda.py +0 -15
  160. modal_global_objects/images/debian_slim.py +0 -15
  161. modal_global_objects/images/micromamba.py +0 -15
  162. test/__init__.py +0 -1
  163. test/aio_test.py +0 -12
  164. test/async_utils_test.py +0 -262
  165. test/blob_test.py +0 -67
  166. test/cli_imports_test.py +0 -149
  167. test/cli_test.py +0 -659
  168. test/client_test.py +0 -194
  169. test/cls_test.py +0 -630
  170. test/config_test.py +0 -137
  171. test/conftest.py +0 -1420
  172. test/container_app_test.py +0 -32
  173. test/container_test.py +0 -1389
  174. test/cpu_test.py +0 -23
  175. test/decorator_test.py +0 -85
  176. test/deprecation_test.py +0 -34
  177. test/dict_test.py +0 -33
  178. test/e2e_test.py +0 -68
  179. test/error_test.py +0 -7
  180. test/function_serialization_test.py +0 -32
  181. test/function_test.py +0 -653
  182. test/function_utils_test.py +0 -101
  183. test/gpu_test.py +0 -159
  184. test/grpc_utils_test.py +0 -141
  185. test/helpers.py +0 -42
  186. test/image_test.py +0 -669
  187. test/live_reload_test.py +0 -80
  188. test/lookup_test.py +0 -70
  189. test/mdmd_test.py +0 -329
  190. test/mount_test.py +0 -162
  191. test/mounted_files_test.py +0 -329
  192. test/network_file_system_test.py +0 -181
  193. test/notebook_test.py +0 -66
  194. test/object_test.py +0 -41
  195. test/package_utils_test.py +0 -25
  196. test/queue_test.py +0 -97
  197. test/resolver_test.py +0 -58
  198. test/retries_test.py +0 -67
  199. test/runner_test.py +0 -85
  200. test/sandbox_test.py +0 -191
  201. test/schedule_test.py +0 -15
  202. test/scheduler_placement_test.py +0 -29
  203. test/secret_test.py +0 -78
  204. test/serialization_test.py +0 -42
  205. test/stub_composition_test.py +0 -10
  206. test/stub_test.py +0 -360
  207. test/test_asgi_wrapper.py +0 -234
  208. test/token_flow_test.py +0 -18
  209. test/traceback_test.py +0 -135
  210. test/tunnel_test.py +0 -29
  211. test/utils_test.py +0 -88
  212. test/version_test.py +0 -14
  213. test/volume_test.py +0 -341
  214. test/watcher_test.py +0 -30
  215. test/webhook_test.py +0 -146
  216. /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
  217. /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
  218. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
modal/_utils/blob_utils.py
@@ -5,23 +5,26 @@ import hashlib
 import io
 import os
 import platform
-from contextlib import contextmanager
+import time
+from collections.abc import AsyncIterator
+from contextlib import AbstractContextManager, contextmanager
 from pathlib import Path, PurePosixPath
-from typing import Any, AsyncIterator, BinaryIO, Callable, List, Optional, Union
+from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional, Union
 from urllib.parse import urlparse
 
-from aiohttp import BytesIOPayload
-from aiohttp.abc import AbstractStreamWriter
-
 from modal_proto import api_pb2
+from modal_proto.modal_api_grpc import ModalClientModal
 
 from ..exception import ExecutionError
-from .async_utils import retry
+from .async_utils import TaskContext, retry
 from .grpc_utils import retry_transient_errors
-from .hash_utils import UploadHashes, get_sha256_hex, get_upload_hashes
-from .http_utils import http_client_with_tls
+from .hash_utils import UploadHashes, get_upload_hashes
+from .http_utils import ClientSessionRegistry
 from .logger import logger
 
+if TYPE_CHECKING:
+    from .bytes_io_segment_payload import BytesIOSegmentPayload
+
 # Max size for function inputs and outputs.
 MAX_OBJECT_SIZE_BYTES = 2 * 1024 * 1024  # 2 MiB
 
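The new TYPE_CHECKING block pairs with string annotations (payload: "BytesIOSegmentPayload") and deferred runtime imports later in the file, keeping the aiohttp-dependent module out of import time. A minimal sketch of the same pattern, with numpy standing in for the slow-to-import dependency (numpy is only an illustration here and is not part of this diff):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by static type checkers; never executed at runtime,
        # so the heavy import does not slow down module import.
        import numpy


    def as_array(values: list) -> "numpy.ndarray":  # string annotation avoids a runtime lookup
        # Deferred import: the cost is paid on first call instead of at import time.
        import numpy

        return numpy.array(values)


    print(as_array([1, 2, 3]))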
@@ -35,129 +38,59 @@ BLOB_MAX_PARALLELISM = 10
 # read ~16MiB chunks by default
 DEFAULT_SEGMENT_CHUNK_SIZE = 2**24
 
-
-class BytesIOSegmentPayload(BytesIOPayload):
-    """Modified bytes payload for concurrent sends of chunks from the same file.
-
-    Adds:
-    * read limit using remaining_bytes, in order to split files across streams
-    * larger read chunk (to prevent excessive read contention between parts)
-    * calculates an md5 for the segment
-
-    Feels like this should be in some standard lib...
-    """
-
-    def __init__(
-        self,
-        bytes_io: BinaryIO,  # should *not* be shared as IO position modification is not locked
-        segment_start: int,
-        segment_length: int,
-        chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
-    ):
-        # not thread safe constructor!
-        super().__init__(bytes_io)
-        self.initial_seek_pos = bytes_io.tell()
-        self.segment_start = segment_start
-        self.segment_length = segment_length
-        # seek to start of file segment we are interested in, in order to make .size() evaluate correctly
-        self._value.seek(self.initial_seek_pos + segment_start)
-        assert self.segment_length <= super().size
-        self.chunk_size = chunk_size
-        self.reset_state()
-
-    def reset_state(self):
-        self._md5_checksum = hashlib.md5()
-        self.num_bytes_read = 0
-        self._value.seek(self.initial_seek_pos)
-
-    @contextmanager
-    def reset_on_error(self):
-        try:
-            yield
-        finally:
-            self.reset_state()
-
-    @property
-    def size(self) -> int:
-        return self.segment_length
-
-    def md5_checksum(self):
-        return self._md5_checksum
-
-    async def write(self, writer: AbstractStreamWriter):
-        loop = asyncio.get_event_loop()
-
-        async def safe_read():
-            read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
-            self._value.seek(read_start)
-            num_bytes = min(self.chunk_size, self.remaining_bytes())
-            chunk = await loop.run_in_executor(None, self._value.read, num_bytes)
-
-            await loop.run_in_executor(None, self._md5_checksum.update, chunk)
-            self.num_bytes_read += len(chunk)
-            return chunk
-
-        chunk = await safe_read()
-        while chunk and self.remaining_bytes() > 0:
-            await writer.write(chunk)
-            chunk = await safe_read()
-        if chunk:
-            await writer.write(chunk)
-
-    def remaining_bytes(self):
-        return self.segment_length - self.num_bytes_read
+# Files larger than this will be multipart uploaded. The server might request multipart upload for smaller files as
+# well, but the limit will never be raised.
+# TODO(dano): remove this once we stop requiring md5 for blobs
+MULTIPART_UPLOAD_THRESHOLD = 1024**3
 
 
 @retry(n_attempts=5, base_delay=0.5, timeout=None)
 async def _upload_to_s3_url(
     upload_url,
-    payload: BytesIOSegmentPayload,
+    payload: "BytesIOSegmentPayload",
     content_md5_b64: Optional[str] = None,
     content_type: Optional[str] = "application/octet-stream",  # set to None to force omission of ContentType header
 ) -> str:
     """Returns etag of s3 object which is a md5 hex checksum of the uploaded content"""
     with payload.reset_on_error():  # ensure retries read the same data
-        async with http_client_with_tls(timeout=None) as session:
-            headers = {}
-            if content_md5_b64 and use_md5(upload_url):
-                headers["Content-MD5"] = content_md5_b64
-            if content_type:
-                headers["Content-Type"] = content_type
-
-            async with session.put(
-                upload_url,
-                data=payload,
-                headers=headers,
-                skip_auto_headers=["content-type"] if content_type is None else [],
-            ) as resp:
-                # S3 signal to slow down request rate.
-                if resp.status == 503:
-                    logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
-                    await asyncio.sleep(1)
-
-                if resp.status != 200:
-                    try:
-                        text = await resp.text()
-                    except Exception:
-                        text = "<no body>"
-                    raise ExecutionError(f"Put to url {upload_url} failed with status {resp.status}: {text}")
-
-                # client side ETag checksum verification
-                # the s3 ETag of a single part upload is a quoted md5 hex of the uploaded content
-                etag = resp.headers["ETag"].strip()
-                if etag.startswith(("W/", "w/")):  # see https://www.rfc-editor.org/rfc/rfc7232#section-2.3
-                    etag = etag[2:]
-                if etag[0] == '"' and etag[-1] == '"':
-                    etag = etag[1:-1]
-                remote_md5 = etag
-
-                local_md5_hex = payload.md5_checksum().hexdigest()
-                if local_md5_hex != remote_md5:
-                    raise ExecutionError(
-                        f"Local data and remote data checksum mismatch ({local_md5_hex} vs {remote_md5})"
-                    )
-
-                return remote_md5
+        headers = {}
+        if content_md5_b64 and use_md5(upload_url):
+            headers["Content-MD5"] = content_md5_b64
+        if content_type:
+            headers["Content-Type"] = content_type
+
+        async with ClientSessionRegistry.get_session().put(
+            upload_url,
+            data=payload,
+            headers=headers,
+            skip_auto_headers=["content-type"] if content_type is None else [],
+        ) as resp:
+            # S3 signal to slow down request rate.
+            if resp.status == 503:
+                logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
+                await asyncio.sleep(1)
+
+            if resp.status != 200:
+                try:
+                    text = await resp.text()
+                except Exception:
+                    text = "<no body>"
+                raise ExecutionError(f"Put to url {upload_url} failed with status {resp.status}: {text}")
+
+            # client side ETag checksum verification
+            # the s3 ETag of a single part upload is a quoted md5 hex of the uploaded content
+            etag = resp.headers["ETag"].strip()
+            if etag.startswith(("W/", "w/")):  # see https://www.rfc-editor.org/rfc/rfc7232#section-2.3
+                etag = etag[2:]
+            if etag[0] == '"' and etag[-1] == '"':
+                etag = etag[1:-1]
+            remote_md5 = etag
+
+            local_md5_hex = payload.md5_checksum().hexdigest()
+            if local_md5_hex != remote_md5:
+                raise ExecutionError(f"Local data and remote data checksum mismatch ({local_md5_hex} vs {remote_md5})")
+
+            return remote_md5
 
 
 async def perform_multipart_upload(
@@ -165,17 +98,20 @@ async def perform_multipart_upload(
     *,
     content_length: int,
     max_part_size: int,
-    part_urls: List[str],
+    part_urls: list[str],
     completion_url: str,
     upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
-):
+    progress_report_cb: Optional[Callable] = None,
+) -> None:
+    from .bytes_io_segment_payload import BytesIOSegmentPayload
+
     upload_coros = []
     file_offset = 0
     num_bytes_left = content_length
 
     # Give each part its own IO reader object to avoid needing to
     # lock access to the reader's position pointer.
-    data_file_readers: List[BinaryIO]
+    data_file_readers: list[BinaryIO]
     if isinstance(data_file, io.BytesIO):
         view = data_file.getbuffer()  # does not copy data
         data_file_readers = [io.BytesIO(view) for _ in range(len(part_urls))]
@@ -190,12 +126,13 @@ async def perform_multipart_upload(
             segment_start=file_offset,
             segment_length=part_length_bytes,
             chunk_size=upload_chunk_size,
+            progress_report_cb=progress_report_cb,
         )
         upload_coros.append(_upload_to_s3_url(part_url, payload=part_payload, content_type=None))
         num_bytes_left -= part_length_bytes
         file_offset += part_length_bytes
 
-    part_etags = await asyncio.gather(*upload_coros)
+    part_etags = await TaskContext.gather(*upload_coros)
 
     # The body of the complete_multipart_upload command needs some data in xml format:
     completion_body = "<CompleteMultipartUpload>\n"
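For intuition about the loop above: each part covers up to max_part_size bytes starting at file_offset, with the final part taking the remainder. The part-sizing line itself sits outside this hunk, so the sketch below infers it from the offset bookkeeping that is shown:

    # Illustrative part layout, inferred from the offset bookkeeping above.
    content_length = 5 * 1024**3 + 512 * 1024**2  # a 5.5 GiB upload
    max_part_size = 1024**3  # 1 GiB parts (chosen by the server in practice)

    file_offset, num_bytes_left, parts = 0, content_length, []
    while num_bytes_left > 0:
        part_length_bytes = min(num_bytes_left, max_part_size)
        parts.append((file_offset, part_length_bytes))
        num_bytes_left -= part_length_bytes
        file_offset += part_length_bytes

    # Six parts: five full 1 GiB segments plus a trailing 512 MiB remainder.
    assert len(parts) == 6
    assert parts[-1] == (5 * 1024**3, 512 * 1024**2)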
@@ -207,25 +144,24 @@ async def perform_multipart_upload(
     bin_hash_parts = [bytes.fromhex(etag) for etag in part_etags]
 
     expected_multipart_etag = hashlib.md5(b"".join(bin_hash_parts)).hexdigest() + f"-{len(part_etags)}"
-    async with http_client_with_tls(timeout=None) as session:
-        resp = await session.post(
-            completion_url, data=completion_body.encode("ascii"), skip_auto_headers=["content-type"]
-        )
-        if resp.status != 200:
-            try:
-                msg = await resp.text()
-            except Exception:
-                msg = "<no body>"
-            raise ExecutionError(f"Error when completing multipart upload: {resp.status}\n{msg}")
-        else:
-            response_body = await resp.text()
-            if expected_multipart_etag not in response_body:
-                raise ExecutionError(
-                    f"Hash mismatch on multipart upload assembly: {expected_multipart_etag} not in {response_body}"
-                )
+    resp = await ClientSessionRegistry.get_session().post(
+        completion_url, data=completion_body.encode("ascii"), skip_auto_headers=["content-type"]
+    )
+    if resp.status != 200:
+        try:
+            msg = await resp.text()
+        except Exception:
+            msg = "<no body>"
+        raise ExecutionError(f"Error when completing multipart upload: {resp.status}\n{msg}")
+    else:
+        response_body = await resp.text()
+        if expected_multipart_etag not in response_body:
+            raise ExecutionError(
+                f"Hash mismatch on multipart upload assembly: {expected_multipart_etag} not in {response_body}"
+            )
 
 
-def get_content_length(data: BinaryIO):
+def get_content_length(data: BinaryIO) -> int:
     # *Remaining* length of file from current seek position
     pos = data.tell()
     data.seek(0, os.SEEK_END)
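The expected_multipart_etag comparison relies on how S3 forms a multipart ETag: the md5 of the concatenated binary md5 digests of each part, suffixed with the part count. A self-contained sketch with made-up parts:

    import hashlib

    # Pretend these are the two uploaded parts of one file.
    part_contents = [b"a" * 1024, b"b" * 2048]

    # Each part's ETag is the hex md5 of that part's bytes...
    part_etags = [hashlib.md5(part).hexdigest() for part in part_contents]

    # ...and the multipart ETag is the md5 of the *binary* part digests plus "-<N>".
    bin_hash_parts = [bytes.fromhex(etag) for etag in part_etags]
    expected = hashlib.md5(b"".join(bin_hash_parts)).hexdigest() + f"-{len(part_etags)}"

    print(expected)  # note: this is not the md5 of the whole concatenated file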
@@ -234,7 +170,9 @@ def get_content_length(data: BinaryIO):
     return content_length - pos
 
 
-async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub) -> str:
+async def _blob_upload(
+    upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub, progress_report_cb: Optional[Callable] = None
+) -> str:
     if isinstance(data, bytes):
         data = io.BytesIO(data)
 
@@ -257,9 +195,14 @@ async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO]
             part_urls=resp.multipart.upload_urls,
             completion_url=resp.multipart.completion_url,
             upload_chunk_size=DEFAULT_SEGMENT_CHUNK_SIZE,
+            progress_report_cb=progress_report_cb,
         )
     else:
-        payload = BytesIOSegmentPayload(data, segment_start=0, segment_length=content_length)
+        from .bytes_io_segment_payload import BytesIOSegmentPayload
+
+        payload = BytesIOSegmentPayload(
+            data, segment_start=0, segment_length=content_length, progress_report_cb=progress_report_cb
+        )
         await _upload_to_s3_url(
             resp.upload_url,
             payload,
@@ -267,79 +210,103 @@ async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO]
             content_md5_b64=upload_hashes.md5_base64,
         )
 
+    if progress_report_cb:
+        progress_report_cb(complete=True)
+
     return blob_id
 
 
-async def blob_upload(payload: bytes, stub) -> str:
+async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
+    size_mib = len(payload) / 1024 / 1024
+    logger.debug(f"Uploading large blob of size {size_mib:.2f} MiB")
+    t0 = time.time()
     if isinstance(payload, str):
         logger.warning("Blob uploading string, not bytes - auto-encoding as utf8")
         payload = payload.encode("utf8")
     upload_hashes = get_upload_hashes(payload)
-    return await _blob_upload(upload_hashes, payload, stub)
+    blob_id = await _blob_upload(upload_hashes, payload, stub)
+    dur_s = max(time.time() - t0, 0.001)  # avoid division by zero
+    throughput_mib_s = (size_mib) / dur_s
+    logger.debug(f"Uploaded large blob of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s)." f" {blob_id}")
+    return blob_id
 
 
-async def blob_upload_file(file_obj: BinaryIO, stub) -> str:
-    upload_hashes = get_upload_hashes(file_obj)
-    return await _blob_upload(upload_hashes, file_obj, stub)
+async def blob_upload_file(
+    file_obj: BinaryIO,
+    stub: ModalClientModal,
+    progress_report_cb: Optional[Callable] = None,
+    sha256_hex: Optional[str] = None,
+    md5_hex: Optional[str] = None,
+) -> str:
+    upload_hashes = get_upload_hashes(file_obj, sha256_hex=sha256_hex, md5_hex=md5_hex)
+    return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)
 
 
 @retry(n_attempts=5, base_delay=0.1, timeout=None)
-async def _download_from_url(download_url) -> bytes:
-    async with http_client_with_tls(timeout=None) as session:
-        async with session.get(download_url) as resp:
-            # S3 signal to slow down request rate.
-            if resp.status == 503:
-                logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
-                await asyncio.sleep(1)
-
-            if resp.status != 200:
-                text = await resp.text()
-                raise ExecutionError(f"Get from url failed with status {resp.status}: {text}")
-            return await resp.read()
-
-
-async def blob_download(blob_id, stub) -> bytes:
-    # convenience function reading all of the downloaded file into memory
+async def _download_from_url(download_url: str) -> bytes:
+    async with ClientSessionRegistry.get_session().get(download_url) as s3_resp:
+        # S3 signal to slow down request rate.
+        if s3_resp.status == 503:
+            logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
+            await asyncio.sleep(1)
+
+        if s3_resp.status != 200:
+            text = await s3_resp.text()
+            raise ExecutionError(f"Get from url failed with status {s3_resp.status}: {text}")
+        return await s3_resp.read()
+
+
+async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes:
+    """Convenience function for reading all of the downloaded file into memory."""
+    logger.debug(f"Downloading large blob {blob_id}")
+    t0 = time.time()
     req = api_pb2.BlobGetRequest(blob_id=blob_id)
     resp = await retry_transient_errors(stub.BlobGet, req)
-
-    return await _download_from_url(resp.download_url)
+    data = await _download_from_url(resp.download_url)
+    size_mib = len(data) / 1024 / 1024
+    dur_s = max(time.time() - t0, 0.001)  # avoid division by zero
+    throughput_mib_s = size_mib / dur_s
+    logger.debug(f"Downloaded large blob {blob_id} of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s)")
+    return data
 
 
-async def blob_iter(blob_id, stub) -> AsyncIterator[bytes]:
+async def blob_iter(blob_id: str, stub: ModalClientModal) -> AsyncIterator[bytes]:
     req = api_pb2.BlobGetRequest(blob_id=blob_id)
     resp = await retry_transient_errors(stub.BlobGet, req)
     download_url = resp.download_url
-    async with http_client_with_tls(timeout=None) as session:
-        async with session.get(download_url) as resp:
-            # S3 signal to slow down request rate.
-            if resp.status == 503:
-                logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
-                await asyncio.sleep(1)
+    async with ClientSessionRegistry.get_session().get(download_url) as s3_resp:
+        # S3 signal to slow down request rate.
+        if s3_resp.status == 503:
+            logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
+            await asyncio.sleep(1)
 
-            if resp.status != 200:
-                text = await resp.text()
-                raise ExecutionError(f"Get from url failed with status {resp.status}: {text}")
+        if s3_resp.status != 200:
+            text = await s3_resp.text()
+            raise ExecutionError(f"Get from url failed with status {s3_resp.status}: {text}")
 
-            async for chunk in resp.content.iter_any():
-                yield chunk
+        async for chunk in s3_resp.content.iter_any():
+            yield chunk
 
 
 @dataclasses.dataclass
 class FileUploadSpec:
-    source: Callable[[], BinaryIO]
+    source: Callable[[], Union[AbstractContextManager, BinaryIO]]
     source_description: Any
     mount_filename: str
 
     use_blob: bool
     content: Optional[bytes]  # typically None if using blob, required otherwise
     sha256_hex: str
+    md5_hex: str
     mode: int  # file permission bits (last 12 bits of st_mode)
     size: int
 
 
 def _get_file_upload_spec(
-    source: Callable[[], BinaryIO], source_description: Any, mount_filename: PurePosixPath, mode: int
+    source: Callable[[], Union[AbstractContextManager, BinaryIO]],
+    source_description: Any,
+    mount_filename: PurePosixPath,
+    mode: int,
 ) -> FileUploadSpec:
     with source() as fp:
         # Current position is ignored - we always upload from position 0
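A recurring change in this file is swapping the per-call "async with http_client_with_tls(timeout=None)" session for ClientSessionRegistry.get_session(), reusing one process-wide aiohttp session across blob transfers. The registry itself lives in modal/_utils/http_utils.py and is not shown in this diff; the sketch below is a hypothetical minimal version of the idea, not the actual implementation:

    from typing import Optional

    import aiohttp


    class ClientSessionRegistry:
        # Hypothetical sketch; the real class in modal/_utils/http_utils.py differs.
        _session: Optional[aiohttp.ClientSession] = None

        @classmethod
        def get_session(cls) -> aiohttp.ClientSession:
            # Reusing one session keeps aiohttp's connection pool and TLS state
            # warm across blob transfers instead of rebuilding them per call.
            if cls._session is None or cls._session.closed:
                cls._session = aiohttp.ClientSession()
            return cls._session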
@@ -348,13 +315,15 @@ def _get_file_upload_spec(
         fp.seek(0)
 
         if size >= LARGE_FILE_LIMIT:
+            # TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs
+            md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None
             use_blob = True
             content = None
-            sha256_hex = get_sha256_hex(fp)
+            hashes = get_upload_hashes(fp, md5_hex=md5_hex)
         else:
             use_blob = False
             content = fp.read()
-            sha256_hex = get_sha256_hex(content)
+            hashes = get_upload_hashes(content)
 
     return FileUploadSpec(
         source=source,
@@ -362,7 +331,8 @@ def _get_file_upload_spec(
         mount_filename=mount_filename.as_posix(),
         use_blob=use_blob,
         content=content,
-        sha256_hex=sha256_hex,
+        sha256_hex=hashes.sha256_hex(),
+        md5_hex=hashes.md5_hex(),
         mode=mode & 0o7777,
         size=size,
     )
@@ -383,10 +353,11 @@ def get_file_upload_spec_from_path(
 
 
 def get_file_upload_spec_from_fileobj(fp: BinaryIO, mount_filename: PurePosixPath, mode: int) -> FileUploadSpec:
+    @contextmanager
     def source():
         # We ignore position in stream and always upload from position 0
         fp.seek(0)
-        return fp
+        yield fp
 
     return _get_file_upload_spec(
         source,
@@ -403,7 +374,7 @@ def use_md5(url: str) -> bool:
     https://github.com/spulec/moto/issues/816
     """
     host = urlparse(url).netloc.split(":")[0]
-    if host.endswith(".amazonaws.com"):
+    if host.endswith(".amazonaws.com") or host.endswith(".r2.cloudflarestorage.com"):
         return True
     elif host in ["127.0.0.1", "localhost", "172.21.0.1"]:
         return False
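This change extends client-side Content-MD5 verification to Cloudflare R2 endpoints. Going only by the branches visible in this hunk, host matching works as follows (a small standalone check with illustrative URLs):

    from urllib.parse import urlparse

    def host_of(url: str) -> str:
        # Same host extraction as use_md5 above.
        return urlparse(url).netloc.split(":")[0]

    # Hosts that get the Content-MD5 header (per the True branch above):
    assert host_of("https://bucket.s3.us-east-1.amazonaws.com/key").endswith(".amazonaws.com")
    assert host_of("https://account.r2.cloudflarestorage.com/key").endswith(".r2.cloudflarestorage.com")

    # Local test endpoints that skip it (per the False branch above):
    assert host_of("http://127.0.0.1:9000/key") in ["127.0.0.1", "localhost", "172.21.0.1"]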
modal/_utils/bytes_io_segment_payload.py (new file)
@@ -0,0 +1,97 @@
+# Copyright Modal Labs 2024
+
+import asyncio
+import hashlib
+from contextlib import contextmanager
+from typing import BinaryIO, Callable, Optional
+
+# Note: this module needs to import aiohttp in global scope
+# This takes about 50ms and isn't needed in many cases for Modal execution
+# To avoid this, we import it in local scope when needed (blob_utils.py)
+from aiohttp import BytesIOPayload
+from aiohttp.abc import AbstractStreamWriter
+
+# read ~16MiB chunks by default
+DEFAULT_SEGMENT_CHUNK_SIZE = 2**24
+
+
+class BytesIOSegmentPayload(BytesIOPayload):
+    """Modified bytes payload for concurrent sends of chunks from the same file.
+
+    Adds:
+    * read limit using remaining_bytes, in order to split files across streams
+    * larger read chunk (to prevent excessive read contention between parts)
+    * calculates an md5 for the segment
+
+    Feels like this should be in some standard lib...
+    """
+
+    def __init__(
+        self,
+        bytes_io: BinaryIO,  # should *not* be shared as IO position modification is not locked
+        segment_start: int,
+        segment_length: int,
+        chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
+        progress_report_cb: Optional[Callable] = None,
+    ):
+        # not thread safe constructor!
+        super().__init__(bytes_io)
+        self.initial_seek_pos = bytes_io.tell()
+        self.segment_start = segment_start
+        self.segment_length = segment_length
+        # seek to start of file segment we are interested in, in order to make .size() evaluate correctly
+        self._value.seek(self.initial_seek_pos + segment_start)
+        assert self.segment_length <= super().size
+        self.chunk_size = chunk_size
+        self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
+        self.reset_state()
+
+    def reset_state(self):
+        self._md5_checksum = hashlib.md5()
+        self.num_bytes_read = 0
+        self._value.seek(self.initial_seek_pos)
+
+    @contextmanager
+    def reset_on_error(self):
+        try:
+            yield
+        except Exception as exc:
+            try:
+                self.progress_report_cb(reset=True)
+            except Exception as cb_exc:
+                raise cb_exc from exc
+            raise exc
+        finally:
+            self.reset_state()
+
+    @property
+    def size(self) -> int:
+        return self.segment_length
+
+    def md5_checksum(self):
+        return self._md5_checksum
+
+    async def write(self, writer: "AbstractStreamWriter"):
+        loop = asyncio.get_event_loop()
+
+        async def safe_read():
+            read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
+            self._value.seek(read_start)
+            num_bytes = min(self.chunk_size, self.remaining_bytes())
+            chunk = await loop.run_in_executor(None, self._value.read, num_bytes)
+
+            await loop.run_in_executor(None, self._md5_checksum.update, chunk)
+            self.num_bytes_read += len(chunk)
+            return chunk
+
+        chunk = await safe_read()
+        while chunk and self.remaining_bytes() > 0:
+            await writer.write(chunk)
+            self.progress_report_cb(advance=len(chunk))
+            chunk = await safe_read()
+        if chunk:
+            await writer.write(chunk)
+            self.progress_report_cb(advance=len(chunk))
+
+    def remaining_bytes(self):
+        return self.segment_length - self.num_bytes_read
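A usage sketch for the class above: two payloads split one buffer into independently sendable segments, each with its own reader so seek positions never collide. The keyword-style progress callback (advance=, reset=, complete=) matches the call sites in blob_utils.py; CollectingWriter is a stand-in for aiohttp's stream writer, and the example assumes a modal 0.72.x install:

    import asyncio
    import io

    from modal._utils.bytes_io_segment_payload import BytesIOSegmentPayload


    class CollectingWriter:
        # Stand-in for aiohttp's AbstractStreamWriter: just collects the bytes.
        def __init__(self) -> None:
            self.buffer = bytearray()

        async def write(self, chunk: bytes) -> None:
            self.buffer.extend(chunk)


    async def main() -> None:
        data = b"x" * 1000 + b"y" * 1000

        def on_progress(advance: int = 0, reset: bool = False, complete: bool = False):
            print(f"advance={advance} reset={reset} complete={complete}")

        # One reader per segment: the constructor comment warns against sharing one.
        first = BytesIOSegmentPayload(
            io.BytesIO(data), segment_start=0, segment_length=1000, progress_report_cb=on_progress
        )
        second = BytesIOSegmentPayload(
            io.BytesIO(data), segment_start=1000, segment_length=1000, progress_report_cb=on_progress
        )

        w1, w2 = CollectingWriter(), CollectingWriter()
        await asyncio.gather(first.write(w1), second.write(w2))
        assert bytes(w1.buffer) == b"x" * 1000
        assert bytes(w2.buffer) == b"y" * 1000


    asyncio.run(main())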