modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +17 -13
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +420 -937
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -59
  11. modal/_resources.py +51 -0
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/_runtime/execution_context.py +89 -0
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +134 -9
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +52 -16
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +479 -100
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +460 -171
  29. modal/_utils/grpc_testing.py +47 -31
  30. modal/_utils/grpc_utils.py +62 -109
  31. modal/_utils/hash_utils.py +61 -19
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +5 -7
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +14 -12
  43. modal/app.py +1003 -314
  44. modal/app.pyi +540 -264
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +63 -53
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +205 -45
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +62 -14
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +64 -58
  55. modal/cli/launch.py +32 -18
  56. modal/cli/network_file_system.py +64 -83
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +35 -10
  59. modal/cli/programs/vscode.py +60 -10
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +234 -131
  62. modal/cli/secret.py +8 -7
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +79 -10
  65. modal/cli/volume.py +110 -109
  66. modal/client.py +250 -144
  67. modal/client.pyi +157 -118
  68. modal/cloud_bucket_mount.py +108 -34
  69. modal/cloud_bucket_mount.pyi +32 -38
  70. modal/cls.py +535 -148
  71. modal/cls.pyi +190 -146
  72. modal/config.py +41 -19
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +111 -65
  76. modal/dict.pyi +136 -131
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +34 -43
  80. modal/experimental.py +61 -2
  81. modal/extensions/ipython.py +5 -5
  82. modal/file_io.py +537 -0
  83. modal/file_io.pyi +235 -0
  84. modal/file_pattern_matcher.py +197 -0
  85. modal/functions.py +906 -911
  86. modal/functions.pyi +466 -430
  87. modal/gpu.py +57 -44
  88. modal/image.py +1089 -479
  89. modal/image.pyi +584 -228
  90. modal/io_streams.py +434 -0
  91. modal/io_streams.pyi +122 -0
  92. modal/mount.py +314 -101
  93. modal/mount.pyi +241 -235
  94. modal/network_file_system.py +92 -92
  95. modal/network_file_system.pyi +152 -110
  96. modal/object.py +67 -36
  97. modal/object.pyi +166 -143
  98. modal/output.py +63 -0
  99. modal/parallel_map.py +434 -0
  100. modal/parallel_map.pyi +75 -0
  101. modal/partial_function.py +282 -117
  102. modal/partial_function.pyi +222 -129
  103. modal/proxy.py +15 -12
  104. modal/proxy.pyi +3 -8
  105. modal/queue.py +182 -65
  106. modal/queue.pyi +218 -118
  107. modal/requirements/2024.04.txt +29 -0
  108. modal/requirements/2024.10.txt +16 -0
  109. modal/requirements/README.md +21 -0
  110. modal/requirements/base-images.json +22 -0
  111. modal/retries.py +48 -7
  112. modal/runner.py +459 -156
  113. modal/runner.pyi +135 -71
  114. modal/running_app.py +38 -0
  115. modal/sandbox.py +514 -236
  116. modal/sandbox.pyi +397 -169
  117. modal/schedule.py +4 -4
  118. modal/scheduler_placement.py +20 -3
  119. modal/secret.py +56 -31
  120. modal/secret.pyi +62 -42
  121. modal/serving.py +51 -56
  122. modal/serving.pyi +44 -36
  123. modal/stream_type.py +15 -0
  124. modal/token_flow.py +5 -3
  125. modal/token_flow.pyi +37 -32
  126. modal/volume.py +285 -157
  127. modal/volume.pyi +249 -184
  128. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
  129. modal-0.72.11.dist-info/RECORD +174 -0
  130. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  131. modal_docs/gen_reference_docs.py +3 -1
  132. modal_docs/mdmd/mdmd.py +0 -1
  133. modal_docs/mdmd/signatures.py +5 -2
  134. modal_global_objects/images/base_images.py +28 -0
  135. modal_global_objects/mounts/python_standalone.py +2 -2
  136. modal_proto/__init__.py +1 -1
  137. modal_proto/api.proto +1288 -533
  138. modal_proto/api_grpc.py +856 -456
  139. modal_proto/api_pb2.py +2165 -1157
  140. modal_proto/api_pb2.pyi +8859 -0
  141. modal_proto/api_pb2_grpc.py +1674 -855
  142. modal_proto/api_pb2_grpc.pyi +1416 -0
  143. modal_proto/modal_api_grpc.py +149 -0
  144. modal_proto/modal_options_grpc.py +3 -0
  145. modal_proto/options_pb2.pyi +20 -0
  146. modal_proto/options_pb2_grpc.pyi +7 -0
  147. modal_proto/py.typed +0 -0
  148. modal_version/__init__.py +1 -1
  149. modal_version/_version_generated.py +2 -2
  150. modal/_asgi.py +0 -370
  151. modal/_container_entrypoint.pyi +0 -378
  152. modal/_container_exec.py +0 -128
  153. modal/_sandbox_shell.py +0 -49
  154. modal/shared_volume.py +0 -23
  155. modal/shared_volume.pyi +0 -24
  156. modal/stub.py +0 -783
  157. modal/stub.pyi +0 -332
  158. modal-0.62.16.dist-info/RECORD +0 -198
  159. modal_global_objects/images/conda.py +0 -15
  160. modal_global_objects/images/debian_slim.py +0 -15
  161. modal_global_objects/images/micromamba.py +0 -15
  162. test/__init__.py +0 -1
  163. test/aio_test.py +0 -12
  164. test/async_utils_test.py +0 -262
  165. test/blob_test.py +0 -67
  166. test/cli_imports_test.py +0 -149
  167. test/cli_test.py +0 -659
  168. test/client_test.py +0 -194
  169. test/cls_test.py +0 -630
  170. test/config_test.py +0 -137
  171. test/conftest.py +0 -1420
  172. test/container_app_test.py +0 -32
  173. test/container_test.py +0 -1389
  174. test/cpu_test.py +0 -23
  175. test/decorator_test.py +0 -85
  176. test/deprecation_test.py +0 -34
  177. test/dict_test.py +0 -33
  178. test/e2e_test.py +0 -68
  179. test/error_test.py +0 -7
  180. test/function_serialization_test.py +0 -32
  181. test/function_test.py +0 -653
  182. test/function_utils_test.py +0 -101
  183. test/gpu_test.py +0 -159
  184. test/grpc_utils_test.py +0 -141
  185. test/helpers.py +0 -42
  186. test/image_test.py +0 -669
  187. test/live_reload_test.py +0 -80
  188. test/lookup_test.py +0 -70
  189. test/mdmd_test.py +0 -329
  190. test/mount_test.py +0 -162
  191. test/mounted_files_test.py +0 -329
  192. test/network_file_system_test.py +0 -181
  193. test/notebook_test.py +0 -66
  194. test/object_test.py +0 -41
  195. test/package_utils_test.py +0 -25
  196. test/queue_test.py +0 -97
  197. test/resolver_test.py +0 -58
  198. test/retries_test.py +0 -67
  199. test/runner_test.py +0 -85
  200. test/sandbox_test.py +0 -191
  201. test/schedule_test.py +0 -15
  202. test/scheduler_placement_test.py +0 -29
  203. test/secret_test.py +0 -78
  204. test/serialization_test.py +0 -42
  205. test/stub_composition_test.py +0 -10
  206. test/stub_test.py +0 -360
  207. test/test_asgi_wrapper.py +0 -234
  208. test/token_flow_test.py +0 -18
  209. test/traceback_test.py +0 -135
  210. test/tunnel_test.py +0 -29
  211. test/utils_test.py +0 -88
  212. test/version_test.py +0 -14
  213. test/volume_test.py +0 -341
  214. test/watcher_test.py +0 -30
  215. test/webhook_test.py +0 -146
  216. /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
  217. /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
  218. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
