modal 0.67.43__py3-none-any.whl → 0.68.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. modal/__init__.py +2 -0
  2. modal/_container_entrypoint.py +4 -1
  3. modal/_ipython.py +3 -13
  4. modal/_runtime/asgi.py +4 -0
  5. modal/_runtime/container_io_manager.py +3 -0
  6. modal/_runtime/user_code_imports.py +17 -20
  7. modal/_traceback.py +16 -2
  8. modal/_utils/blob_utils.py +27 -92
  9. modal/_utils/bytes_io_segment_payload.py +97 -0
  10. modal/_utils/function_utils.py +5 -1
  11. modal/_utils/grpc_testing.py +6 -2
  12. modal/_utils/hash_utils.py +51 -10
  13. modal/_utils/http_utils.py +19 -10
  14. modal/_utils/{pattern_matcher.py → pattern_utils.py} +1 -70
  15. modal/_utils/shell_utils.py +11 -5
  16. modal/cli/_traceback.py +11 -4
  17. modal/cli/run.py +25 -12
  18. modal/client.py +6 -37
  19. modal/client.pyi +2 -6
  20. modal/cls.py +132 -62
  21. modal/cls.pyi +13 -7
  22. modal/exception.py +20 -0
  23. modal/file_io.py +380 -0
  24. modal/file_io.pyi +185 -0
  25. modal/file_pattern_matcher.py +121 -0
  26. modal/functions.py +33 -11
  27. modal/functions.pyi +11 -9
  28. modal/image.py +88 -8
  29. modal/image.pyi +20 -4
  30. modal/mount.py +49 -9
  31. modal/mount.pyi +19 -4
  32. modal/network_file_system.py +4 -1
  33. modal/object.py +4 -2
  34. modal/partial_function.py +22 -10
  35. modal/partial_function.pyi +10 -2
  36. modal/runner.py +5 -4
  37. modal/runner.pyi +2 -1
  38. modal/sandbox.py +40 -0
  39. modal/sandbox.pyi +18 -0
  40. modal/volume.py +5 -1
  41. {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/METADATA +2 -2
  42. {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/RECORD +52 -48
  43. modal_docs/gen_reference_docs.py +1 -0
  44. modal_proto/api.proto +33 -1
  45. modal_proto/api_pb2.py +813 -737
  46. modal_proto/api_pb2.pyi +160 -13
  47. modal_version/__init__.py +1 -1
  48. modal_version/_version_generated.py +1 -1
  49. {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/LICENSE +0 -0
  50. {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/WHEEL +0 -0
  51. {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/entry_points.txt +0 -0
  52. {modal-0.67.43.dist-info → modal-0.68.24.dist-info}/top_level.txt +0 -0
modal/__init__.py CHANGED
@@ -17,6 +17,7 @@ try:
      from .cls import Cls, parameter
      from .dict import Dict
      from .exception import Error
+     from .file_pattern_matcher import FilePatternMatcher
      from .functions import Function
      from .image import Image
      from .mount import Mount
@@ -48,6 +49,7 @@ __all__ = [
      "Cron",
      "Dict",
      "Error",
+     "FilePatternMatcher",
      "Function",
      "Image",
      "Mount",
modal/_container_entrypoint.py CHANGED
@@ -6,7 +6,7 @@ from modal._runtime.user_code_imports import Service, import_class_service, impo

  telemetry_socket = os.environ.get("MODAL_TELEMETRY_SOCKET")
  if telemetry_socket:
-     from runtime._telemetry import instrument_imports
+     from ._runtime.telemetry import instrument_imports

      instrument_imports(telemetry_socket)

@@ -415,6 +415,9 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):

      _client: _Client = synchronizer._translate_in(client)  # TODO(erikbern): ugly

+     # Call ContainerHello - currently a noop but might be used later for things
+     container_io_manager.hello()
+
      with container_io_manager.heartbeats(is_snapshotting_function), UserCodeEventLoop() as event_loop:
          # If this is a serialized function, fetch the definition from the server
          if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
modal/_ipython.py CHANGED
@@ -1,21 +1,11 @@
  # Copyright Modal Labs 2022
  import sys
- import warnings
-
- ipy_outstream = None
- try:
-     with warnings.catch_warnings():
-         warnings.simplefilter("ignore")
-         import ipykernel.iostream
-
-         ipy_outstream = ipykernel.iostream.OutStream
- except ImportError:
-     pass


  def is_notebook(stdout=None):
-     if ipy_outstream is None:
+     ipykernel_iostream = sys.modules.get("ipykernel.iostream")
+     if ipykernel_iostream is None:
          return False
      if stdout is None:
          stdout = sys.stdout
-     return isinstance(stdout, ipy_outstream)
+     return isinstance(stdout, ipykernel_iostream.OutStream)
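The rewrite removes the eager ipykernel import entirely: is_notebook now just looks in sys.modules, on the reasoning that if ipykernel.iostream was never imported by the process, it cannot be the provider of sys.stdout. A sketch of the same detection pattern:

    import sys

    def stdout_is_ipykernel() -> bool:
        # Illustrative sketch, not the library's code: look the module up
        # instead of importing it, so the check itself stays import-free.
        mod = sys.modules.get("ipykernel.iostream")
        return mod is not None and isinstance(sys.stdout, mod.OutStream)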
modal/_runtime/asgi.py CHANGED
@@ -1,4 +1,8 @@
  # Copyright Modal Labs 2022
+
+ # Note: this module isn't imported unless it's needed.
+ # This is because aiohttp is a pretty big dependency that adds significant latency when imported
+
  import asyncio
  from collections.abc import AsyncGenerator
  from typing import Any, Callable, NoReturn, Optional, cast
modal/_runtime/container_io_manager.py CHANGED
@@ -335,6 +335,9 @@ class _ContainerIOManager:
          """Only used for tests."""
          cls._singleton = None

+     async def hello(self):
+         await self._client.stub.ContainerHello(Empty())
+
      async def _run_heartbeat_loop(self):
          while 1:
              t0 = time.monotonic()
modal/_runtime/user_code_imports.py CHANGED
@@ -9,15 +9,6 @@ import modal._runtime.container_io_manager
  import modal.cls
  import modal.object
  from modal import Function
- from modal._runtime.asgi import (
-     LifespanManager,
-     asgi_app_wrapper,
-     get_ip_address,
-     wait_for_web_server,
-     web_server_proxy,
-     webhook_asgi_app,
-     wsgi_app_wrapper,
- )
  from modal._utils.async_utils import synchronizer
  from modal._utils.function_utils import LocalFunctionError, is_async as get_is_async, is_global_object
  from modal.exception import ExecutionError, InvalidError
@@ -28,6 +19,7 @@ from modal_proto import api_pb2
  if typing.TYPE_CHECKING:
      import modal.app
      import modal.partial_function
+     from modal._runtime.asgi import LifespanManager


  @dataclass
@@ -36,7 +28,7 @@ class FinalizedFunction:
      is_async: bool
      is_generator: bool
      data_format: int  # api_pb2.DataFormat
-     lifespan_manager: Optional[LifespanManager] = None
+     lifespan_manager: Optional["LifespanManager"] = None


  class Service(metaclass=ABCMeta):
@@ -63,19 +55,22 @@ def construct_webhook_callable(
      webhook_config: api_pb2.WebhookConfig,
      container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager",
  ):
+     # Note: aiohttp is a significant dependency of the `asgi` module, so we import it locally
+     from modal._runtime import asgi
+
      # For webhooks, the user function is used to construct an asgi app:
      if webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP:
          # Function returns an asgi_app, which we can use as a callable.
-         return asgi_app_wrapper(user_defined_callable(), container_io_manager)
+         return asgi.asgi_app_wrapper(user_defined_callable(), container_io_manager)

      elif webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP:
-         # Function returns an wsgi_app, which we can use as a callable.
-         return wsgi_app_wrapper(user_defined_callable(), container_io_manager)
+         # Function returns an wsgi_app, which we can use as a callable
+         return asgi.wsgi_app_wrapper(user_defined_callable(), container_io_manager)

      elif webhook_config.type == api_pb2.WEBHOOK_TYPE_FUNCTION:
          # Function is a webhook without an ASGI app. Create one for it.
-         return asgi_app_wrapper(
-             webhook_asgi_app(user_defined_callable, webhook_config.method, webhook_config.web_endpoint_docs),
+         return asgi.asgi_app_wrapper(
+             asgi.webhook_asgi_app(user_defined_callable, webhook_config.method, webhook_config.web_endpoint_docs),
              container_io_manager,
          )
@@ -86,11 +81,11 @@
          # We intentionally try to connect to the external interface instead of the loopback
          # interface here so users are forced to expose the server. This allows us to potentially
          # change the implementation to use an external bridge in the future.
-         host = get_ip_address(b"eth0")
+         host = asgi.get_ip_address(b"eth0")
          port = webhook_config.web_server_port
          startup_timeout = webhook_config.web_server_startup_timeout
-         wait_for_web_server(host, port, timeout=startup_timeout)
-         return asgi_app_wrapper(web_server_proxy(host, port), container_io_manager)
+         asgi.wait_for_web_server(host, port, timeout=startup_timeout)
+         return asgi.asgi_app_wrapper(asgi.web_server_proxy(host, port), container_io_manager)
      else:
          raise InvalidError(f"Unrecognized web endpoint type {webhook_config.type}")

@@ -269,10 +264,12 @@ def import_single_function_service(
              # The cls decorator is in global scope
              _cls = synchronizer._translate_in(cls)
              user_defined_callable = _cls._callables[fun_name]
-             function = _cls._method_functions.get(fun_name)
+             function = _cls._method_functions.get(
+                 fun_name
+             )  # bound to the class service function - there is no instance
              active_app = _cls._app
          else:
-             # This is a raw class
+             # This is non-decorated class
              user_defined_callable = getattr(cls, fun_name)
      else:
          raise InvalidError(f"Invalid function qualname {qual_name}")
modal/_traceback.py CHANGED
@@ -1,16 +1,21 @@
  # Copyright Modal Labs 2022
- """Helper functions related to operating on traceback objects.
+ """Helper functions related to operating on exceptions, warnings, and traceback objects.

  Functions related to *displaying* tracebacks should go in `modal/cli/_traceback.py`
  so that Rich is not a dependency of the container Client.
  """
+
  import re
  import sys
  import traceback
+ import warnings
  from types import TracebackType
- from typing import Any, Optional
+ from typing import Any, Iterable, Optional
+
+ from modal_proto import api_pb2

  from ._vendor.tblib import Traceback as TBLibTraceback
+ from .exception import ServerWarning

  TBDictType = dict[str, Any]
  LineCacheType = dict[tuple[str, str], str]
@@ -109,3 +114,12 @@ def print_exception(exc: Optional[type[BaseException]], value: Optional[BaseExce
      if sys.version_info < (3, 11) and value is not None:
          notes = getattr(value, "__notes__", [])
          print(*notes, sep="\n", file=sys.stderr)
+
+
+ def print_server_warnings(server_warnings: Iterable[api_pb2.Warning]):
+     """Issue a warning originating from the server with empty metadata about local origin.
+
+     When using the Modal CLI, these warnings should get caught and coerced into Rich panels.
+     """
+     for warning in server_warnings:
+         warnings.warn_explicit(warning.message, ServerWarning, "<modal-server>", 0)
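Because print_server_warnings routes server-sent messages through the stdlib warnings machinery with the new ServerWarning category (added in modal/exception.py), downstream code can manage them like any other warning class. A hedged sketch:

    import warnings

    from modal.exception import ServerWarning

    # e.g. silence (or escalate with action="error") server-origin warnings:
    warnings.filterwarnings("ignore", category=ServerWarning)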
modal/_utils/blob_utils.py CHANGED
@@ -9,22 +9,22 @@ import time
  from collections.abc import AsyncIterator
  from contextlib import AbstractContextManager, contextmanager
  from pathlib import Path, PurePosixPath
- from typing import Any, BinaryIO, Callable, Optional, Union
+ from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional, Union
  from urllib.parse import urlparse

- from aiohttp import BytesIOPayload
- from aiohttp.abc import AbstractStreamWriter
-
  from modal_proto import api_pb2
  from modal_proto.modal_api_grpc import ModalClientModal

  from ..exception import ExecutionError
  from .async_utils import TaskContext, retry
  from .grpc_utils import retry_transient_errors
- from .hash_utils import UploadHashes, get_sha256_hex, get_upload_hashes
+ from .hash_utils import UploadHashes, get_upload_hashes
  from .http_utils import ClientSessionRegistry
  from .logger import logger

+ if TYPE_CHECKING:
+     from .bytes_io_segment_payload import BytesIOSegmentPayload
+
  # Max size for function inputs and outputs.
  MAX_OBJECT_SIZE_BYTES = 2 * 1024 * 1024  # 2 MiB

@@ -38,93 +38,16 @@ BLOB_MAX_PARALLELISM = 10
  # read ~16MiB chunks by default
  DEFAULT_SEGMENT_CHUNK_SIZE = 2**24

-
- class BytesIOSegmentPayload(BytesIOPayload):
-     """Modified bytes payload for concurrent sends of chunks from the same file.
-
-     Adds:
-     * read limit using remaining_bytes, in order to split files across streams
-     * larger read chunk (to prevent excessive read contention between parts)
-     * calculates an md5 for the segment
-
-     Feels like this should be in some standard lib...
-     """
-
-     def __init__(
-         self,
-         bytes_io: BinaryIO,  # should *not* be shared as IO position modification is not locked
-         segment_start: int,
-         segment_length: int,
-         chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
-         progress_report_cb: Optional[Callable] = None,
-     ):
-         # not thread safe constructor!
-         super().__init__(bytes_io)
-         self.initial_seek_pos = bytes_io.tell()
-         self.segment_start = segment_start
-         self.segment_length = segment_length
-         # seek to start of file segment we are interested in, in order to make .size() evaluate correctly
-         self._value.seek(self.initial_seek_pos + segment_start)
-         assert self.segment_length <= super().size
-         self.chunk_size = chunk_size
-         self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
-         self.reset_state()
-
-     def reset_state(self):
-         self._md5_checksum = hashlib.md5()
-         self.num_bytes_read = 0
-         self._value.seek(self.initial_seek_pos)
-
-     @contextmanager
-     def reset_on_error(self):
-         try:
-             yield
-         except Exception as exc:
-             try:
-                 self.progress_report_cb(reset=True)
-             except Exception as cb_exc:
-                 raise cb_exc from exc
-             raise exc
-         finally:
-             self.reset_state()
-
-     @property
-     def size(self) -> int:
-         return self.segment_length
-
-     def md5_checksum(self):
-         return self._md5_checksum
-
-     async def write(self, writer: AbstractStreamWriter):
-         loop = asyncio.get_event_loop()
-
-         async def safe_read():
-             read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
-             self._value.seek(read_start)
-             num_bytes = min(self.chunk_size, self.remaining_bytes())
-             chunk = await loop.run_in_executor(None, self._value.read, num_bytes)
-
-             await loop.run_in_executor(None, self._md5_checksum.update, chunk)
-             self.num_bytes_read += len(chunk)
-             return chunk
-
-         chunk = await safe_read()
-         while chunk and self.remaining_bytes() > 0:
-             await writer.write(chunk)
-             self.progress_report_cb(advance=len(chunk))
-             chunk = await safe_read()
-         if chunk:
-             await writer.write(chunk)
-             self.progress_report_cb(advance=len(chunk))
-
-     def remaining_bytes(self):
-         return self.segment_length - self.num_bytes_read
+ # Files larger than this will be multipart uploaded. The server might request multipart upload for smaller files as
+ # well, but the limit will never be raised.
+ # TODO(dano): remove this once we stop requiring md5 for blobs
+ MULTIPART_UPLOAD_THRESHOLD = 1024**3


  @retry(n_attempts=5, base_delay=0.5, timeout=None)
  async def _upload_to_s3_url(
      upload_url,
-     payload: BytesIOSegmentPayload,
+     payload: "BytesIOSegmentPayload",
      content_md5_b64: Optional[str] = None,
      content_type: Optional[str] = "application/octet-stream",  # set to None to force omission of ContentType header
  ) -> str:
@@ -180,6 +103,8 @@ async def perform_multipart_upload(
      upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
      progress_report_cb: Optional[Callable] = None,
  ) -> None:
+     from .bytes_io_segment_payload import BytesIOSegmentPayload
+
      upload_coros = []
      file_offset = 0
      num_bytes_left = content_length
@@ -273,6 +198,8 @@ async def _blob_upload(
              progress_report_cb=progress_report_cb,
          )
      else:
+         from .bytes_io_segment_payload import BytesIOSegmentPayload
+
          payload = BytesIOSegmentPayload(
              data, segment_start=0, segment_length=content_length, progress_report_cb=progress_report_cb
          )
@@ -305,9 +232,13 @@ async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:


  async def blob_upload_file(
-     file_obj: BinaryIO, stub: ModalClientModal, progress_report_cb: Optional[Callable] = None
+     file_obj: BinaryIO,
+     stub: ModalClientModal,
+     progress_report_cb: Optional[Callable] = None,
+     sha256_hex: Optional[str] = None,
+     md5_hex: Optional[str] = None,
  ) -> str:
-     upload_hashes = get_upload_hashes(file_obj)
+     upload_hashes = get_upload_hashes(file_obj, sha256_hex=sha256_hex, md5_hex=md5_hex)
      return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)


@@ -366,6 +297,7 @@ class FileUploadSpec:
      use_blob: bool
      content: Optional[bytes]  # typically None if using blob, required otherwise
      sha256_hex: str
+     md5_hex: str
      mode: int  # file permission bits (last 12 bits of st_mode)
      size: int

@@ -383,13 +315,15 @@ def _get_file_upload_spec(
      fp.seek(0)

      if size >= LARGE_FILE_LIMIT:
+         # TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs
+         md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None
          use_blob = True
          content = None
-         sha256_hex = get_sha256_hex(fp)
+         hashes = get_upload_hashes(fp, md5_hex=md5_hex)
      else:
          use_blob = False
          content = fp.read()
-         sha256_hex = get_sha256_hex(content)
+         hashes = get_upload_hashes(content)

      return FileUploadSpec(
          source=source,
@@ -397,7 +331,8 @@
          mount_filename=mount_filename.as_posix(),
          use_blob=use_blob,
          content=content,
-         sha256_hex=sha256_hex,
+         sha256_hex=hashes.sha256_hex(),
+         md5_hex=hashes.md5_hex(),
          mode=mode & 0o7777,
          size=size,
      )
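The placeholder md5 is presumably safe because any file over MULTIPART_UPLOAD_THRESHOLD (1 GiB) goes through the multipart path, where integrity is tracked per segment by BytesIOSegmentPayload's per-part md5 rather than by a whole-file digest. A sketch of how a multipart upload slices a file into (start, length) segments, under the parameters shown above:

    # Illustrative sketch only, not the library's code:
    def plan_segments(content_length: int, max_part_size: int) -> list[tuple[int, int]]:
        segments = []
        offset = 0
        while offset < content_length:
            length = min(max_part_size, content_length - offset)
            segments.append((offset, length))
            offset += length
        return segments

    # plan_segments(5 * 2**24, 2 * 2**24) -> [(0, 32 MiB), (32 MiB, 32 MiB), (64 MiB, 16 MiB)]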
modal/_utils/bytes_io_segment_payload.py ADDED
@@ -0,0 +1,97 @@
+ # Copyright Modal Labs 2024
+
+ import asyncio
+ import hashlib
+ from contextlib import contextmanager
+ from typing import BinaryIO, Callable, Optional
+
+ # Note: this module needs to import aiohttp in global scope
+ # This takes about 50ms and isn't needed in many cases for Modal execution
+ # To avoid this, we import it in local scope when needed (blob_utils.py)
+ from aiohttp import BytesIOPayload
+ from aiohttp.abc import AbstractStreamWriter
+
+ # read ~16MiB chunks by default
+ DEFAULT_SEGMENT_CHUNK_SIZE = 2**24
+
+
+ class BytesIOSegmentPayload(BytesIOPayload):
+     """Modified bytes payload for concurrent sends of chunks from the same file.
+
+     Adds:
+     * read limit using remaining_bytes, in order to split files across streams
+     * larger read chunk (to prevent excessive read contention between parts)
+     * calculates an md5 for the segment
+
+     Feels like this should be in some standard lib...
+     """
+
+     def __init__(
+         self,
+         bytes_io: BinaryIO,  # should *not* be shared as IO position modification is not locked
+         segment_start: int,
+         segment_length: int,
+         chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
+         progress_report_cb: Optional[Callable] = None,
+     ):
+         # not thread safe constructor!
+         super().__init__(bytes_io)
+         self.initial_seek_pos = bytes_io.tell()
+         self.segment_start = segment_start
+         self.segment_length = segment_length
+         # seek to start of file segment we are interested in, in order to make .size() evaluate correctly
+         self._value.seek(self.initial_seek_pos + segment_start)
+         assert self.segment_length <= super().size
+         self.chunk_size = chunk_size
+         self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
+         self.reset_state()
+
+     def reset_state(self):
+         self._md5_checksum = hashlib.md5()
+         self.num_bytes_read = 0
+         self._value.seek(self.initial_seek_pos)
+
+     @contextmanager
+     def reset_on_error(self):
+         try:
+             yield
+         except Exception as exc:
+             try:
+                 self.progress_report_cb(reset=True)
+             except Exception as cb_exc:
+                 raise cb_exc from exc
+             raise exc
+         finally:
+             self.reset_state()
+
+     @property
+     def size(self) -> int:
+         return self.segment_length
+
+     def md5_checksum(self):
+         return self._md5_checksum
+
+     async def write(self, writer: "AbstractStreamWriter"):
+         loop = asyncio.get_event_loop()
+
+         async def safe_read():
+             read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
+             self._value.seek(read_start)
+             num_bytes = min(self.chunk_size, self.remaining_bytes())
+             chunk = await loop.run_in_executor(None, self._value.read, num_bytes)
+
+             await loop.run_in_executor(None, self._md5_checksum.update, chunk)
+             self.num_bytes_read += len(chunk)
+             return chunk
+
+         chunk = await safe_read()
+         while chunk and self.remaining_bytes() > 0:
+             await writer.write(chunk)
+             self.progress_report_cb(advance=len(chunk))
+             chunk = await safe_read()
+         if chunk:
+             await writer.write(chunk)
+             self.progress_report_cb(advance=len(chunk))
+
+     def remaining_bytes(self):
+         return self.segment_length - self.num_bytes_read
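As the constructor comment warns, seek positions are not synchronized, so each concurrent part needs its own file handle. A hedged sketch of streaming two disjoint segments of one file with aiohttp (URLs and part size are placeholders; file handles are left open for brevity):

    import asyncio

    import aiohttp

    async def upload_two_parts(path: str, part_size: int, url1: str, url2: str) -> None:
        # One file handle per payload, since positions are not locked across payloads.
        p1 = BytesIOSegmentPayload(open(path, "rb"), segment_start=0, segment_length=part_size)
        p2 = BytesIOSegmentPayload(open(path, "rb"), segment_start=part_size, segment_length=part_size)

        async def put(url: str, payload: BytesIOSegmentPayload) -> None:
            async with aiohttp.ClientSession() as session:
                async with session.put(url, data=payload) as resp:
                    resp.raise_for_status()

        # Each PUT streams its own byte range; the md5 accumulates per segment.
        await asyncio.gather(put(url1, p1), put(url2, p2))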
modal/_utils/function_utils.py CHANGED
@@ -99,7 +99,11 @@ def get_function_type(is_generator: Optional[bool]) -> "api_pb2.Function.Functio


  class FunctionInfo:
-     """Class that helps us extract a bunch of information about a function."""
+     """Class that helps us extract a bunch of information about a locally defined function.
+
+     Used for populating the definition of a remote function, and for making .local() calls
+     on a host with the local definition available.
+     """

      raw_f: Optional[Callable[..., Any]]  # if None - this is a "class service function"
      function_name: str
modal/_utils/grpc_testing.py CHANGED
@@ -50,7 +50,7 @@ def patch_mock_servicer(cls):

  @contextlib.contextmanager
  def intercept(servicer):
-     ctx = InterceptionContext()
+     ctx = InterceptionContext(servicer)
      servicer.interception_context = ctx
      yield ctx
      ctx._assert_responses_consumed()
@@ -101,7 +101,8 @@ class ResponseNotConsumed(Exception):


  class InterceptionContext:
-     def __init__(self):
+     def __init__(self, servicer):
+         self._servicer = servicer
          self.calls: list[tuple[str, Any]] = []  # List[Tuple[method_name, message]]
          self.custom_responses: dict[str, list[tuple[Callable[[Any], bool], list[Any]]]] = defaultdict(list)
          self.custom_defaults: dict[str, Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]] = {}
@@ -149,6 +150,9 @@ class InterceptionContext:
          raise KeyError(f"No message of that type in call list: {self.calls}")

      def get_requests(self, method_name: str) -> list[Any]:
+         if not hasattr(self._servicer, method_name):
+             # we check this to prevent things like `assert ctx.get_requests("ASdfFunctionCreate") == 0` passing
+             raise ValueError(f"{method_name} not in MockServicer - did you spell it right?")
          return [msg for _method_name, msg in self.calls if _method_name == method_name]

      def _add_recv(self, method_name: str, msg):
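With the servicer now threaded through, get_requests can reject misspelled method names instead of silently returning an empty list, which previously let assertions pass vacuously. A hypothetical test sketch (servicer and run_client_code stand in for the suite's real fixtures):

    with intercept(servicer) as ctx:
        run_client_code()
        assert len(ctx.get_requests("FunctionCreate")) == 1
        # A typo like ctx.get_requests("FunctoinCreate") now raises ValueError
        # instead of returning [] and making the assertion meaningless.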
modal/_utils/hash_utils.py CHANGED
@@ -2,12 +2,15 @@
  import base64
  import dataclasses
  import hashlib
- from typing import BinaryIO, Callable, Union
+ import time
+ from typing import BinaryIO, Callable, Optional, Sequence, Union

- HASH_CHUNK_SIZE = 4096
+ from modal.config import logger

+ HASH_CHUNK_SIZE = 65536

- def _update(hashers: list[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
+
+ def _update(hashers: Sequence[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
      if isinstance(data, bytes):
          for hasher in hashers:
              hasher(data)
@@ -26,20 +29,26 @@ def _update(hashers: list[Callable[[bytes], None]], data: Union[bytes, BinaryIO]


  def get_sha256_hex(data: Union[bytes, BinaryIO]) -> str:
+     t0 = time.monotonic()
      hasher = hashlib.sha256()
      _update([hasher.update], data)
+     logger.debug("get_sha256_hex took %.3fs", time.monotonic() - t0)
      return hasher.hexdigest()


  def get_sha256_base64(data: Union[bytes, BinaryIO]) -> str:
+     t0 = time.monotonic()
      hasher = hashlib.sha256()
      _update([hasher.update], data)
+     logger.debug("get_sha256_base64 took %.3fs", time.monotonic() - t0)
      return base64.b64encode(hasher.digest()).decode("ascii")


  def get_md5_base64(data: Union[bytes, BinaryIO]) -> str:
+     t0 = time.monotonic()
      hasher = hashlib.md5()
      _update([hasher.update], data)
+     logger.debug("get_md5_base64 took %.3fs", time.monotonic() - t0)
      return base64.b64encode(hasher.digest()).decode("utf-8")

@@ -48,12 +57,44 @@ class UploadHashes:
      md5_base64: str
      sha256_base64: str

+     def md5_hex(self) -> str:
+         return base64.b64decode(self.md5_base64).hex()
+
+     def sha256_hex(self) -> str:
+         return base64.b64decode(self.sha256_base64).hex()
+
+
+ def get_upload_hashes(
+     data: Union[bytes, BinaryIO], sha256_hex: Optional[str] = None, md5_hex: Optional[str] = None
+ ) -> UploadHashes:
+     t0 = time.monotonic()
+     hashers = {}
+
+     if not sha256_hex:
+         sha256 = hashlib.sha256()
+         hashers["sha256"] = sha256
+     if not md5_hex:
+         md5 = hashlib.md5()
+         hashers["md5"] = md5
+
+     if hashers:
+         updaters = [h.update for h in hashers.values()]
+         _update(updaters, data)

- def get_upload_hashes(data: Union[bytes, BinaryIO]) -> UploadHashes:
-     md5 = hashlib.md5()
-     sha256 = hashlib.sha256()
-     _update([md5.update, sha256.update], data)
-     return UploadHashes(
-         md5_base64=base64.b64encode(md5.digest()).decode("ascii"),
-         sha256_base64=base64.b64encode(sha256.digest()).decode("ascii"),
+     if sha256_hex:
+         sha256_base64 = base64.b64encode(bytes.fromhex(sha256_hex)).decode("ascii")
+     else:
+         sha256_base64 = base64.b64encode(hashers["sha256"].digest()).decode("ascii")
+
+     if md5_hex:
+         md5_base64 = base64.b64encode(bytes.fromhex(md5_hex)).decode("ascii")
+     else:
+         md5_base64 = base64.b64encode(hashers["md5"].digest()).decode("ascii")
+
+     hashes = UploadHashes(
+         md5_base64=md5_base64,
+         sha256_base64=sha256_base64,
      )
+
+     logger.debug("get_upload_hashes took %.3fs (%s)", time.monotonic() - t0, hashers.keys())
+     return hashes
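The new optional sha256_hex/md5_hex parameters let callers supply digests they already know, so only the missing hash is computed from the data; precomputed values are merely transcoded from hex to base64. A usage sketch:

    import hashlib

    data = b"example payload"
    known_sha256 = hashlib.sha256(data).hexdigest()

    # Only the md5 is computed here; the sha256 is reused as passed in.
    hashes = get_upload_hashes(data, sha256_hex=known_sha256)
    assert hashes.sha256_hex() == known_sha256  # hex <-> base64 round-trips losslessly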