modal 0.67.0__py3-none-any.whl → 0.67.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/_clustered_functions.py +2 -2
- modal/_clustered_functions.pyi +2 -2
- modal/_container_entrypoint.py +5 -4
- modal/_output.py +29 -28
- modal/_pty.py +2 -2
- modal/_resolver.py +6 -5
- modal/_resources.py +3 -3
- modal/_runtime/asgi.py +46 -6
- modal/_runtime/container_io_manager.py +22 -26
- modal/_runtime/execution_context.py +2 -2
- modal/_runtime/telemetry.py +1 -2
- modal/_runtime/user_code_imports.py +12 -14
- modal/_serialization.py +3 -7
- modal/_traceback.py +5 -5
- modal/_tunnel.py +5 -4
- modal/_tunnel.pyi +2 -2
- modal/_utils/async_utils.py +53 -17
- modal/_utils/blob_utils.py +22 -7
- modal/_utils/function_utils.py +14 -10
- modal/_utils/grpc_testing.py +7 -6
- modal/_utils/grpc_utils.py +2 -3
- modal/_utils/hash_utils.py +2 -2
- modal/_utils/mount_utils.py +5 -4
- modal/_utils/package_utils.py +2 -3
- modal/_utils/pattern_matcher.py +6 -6
- modal/_utils/rand_pb_testing.py +3 -3
- modal/_utils/shell_utils.py +2 -1
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +8 -7
- modal/app.py +81 -69
- modal/app.pyi +104 -99
- modal/call_graph.py +6 -6
- modal/cli/_download.py +3 -2
- modal/cli/_traceback.py +4 -4
- modal/cli/app.py +4 -4
- modal/cli/container.py +4 -4
- modal/cli/dict.py +1 -1
- modal/cli/environment.py +2 -3
- modal/cli/import_refs.py +1 -1
- modal/cli/launch.py +2 -2
- modal/cli/network_file_system.py +1 -1
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +2 -2
- modal/cli/programs/vscode.py +3 -3
- modal/cli/queues.py +1 -1
- modal/cli/run.py +6 -6
- modal/cli/secret.py +3 -3
- modal/cli/utils.py +2 -1
- modal/cli/volume.py +3 -3
- modal/client.py +6 -11
- modal/client.pyi +18 -27
- modal/cloud_bucket_mount.py +3 -3
- modal/cloud_bucket_mount.pyi +2 -2
- modal/cls.py +32 -32
- modal/cls.pyi +35 -34
- modal/config.py +3 -2
- modal/container_process.py +6 -2
- modal/dict.py +6 -3
- modal/dict.pyi +10 -9
- modal/environments.py +3 -3
- modal/environments.pyi +3 -3
- modal/exception.py +2 -3
- modal/functions.py +111 -40
- modal/functions.pyi +71 -48
- modal/image.py +46 -49
- modal/image.pyi +102 -101
- modal/io_streams.py +20 -12
- modal/io_streams.pyi +24 -14
- modal/mount.py +24 -24
- modal/mount.pyi +28 -29
- modal/network_file_system.py +14 -11
- modal/network_file_system.pyi +12 -11
- modal/object.py +9 -8
- modal/object.pyi +47 -34
- modal/output.py +2 -1
- modal/parallel_map.py +4 -4
- modal/partial_function.py +10 -14
- modal/partial_function.pyi +17 -18
- modal/queue.py +11 -8
- modal/queue.pyi +23 -22
- modal/retries.py +38 -0
- modal/runner.py +8 -7
- modal/runner.pyi +8 -14
- modal/running_app.py +3 -3
- modal/sandbox.py +20 -13
- modal/sandbox.pyi +73 -72
- modal/scheduler_placement.py +2 -1
- modal/secret.py +7 -7
- modal/secret.pyi +12 -12
- modal/serving.py +4 -3
- modal/serving.pyi +5 -4
- modal/token_flow.py +3 -2
- modal/token_flow.pyi +3 -3
- modal/volume.py +16 -23
- modal/volume.pyi +17 -16
- {modal-0.67.0.dist-info → modal-0.67.22.dist-info}/METADATA +2 -2
- modal-0.67.22.dist-info/RECORD +168 -0
- modal_docs/mdmd/signatures.py +1 -2
- modal_global_objects/mounts/python_standalone.py +1 -1
- modal_proto/api.proto +13 -0
- modal_proto/api_grpc.py +16 -0
- modal_proto/api_pb2.py +241 -221
- modal_proto/api_pb2.pyi +41 -0
- modal_proto/api_pb2_grpc.py +33 -0
- modal_proto/api_pb2_grpc.pyi +10 -0
- modal_proto/modal_api_grpc.py +1 -0
- modal_version/_version_generated.py +1 -1
- modal-0.67.0.dist-info/RECORD +0 -168
- {modal-0.67.0.dist-info → modal-0.67.22.dist-info}/LICENSE +0 -0
- {modal-0.67.0.dist-info → modal-0.67.22.dist-info}/WHEEL +0 -0
- {modal-0.67.0.dist-info → modal-0.67.22.dist-info}/entry_points.txt +0 -0
- {modal-0.67.0.dist-info → modal-0.67.22.dist-info}/top_level.txt +0 -0
modal/_serialization.py
CHANGED
@@ -398,10 +398,8 @@ PARAM_TYPE_MAPPING = {
|
|
398
398
|
}
|
399
399
|
|
400
400
|
|
401
|
-
def serialize_proto_params(
|
402
|
-
|
403
|
-
) -> bytes:
|
404
|
-
proto_params: typing.List[api_pb2.ClassParameterValue] = []
|
401
|
+
def serialize_proto_params(python_params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]) -> bytes:
|
402
|
+
proto_params: list[api_pb2.ClassParameterValue] = []
|
405
403
|
for schema_param in schema:
|
406
404
|
type_info = PARAM_TYPE_MAPPING.get(schema_param.type)
|
407
405
|
if not type_info:
|
@@ -426,9 +424,7 @@ def serialize_proto_params(
|
|
426
424
|
return proto_bytes
|
427
425
|
|
428
426
|
|
429
|
-
def deserialize_proto_params(
|
430
|
-
serialized_params: bytes, schema: typing.List[api_pb2.ClassParameterSpec]
|
431
|
-
) -> typing.Dict[str, Any]:
|
427
|
+
def deserialize_proto_params(serialized_params: bytes, schema: list[api_pb2.ClassParameterSpec]) -> dict[str, Any]:
|
432
428
|
proto_struct = api_pb2.ClassParameterSet()
|
433
429
|
proto_struct.ParseFromString(serialized_params)
|
434
430
|
value_by_name = {p.name: p for p in proto_struct.parameters}
|
modal/_traceback.py
CHANGED
@@ -8,15 +8,15 @@ import re
|
|
8
8
|
import sys
|
9
9
|
import traceback
|
10
10
|
from types import TracebackType
|
11
|
-
from typing import Any,
|
11
|
+
from typing import Any, Optional
|
12
12
|
|
13
13
|
from ._vendor.tblib import Traceback as TBLibTraceback
|
14
14
|
|
15
|
-
TBDictType =
|
16
|
-
LineCacheType =
|
15
|
+
TBDictType = dict[str, Any]
|
16
|
+
LineCacheType = dict[tuple[str, str], str]
|
17
17
|
|
18
18
|
|
19
|
-
def extract_traceback(exc: BaseException, task_id: str) ->
|
19
|
+
def extract_traceback(exc: BaseException, task_id: str) -> tuple[TBDictType, LineCacheType]:
|
20
20
|
"""Given an exception, extract a serializable traceback (with task ID markers included),
|
21
21
|
and a line cache that maps (filename, lineno) to line contents. The latter is used to show
|
22
22
|
a helpful traceback to the user, even if they don't have packages installed locally that
|
@@ -103,7 +103,7 @@ def traceback_contains_remote_call(tb: Optional[TracebackType]) -> bool:
|
|
103
103
|
return False
|
104
104
|
|
105
105
|
|
106
|
-
def print_exception(exc: Optional[
|
106
|
+
def print_exception(exc: Optional[type[BaseException]], value: Optional[BaseException], tb: Optional[TracebackType]):
|
107
107
|
"""Add backwards compatibility for printing exceptions with "notes" for Python<3.11."""
|
108
108
|
traceback.print_exception(exc, value, tb)
|
109
109
|
if sys.version_info < (3, 11) and value is not None:
|
modal/_tunnel.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
# Copyright Modal Labs 2023
|
2
2
|
"""Client for Modal relay servers, allowing users to expose TLS."""
|
3
3
|
|
4
|
+
from collections.abc import AsyncIterator
|
4
5
|
from dataclasses import dataclass
|
5
|
-
from typing import
|
6
|
+
from typing import Optional
|
6
7
|
|
7
8
|
from grpclib import GRPCError, Status
|
8
9
|
from synchronicity.async_wrap import asynccontextmanager
|
@@ -35,12 +36,12 @@ class Tunnel:
|
|
35
36
|
return value
|
36
37
|
|
37
38
|
@property
|
38
|
-
def tls_socket(self) ->
|
39
|
+
def tls_socket(self) -> tuple[str, int]:
|
39
40
|
"""Get the public TLS socket as a (host, port) tuple."""
|
40
41
|
return (self.host, self.port)
|
41
42
|
|
42
43
|
@property
|
43
|
-
def tcp_socket(self) ->
|
44
|
+
def tcp_socket(self) -> tuple[str, int]:
|
44
45
|
"""Get the public TCP socket as a (host, port) tuple."""
|
45
46
|
if not self.unencrypted_host:
|
46
47
|
raise InvalidError(
|
@@ -61,7 +62,7 @@ async def _forward(port: int, *, unencrypted: bool = False, client: Optional[_Cl
|
|
61
62
|
|
62
63
|
**Usage:**
|
63
64
|
|
64
|
-
```python
|
65
|
+
```python notest
|
65
66
|
import modal
|
66
67
|
from flask import Flask
|
67
68
|
|
modal/_tunnel.pyi
CHANGED
@@ -12,9 +12,9 @@ class Tunnel:
|
|
12
12
|
@property
|
13
13
|
def url(self) -> str: ...
|
14
14
|
@property
|
15
|
-
def tls_socket(self) ->
|
15
|
+
def tls_socket(self) -> tuple[str, int]: ...
|
16
16
|
@property
|
17
|
-
def tcp_socket(self) ->
|
17
|
+
def tcp_socket(self) -> tuple[str, int]: ...
|
18
18
|
def __init__(self, host: str, port: int, unencrypted_host: str, unencrypted_port: int) -> None: ...
|
19
19
|
def __repr__(self): ...
|
20
20
|
def __eq__(self, other): ...
|
modal/_utils/async_utils.py
CHANGED
@@ -6,19 +6,13 @@ import inspect
|
|
6
6
|
import itertools
|
7
7
|
import time
|
8
8
|
import typing
|
9
|
+
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Iterable, Iterator
|
9
10
|
from contextlib import asynccontextmanager
|
10
11
|
from dataclasses import dataclass
|
11
12
|
from typing import (
|
12
13
|
Any,
|
13
|
-
AsyncGenerator,
|
14
|
-
Awaitable,
|
15
14
|
Callable,
|
16
|
-
Iterable,
|
17
|
-
Iterator,
|
18
|
-
List,
|
19
15
|
Optional,
|
20
|
-
Set,
|
21
|
-
Tuple,
|
22
16
|
TypeVar,
|
23
17
|
Union,
|
24
18
|
cast,
|
@@ -118,7 +112,7 @@ class TaskContext:
|
|
118
112
|
```
|
119
113
|
"""
|
120
114
|
|
121
|
-
_loops:
|
115
|
+
_loops: set[asyncio.Task]
|
122
116
|
|
123
117
|
def __init__(self, grace: Optional[float] = None):
|
124
118
|
self._grace = grace
|
@@ -272,7 +266,7 @@ async def queue_batch_iterator(q: asyncio.Queue, max_batch_size=100, debounce_ti
|
|
272
266
|
|
273
267
|
Treats a None value as end of queue items
|
274
268
|
"""
|
275
|
-
item_list:
|
269
|
+
item_list: list[Any] = []
|
276
270
|
|
277
271
|
while True:
|
278
272
|
if q.empty() and len(item_list) > 0:
|
@@ -387,8 +381,7 @@ class AsyncOrSyncIterable:
|
|
387
381
|
def __iter__(self):
|
388
382
|
try:
|
389
383
|
with Runner() as runner:
|
390
|
-
|
391
|
-
yield output # type: ignore
|
384
|
+
yield from run_async_gen(runner, self._async_iterable)
|
392
385
|
except NestedEventLoops:
|
393
386
|
raise InvalidError(self.nested_async_message)
|
394
387
|
|
@@ -491,14 +484,17 @@ class aclosing(typing.Generic[T]): # noqa
|
|
491
484
|
await self.agen.aclose()
|
492
485
|
|
493
486
|
|
494
|
-
async def sync_or_async_iter(iter: Union[Iterable[T],
|
487
|
+
async def sync_or_async_iter(iter: Union[Iterable[T], AsyncIterable[T]]) -> AsyncGenerator[T, None]:
|
495
488
|
if hasattr(iter, "__aiter__"):
|
496
489
|
agen = typing.cast(AsyncGenerator[T, None], iter)
|
497
490
|
try:
|
498
491
|
async for item in agen:
|
499
492
|
yield item
|
500
493
|
finally:
|
501
|
-
|
494
|
+
if hasattr(agen, "aclose"):
|
495
|
+
# All AsyncGenerator's have an aclose method
|
496
|
+
# but some AsyncIterable's don't necessarily
|
497
|
+
await agen.aclose()
|
502
498
|
else:
|
503
499
|
assert hasattr(iter, "__iter__"), "sync_or_async_iter requires an Iterable or AsyncGenerator"
|
504
500
|
# This intentionally could block the event loop for the duration of calling __iter__ and __next__,
|
@@ -509,12 +505,12 @@ async def sync_or_async_iter(iter: Union[Iterable[T], AsyncGenerator[T, None]])
|
|
509
505
|
|
510
506
|
|
511
507
|
@typing.overload
|
512
|
-
def async_zip(g1: AsyncGenerator[T, None], g2: AsyncGenerator[V, None], /) -> AsyncGenerator[
|
508
|
+
def async_zip(g1: AsyncGenerator[T, None], g2: AsyncGenerator[V, None], /) -> AsyncGenerator[tuple[T, V], None]:
|
513
509
|
...
|
514
510
|
|
515
511
|
|
516
512
|
@typing.overload
|
517
|
-
def async_zip(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[
|
513
|
+
def async_zip(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[tuple[T, ...], None]:
|
518
514
|
...
|
519
515
|
|
520
516
|
|
@@ -573,6 +569,46 @@ STOP_SENTINEL = StopSentinelType()
|
|
573
569
|
|
574
570
|
|
575
571
|
async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
572
|
+
"""
|
573
|
+
Asynchronously merges multiple async generators into a single async generator.
|
574
|
+
|
575
|
+
This function takes multiple async generators and yields their values in the order
|
576
|
+
they are produced. If any generator raises an exception, the exception is propagated.
|
577
|
+
|
578
|
+
Args:
|
579
|
+
*generators: One or more async generators to be merged.
|
580
|
+
|
581
|
+
Yields:
|
582
|
+
The values produced by the input async generators.
|
583
|
+
|
584
|
+
Raises:
|
585
|
+
Exception: If any of the input generators raises an exception, it is propagated.
|
586
|
+
|
587
|
+
Usage:
|
588
|
+
```python
|
589
|
+
import asyncio
|
590
|
+
from modal._utils.async_utils import async_merge
|
591
|
+
|
592
|
+
async def gen1():
|
593
|
+
yield 1
|
594
|
+
yield 2
|
595
|
+
|
596
|
+
async def gen2():
|
597
|
+
yield "a"
|
598
|
+
yield "b"
|
599
|
+
|
600
|
+
async def example():
|
601
|
+
values = set()
|
602
|
+
async for value in async_merge(gen1(), gen2()):
|
603
|
+
values.add(value)
|
604
|
+
|
605
|
+
return values
|
606
|
+
|
607
|
+
# Output could be: {1, "a", 2, "b"} (order may vary)
|
608
|
+
values = asyncio.run(example())
|
609
|
+
assert values == {1, "a", 2, "b"}
|
610
|
+
```
|
611
|
+
"""
|
576
612
|
queue: asyncio.Queue[Union[ValueWrapper[T], ExceptionWrapper]] = asyncio.Queue(maxsize=len(generators) * 10)
|
577
613
|
|
578
614
|
async def producer(generator: AsyncGenerator[T, None]):
|
@@ -582,7 +618,7 @@ async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T,
|
|
582
618
|
except Exception as e:
|
583
619
|
await queue.put(ExceptionWrapper(e))
|
584
620
|
|
585
|
-
tasks =
|
621
|
+
tasks = {asyncio.create_task(producer(gen)) for gen in generators}
|
586
622
|
new_output_task = asyncio.create_task(queue.get())
|
587
623
|
|
588
624
|
try:
|
@@ -689,7 +725,7 @@ async def async_map_ordered(
|
|
689
725
|
) -> AsyncGenerator[V, None]:
|
690
726
|
semaphore = asyncio.Semaphore(buffer_size or concurrency)
|
691
727
|
|
692
|
-
async def mapper_func_wrapper(tup:
|
728
|
+
async def mapper_func_wrapper(tup: tuple[int, T]) -> tuple[int, V]:
|
693
729
|
return (tup[0], await async_mapper_func(tup[1]))
|
694
730
|
|
695
731
|
async def counter() -> AsyncGenerator[int, None]:
|
modal/_utils/blob_utils.py
CHANGED
@@ -5,9 +5,11 @@ import hashlib
|
|
5
5
|
import io
|
6
6
|
import os
|
7
7
|
import platform
|
8
|
+
import time
|
9
|
+
from collections.abc import AsyncIterator
|
8
10
|
from contextlib import AbstractContextManager, contextmanager
|
9
11
|
from pathlib import Path, PurePosixPath
|
10
|
-
from typing import Any,
|
12
|
+
from typing import Any, BinaryIO, Callable, Optional, Union
|
11
13
|
from urllib.parse import urlparse
|
12
14
|
|
13
15
|
from aiohttp import BytesIOPayload
|
@@ -173,7 +175,7 @@ async def perform_multipart_upload(
|
|
173
175
|
*,
|
174
176
|
content_length: int,
|
175
177
|
max_part_size: int,
|
176
|
-
part_urls:
|
178
|
+
part_urls: list[str],
|
177
179
|
completion_url: str,
|
178
180
|
upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
|
179
181
|
progress_report_cb: Optional[Callable] = None,
|
@@ -184,7 +186,7 @@ async def perform_multipart_upload(
|
|
184
186
|
|
185
187
|
# Give each part its own IO reader object to avoid needing to
|
186
188
|
# lock access to the reader's position pointer.
|
187
|
-
data_file_readers:
|
189
|
+
data_file_readers: list[BinaryIO]
|
188
190
|
if isinstance(data_file, io.BytesIO):
|
189
191
|
view = data_file.getbuffer() # does not copy data
|
190
192
|
data_file_readers = [io.BytesIO(view) for _ in range(len(part_urls))]
|
@@ -288,11 +290,18 @@ async def _blob_upload(
|
|
288
290
|
|
289
291
|
|
290
292
|
async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
|
293
|
+
size_mib = len(payload) / 1024 / 1024
|
294
|
+
logger.debug(f"Uploading large blob of size {size_mib:.2f} MiB")
|
295
|
+
t0 = time.time()
|
291
296
|
if isinstance(payload, str):
|
292
297
|
logger.warning("Blob uploading string, not bytes - auto-encoding as utf8")
|
293
298
|
payload = payload.encode("utf8")
|
294
299
|
upload_hashes = get_upload_hashes(payload)
|
295
|
-
|
300
|
+
blob_id = await _blob_upload(upload_hashes, payload, stub)
|
301
|
+
dur_s = max(time.time() - t0, 0.001) # avoid division by zero
|
302
|
+
throughput_mib_s = (size_mib) / dur_s
|
303
|
+
logger.debug(f"Uploaded large blob of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s)." f" {blob_id}")
|
304
|
+
return blob_id
|
296
305
|
|
297
306
|
|
298
307
|
async def blob_upload_file(
|
@@ -317,11 +326,17 @@ async def _download_from_url(download_url: str) -> bytes:
|
|
317
326
|
|
318
327
|
|
319
328
|
async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes:
|
320
|
-
|
329
|
+
"""Convenience function for reading all of the downloaded file into memory."""
|
330
|
+
logger.debug(f"Downloading large blob {blob_id}")
|
331
|
+
t0 = time.time()
|
321
332
|
req = api_pb2.BlobGetRequest(blob_id=blob_id)
|
322
333
|
resp = await retry_transient_errors(stub.BlobGet, req)
|
323
|
-
|
324
|
-
|
334
|
+
data = await _download_from_url(resp.download_url)
|
335
|
+
size_mib = len(data) / 1024 / 1024
|
336
|
+
dur_s = max(time.time() - t0, 0.001) # avoid division by zero
|
337
|
+
throughput_mib_s = size_mib / dur_s
|
338
|
+
logger.debug(f"Downloaded large blob {blob_id} of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s)")
|
339
|
+
return data
|
325
340
|
|
326
341
|
|
327
342
|
async def blob_iter(blob_id: str, stub: ModalClientModal) -> AsyncIterator[bytes]:
|
modal/_utils/function_utils.py
CHANGED
@@ -2,9 +2,10 @@
|
|
2
2
|
import asyncio
|
3
3
|
import inspect
|
4
4
|
import os
|
5
|
+
from collections.abc import AsyncGenerator
|
5
6
|
from enum import Enum
|
6
7
|
from pathlib import Path, PurePosixPath
|
7
|
-
from typing import Any,
|
8
|
+
from typing import Any, Callable, Literal, Optional
|
8
9
|
|
9
10
|
from grpclib import GRPCError
|
10
11
|
from grpclib.exceptions import StreamTerminatedError
|
@@ -30,7 +31,7 @@ class FunctionInfoType(Enum):
|
|
30
31
|
|
31
32
|
|
32
33
|
# TODO(elias): Add support for quoted/str annotations
|
33
|
-
CLASS_PARAM_TYPE_MAP:
|
34
|
+
CLASS_PARAM_TYPE_MAP: dict[type, tuple["api_pb2.ParameterType.ValueType", str]] = {
|
34
35
|
str: (api_pb2.PARAM_TYPE_STRING, "string_default"),
|
35
36
|
int: (api_pb2.PARAM_TYPE_INT, "int_default"),
|
36
37
|
}
|
@@ -102,7 +103,7 @@ class FunctionInfo:
|
|
102
103
|
|
103
104
|
raw_f: Optional[Callable[..., Any]] # if None - this is a "class service function"
|
104
105
|
function_name: str
|
105
|
-
user_cls: Optional[
|
106
|
+
user_cls: Optional[type[Any]]
|
106
107
|
definition_type: "modal_proto.api_pb2.Function.DefinitionType.ValueType"
|
107
108
|
module_name: Optional[str]
|
108
109
|
|
@@ -123,7 +124,7 @@ class FunctionInfo:
|
|
123
124
|
f: Optional[Callable[..., Any]],
|
124
125
|
serialized=False,
|
125
126
|
name_override: Optional[str] = None,
|
126
|
-
user_cls: Optional[
|
127
|
+
user_cls: Optional[type] = None,
|
127
128
|
):
|
128
129
|
self.raw_f = f
|
129
130
|
self.user_cls = user_cls
|
@@ -133,6 +134,9 @@ class FunctionInfo:
|
|
133
134
|
elif f is None and user_cls:
|
134
135
|
# "service function" for running all methods of a class
|
135
136
|
self.function_name = f"{user_cls.__name__}.*"
|
137
|
+
elif f and user_cls:
|
138
|
+
# Method may be defined on superclass of the wrapped class
|
139
|
+
self.function_name = f"{user_cls.__name__}.{f.__name__}"
|
136
140
|
else:
|
137
141
|
self.function_name = f.__qualname__
|
138
142
|
|
@@ -147,7 +151,7 @@ class FunctionInfo:
|
|
147
151
|
# Get the package path
|
148
152
|
# Note: __import__ always returns the top-level package.
|
149
153
|
self._file = os.path.abspath(module.__file__)
|
150
|
-
package_paths =
|
154
|
+
package_paths = {os.path.abspath(p) for p in __import__(module.__package__).__path__}
|
151
155
|
# There might be multiple package paths in some weird cases
|
152
156
|
base_dirs = [
|
153
157
|
base_dir for base_dir in package_paths if os.path.commonpath((base_dir, self._file)) == base_dir
|
@@ -210,7 +214,7 @@ class FunctionInfo:
|
|
210
214
|
logger.debug(f"Serializing function for class service function {self.user_cls.__qualname__} as empty")
|
211
215
|
return b""
|
212
216
|
|
213
|
-
def get_cls_vars(self) ->
|
217
|
+
def get_cls_vars(self) -> dict[str, Any]:
|
214
218
|
if self.user_cls is not None:
|
215
219
|
cls_vars = {
|
216
220
|
attr: getattr(self.user_cls, attr)
|
@@ -220,7 +224,7 @@ class FunctionInfo:
|
|
220
224
|
return cls_vars
|
221
225
|
return {}
|
222
226
|
|
223
|
-
def get_cls_var_attrs(self) ->
|
227
|
+
def get_cls_var_attrs(self) -> dict[str, Any]:
|
224
228
|
import dis
|
225
229
|
|
226
230
|
import opcode
|
@@ -241,7 +245,7 @@ class FunctionInfo:
|
|
241
245
|
f_attrs = {k: cls_vars[k] for k in cls_vars if k in f_attr_ops}
|
242
246
|
return f_attrs
|
243
247
|
|
244
|
-
def get_globals(self) ->
|
248
|
+
def get_globals(self) -> dict[str, Any]:
|
245
249
|
from .._vendor.cloudpickle import _extract_code_globals
|
246
250
|
|
247
251
|
func = self.raw_f
|
@@ -262,7 +266,7 @@ class FunctionInfo:
|
|
262
266
|
# annotation parameters trigger strictly typed parameterization
|
263
267
|
# which enables web endpoint for parameterized classes
|
264
268
|
|
265
|
-
modal_parameters:
|
269
|
+
modal_parameters: list[api_pb2.ClassParameterSpec] = []
|
266
270
|
signature = _get_class_constructor_signature(self.user_cls)
|
267
271
|
for param in signature.parameters.values():
|
268
272
|
has_default = param.default is not param.empty
|
@@ -278,7 +282,7 @@ class FunctionInfo:
|
|
278
282
|
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=modal_parameters
|
279
283
|
)
|
280
284
|
|
281
|
-
def get_entrypoint_mount(self) ->
|
285
|
+
def get_entrypoint_mount(self) -> list[_Mount]:
|
282
286
|
"""
|
283
287
|
Includes:
|
284
288
|
* Implicit mount of the function itself (the module or package that the function is part of)
|
modal/_utils/grpc_testing.py
CHANGED
@@ -4,7 +4,8 @@ import inspect
|
|
4
4
|
import logging
|
5
5
|
import typing
|
6
6
|
from collections import Counter, defaultdict
|
7
|
-
from
|
7
|
+
from collections.abc import Awaitable
|
8
|
+
from typing import Any, Callable
|
8
9
|
|
9
10
|
import grpclib.server
|
10
11
|
from grpclib import GRPCError, Status
|
@@ -93,7 +94,7 @@ def patch_mock_servicer(cls):
|
|
93
94
|
|
94
95
|
|
95
96
|
class ResponseNotConsumed(Exception):
|
96
|
-
def __init__(self, unconsumed_requests:
|
97
|
+
def __init__(self, unconsumed_requests: list[str]):
|
97
98
|
self.unconsumed_requests = unconsumed_requests
|
98
99
|
request_count = Counter(unconsumed_requests)
|
99
100
|
super().__init__(f"Expected but did not receive the following requests: {request_count}")
|
@@ -101,9 +102,9 @@ class ResponseNotConsumed(Exception):
|
|
101
102
|
|
102
103
|
class InterceptionContext:
|
103
104
|
def __init__(self):
|
104
|
-
self.calls:
|
105
|
-
self.custom_responses:
|
106
|
-
self.custom_defaults:
|
105
|
+
self.calls: list[tuple[str, Any]] = [] # List[Tuple[method_name, message]]
|
106
|
+
self.custom_responses: dict[str, list[tuple[Callable[[Any], bool], list[Any]]]] = defaultdict(list)
|
107
|
+
self.custom_defaults: dict[str, Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]] = {}
|
107
108
|
|
108
109
|
def add_response(
|
109
110
|
self, method_name: str, first_payload, *, request_filter: Callable[[Any], bool] = lambda req: True
|
@@ -147,7 +148,7 @@ class InterceptionContext:
|
|
147
148
|
|
148
149
|
raise KeyError(f"No message of that type in call list: {self.calls}")
|
149
150
|
|
150
|
-
def get_requests(self, method_name: str) ->
|
151
|
+
def get_requests(self, method_name: str) -> list[Any]:
|
151
152
|
return [msg for _method_name, msg in self.calls if _method_name == method_name]
|
152
153
|
|
153
154
|
def _add_recv(self, method_name: str, msg):
|
modal/_utils/grpc_utils.py
CHANGED
@@ -7,10 +7,9 @@ import time
|
|
7
7
|
import typing
|
8
8
|
import urllib.parse
|
9
9
|
import uuid
|
10
|
+
from collections.abc import AsyncIterator
|
10
11
|
from typing import (
|
11
12
|
Any,
|
12
|
-
AsyncIterator,
|
13
|
-
Dict,
|
14
13
|
Optional,
|
15
14
|
TypeVar,
|
16
15
|
)
|
@@ -72,7 +71,7 @@ RETRYABLE_GRPC_STATUS_CODES = [
|
|
72
71
|
|
73
72
|
def create_channel(
|
74
73
|
server_url: str,
|
75
|
-
metadata:
|
74
|
+
metadata: dict[str, str] = {},
|
76
75
|
) -> grpclib.client.Channel:
|
77
76
|
"""Creates a grpclib.Channel.
|
78
77
|
|
modal/_utils/hash_utils.py
CHANGED
@@ -2,12 +2,12 @@
|
|
2
2
|
import base64
|
3
3
|
import dataclasses
|
4
4
|
import hashlib
|
5
|
-
from typing import BinaryIO, Callable,
|
5
|
+
from typing import BinaryIO, Callable, Union
|
6
6
|
|
7
7
|
HASH_CHUNK_SIZE = 4096
|
8
8
|
|
9
9
|
|
10
|
-
def _update(hashers:
|
10
|
+
def _update(hashers: list[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
|
11
11
|
if isinstance(data, bytes):
|
12
12
|
for hasher in hashers:
|
13
13
|
hasher(data)
|
modal/_utils/mount_utils.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
# Copyright Modal Labs 2022
|
2
2
|
import posixpath
|
3
3
|
import typing
|
4
|
+
from collections.abc import Mapping, Sequence
|
4
5
|
from pathlib import PurePath, PurePosixPath
|
5
|
-
from typing import
|
6
|
+
from typing import Union
|
6
7
|
|
7
8
|
from ..cloud_bucket_mount import _CloudBucketMount
|
8
9
|
from ..exception import InvalidError
|
@@ -15,7 +16,7 @@ T = typing.TypeVar("T", bound=Union[_Volume, _NetworkFileSystem, _CloudBucketMou
|
|
15
16
|
def validate_mount_points(
|
16
17
|
display_name: str,
|
17
18
|
volume_likes: Mapping[Union[str, PurePosixPath], T],
|
18
|
-
) ->
|
19
|
+
) -> list[tuple[str, T]]:
|
19
20
|
"""Mount point path validation for volumes and network file systems."""
|
20
21
|
|
21
22
|
if not isinstance(volume_likes, dict):
|
@@ -57,11 +58,11 @@ def validate_network_file_systems(
|
|
57
58
|
|
58
59
|
def validate_volumes(
|
59
60
|
volumes: Mapping[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]],
|
60
|
-
) -> Sequence[
|
61
|
+
) -> Sequence[tuple[str, Union[_Volume, _CloudBucketMount]]]:
|
61
62
|
validated_volumes = validate_mount_points("Volume", volumes)
|
62
63
|
# We don't support mounting a modal.Volume in more than one location,
|
63
64
|
# but the same CloudBucketMount object can be used in more than one location.
|
64
|
-
volume_to_paths:
|
65
|
+
volume_to_paths: dict[_Volume, list[str]] = {}
|
65
66
|
for path, volume in validated_volumes:
|
66
67
|
if not isinstance(volume, (_Volume, _CloudBucketMount)):
|
67
68
|
raise InvalidError(f"Object of type {type(volume)} mounted at '{path}' is not useable as a volume.")
|
modal/_utils/package_utils.py
CHANGED
@@ -4,7 +4,6 @@ import importlib.util
|
|
4
4
|
import typing
|
5
5
|
from importlib.metadata import PackageNotFoundError, files
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import Tuple
|
8
7
|
|
9
8
|
from ..exception import ModuleNotMountable
|
10
9
|
|
@@ -24,7 +23,7 @@ def get_file_formats(module):
|
|
24
23
|
BINARY_FORMATS = ["so", "S", "s", "asm"] # TODO
|
25
24
|
|
26
25
|
|
27
|
-
def get_module_mount_info(module_name: str) -> typing.Sequence[
|
26
|
+
def get_module_mount_info(module_name: str) -> typing.Sequence[tuple[bool, Path]]:
|
28
27
|
"""Returns a list of tuples [(is_dir, path)] describing how to mount a given module."""
|
29
28
|
file_formats = get_file_formats(module_name)
|
30
29
|
if set(BINARY_FORMATS) & set(file_formats):
|
@@ -49,7 +48,7 @@ def get_module_mount_info(module_name: str) -> typing.Sequence[typing.Tuple[bool
|
|
49
48
|
return entries
|
50
49
|
|
51
50
|
|
52
|
-
def parse_major_minor_version(version_string: str) ->
|
51
|
+
def parse_major_minor_version(version_string: str) -> tuple[int, int]:
|
53
52
|
parts = version_string.split(".")
|
54
53
|
if len(parts) < 2:
|
55
54
|
raise ValueError("version_string must have at least an 'X.Y' format")
|
modal/_utils/pattern_matcher.py
CHANGED
@@ -12,7 +12,7 @@ then asking it whether file paths match any of its patterns.
|
|
12
12
|
import enum
|
13
13
|
import os
|
14
14
|
import re
|
15
|
-
from typing import
|
15
|
+
from typing import Optional, TextIO
|
16
16
|
|
17
17
|
escape_chars = frozenset(".+()|{}$")
|
18
18
|
|
@@ -32,7 +32,7 @@ class Pattern:
|
|
32
32
|
"""Initialize a new Pattern instance."""
|
33
33
|
self.match_type = MatchType.UNKNOWN
|
34
34
|
self.cleaned_pattern = ""
|
35
|
-
self.dirs:
|
35
|
+
self.dirs: list[str] = []
|
36
36
|
self.regexp: Optional[re.Pattern] = None
|
37
37
|
self.exclusion = False
|
38
38
|
|
@@ -151,7 +151,7 @@ class Pattern:
|
|
151
151
|
class PatternMatcher:
|
152
152
|
"""Allows checking paths against a list of patterns."""
|
153
153
|
|
154
|
-
def __init__(self, patterns:
|
154
|
+
def __init__(self, patterns: list[str]) -> None:
|
155
155
|
"""Initialize a new PatternMatcher instance.
|
156
156
|
|
157
157
|
Args:
|
@@ -160,7 +160,7 @@ class PatternMatcher:
|
|
160
160
|
Raises:
|
161
161
|
ValueError: If an illegal exclusion pattern is provided.
|
162
162
|
"""
|
163
|
-
self.patterns:
|
163
|
+
self.patterns: list[Pattern] = []
|
164
164
|
self.exclusions = False
|
165
165
|
for pattern in patterns:
|
166
166
|
pattern = pattern.strip()
|
@@ -217,7 +217,7 @@ class PatternMatcher:
|
|
217
217
|
return matched
|
218
218
|
|
219
219
|
|
220
|
-
def read_ignorefile(reader: TextIO) ->
|
220
|
+
def read_ignorefile(reader: TextIO) -> list[str]:
|
221
221
|
"""Read an ignore file from a reader and return the list of file patterns to
|
222
222
|
ignore, applying the following rules:
|
223
223
|
|
@@ -241,7 +241,7 @@ def read_ignorefile(reader: TextIO) -> List[str]:
|
|
241
241
|
if reader is None:
|
242
242
|
return []
|
243
243
|
|
244
|
-
excludes:
|
244
|
+
excludes: list[str] = []
|
245
245
|
|
246
246
|
for line in reader:
|
247
247
|
pattern = line.rstrip("\n\r")
|
modal/_utils/rand_pb_testing.py
CHANGED
@@ -7,13 +7,13 @@ Modal, with random seeds, and it supports oneofs, and Protobuf v4.
|
|
7
7
|
|
8
8
|
import string
|
9
9
|
from random import Random
|
10
|
-
from typing import Any, Callable,
|
10
|
+
from typing import Any, Callable, Optional, TypeVar
|
11
11
|
|
12
12
|
from google.protobuf.descriptor import Descriptor, FieldDescriptor
|
13
13
|
|
14
14
|
T = TypeVar("T")
|
15
15
|
|
16
|
-
_FIELD_RANDOM_GENERATOR:
|
16
|
+
_FIELD_RANDOM_GENERATOR: dict[int, Callable[[Random], Any]] = {
|
17
17
|
FieldDescriptor.TYPE_DOUBLE: lambda rand: rand.normalvariate(0, 1),
|
18
18
|
FieldDescriptor.TYPE_FLOAT: lambda rand: rand.normalvariate(0, 1),
|
19
19
|
FieldDescriptor.TYPE_INT32: lambda rand: int.from_bytes(rand.randbytes(4), "little", signed=True),
|
@@ -71,7 +71,7 @@ def _fill(msg, desc: Descriptor, rand: Random) -> None:
|
|
71
71
|
setattr(msg, field.name, generator(rand))
|
72
72
|
|
73
73
|
|
74
|
-
def rand_pb(proto:
|
74
|
+
def rand_pb(proto: type[T], rand: Optional[Random] = None) -> T:
|
75
75
|
"""Generate a pseudorandom protobuf message.
|
76
76
|
|
77
77
|
```python notest
|
modal/_utils/shell_utils.py
CHANGED