@@ -4,7 +4,8 @@ import inspect
4
4
  import logging
5
5
  import typing
6
6
  from collections import Counter, defaultdict
7
- from typing import Any, Awaitable, Callable, Dict, List, Tuple
7
+ from collections.abc import Awaitable
8
+ from typing import Any, Callable
8
9
 
9
10
  import grpclib.server
10
11
  from grpclib import GRPCError, Status
@@ -26,7 +27,8 @@ def patch_mock_servicer(cls):
26
27
  await some_complex_method()
27
28
  assert ctx.calls == [("SomeMethod", MyMessage(foo="bar"))]
28
29
  ```
29
- Also allows to set a predefined queue of responses, temporarily replacing a mock servicer's default responses for a method:
30
+ Also allows to set a predefined queue of responses, temporarily replacing
31
+ a mock servicer's default responses for a method:
30
32
 
31
33
  ```python notest
32
34
  with servicer.intercept() as ctx:
@@ -48,10 +50,10 @@ def patch_mock_servicer(cls):
48
50
 
49
51
  @contextlib.contextmanager
50
52
  def intercept(servicer):
51
- ctx = InterceptionContext()
53
+ ctx = InterceptionContext(servicer)
52
54
  servicer.interception_context = ctx
53
55
  yield ctx
54
- ctx.assert_responses_consumed()
56
+ ctx._assert_responses_consumed()
55
57
  servicer.interception_context = None
