modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__init__.py +17 -13
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +420 -937
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -59
- modal/_resources.py +51 -0
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1036 -0
- modal/_runtime/execution_context.py +89 -0
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +134 -9
- modal/_traceback.py +47 -187
- modal/_tunnel.py +52 -16
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +479 -100
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +460 -171
- modal/_utils/grpc_testing.py +47 -31
- modal/_utils/grpc_utils.py +62 -109
- modal/_utils/hash_utils.py +61 -19
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +5 -7
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +14 -12
- modal/app.py +1003 -314
- modal/app.pyi +540 -264
- modal/call_graph.py +7 -6
- modal/cli/_download.py +63 -53
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +205 -45
- modal/cli/config.py +12 -5
- modal/cli/container.py +62 -14
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +64 -58
- modal/cli/launch.py +32 -18
- modal/cli/network_file_system.py +64 -83
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +35 -10
- modal/cli/programs/vscode.py +60 -10
- modal/cli/queues.py +131 -0
- modal/cli/run.py +234 -131
- modal/cli/secret.py +8 -7
- modal/cli/token.py +7 -2
- modal/cli/utils.py +79 -10
- modal/cli/volume.py +110 -109
- modal/client.py +250 -144
- modal/client.pyi +157 -118
- modal/cloud_bucket_mount.py +108 -34
- modal/cloud_bucket_mount.pyi +32 -38
- modal/cls.py +535 -148
- modal/cls.pyi +190 -146
- modal/config.py +41 -19
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +111 -65
- modal/dict.pyi +136 -131
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +34 -43
- modal/experimental.py +61 -2
- modal/extensions/ipython.py +5 -5
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +197 -0
- modal/functions.py +906 -911
- modal/functions.pyi +466 -430
- modal/gpu.py +57 -44
- modal/image.py +1089 -479
- modal/image.pyi +584 -228
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +314 -101
- modal/mount.pyi +241 -235
- modal/network_file_system.py +92 -92
- modal/network_file_system.pyi +152 -110
- modal/object.py +67 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +434 -0
- modal/parallel_map.pyi +75 -0
- modal/partial_function.py +282 -117
- modal/partial_function.pyi +222 -129
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +182 -65
- modal/queue.pyi +218 -118
- modal/requirements/2024.04.txt +29 -0
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +48 -7
- modal/runner.py +459 -156
- modal/runner.pyi +135 -71
- modal/running_app.py +38 -0
- modal/sandbox.py +514 -236
- modal/sandbox.pyi +397 -169
- modal/schedule.py +4 -4
- modal/scheduler_placement.py +20 -3
- modal/secret.py +56 -31
- modal/secret.pyi +62 -42
- modal/serving.py +51 -56
- modal/serving.pyi +44 -36
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +285 -157
- modal/volume.pyi +249 -184
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
- modal-0.72.11.dist-info/RECORD +174 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +5 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1288 -533
- modal_proto/api_grpc.py +856 -456
- modal_proto/api_pb2.py +2165 -1157
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1674 -855
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_entrypoint.pyi +0 -378
- modal/_container_exec.py +0 -128
- modal/_sandbox_shell.py +0 -49
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal/stub.py +0 -783
- modal/stub.pyi +0 -332
- modal-0.62.16.dist-info/RECORD +0 -198
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -262
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -659
- test/client_test.py +0 -194
- test/cls_test.py +0 -630
- test/config_test.py +0 -137
- test/conftest.py +0 -1420
- test/container_app_test.py +0 -32
- test/container_test.py +0 -1389
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -33
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -653
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -141
- test/helpers.py +0 -42
- test/image_test.py +0 -669
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -329
- test/network_file_system_test.py +0 -181
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -97
- test/resolver_test.py +0 -58
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -29
- test/secret_test.py +0 -78
- test/serialization_test.py +0 -42
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -360
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -341
- test/watcher_test.py +0 -30
- test/webhook_test.py +0 -146
- /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
- /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
modal/_utils/grpc_testing.py
CHANGED
@@ -4,7 +4,8 @@ import inspect
|
|
4
4
|
import logging
|
5
5
|
import typing
|
6
6
|
from collections import Counter, defaultdict
|
7
|
-
from
|
7
|
+
from collections.abc import Awaitable
|
8
|
+
from typing import Any, Callable
|
8
9
|
|
9
10
|
import grpclib.server
|
10
11
|
from grpclib import GRPCError, Status
|
@@ -26,7 +27,8 @@ def patch_mock_servicer(cls):
|
|
26
27
|
await some_complex_method()
|
27
28
|
assert ctx.calls == [("SomeMethod", MyMessage(foo="bar"))]
|
28
29
|
```
|
29
|
-
Also allows to set a predefined queue of responses, temporarily replacing
|
30
|
+
Also allows to set a predefined queue of responses, temporarily replacing
|
31
|
+
a mock servicer's default responses for a method:
|
30
32
|
|
31
33
|
```python notest
|
32
34
|
with servicer.intercept() as ctx:
|
@@ -48,10 +50,10 @@ def patch_mock_servicer(cls):
|
|
48
50
|
|
49
51
|
@contextlib.contextmanager
|
50
52
|
def intercept(servicer):
|
51
|
-
ctx = InterceptionContext()
|
53
|
+
ctx = InterceptionContext(servicer)
|
52
54
|
servicer.interception_context = ctx
|
53
55
|
yield ctx
|
54
|
-
ctx.
|
56
|
+
ctx._assert_responses_consumed()
|
55
57
|
servicer.interception_context = None
|
56
58
|
|
57
59
|
cls.intercept = intercept
|
@@ -63,7 +65,7 @@ def patch_mock_servicer(cls):
|
|
63
65
|
ctx = servicer_self.interception_context
|
64
66
|
if ctx:
|
65
67
|
intercepted_stream = await InterceptedStream(ctx, method_name, stream).initialize()
|
66
|
-
custom_responder = ctx.
|
68
|
+
custom_responder = ctx._next_custom_responder(method_name, intercepted_stream.request_message)
|
67
69
|
if custom_responder:
|
68
70
|
return await custom_responder(servicer_self, intercepted_stream)
|
69
71
|
else:
|
@@ -92,31 +94,36 @@ def patch_mock_servicer(cls):
|
|
92
94
|
|
93
95
|
|
94
96
|
class ResponseNotConsumed(Exception):
|
95
|
-
def __init__(self, unconsumed_requests:
|
97
|
+
def __init__(self, unconsumed_requests: list[str]):
|
96
98
|
self.unconsumed_requests = unconsumed_requests
|
97
99
|
request_count = Counter(unconsumed_requests)
|
98
100
|
super().__init__(f"Expected but did not receive the following requests: {request_count}")
|
99
101
|
|
100
102
|
|
101
103
|
class InterceptionContext:
|
102
|
-
def __init__(self):
|
103
|
-
self.
|
104
|
-
self.
|
105
|
-
self.
|
106
|
-
|
107
|
-
def add_recv(self, method_name: str, msg):
|
108
|
-
self.calls.append((method_name, msg))
|
104
|
+
def __init__(self, servicer):
|
105
|
+
self._servicer = servicer
|
106
|
+
self.calls: list[tuple[str, Any]] = [] # List[Tuple[method_name, message]]
|
107
|
+
self.custom_responses: dict[str, list[tuple[Callable[[Any], bool], list[Any]]]] = defaultdict(list)
|
108
|
+
self.custom_defaults: dict[str, Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]] = {}
|
109
109
|
|
110
110
|
def add_response(
|
111
111
|
self, method_name: str, first_payload, *, request_filter: Callable[[Any], bool] = lambda req: True
|
112
112
|
):
|
113
|
-
|
113
|
+
"""Adds one response payload to an expected queue of responses for a method.
|
114
|
+
|
115
|
+
These responses will be used once each instead of calling the MockServicer's
|
116
|
+
implementation of the method.
|
117
|
+
|
118
|
+
The interception context will throw an exception on exit if not all of the added
|
119
|
+
responses have been consumed.
|
120
|
+
"""
|
114
121
|
self.custom_responses[method_name].append((request_filter, [first_payload]))
|
115
122
|
|
116
123
|
def set_responder(
|
117
124
|
self, method_name: str, responder: Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]
|
118
125
|
):
|
119
|
-
"""Replace the default responder
|
126
|
+
"""Replace the default responder from the MockClientServicer with a custom implementation
|
120
127
|
|
121
128
|
```python notest
|
122
129
|
def custom_responder(servicer, stream):
|
@@ -127,11 +134,31 @@ class InterceptionContext:
|
|
127
134
|
ctx.set_responder("SomeMethod", custom_responder)
|
128
135
|
```
|
129
136
|
|
130
|
-
Responses added via `.add_response()` take precedence
|
137
|
+
Responses added via `.add_response()` take precedence over the use of this replacement
|
131
138
|
"""
|
132
139
|
self.custom_defaults[method_name] = responder
|
133
140
|
|
134
|
-
def
|
141
|
+
def pop_request(self, method_name):
|
142
|
+
# fast forward to the next request of type method_name
|
143
|
+
# dropping any preceding requests if there is a match
|
144
|
+
# returns the payload of the request
|
145
|
+
for i, (_method_name, msg) in enumerate(self.calls):
|
146
|
+
if _method_name == method_name:
|
147
|
+
self.calls = self.calls[i + 1 :]
|
148
|
+
return msg
|
149
|
+
|
150
|
+
raise KeyError(f"No message of that type in call list: {self.calls}")
|
151
|
+
|
152
|
+
def get_requests(self, method_name: str) -> list[Any]:
|
153
|
+
if not hasattr(self._servicer, method_name):
|
154
|
+
# we check this to prevent things like `assert ctx.get_requests("ASdfFunctionCreate") == 0` passing
|
155
|
+
raise ValueError(f"{method_name} not in MockServicer - did you spell it right?")
|
156
|
+
return [msg for _method_name, msg in self.calls if _method_name == method_name]
|
157
|
+
|
158
|
+
def _add_recv(self, method_name: str, msg):
|
159
|
+
self.calls.append((method_name, msg))
|
160
|
+
|
161
|
+
def _next_custom_responder(self, method_name, request):
|
135
162
|
method_responses = self.custom_responses[method_name]
|
136
163
|
for i, (request_filter, response_messages) in enumerate(method_responses):
|
137
164
|
try:
|
@@ -158,7 +185,7 @@ class InterceptionContext:
|
|
158
185
|
|
159
186
|
return responder
|
160
187
|
|
161
|
-
def
|
188
|
+
def _assert_responses_consumed(self):
|
162
189
|
unconsumed = []
|
163
190
|
for method_name, queued_responses in self.custom_responses.items():
|
164
191
|
unconsumed += [method_name] * len(queued_responses)
|
@@ -166,20 +193,9 @@ class InterceptionContext:
|
|
166
193
|
if unconsumed:
|
167
194
|
raise ResponseNotConsumed(unconsumed)
|
168
195
|
|
169
|
-
def pop_request(self, method_name):
|
170
|
-
# fast forward to the next request of type method_name
|
171
|
-
# dropping any preceding requests if there is a match
|
172
|
-
# returns the payload of the request
|
173
|
-
for i, (_method_name, msg) in enumerate(self.calls):
|
174
|
-
if _method_name == method_name:
|
175
|
-
self.calls = self.calls[i + 1 :]
|
176
|
-
return msg
|
177
|
-
|
178
|
-
raise Exception(f"No message of that type in call list: {self.calls}")
|
179
|
-
|
180
196
|
|
181
197
|
class InterceptedStream:
|
182
|
-
def __init__(self, interception_context, method_name, stream):
|
198
|
+
def __init__(self, interception_context: InterceptionContext, method_name: str, stream):
|
183
199
|
self.interception_context = interception_context
|
184
200
|
self.method_name = method_name
|
185
201
|
self.stream = stream
|
@@ -196,7 +212,7 @@ class InterceptedStream:
|
|
196
212
|
return ret
|
197
213
|
|
198
214
|
msg = await self.stream.recv_message()
|
199
|
-
self.interception_context.
|
215
|
+
self.interception_context._add_recv(self.method_name, msg)
|
200
216
|
return msg
|
201
217
|
|
202
218
|
async def send_message(self, msg):
|
modal/_utils/grpc_utils.py
CHANGED
@@ -4,30 +4,37 @@ import contextlib
|
|
4
4
|
import platform
|
5
5
|
import socket
|
6
6
|
import time
|
7
|
+
import typing
|
7
8
|
import urllib.parse
|
8
9
|
import uuid
|
10
|
+
from collections.abc import AsyncIterator
|
9
11
|
from typing import (
|
10
12
|
Any,
|
11
|
-
AsyncIterator,
|
12
|
-
Dict,
|
13
|
-
List,
|
14
13
|
Optional,
|
15
|
-
Type,
|
16
14
|
TypeVar,
|
17
15
|
)
|
18
16
|
|
19
17
|
import grpclib.client
|
20
18
|
import grpclib.config
|
21
19
|
import grpclib.events
|
20
|
+
import grpclib.protocol
|
21
|
+
import grpclib.stream
|
22
22
|
from google.protobuf.message import Message
|
23
23
|
from grpclib import GRPCError, Status
|
24
24
|
from grpclib.exceptions import StreamTerminatedError
|
25
25
|
from grpclib.protocol import H2Protocol
|
26
26
|
|
27
|
+
from modal.exception import AuthError, ConnectionError
|
27
28
|
from modal_version import __version__
|
28
29
|
|
29
30
|
from .logger import logger
|
30
31
|
|
32
|
+
RequestType = TypeVar("RequestType", bound=Message)
|
33
|
+
ResponseType = TypeVar("ResponseType", bound=Message)
|
34
|
+
|
35
|
+
if typing.TYPE_CHECKING:
|
36
|
+
import modal.client
|
37
|
+
|
31
38
|
# Monkey patches grpclib to have a Modal User Agent header.
|
32
39
|
grpclib.client.USER_AGENT = "modal-client/{version} ({sys}; {py}/{py_ver})'".format(
|
33
40
|
version=__version__,
|
@@ -54,81 +61,6 @@ class Subchannel:
|
|
54
61
|
return True
|
55
62
|
|
56
63
|
|
57
|
-
class ChannelPool(grpclib.client.Channel):
|
58
|
-
"""Use multiple channels under the hood. A drop-in replacement for the grpclib Channel.
|
59
|
-
|
60
|
-
The main reason is to get around limitations with TCP connections over the internet,
|
61
|
-
in particular idle timeouts.
|
62
|
-
|
63
|
-
The algorithm is very simple. It reuses the last subchannel as long as it has had less
|
64
|
-
than 64 requests or if it was created less than 30s ago. It closes any subchannel that
|
65
|
-
hits 90s age. This means requests using the ChannelPool can't be longer than 60s.
|
66
|
-
"""
|
67
|
-
|
68
|
-
_max_requests: int
|
69
|
-
_max_lifetime: float
|
70
|
-
_max_active: float
|
71
|
-
_subchannels: List[Subchannel]
|
72
|
-
|
73
|
-
def __init__(
|
74
|
-
self,
|
75
|
-
*args,
|
76
|
-
max_requests=64, # Maximum number of total requests per subchannel
|
77
|
-
max_active=30, # Don't accept more connections on the subchannel after this many seconds
|
78
|
-
max_lifetime=90, # Close subchannel after this many seconds
|
79
|
-
**kwargs,
|
80
|
-
):
|
81
|
-
self._subchannels = []
|
82
|
-
self._max_requests = max_requests
|
83
|
-
self._max_active = max_active
|
84
|
-
self._max_lifetime = max_lifetime
|
85
|
-
super().__init__(*args, **kwargs)
|
86
|
-
|
87
|
-
async def __connect__(self):
|
88
|
-
now = time.time()
|
89
|
-
# Remove any closed subchannels
|
90
|
-
while len(self._subchannels) > 0 and not self._subchannels[-1].connected():
|
91
|
-
self._subchannels.pop()
|
92
|
-
|
93
|
-
# Close and delete any subchannels that are past their lifetime
|
94
|
-
while len(self._subchannels) > 0 and now - self._subchannels[0].created_at > self._max_lifetime:
|
95
|
-
self._subchannels.pop(0).protocol.processor.close()
|
96
|
-
|
97
|
-
# See if we can reuse the last subchannel
|
98
|
-
create_subchannel = None
|
99
|
-
if len(self._subchannels) > 0:
|
100
|
-
if self._subchannels[-1].created_at < now - self._max_active:
|
101
|
-
# Don't reuse subchannel that's too old
|
102
|
-
create_subchannel = True
|
103
|
-
elif self._subchannels[-1].requests > self._max_requests:
|
104
|
-
create_subchannel = True
|
105
|
-
else:
|
106
|
-
create_subchannel = False
|
107
|
-
else:
|
108
|
-
create_subchannel = True
|
109
|
-
|
110
|
-
# Create new if needed
|
111
|
-
# There's a theoretical race condition here.
|
112
|
-
# This is harmless but may lead to superfluous protocols.
|
113
|
-
if create_subchannel:
|
114
|
-
protocol = await self._create_connection()
|
115
|
-
self._subchannels.append(Subchannel(protocol))
|
116
|
-
|
117
|
-
self._subchannels[-1].requests += 1
|
118
|
-
return self._subchannels[-1].protocol
|
119
|
-
|
120
|
-
def close(self) -> None:
|
121
|
-
while len(self._subchannels) > 0:
|
122
|
-
self._subchannels.pop(0).protocol.processor.close()
|
123
|
-
|
124
|
-
def __del__(self) -> None:
|
125
|
-
if len(self._subchannels) > 0:
|
126
|
-
logger.warning("Channel pool not properly closed")
|
127
|
-
|
128
|
-
|
129
|
-
_SendType = TypeVar("_SendType")
|
130
|
-
_RecvType = TypeVar("_RecvType")
|
131
|
-
|
132
64
|
RETRYABLE_GRPC_STATUS_CODES = [
|
133
65
|
Status.DEADLINE_EXCEEDED,
|
134
66
|
Status.UNAVAILABLE,
|
@@ -139,9 +71,7 @@ RETRYABLE_GRPC_STATUS_CODES = [
|
|
139
71
|
|
140
72
|
def create_channel(
|
141
73
|
server_url: str,
|
142
|
-
metadata:
|
143
|
-
*,
|
144
|
-
use_pool: Optional[bool] = None, # If None, inferred from the scheme
|
74
|
+
metadata: dict[str, str] = {},
|
145
75
|
) -> grpclib.client.Channel:
|
146
76
|
"""Creates a grpclib.Channel.
|
147
77
|
|
@@ -150,15 +80,6 @@ def create_channel(
|
|
150
80
|
"""
|
151
81
|
o = urllib.parse.urlparse(server_url)
|
152
82
|
|
153
|
-
if use_pool is None:
|
154
|
-
use_pool = o.scheme in ("http", "https")
|
155
|
-
|
156
|
-
channel_cls: Type[grpclib.client.Channel]
|
157
|
-
if use_pool:
|
158
|
-
channel_cls = ChannelPool
|
159
|
-
else:
|
160
|
-
channel_cls = grpclib.client.Channel
|
161
|
-
|
162
83
|
channel: grpclib.client.Channel
|
163
84
|
config = grpclib.config.Configuration(
|
164
85
|
http2_connection_window_size=64 * 1024 * 1024, # 64 MiB
|
@@ -166,7 +87,7 @@ def create_channel(
|
|
166
87
|
)
|
167
88
|
|
168
89
|
if o.scheme == "unix":
|
169
|
-
channel =
|
90
|
+
channel = grpclib.client.Channel(path=o.path, config=config) # probably pointless to use a pool ever
|
170
91
|
elif o.scheme in ("http", "https"):
|
171
92
|
target = o.netloc
|
172
93
|
parts = target.split(":")
|
@@ -174,7 +95,7 @@ def create_channel(
|
|
174
95
|
ssl = o.scheme.endswith("s")
|
175
96
|
host = parts[0]
|
176
97
|
port = int(parts[1]) if len(parts) == 2 else 443 if ssl else 80
|
177
|
-
channel =
|
98
|
+
channel = grpclib.client.Channel(host, port, ssl=ssl, config=config)
|
178
99
|
else:
|
179
100
|
raise Exception(f"Unknown scheme: {o.scheme}")
|
180
101
|
|
@@ -189,23 +110,31 @@ def create_channel(
|
|
189
110
|
logger.debug(f"Sending request to {event.method_name}")
|
190
111
|
|
191
112
|
grpclib.events.listen(channel, grpclib.events.SendRequest, send_request)
|
113
|
+
|
192
114
|
return channel
|
193
115
|
|
194
116
|
|
117
|
+
async def connect_channel(channel: grpclib.client.Channel):
|
118
|
+
"""Connects socket (potentially raising errors raising to connectivity."""
|
119
|
+
await channel.__connect__()
|
120
|
+
|
121
|
+
|
122
|
+
if typing.TYPE_CHECKING:
|
123
|
+
import modal.client
|
124
|
+
|
125
|
+
|
195
126
|
async def unary_stream(
|
196
|
-
method:
|
197
|
-
request:
|
127
|
+
method: "modal.client.UnaryStreamWrapper[RequestType, ResponseType]",
|
128
|
+
request: RequestType,
|
198
129
|
metadata: Optional[Any] = None,
|
199
|
-
) -> AsyncIterator[
|
200
|
-
|
201
|
-
async
|
202
|
-
|
203
|
-
async for item in stream:
|
204
|
-
yield item
|
130
|
+
) -> AsyncIterator[ResponseType]:
|
131
|
+
# TODO: remove this, since we have a method now
|
132
|
+
async for item in method.unary_stream(request, metadata):
|
133
|
+
yield item
|
205
134
|
|
206
135
|
|
207
136
|
async def retry_transient_errors(
|
208
|
-
fn,
|
137
|
+
fn: "modal.client.UnaryUnaryWrapper[RequestType, ResponseType]",
|
209
138
|
*args,
|
210
139
|
base_delay: float = 0.1,
|
211
140
|
max_delay: float = 1,
|
@@ -215,7 +144,7 @@ async def retry_transient_errors(
|
|
215
144
|
attempt_timeout: Optional[float] = None, # timeout for each attempt
|
216
145
|
total_timeout: Optional[float] = None, # timeout for the entire function call
|
217
146
|
attempt_timeout_floor=2.0, # always have at least this much timeout (only for total_timeout)
|
218
|
-
):
|
147
|
+
) -> ResponseType:
|
219
148
|
"""Retry on transient gRPC failures with back-off until max_retries is reached.
|
220
149
|
If max_retries is None, retry forever."""
|
221
150
|
|
@@ -247,16 +176,35 @@ async def retry_transient_errors(
|
|
247
176
|
timeout = None
|
248
177
|
try:
|
249
178
|
return await fn(*args, metadata=metadata, timeout=timeout)
|
250
|
-
except (StreamTerminatedError, GRPCError,
|
179
|
+
except (StreamTerminatedError, GRPCError, OSError, asyncio.TimeoutError, AttributeError) as exc:
|
251
180
|
if isinstance(exc, GRPCError) and exc.status not in status_codes:
|
252
|
-
|
181
|
+
if exc.status == Status.UNAUTHENTICATED:
|
182
|
+
raise AuthError(exc.message)
|
183
|
+
else:
|
184
|
+
raise exc
|
253
185
|
|
254
186
|
if max_retries is not None and n_retries >= max_retries:
|
187
|
+
final_attempt = True
|
188
|
+
elif total_deadline is not None and time.time() + delay + attempt_timeout_floor >= total_deadline:
|
189
|
+
final_attempt = True
|
190
|
+
else:
|
191
|
+
final_attempt = False
|
192
|
+
|
193
|
+
if final_attempt:
|
194
|
+
if isinstance(exc, OSError):
|
195
|
+
raise ConnectionError(str(exc))
|
196
|
+
elif isinstance(exc, asyncio.TimeoutError):
|
197
|
+
raise ConnectionError(str(exc))
|
198
|
+
else:
|
199
|
+
raise exc
|
200
|
+
|
201
|
+
if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
|
202
|
+
# StreamTerminatedError are not properly raised in grpclib<=0.4.7
|
203
|
+
# fixed in https://github.com/vmagamedov/grpclib/issues/185
|
204
|
+
# TODO: update to newer version (>=0.4.8) once stable
|
255
205
|
raise exc
|
256
206
|
|
257
|
-
|
258
|
-
# no point sleeping if that's going to push us past the deadline
|
259
|
-
raise exc
|
207
|
+
logger.debug(f"Retryable failure {repr(exc)} {n_retries=} {delay=} for {fn.name}")
|
260
208
|
|
261
209
|
n_retries += 1
|
262
210
|
|
@@ -265,7 +213,12 @@ async def retry_transient_errors(
|
|
265
213
|
|
266
214
|
|
267
215
|
def find_free_port() -> int:
|
268
|
-
"""
|
216
|
+
"""
|
217
|
+
Find a free TCP port, useful for testing.
|
218
|
+
|
219
|
+
WARN: if a returned free port is not bound immediately by the caller, that same port
|
220
|
+
may be returned in subsequent calls to this function, potentially creating port collisions.
|
221
|
+
"""
|
269
222
|
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
270
223
|
s.bind(("", 0))
|
271
224
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
modal/_utils/hash_utils.py
CHANGED
@@ -2,43 +2,53 @@
|
|
2
2
|
import base64
|
3
3
|
import dataclasses
|
4
4
|
import hashlib
|
5
|
-
|
5
|
+
import time
|
6
|
+
from typing import BinaryIO, Callable, Optional, Sequence, Union
|
6
7
|
|
7
|
-
|
8
|
+
from modal.config import logger
|
8
9
|
|
10
|
+
HASH_CHUNK_SIZE = 65536
|
9
11
|
|
10
|
-
|
12
|
+
|
13
|
+
def _update(hashers: Sequence[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
|
11
14
|
if isinstance(data, bytes):
|
12
15
|
for hasher in hashers:
|
13
|
-
hasher
|
16
|
+
hasher(data)
|
14
17
|
else:
|
18
|
+
assert not isinstance(data, (bytearray, memoryview)) # https://github.com/microsoft/pyright/issues/5697
|
15
19
|
pos = data.tell()
|
16
|
-
while
|
20
|
+
while True:
|
17
21
|
chunk = data.read(HASH_CHUNK_SIZE)
|
18
22
|
if not isinstance(chunk, bytes):
|
19
23
|
raise ValueError(f"Only accepts bytes or byte buffer objects, not {type(chunk)} buffers")
|
20
24
|
if not chunk:
|
21
25
|
break
|
22
26
|
for hasher in hashers:
|
23
|
-
hasher
|
27
|
+
hasher(chunk)
|
24
28
|
data.seek(pos)
|
25
29
|
|
26
30
|
|
27
|
-
def get_sha256_hex(data: Union[bytes,
|
31
|
+
def get_sha256_hex(data: Union[bytes, BinaryIO]) -> str:
|
32
|
+
t0 = time.monotonic()
|
28
33
|
hasher = hashlib.sha256()
|
29
|
-
_update([hasher], data)
|
34
|
+
_update([hasher.update], data)
|
35
|
+
logger.debug("get_sha256_hex took %.3fs", time.monotonic() - t0)
|
30
36
|
return hasher.hexdigest()
|
31
37
|
|
32
38
|
|
33
|
-
def get_sha256_base64(data: Union[bytes,
|
39
|
+
def get_sha256_base64(data: Union[bytes, BinaryIO]) -> str:
|
40
|
+
t0 = time.monotonic()
|
34
41
|
hasher = hashlib.sha256()
|
35
|
-
_update([hasher], data)
|
42
|
+
_update([hasher.update], data)
|
43
|
+
logger.debug("get_sha256_base64 took %.3fs", time.monotonic() - t0)
|
36
44
|
return base64.b64encode(hasher.digest()).decode("ascii")
|
37
45
|
|
38
46
|
|
39
|
-
def get_md5_base64(data: Union[bytes,
|
47
|
+
def get_md5_base64(data: Union[bytes, BinaryIO]) -> str:
|
48
|
+
t0 = time.monotonic()
|
40
49
|
hasher = hashlib.md5()
|
41
|
-
_update([hasher], data)
|
50
|
+
_update([hasher.update], data)
|
51
|
+
logger.debug("get_md5_base64 took %.3fs", time.monotonic() - t0)
|
42
52
|
return base64.b64encode(hasher.digest()).decode("utf-8")
|
43
53
|
|
44
54
|
|
@@ -47,12 +57,44 @@ class UploadHashes:
|
|
47
57
|
md5_base64: str
|
48
58
|
sha256_base64: str
|
49
59
|
|
60
|
+
def md5_hex(self) -> str:
|
61
|
+
return base64.b64decode(self.md5_base64).hex()
|
62
|
+
|
63
|
+
def sha256_hex(self) -> str:
|
64
|
+
return base64.b64decode(self.sha256_base64).hex()
|
65
|
+
|
66
|
+
|
67
|
+
def get_upload_hashes(
|
68
|
+
data: Union[bytes, BinaryIO], sha256_hex: Optional[str] = None, md5_hex: Optional[str] = None
|
69
|
+
) -> UploadHashes:
|
70
|
+
t0 = time.monotonic()
|
71
|
+
hashers = {}
|
72
|
+
|
73
|
+
if not sha256_hex:
|
74
|
+
sha256 = hashlib.sha256()
|
75
|
+
hashers["sha256"] = sha256
|
76
|
+
if not md5_hex:
|
77
|
+
md5 = hashlib.md5()
|
78
|
+
hashers["md5"] = md5
|
79
|
+
|
80
|
+
if hashers:
|
81
|
+
updaters = [h.update for h in hashers.values()]
|
82
|
+
_update(updaters, data)
|
50
83
|
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
84
|
+
if sha256_hex:
|
85
|
+
sha256_base64 = base64.b64encode(bytes.fromhex(sha256_hex)).decode("ascii")
|
86
|
+
else:
|
87
|
+
sha256_base64 = base64.b64encode(hashers["sha256"].digest()).decode("ascii")
|
88
|
+
|
89
|
+
if md5_hex:
|
90
|
+
md5_base64 = base64.b64encode(bytes.fromhex(md5_hex)).decode("ascii")
|
91
|
+
else:
|
92
|
+
md5_base64 = base64.b64encode(hashers["md5"].digest()).decode("ascii")
|
93
|
+
|
94
|
+
hashes = UploadHashes(
|
95
|
+
md5_base64=md5_base64,
|
96
|
+
sha256_base64=sha256_base64,
|
58
97
|
)
|
98
|
+
|
99
|
+
logger.debug("get_upload_hashes took %.3fs (%s)", time.monotonic() - t0, hashers.keys())
|
100
|
+
return hashes
|
modal/_utils/http_utils.py
CHANGED
@@ -1,16 +1,18 @@
|
|
1
1
|
# Copyright Modal Labs 2022
|
2
2
|
import contextlib
|
3
|
-
import
|
4
|
-
import ssl
|
5
|
-
from typing import Optional
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
6
4
|
|
7
|
-
|
8
|
-
|
9
|
-
from aiohttp.web import Application
|
10
|
-
from aiohttp.web_runner import AppRunner, SockSite
|
5
|
+
# Note: importing aiohttp seems to take about 100ms, and it's not really necessarily,
|
6
|
+
# unless we need to work with blobs. So that's why we import it lazily instead.
|
11
7
|
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from aiohttp import ClientSession
|
10
|
+
from aiohttp.web import Application
|
12
11
|
|
13
|
-
|
12
|
+
from .async_utils import on_shutdown
|
13
|
+
|
14
|
+
|
15
|
+
def _http_client_with_tls(timeout: Optional[float]) -> "ClientSession":
|
14
16
|
"""Create a new HTTP client session with standard, bundled TLS certificates.
|
15
17
|
|
16
18
|
This is necessary to prevent client issues on some system where Python does
|
@@ -20,15 +22,43 @@ def http_client_with_tls(timeout: Optional[float]) -> ClientSession:
|
|
20
22
|
Specifically: the error "unable to get local issuer certificate" when making
|
21
23
|
an aiohttp request.
|
22
24
|
"""
|
25
|
+
import ssl
|
26
|
+
|
27
|
+
import certifi
|
28
|
+
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
29
|
+
|
23
30
|
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
24
31
|
connector = TCPConnector(ssl=ssl_context)
|
25
32
|
return ClientSession(connector=connector, timeout=ClientTimeout(total=timeout))
|
26
33
|
|
27
34
|
|
35
|
+
class ClientSessionRegistry:
|
36
|
+
_client_session: "ClientSession"
|
37
|
+
_client_session_active: bool = False
|
38
|
+
|
39
|
+
@staticmethod
|
40
|
+
def get_session():
|
41
|
+
if not ClientSessionRegistry._client_session_active:
|
42
|
+
ClientSessionRegistry._client_session = _http_client_with_tls(timeout=None)
|
43
|
+
ClientSessionRegistry._client_session_active = True
|
44
|
+
on_shutdown(ClientSessionRegistry.close_session())
|
45
|
+
return ClientSessionRegistry._client_session
|
46
|
+
|
47
|
+
@staticmethod
|
48
|
+
async def close_session():
|
49
|
+
if ClientSessionRegistry._client_session_active:
|
50
|
+
await ClientSessionRegistry._client_session.close()
|
51
|
+
ClientSessionRegistry._client_session_active = False
|
52
|
+
|
53
|
+
|
28
54
|
@contextlib.asynccontextmanager
|
29
|
-
async def run_temporary_http_server(app: Application):
|
55
|
+
async def run_temporary_http_server(app: "Application"):
|
30
56
|
# Allocates a random port, runs a server in a context manager
|
31
57
|
# This is used in various tests
|
58
|
+
import socket
|
59
|
+
|
60
|
+
from aiohttp.web_runner import AppRunner, SockSite
|
61
|
+
|
32
62
|
sock = socket.socket()
|
33
63
|
sock.bind(("", 0))
|
34
64
|
port = sock.getsockname()[1]
|
modal/_utils/logger.py
CHANGED
@@ -17,7 +17,8 @@ def configure_logger(logger: logging.Logger, log_level: str, log_format: str):
|
|
17
17
|
json_formatter = jsonlogger.JsonFormatter(
|
18
18
|
fmt=(
|
19
19
|
"%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] "
|
20
|
-
"[dd.service=%(dd.service)s dd.env=%(dd.env)s dd.version=%(dd.version)s dd.trace_id=%(dd.trace_id)s
|
20
|
+
"[dd.service=%(dd.service)s dd.env=%(dd.env)s dd.version=%(dd.version)s dd.trace_id=%(dd.trace_id)s "
|
21
|
+
"dd.span_id=%(dd.span_id)s] "
|
21
22
|
"- %(message)s"
|
22
23
|
),
|
23
24
|
datefmt="%Y-%m-%dT%H:%M:%S%z",
|