56
58
 
57
59
  cls.intercept = intercept
@@ -63,7 +65,7 @@ def patch_mock_servicer(cls):
63
65
  ctx = servicer_self.interception_context
64
66
  if ctx:
65
67
  intercepted_stream = await InterceptedStream(ctx, method_name, stream).initialize()
66
- custom_responder = ctx.next_custom_responder(method_name, intercepted_stream.request_message)
68
+ custom_responder = ctx._next_custom_responder(method_name, intercepted_stream.request_message)
67
69
  if custom_responder:
68
70
  return await custom_responder(servicer_self, intercepted_stream)
69
71
  else:
@@ -92,31 +94,36 @@ def patch_mock_servicer(cls):
92
94
 
93
95
 
94
96
  class ResponseNotConsumed(Exception):
95
- def __init__(self, unconsumed_requests: List[str]):
97
+ def __init__(self, unconsumed_requests: list[str]):
96
98
  self.unconsumed_requests = unconsumed_requests
97
99
  request_count = Counter(unconsumed_requests)
98
100
  super().__init__(f"Expected but did not receive the following requests: {request_count}")
99
101
 
100
102
 
101
103
  class InterceptionContext:
102
- def __init__(self):
103
- self.calls: List[Tuple[str, Any]] = [] # List[Tuple[method_name, message]]
104
- self.custom_responses: Dict[str, List[Tuple[Callable[[Any], bool], List[Any]]]] = defaultdict(list)
105
- self.custom_defaults: Dict[str, Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]] = {}
106
-
107
- def add_recv(self, method_name: str, msg):
108
- self.calls.append((method_name, msg))
104
+ def __init__(self, servicer):
105
+ self._servicer = servicer
106
+ self.calls: list[tuple[str, Any]] = [] # List[Tuple[method_name, message]]
107
+ self.custom_responses: dict[str, list[tuple[Callable[[Any], bool], list[Any]]]] = defaultdict(list)
108
+ self.custom_defaults: dict[str, Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]] = {}
109
109
 
110
110
  def add_response(
111
111
  self, method_name: str, first_payload, *, request_filter: Callable[[Any], bool] = lambda req: True
112
112
  ):
113
- # adds one response to a queue of responses for requests of the specified type
113
+ """Adds one response payload to an expected queue of responses for a method.
114
+
115
+ These responses will be used once each instead of calling the MockServicer's
116
+ implementation of the method.
117
+
118
+ The interception context will throw an exception on exit if not all of the added
119
+ responses have been consumed.
120
+ """
114
121
  self.custom_responses[method_name].append((request_filter, [first_payload]))
115
122
 
116
123
  def set_responder(
117
124
  self, method_name: str, responder: Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]
118
125
  ):
119
- """Replace the default responder method. E.g.
126
+ """Replace the default responder from the MockClientServicer with a custom implementation
120
127
 
121
128
  ```python notest
122
129
  def custom_responder(servicer, stream):
@@ -127,11 +134,31 @@ class InterceptionContext:
127
134
  ctx.set_responder("SomeMethod", custom_responder)
128
135
  ```
129
136
 
130
- Responses added via `.add_response()` take precedence.
137
+ Responses added via `.add_response()` take precedence over the use of this replacement
131
138
  """
132
139
  self.custom_defaults[method_name] = responder
133
140
 
134
- def next_custom_responder(self, method_name, request):
141
+ def pop_request(self, method_name):
142
+ # fast forward to the next request of type method_name
143
+ # dropping any preceding requests if there is a match
144
+ # returns the payload of the request
145
+ for i, (_method_name, msg) in enumerate(self.calls):
146
+ if _method_name == method_name:
147
+ self.calls = self.calls[i + 1 :]
148
+ return msg
149
+
150
+ raise KeyError(f"No message of that type in call list: {self.calls}")
151
+
152
+ def get_requests(self, method_name: str) -> list[Any]:
153
+ if not hasattr(self._servicer, method_name):
154
+ # we check this to prevent things like `assert ctx.get_requests("ASdfFunctionCreate") == 0` passing
155
+ raise ValueError(f"{method_name} not in MockServicer - did you spell it right?")
156
+ return [msg for _method_name, msg in self.calls if _method_name == method_name]
157
+
158
+ def _add_recv(self, method_name: str, msg):
159
+ self.calls.append((method_name, msg))
160
+
161
+ def _next_custom_responder(self, method_name, request):
135
162
  method_responses = self.custom_responses[method_name]
136
163
  for i, (request_filter, response_messages) in enumerate(method_responses):
137
164
  try:
@@ -158,7 +185,7 @@ class InterceptionContext:
158
185
 
159
186
  return responder
160
187
 
161
- def assert_responses_consumed(self):
188
+ def _assert_responses_consumed(self):
162
189
  unconsumed = []
163
190
  for method_name, queued_responses in self.custom_responses.items():
164
191
  unconsumed += [method_name] * len(queued_responses)
@@ -166,20 +193,9 @@ class InterceptionContext:
166
193
  if unconsumed:
167
194
  raise ResponseNotConsumed(unconsumed)
168
195
 
169
- def pop_request(self, method_name):
170
- # fast forward to the next request of type method_name
171
- # dropping any preceding requests if there is a match
172
- # returns the payload of the request
173
- for i, (_method_name, msg) in enumerate(self.calls):
174
- if _method_name == method_name:
175
- self.calls = self.calls[i + 1 :]
176
- return msg
177
-
178
- raise Exception(f"No message of that type in call list: {self.calls}")
179
-
180
196
 
181
197
  class InterceptedStream:
182
- def __init__(self, interception_context, method_name, stream):
198
+ def __init__(self, interception_context: InterceptionContext, method_name: str, stream):
183
199
  self.interception_context = interception_context
184
200
  self.method_name = method_name
185
201
  self.stream = stream
@@ -196,7 +212,7 @@ class InterceptedStream:
196
212
  return ret
197
213
 
198
214
  msg = await self.stream.recv_message()
199
- self.interception_context.add_recv(self.method_name, msg)
215
+ self.interception_context._add_recv(self.method_name, msg)
200
216
  return msg
201
217
 
202
218
  async def send_message(self, msg):
@@ -4,30 +4,37 @@ import contextlib
4
4
  import platform
5
5
  import socket
6
6
  import time
7
+ import typing
7
8
  import urllib.parse
8
9
  import uuid
10
+ from collections.abc import AsyncIterator
9
11
  from typing import (
10
12
  Any,
11
- AsyncIterator,
12
- Dict,
13
- List,
14
13
  Optional,
15
- Type,
16
14
  TypeVar,
17
15
  )
18
16
 
19
17
  import grpclib.client
20
18
  import grpclib.config
21
19
  import grpclib.events
20
+ import grpclib.protocol
21
+ import grpclib.stream
22
22
  from google.protobuf.message import Message
23
23
  from grpclib import GRPCError, Status
24
24
  from grpclib.exceptions import StreamTerminatedError
25
25
  from grpclib.protocol import H2Protocol
26
26
 
27
+ from modal.exception import AuthError, ConnectionError
27
28
  from modal_version import __version__
28
29
 
29
30
  from .logger import logger
30
31
 
32
+ RequestType = TypeVar("RequestType", bound=Message)
33
+ ResponseType = TypeVar("ResponseType", bound=Message)
34
+
35
+ if typing.TYPE_CHECKING:
36
+ import modal.client
37
+
31
38
  # Monkey patches grpclib to have a Modal User Agent header.
32
39
  grpclib.client.USER_AGENT = "modal-client/{version} ({sys}; {py}/{py_ver})'".format(
33
40
  version=__version__,
@@ -54,81 +61,6 @@ class Subchannel:
54
61
  return True
55
62
 
56
63
 
57
- class ChannelPool(grpclib.client.Channel):
58
- """Use multiple channels under the hood. A drop-in replacement for the grpclib Channel.
59
-
60
- The main reason is to get around limitations with TCP connections over the internet,
61
- in particular idle timeouts.
62
-
63
- The algorithm is very simple. It reuses the last subchannel as long as it has had less
64
- than 64 requests or if it was created less than 30s ago. It closes any subchannel that
65
- hits 90s age. This means requests using the ChannelPool can't be longer than 60s.
66
- """
67
-
68
- _max_requests: int
69
- _max_lifetime: float
70
- _max_active: float
71
- _subchannels: List[Subchannel]
72
-
73
- def __init__(
74
- self,
75
- *args,
76
- max_requests=64, # Maximum number of total requests per subchannel
77
- max_active=30, # Don't accept more connections on the subchannel after this many seconds
78
- max_lifetime=90, # Close subchannel after this many seconds
79
- **kwargs,
80
- ):
81
- self._subchannels = []
82
- self._max_requests = max_requests
83
- self._max_active = max_active
84
- self._max_lifetime = max_lifetime
85
- super().__init__(*args, **kwargs)
86
-
87
- async def __connect__(self):
88
- now = time.time()
89
- # Remove any closed subchannels
90
- while len(self._subchannels) > 0 and not self._subchannels[-1].connected():
91
- self._subchannels.pop()
92
-
93
- # Close and delete any subchannels that are past their lifetime
94
- while len(self._subchannels) > 0 and now - self._subchannels[0].created_at > self._max_lifetime:
95
- self._subchannels.pop(0).protocol.processor.close()
96
-
97
- # See if we can reuse the last subchannel
98
- create_subchannel = None
99
- if len(self._subchannels) > 0:
100
- if self._subchannels[-1].created_at < now - self._max_active:
101
- # Don't reuse subchannel that's too old
102
- create_subchannel = True
103
- elif self._subchannels[-1].requests > self._max_requests:
104
- create_subchannel = True
105
- else:
106
- create_subchannel = False
107
- else:
108
- create_subchannel = True
109
-
110
- # Create new if needed
111
- # There's a theoretical race condition here.
112
- # This is harmless but may lead to superfluous protocols.
113
- if create_subchannel:
114
- protocol = await self._create_connection()
115
- self._subchannels.append(Subchannel(protocol))
116
-
117
- self._subchannels[-1].requests += 1
118
- return self._subchannels[-1].protocol
119
-
120
- def close(self) -> None:
121
- while len(self._subchannels) > 0:
122
- self._subchannels.pop(0).protocol.processor.close()
123
-
124
- def __del__(self) -> None:
125
- if len(self._subchannels) > 0:
126
- logger.warning("Channel pool not properly closed")
127
-
128
-
129
- _SendType = TypeVar("_SendType")
130
- _RecvType = TypeVar("_RecvType")
131
-
132
64
  RETRYABLE_GRPC_STATUS_CODES = [
133
65
  Status.DEADLINE_EXCEEDED,
134
66
  Status.UNAVAILABLE,
@@ -139,9 +71,7 @@ RETRYABLE_GRPC_STATUS_CODES = [
139
71
 
140
72
  def create_channel(
141
73
  server_url: str,
142
- metadata: Dict[str, str] = {},
143
- *,
144
- use_pool: Optional[bool] = None, # If None, inferred from the scheme
74
+ metadata: dict[str, str] = {},
145
75
  ) -> grpclib.client.Channel:
146
76
  """Creates a grpclib.Channel.
147
77
 
@@ -150,15 +80,6 @@ def create_channel(
150
80
  """
151
81
  o = urllib.parse.urlparse(server_url)
152
82
 
153
- if use_pool is None:
154
- use_pool = o.scheme in ("http", "https")
155
-
156
- channel_cls: Type[grpclib.client.Channel]
157
- if use_pool:
158
- channel_cls = ChannelPool
159
- else:
160
- channel_cls = grpclib.client.Channel
161
-
162
83
  channel: grpclib.client.Channel
163
84
  config = grpclib.config.Configuration(
164
85
  http2_connection_window_size=64 * 1024 * 1024, # 64 MiB
@@ -166,7 +87,7 @@ def create_channel(
166
87
  )
167
88
 
168
89
  if o.scheme == "unix":
169
- channel = channel_cls(path=o.path, config=config) # probably pointless to use a pool ever
90
+ channel = grpclib.client.Channel(path=o.path, config=config) # probably pointless to use a pool ever
170
91
  elif o.scheme in ("http", "https"):
171
92
  target = o.netloc
172
93
  parts = target.split(":")
@@ -174,7 +95,7 @@ def create_channel(
174
95
  ssl = o.scheme.endswith("s")
175
96
  host = parts[0]
176
97
  port = int(parts[1]) if len(parts) == 2 else 443 if ssl else 80
177
- channel = channel_cls(host, port, ssl=ssl, config=config)
98
+ channel = grpclib.client.Channel(host, port, ssl=ssl, config=config)
178
99
  else:
179
100
  raise Exception(f"Unknown scheme: {o.scheme}")
180
101
 
@@ -189,23 +110,31 @@ def create_channel(
189
110
  logger.debug(f"Sending request to {event.method_name}")
190
111
 
191
112
  grpclib.events.listen(channel, grpclib.events.SendRequest, send_request)
113
+
192
114
  return channel
193
115
 
194
116
 
117
+ async def connect_channel(channel: grpclib.client.Channel):
118
+ """Connects socket (potentially raising errors raising to connectivity."""
119
+ await channel.__connect__()
120
+
121
+
122
+ if typing.TYPE_CHECKING:
123
+ import modal.client
124
+
125
+
195
126
  async def unary_stream(
196
- method: grpclib.client.UnaryStreamMethod[_SendType, _RecvType],
197
- request: _SendType,
127
+ method: "modal.client.UnaryStreamWrapper[RequestType, ResponseType]",
128
+ request: RequestType,
198
129
  metadata: Optional[Any] = None,
199
- ) -> AsyncIterator[_RecvType]:
200
- """Helper for making a unary-streaming gRPC request."""
201
- async with method.open(metadata=metadata) as stream:
202
- await stream.send_message(request, end=True)
203
- async for item in stream:
204
- yield item
130
+ ) -> AsyncIterator[ResponseType]:
131
+ # TODO: remove this, since we have a method now
132
+ async for item in method.unary_stream(request, metadata):
133
+ yield item
205
134
 
206
135
 
207
136
  async def retry_transient_errors(
208
- fn,
137
+ fn: "modal.client.UnaryUnaryWrapper[RequestType, ResponseType]",
209
138
  *args,
210
139
  base_delay: float = 0.1,
211
140
  max_delay: float = 1,
@@ -215,7 +144,7 @@ async def retry_transient_errors(
215
144
  attempt_timeout: Optional[float] = None, # timeout for each attempt
216
145
  total_timeout: Optional[float] = None, # timeout for the entire function call
217
146
  attempt_timeout_floor=2.0, # always have at least this much timeout (only for total_timeout)
218
- ):
147
+ ) -> ResponseType:
219
148
  """Retry on transient gRPC failures with back-off until max_retries is reached.
220
149
  If max_retries is None, retry forever."""
221
150
 
@@ -247,16 +176,35 @@ async def retry_transient_errors(
247
176
  timeout = None
248
177
  try:
249
178
  return await fn(*args, metadata=metadata, timeout=timeout)
250
- except (StreamTerminatedError, GRPCError, socket.gaierror, asyncio.TimeoutError) as exc:
179
+ except (StreamTerminatedError, GRPCError, OSError, asyncio.TimeoutError, AttributeError) as exc:
251
180
  if isinstance(exc, GRPCError) and exc.status not in status_codes:
252
- raise exc
181
+ if exc.status == Status.UNAUTHENTICATED:
182
+ raise AuthError(exc.message)
183
+ else:
184
+ raise exc
253
185
 
254
186
  if max_retries is not None and n_retries >= max_retries:
187
+ final_attempt = True
188
+ elif total_deadline is not None and time.time() + delay + attempt_timeout_floor >= total_deadline:
189
+ final_attempt = True
190
+ else:
191
+ final_attempt = False
192
+
193
+ if final_attempt:
194
+ if isinstance(exc, OSError):
195
+ raise ConnectionError(str(exc))
196
+ elif isinstance(exc, asyncio.TimeoutError):
197
+ raise ConnectionError(str(exc))
198
+ else:
199
+ raise exc
200
+
201
+ if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
202
+ # StreamTerminatedError are not properly raised in grpclib<=0.4.7
203
+ # fixed in https://github.com/vmagamedov/grpclib/issues/185
204
+ # TODO: update to newer version (>=0.4.8) once stable
255
205
  raise exc
256
206
 
257
- if total_deadline and time.time() + delay + attempt_timeout_floor >= total_deadline:
258
- # no point sleeping if that's going to push us past the deadline
259
- raise exc
207
+ logger.debug(f"Retryable failure {repr(exc)} {n_retries=} {delay=} for {fn.name}")
260
208
 
261
209
  n_retries += 1
262
210
 
@@ -265,7 +213,12 @@ async def retry_transient_errors(
265
213
 
266
214
 
267
215
  def find_free_port() -> int:
268
- """Find a free TCP port, useful for testing."""
216
+ """
217
+ Find a free TCP port, useful for testing.
218
+
219
+ WARN: if a returned free port is not bound immediately by the caller, that same port
220
+ may be returned in subsequent calls to this function, potentially creating port collisions.
221
+ """
269
222
  with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
270
223
  s.bind(("", 0))
271
224
  s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -2,43 +2,53 @@
2
2
  import base64
3
3
  import dataclasses
4
4
  import hashlib
5
- from typing import IO, Union
5
+ import time
6
+ from typing import BinaryIO, Callable, Optional, Sequence, Union
6
7
 
7
- HASH_CHUNK_SIZE = 4096
8
+ from modal.config import logger
8
9
 
10
+ HASH_CHUNK_SIZE = 65536
9
11
 
10
- def _update(hashers, data: Union[bytes, IO[bytes]]):
12
+
13
+ def _update(hashers: Sequence[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
11
14
  if isinstance(data, bytes):
12
15
  for hasher in hashers:
13
- hasher.update(data)
16
+ hasher(data)
14
17
  else:
18
+ assert not isinstance(data, (bytearray, memoryview)) # https://github.com/microsoft/pyright/issues/5697
15
19
  pos = data.tell()
16
- while 1:
20
+ while True:
17
21
  chunk = data.read(HASH_CHUNK_SIZE)
18
22
  if not isinstance(chunk, bytes):
19
23
  raise ValueError(f"Only accepts bytes or byte buffer objects, not {type(chunk)} buffers")
20
24
  if not chunk:
21
25
  break
22
26
  for hasher in hashers:
23
- hasher.update(chunk)
27
+ hasher(chunk)
24
28
  data.seek(pos)
25
29
 
26
30
 
27
- def get_sha256_hex(data: Union[bytes, IO[bytes]]) -> str:
31
+ def get_sha256_hex(data: Union[bytes, BinaryIO]) -> str:
32
+ t0 = time.monotonic()
28
33
  hasher = hashlib.sha256()
29
- _update([hasher], data)
34
+ _update([hasher.update], data)
35
+ logger.debug("get_sha256_hex took %.3fs", time.monotonic() - t0)
30
36
  return hasher.hexdigest()
31
37
 
32
38
 
33
- def get_sha256_base64(data: Union[bytes, IO[bytes]]) -> str:
39
+ def get_sha256_base64(data: Union[bytes, BinaryIO]) -> str:
40
+ t0 = time.monotonic()
34
41
  hasher = hashlib.sha256()
35
- _update([hasher], data)
42
+ _update([hasher.update], data)
43
+ logger.debug("get_sha256_base64 took %.3fs", time.monotonic() - t0)
36
44
  return base64.b64encode(hasher.digest()).decode("ascii")
37
45
 
38
46
 
39
- def get_md5_base64(data: Union[bytes, IO[bytes]]) -> str:
47
+ def get_md5_base64(data: Union[bytes, BinaryIO]) -> str:
48
+ t0 = time.monotonic()
40
49
  hasher = hashlib.md5()
41
- _update([hasher], data)
50
+ _update([hasher.update], data)
51
+ logger.debug("get_md5_base64 took %.3fs", time.monotonic() - t0)
42
52
  return base64.b64encode(hasher.digest()).decode("utf-8")
43
53
 
44
54
 
@@ -47,12 +57,44 @@ class UploadHashes:
47
57
  md5_base64: str
48
58
  sha256_base64: str
49
59
 
60
+ def md5_hex(self) -> str:
61
+ return base64.b64decode(self.md5_base64).hex()
62
+
63
+ def sha256_hex(self) -> str:
64
+ return base64.b64decode(self.sha256_base64).hex()
65
+
66
+
67
+ def get_upload_hashes(
68
+ data: Union[bytes, BinaryIO], sha256_hex: Optional[str] = None, md5_hex: Optional[str] = None
69
+ ) -> UploadHashes:
70
+ t0 = time.monotonic()
71
+ hashers = {}
72
+
73
+ if not sha256_hex:
74
+ sha256 = hashlib.sha256()
75
+ hashers["sha256"] = sha256
76
+ if not md5_hex:
77
+ md5 = hashlib.md5()
78
+ hashers["md5"] = md5
79
+
80
+ if hashers:
81
+ updaters = [h.update for h in hashers.values()]
82
+ _update(updaters, data)
50
83
 
51
- def get_upload_hashes(data: Union[bytes, IO[bytes]]) -> UploadHashes:
52
- md5 = hashlib.md5()
53
- sha256 = hashlib.sha256()
54
- _update([md5, sha256], data)
55
- return UploadHashes(
56
- md5_base64=base64.b64encode(md5.digest()).decode("ascii"),
57
- sha256_base64=base64.b64encode(sha256.digest()).decode("ascii"),
84
+ if sha256_hex:
85
+ sha256_base64 = base64.b64encode(bytes.fromhex(sha256_hex)).decode("ascii")
86
+ else:
87
+ sha256_base64 = base64.b64encode(hashers["sha256"].digest()).decode("ascii")
88
+
89
+ if md5_hex:
90
+ md5_base64 = base64.b64encode(bytes.fromhex(md5_hex)).decode("ascii")
91
+ else:
92
+ md5_base64 = base64.b64encode(hashers["md5"].digest()).decode("ascii")
93
+
94
+ hashes = UploadHashes(
95
+ md5_base64=md5_base64,
96
+ sha256_base64=sha256_base64,
58
97
  )
98
+
99
+ logger.debug("get_upload_hashes took %.3fs (%s)", time.monotonic() - t0, hashers.keys())
100
+ return hashes
@@ -1,16 +1,18 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import contextlib
3
- import socket
4
- import ssl
5
- from typing import Optional
3
+ from typing import TYPE_CHECKING, Optional
6
4
 
7
- import certifi
8
- from aiohttp import ClientSession, ClientTimeout, TCPConnector
9
- from aiohttp.web import Application
10
- from aiohttp.web_runner import AppRunner, SockSite
5
+ # Note: importing aiohttp seems to take about 100ms, and it's not really necessary,
6
+ # unless we need to work with blobs. So that's why we import it lazily instead.
11
7
 
8
+ if TYPE_CHECKING:
9
+ from aiohttp import ClientSession
10
+ from aiohttp.web import Application
12
11
 
13
- def http_client_with_tls(timeout: Optional[float]) -> ClientSession:
12
+ from .async_utils import on_shutdown
13
+
14
+
15
+ def _http_client_with_tls(timeout: Optional[float]) -> "ClientSession":
14
16
  """Create a new HTTP client session with standard, bundled TLS certificates.
15
17
 
16
18
  This is necessary to prevent client issues on some system where Python does
@@ -20,15 +22,43 @@ def http_client_with_tls(timeout: Optional[float]) -> ClientSession:
20
22
  Specifically: the error "unable to get local issuer certificate" when making
21
23
  an aiohttp request.
22
24
  """
25
+ import ssl
26
+
27
+ import certifi
28
+ from aiohttp import ClientSession, ClientTimeout, TCPConnector
29
+
23
30
  ssl_context = ssl.create_default_context(cafile=certifi.where())
24
31
  connector = TCPConnector(ssl=ssl_context)
25
32
  return ClientSession(connector=connector, timeout=ClientTimeout(total=timeout))
26
33
 
27
34
 
35
+ class ClientSessionRegistry:
36
+ _client_session: "ClientSession"
37
+ _client_session_active: bool = False
38
+
39
+ @staticmethod
40
+ def get_session():
41
+ if not ClientSessionRegistry._client_session_active:
42
+ ClientSessionRegistry._client_session = _http_client_with_tls(timeout=None)
43
+ ClientSessionRegistry._client_session_active = True
44
+ on_shutdown(ClientSessionRegistry.close_session())
45
+ return ClientSessionRegistry._client_session
46
+
47
+ @staticmethod
48
+ async def close_session():
49
+ if ClientSessionRegistry._client_session_active:
50
+ await ClientSessionRegistry._client_session.close()
51
+ ClientSessionRegistry._client_session_active = False
52
+
53
+
28
54
  @contextlib.asynccontextmanager
29
- async def run_temporary_http_server(app: Application):
55
+ async def run_temporary_http_server(app: "Application"):
30
56
  # Allocates a random port, runs a server in a context manager
31
57
  # This is used in various tests
58
+ import socket
59
+
60
+ from aiohttp.web_runner import AppRunner, SockSite
61
+
32
62
  sock = socket.socket()
33
63
  sock.bind(("", 0))
34
64
  port = sock.getsockname()[1]
modal/_utils/logger.py CHANGED
@@ -17,7 +17,8 @@ def configure_logger(logger: logging.Logger, log_level: str, log_format: str):
17
17
  json_formatter = jsonlogger.JsonFormatter(
18
18
  fmt=(
19
19
  "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] "
20
- "[dd.service=%(dd.service)s dd.env=%(dd.env)s dd.version=%(dd.version)s dd.trace_id=%(dd.trace_id)s dd.span_id=%(dd.span_id)s] "
20
+ "[dd.service=%(dd.service)s dd.env=%(dd.env)s dd.version=%(dd.version)s dd.trace_id=%(dd.trace_id)s "
21
+ "dd.span_id=%(dd.span_id)s] "
21
22
  "- %(message)s"
22
23
  ),
23
24
  datefmt="%Y-%m-%dT%H:%M:%S%z",