modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +402 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1025 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +196 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +943 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
  128. modal-0.72.13.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
@@ -4,7 +4,8 @@ import inspect
4
4
  import logging
5
5
  import typing
6
6
  from collections import Counter, defaultdict
7
- from typing import Any, Awaitable, Callable, Dict, List, Tuple
7
+ from collections.abc import Awaitable
8
+ from typing import Any, Callable
8
9
 
9
10
  import grpclib.server
10
11
  from grpclib import GRPCError, Status
@@ -26,7 +27,8 @@ def patch_mock_servicer(cls):
26
27
  await some_complex_method()
27
28
  assert ctx.calls == [("SomeMethod", MyMessage(foo="bar"))]
28
29
  ```
29
- Also allows to set a predefined queue of responses, temporarily replacing a mock servicer's default responses for a method:
30
+ Also allows to set a predefined queue of responses, temporarily replacing
31
+ a mock servicer's default responses for a method:
30
32
 
31
33
  ```python notest
32
34
  with servicer.intercept() as ctx:
@@ -48,10 +50,10 @@ def patch_mock_servicer(cls):
48
50
 
49
51
  @contextlib.contextmanager
50
52
  def intercept(servicer):
51
- ctx = InterceptionContext()
53
+ ctx = InterceptionContext(servicer)
52
54
  servicer.interception_context = ctx
53
55
  yield ctx
54
- ctx.assert_responses_consumed()
56
+ ctx._assert_responses_consumed()
55
57
  servicer.interception_context = None
56
58
 
57
59
  cls.intercept = intercept
@@ -63,7 +65,7 @@ def patch_mock_servicer(cls):
63
65
  ctx = servicer_self.interception_context
64
66
  if ctx:
65
67
  intercepted_stream = await InterceptedStream(ctx, method_name, stream).initialize()
66
- custom_responder = ctx.next_custom_responder(method_name, intercepted_stream.request_message)
68
+ custom_responder = ctx._next_custom_responder(method_name, intercepted_stream.request_message)
67
69
  if custom_responder:
68
70
  return await custom_responder(servicer_self, intercepted_stream)
69
71
  else:
@@ -92,31 +94,36 @@ def patch_mock_servicer(cls):
92
94
 
93
95
 
94
96
  class ResponseNotConsumed(Exception):
95
- def __init__(self, unconsumed_requests: List[str]):
97
+ def __init__(self, unconsumed_requests: list[str]):
96
98
  self.unconsumed_requests = unconsumed_requests
97
99
  request_count = Counter(unconsumed_requests)
98
100
  super().__init__(f"Expected but did not receive the following requests: {request_count}")
99
101
 
100
102
 
101
103
  class InterceptionContext:
102
- def __init__(self):
103
- self.calls: List[Tuple[str, Any]] = [] # List[Tuple[method_name, message]]
104
- self.custom_responses: Dict[str, List[Tuple[Callable[[Any], bool], List[Any]]]] = defaultdict(list)
105
- self.custom_defaults: Dict[str, Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]] = {}
106
-
107
- def add_recv(self, method_name: str, msg):
108
- self.calls.append((method_name, msg))
104
+ def __init__(self, servicer):
105
+ self._servicer = servicer
106
+ self.calls: list[tuple[str, Any]] = [] # List[Tuple[method_name, message]]
107
+ self.custom_responses: dict[str, list[tuple[Callable[[Any], bool], list[Any]]]] = defaultdict(list)
108
+ self.custom_defaults: dict[str, Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]] = {}
109
109
 
110
110
  def add_response(
111
111
  self, method_name: str, first_payload, *, request_filter: Callable[[Any], bool] = lambda req: True
112
112
  ):
113
- # adds one response to a queue of responses for requests of the specified type
113
+ """Adds one response payload to an expected queue of responses for a method.
114
+
115
+ These responses will be used once each instead of calling the MockServicer's
116
+ implementation of the method.
117
+
118
+ The interception context will throw an exception on exit if not all of the added
119
+ responses have been consumed.
120
+ """
114
121
  self.custom_responses[method_name].append((request_filter, [first_payload]))
115
122
 
116
123
  def set_responder(
117
124
  self, method_name: str, responder: Callable[["MockClientServicer", grpclib.server.Stream], Awaitable[None]]
118
125
  ):
119
- """Replace the default responder method. E.g.
126
+ """Replace the default responder from the MockClientServicer with a custom implementation
120
127
 
121
128
  ```python notest
122
129
  def custom_responder(servicer, stream):
@@ -127,11 +134,31 @@ class InterceptionContext:
127
134
  ctx.set_responder("SomeMethod", custom_responder)
128
135
  ```
129
136
 
130
- Responses added via `.add_response()` take precedence.
137
+ Responses added via `.add_response()` take precedence over the use of this replacement
131
138
  """
132
139
  self.custom_defaults[method_name] = responder
133
140
 
134
- def next_custom_responder(self, method_name, request):
141
+ def pop_request(self, method_name):
142
+ # fast forward to the next request of type method_name
143
+ # dropping any preceding requests if there is a match
144
+ # returns the payload of the request
145
+ for i, (_method_name, msg) in enumerate(self.calls):
146
+ if _method_name == method_name:
147
+ self.calls = self.calls[i + 1 :]
148
+ return msg
149
+
150
+ raise KeyError(f"No message of that type in call list: {self.calls}")
151
+
152
+ def get_requests(self, method_name: str) -> list[Any]:
153
+ if not hasattr(self._servicer, method_name):
154
+ # we check this to prevent things like `assert ctx.get_requests("ASdfFunctionCreate") == 0` passing
155
+ raise ValueError(f"{method_name} not in MockServicer - did you spell it right?")
156
+ return [msg for _method_name, msg in self.calls if _method_name == method_name]
157
+
158
+ def _add_recv(self, method_name: str, msg):
159
+ self.calls.append((method_name, msg))
160
+
161
+ def _next_custom_responder(self, method_name, request):
135
162
  method_responses = self.custom_responses[method_name]
136
163
  for i, (request_filter, response_messages) in enumerate(method_responses):
137
164
  try:
@@ -158,7 +185,7 @@ class InterceptionContext:
158
185
 
159
186
  return responder
160
187
 
161
- def assert_responses_consumed(self):
188
+ def _assert_responses_consumed(self):
162
189
  unconsumed = []
163
190
  for method_name, queued_responses in self.custom_responses.items():
164
191
  unconsumed += [method_name] * len(queued_responses)
@@ -166,23 +193,9 @@ class InterceptionContext:
166
193
  if unconsumed:
167
194
  raise ResponseNotConsumed(unconsumed)
168
195
 
169
- def pop_request(self, method_name):
170
- # fast forward to the next request of type method_name
171
- # dropping any preceding requests if there is a match
172
- # returns the payload of the request
173
- for i, (_method_name, msg) in enumerate(self.calls):
174
- if _method_name == method_name:
175
- self.calls = self.calls[i + 1 :]
176
- return msg
177
-
178
- raise KeyError(f"No message of that type in call list: {self.calls}")
179
-
180
- def get_requests(self, method_name: str) -> List[Any]:
181
- return [msg for _method_name, msg in self.calls if _method_name == method_name]
182
-
183
196
 
184
197
  class InterceptedStream:
185
- def __init__(self, interception_context, method_name, stream):
198
+ def __init__(self, interception_context: InterceptionContext, method_name: str, stream):
186
199
  self.interception_context = interception_context
187
200
  self.method_name = method_name
188
201
  self.stream = stream
@@ -199,7 +212,7 @@ class InterceptedStream:
199
212
  return ret
200
213
 
201
214
  msg = await self.stream.recv_message()
202
- self.interception_context.add_recv(self.method_name, msg)
215
+ self.interception_context._add_recv(self.method_name, msg)
203
216
  return msg
204
217
 
205
218
  async def send_message(self, msg):
@@ -4,12 +4,12 @@ import contextlib
4
4
  import platform
5
5
  import socket
6
6
  import time
7
+ import typing
7
8
  import urllib.parse
8
9
  import uuid
10
+ from collections.abc import AsyncIterator
9
11
  from typing import (
10
12
  Any,
11
- AsyncIterator,
12
- Dict,
13
13
  Optional,
14
14
  TypeVar,
15
15
  )
@@ -17,15 +17,24 @@ from typing import (
17
17
  import grpclib.client
18
18
  import grpclib.config
19
19
  import grpclib.events
20
+ import grpclib.protocol
21
+ import grpclib.stream
20
22
  from google.protobuf.message import Message
21
23
  from grpclib import GRPCError, Status
22
24
  from grpclib.exceptions import StreamTerminatedError
23
25
  from grpclib.protocol import H2Protocol
24
26
 
27
+ from modal.exception import AuthError, ConnectionError
25
28
  from modal_version import __version__
26
29
 
27
30
  from .logger import logger
28
31
 
32
+ RequestType = TypeVar("RequestType", bound=Message)
33
+ ResponseType = TypeVar("ResponseType", bound=Message)
34
+
35
+ if typing.TYPE_CHECKING:
36
+ import modal.client
37
+
29
38
  # Monkey patches grpclib to have a Modal User Agent header.
30
39
  grpclib.client.USER_AGENT = "modal-client/{version} ({sys}; {py}/{py_ver})'".format(
31
40
  version=__version__,
@@ -52,9 +61,6 @@ class Subchannel:
52
61
  return True
53
62
 
54
63
 
55
- _SendType = TypeVar("_SendType")
56
- _RecvType = TypeVar("_RecvType")
57
-
58
64
  RETRYABLE_GRPC_STATUS_CODES = [
59
65
  Status.DEADLINE_EXCEEDED,
60
66
  Status.UNAVAILABLE,
@@ -65,7 +71,7 @@ RETRYABLE_GRPC_STATUS_CODES = [
65
71
 
66
72
  def create_channel(
67
73
  server_url: str,
68
- metadata: Dict[str, str] = {},
74
+ metadata: dict[str, str] = {},
69
75
  ) -> grpclib.client.Channel:
70
76
  """Creates a grpclib.Channel.
71
77
 
@@ -104,23 +110,31 @@ def create_channel(
104
110
  logger.debug(f"Sending request to {event.method_name}")
105
111
 
106
112
  grpclib.events.listen(channel, grpclib.events.SendRequest, send_request)
113
+
107
114
  return channel
108
115
 
109
116
 
117
+ async def connect_channel(channel: grpclib.client.Channel):
118
+ """Connects socket (potentially raising errors related to connectivity)."""
119
+ await channel.__connect__()
120
+
121
+
122
+ if typing.TYPE_CHECKING:
123
+ import modal.client
124
+
125
+
110
126
  async def unary_stream(
111
- method: grpclib.client.UnaryStreamMethod[_SendType, _RecvType],
112
- request: _SendType,
127
+ method: "modal.client.UnaryStreamWrapper[RequestType, ResponseType]",
128
+ request: RequestType,
113
129
  metadata: Optional[Any] = None,
114
- ) -> AsyncIterator[_RecvType]:
115
- """Helper for making a unary-streaming gRPC request."""
116
- async with method.open(metadata=metadata) as stream:
117
- await stream.send_message(request, end=True)
118
- async for item in stream:
119
- yield item
130
+ ) -> AsyncIterator[ResponseType]:
131
+ # TODO: remove this, since we have a method now
132
+ async for item in method.unary_stream(request, metadata):
133
+ yield item
120
134
 
121
135
 
122
136
  async def retry_transient_errors(
123
- fn,
137
+ fn: "modal.client.UnaryUnaryWrapper[RequestType, ResponseType]",
124
138
  *args,
125
139
  base_delay: float = 0.1,
126
140
  max_delay: float = 1,
@@ -130,7 +144,7 @@ async def retry_transient_errors(
130
144
  attempt_timeout: Optional[float] = None, # timeout for each attempt
131
145
  total_timeout: Optional[float] = None, # timeout for the entire function call
132
146
  attempt_timeout_floor=2.0, # always have at least this much timeout (only for total_timeout)
133
- ):
147
+ ) -> ResponseType:
134
148
  """Retry on transient gRPC failures with back-off until max_retries is reached.
135
149
  If max_retries is None, retry forever."""
136
150
 
@@ -162,16 +176,35 @@ async def retry_transient_errors(
162
176
  timeout = None
163
177
  try:
164
178
  return await fn(*args, metadata=metadata, timeout=timeout)
165
- except (StreamTerminatedError, GRPCError, socket.gaierror, asyncio.TimeoutError) as exc:
179
+ except (StreamTerminatedError, GRPCError, OSError, asyncio.TimeoutError, AttributeError) as exc:
166
180
  if isinstance(exc, GRPCError) and exc.status not in status_codes:
167
- raise exc
181
+ if exc.status == Status.UNAUTHENTICATED:
182
+ raise AuthError(exc.message)
183
+ else:
184
+ raise exc
168
185
 
169
186
  if max_retries is not None and n_retries >= max_retries:
187
+ final_attempt = True
188
+ elif total_deadline is not None and time.time() + delay + attempt_timeout_floor >= total_deadline:
189
+ final_attempt = True
190
+ else:
191
+ final_attempt = False
192
+
193
+ if final_attempt:
194
+ if isinstance(exc, OSError):
195
+ raise ConnectionError(str(exc))
196
+ elif isinstance(exc, asyncio.TimeoutError):
197
+ raise ConnectionError(str(exc))
198
+ else:
199
+ raise exc
200
+
201
+ if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
202
+ # StreamTerminatedError are not properly raised in grpclib<=0.4.7
203
+ # fixed in https://github.com/vmagamedov/grpclib/issues/185
204
+ # TODO: update to newer version (>=0.4.8) once stable
170
205
  raise exc
171
206
 
172
- if total_deadline and time.time() + delay + attempt_timeout_floor >= total_deadline:
173
- # no point sleeping if that's going to push us past the deadline
174
- raise exc
207
+ logger.debug(f"Retryable failure {repr(exc)} {n_retries=} {delay=} for {fn.name}")
175
208
 
176
209
  n_retries += 1
177
210
 
@@ -2,12 +2,15 @@
2
2
  import base64
3
3
  import dataclasses
4
4
  import hashlib
5
- from typing import BinaryIO, Callable, List, Union
5
+ import time
6
+ from typing import BinaryIO, Callable, Optional, Sequence, Union
6
7
 
7
- HASH_CHUNK_SIZE = 4096
8
+ from modal.config import logger
8
9
 
10
+ HASH_CHUNK_SIZE = 65536
9
11
 
10
- def _update(hashers: List[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
12
+
13
+ def _update(hashers: Sequence[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
11
14
  if isinstance(data, bytes):
12
15
  for hasher in hashers:
13
16
  hasher(data)
@@ -26,20 +29,26 @@ def _update(hashers: List[Callable[[bytes], None]], data: Union[bytes, BinaryIO]
26
29
 
27
30
 
28
31
  def get_sha256_hex(data: Union[bytes, BinaryIO]) -> str:
32
+ t0 = time.monotonic()
29
33
  hasher = hashlib.sha256()
30
34
  _update([hasher.update], data)
35
+ logger.debug("get_sha256_hex took %.3fs", time.monotonic() - t0)
31
36
  return hasher.hexdigest()
32
37
 
33
38
 
34
39
  def get_sha256_base64(data: Union[bytes, BinaryIO]) -> str:
40
+ t0 = time.monotonic()
35
41
  hasher = hashlib.sha256()
36
42
  _update([hasher.update], data)
43
+ logger.debug("get_sha256_base64 took %.3fs", time.monotonic() - t0)
37
44
  return base64.b64encode(hasher.digest()).decode("ascii")
38
45
 
39
46
 
40
47
  def get_md5_base64(data: Union[bytes, BinaryIO]) -> str:
48
+ t0 = time.monotonic()
41
49
  hasher = hashlib.md5()
42
50
  _update([hasher.update], data)
51
+ logger.debug("get_md5_base64 took %.3fs", time.monotonic() - t0)
43
52
  return base64.b64encode(hasher.digest()).decode("utf-8")
44
53
 
45
54
 
@@ -48,12 +57,44 @@ class UploadHashes:
48
57
  md5_base64: str
49
58
  sha256_base64: str
50
59
 
60
+ def md5_hex(self) -> str:
61
+ return base64.b64decode(self.md5_base64).hex()
62
+
63
+ def sha256_hex(self) -> str:
64
+ return base64.b64decode(self.sha256_base64).hex()
65
+
66
+
67
+ def get_upload_hashes(
68
+ data: Union[bytes, BinaryIO], sha256_hex: Optional[str] = None, md5_hex: Optional[str] = None
69
+ ) -> UploadHashes:
70
+ t0 = time.monotonic()
71
+ hashers = {}
72
+
73
+ if not sha256_hex:
74
+ sha256 = hashlib.sha256()
75
+ hashers["sha256"] = sha256
76
+ if not md5_hex:
77
+ md5 = hashlib.md5()
78
+ hashers["md5"] = md5
79
+
80
+ if hashers:
81
+ updaters = [h.update for h in hashers.values()]
82
+ _update(updaters, data)
51
83
 
52
- def get_upload_hashes(data: Union[bytes, BinaryIO]) -> UploadHashes:
53
- md5 = hashlib.md5()
54
- sha256 = hashlib.sha256()
55
- _update([md5.update, sha256.update], data)
56
- return UploadHashes(
57
- md5_base64=base64.b64encode(md5.digest()).decode("ascii"),
58
- sha256_base64=base64.b64encode(sha256.digest()).decode("ascii"),
84
+ if sha256_hex:
85
+ sha256_base64 = base64.b64encode(bytes.fromhex(sha256_hex)).decode("ascii")
86
+ else:
87
+ sha256_base64 = base64.b64encode(hashers["sha256"].digest()).decode("ascii")
88
+
89
+ if md5_hex:
90
+ md5_base64 = base64.b64encode(bytes.fromhex(md5_hex)).decode("ascii")
91
+ else:
92
+ md5_base64 = base64.b64encode(hashers["md5"].digest()).decode("ascii")
93
+
94
+ hashes = UploadHashes(
95
+ md5_base64=md5_base64,
96
+ sha256_base64=sha256_base64,
59
97
  )
98
+
99
+ logger.debug("get_upload_hashes took %.3fs (%s)", time.monotonic() - t0, hashers.keys())
100
+ return hashes
@@ -1,16 +1,18 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import contextlib
3
- import socket
4
- import ssl
5
- from typing import Optional
3
+ from typing import TYPE_CHECKING, Optional
6
4
 
7
- import certifi
8
- from aiohttp import ClientSession, ClientTimeout, TCPConnector
9
- from aiohttp.web import Application
10
- from aiohttp.web_runner import AppRunner, SockSite
5
+ # Note: importing aiohttp seems to take about 100ms, and it's not really necessary,
6
+ # unless we need to work with blobs. So that's why we import it lazily instead.
11
7
 
8
+ if TYPE_CHECKING:
9
+ from aiohttp import ClientSession
10
+ from aiohttp.web import Application
12
11
 
13
- def http_client_with_tls(timeout: Optional[float]) -> ClientSession:
12
+ from .async_utils import on_shutdown
13
+
14
+
15
+ def _http_client_with_tls(timeout: Optional[float]) -> "ClientSession":
14
16
  """Create a new HTTP client session with standard, bundled TLS certificates.
15
17
 
16
18
  This is necessary to prevent client issues on some system where Python does
@@ -20,15 +22,43 @@ def http_client_with_tls(timeout: Optional[float]) -> ClientSession:
20
22
  Specifically: the error "unable to get local issuer certificate" when making
21
23
  an aiohttp request.
22
24
  """
25
+ import ssl
26
+
27
+ import certifi
28
+ from aiohttp import ClientSession, ClientTimeout, TCPConnector
29
+
23
30
  ssl_context = ssl.create_default_context(cafile=certifi.where())
24
31
  connector = TCPConnector(ssl=ssl_context)
25
32
  return ClientSession(connector=connector, timeout=ClientTimeout(total=timeout))
26
33
 
27
34
 
35
+ class ClientSessionRegistry:
36
+ _client_session: "ClientSession"
37
+ _client_session_active: bool = False
38
+
39
+ @staticmethod
40
+ def get_session():
41
+ if not ClientSessionRegistry._client_session_active:
42
+ ClientSessionRegistry._client_session = _http_client_with_tls(timeout=None)
43
+ ClientSessionRegistry._client_session_active = True
44
+ on_shutdown(ClientSessionRegistry.close_session())
45
+ return ClientSessionRegistry._client_session
46
+
47
+ @staticmethod
48
+ async def close_session():
49
+ if ClientSessionRegistry._client_session_active:
50
+ await ClientSessionRegistry._client_session.close()
51
+ ClientSessionRegistry._client_session_active = False
52
+
53
+
28
54
  @contextlib.asynccontextmanager
29
- async def run_temporary_http_server(app: Application):
55
+ async def run_temporary_http_server(app: "Application"):
30
56
  # Allocates a random port, runs a server in a context manager
31
57
  # This is used in various tests
58
+ import socket
59
+
60
+ from aiohttp.web_runner import AppRunner, SockSite
61
+
32
62
  sock = socket.socket()
33
63
  sock.bind(("", 0))
34
64
  port = sock.getsockname()[1]
modal/_utils/logger.py CHANGED
@@ -17,7 +17,8 @@ def configure_logger(logger: logging.Logger, log_level: str, log_format: str):
17
17
  json_formatter = jsonlogger.JsonFormatter(
18
18
  fmt=(
19
19
  "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] "
20
- "[dd.service=%(dd.service)s dd.env=%(dd.env)s dd.version=%(dd.version)s dd.trace_id=%(dd.trace_id)s dd.span_id=%(dd.span_id)s] "
20
+ "[dd.service=%(dd.service)s dd.env=%(dd.env)s dd.version=%(dd.version)s dd.trace_id=%(dd.trace_id)s "
21
+ "dd.span_id=%(dd.span_id)s] "
21
22
  "- %(message)s"
22
23
  ),
23
24
  datefmt="%Y-%m-%dT%H:%M:%S%z",
@@ -1,26 +1,29 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import posixpath
3
3
  import typing
4
+ from collections.abc import Mapping, Sequence
4
5
  from pathlib import PurePath, PurePosixPath
5
- from typing import TYPE_CHECKING, Dict, List, Mapping, Sequence, Tuple, Union
6
+ from typing import Union
6
7
 
8
+ from ..cloud_bucket_mount import _CloudBucketMount
7
9
  from ..exception import InvalidError
10
+ from ..network_file_system import _NetworkFileSystem
8
11
  from ..volume import _Volume
9
12
 
10
- if TYPE_CHECKING:
11
- from ..cloud_bucket_mount import _CloudBucketMount
12
- from ..network_file_system import _NetworkFileSystem
13
-
14
-
15
- T = typing.TypeVar("T", bound=Union["_Volume", "_NetworkFileSystem", "_CloudBucketMount"])
13
+ T = typing.TypeVar("T", bound=Union[_Volume, _NetworkFileSystem, _CloudBucketMount])
16
14
 
17
15
 
18
16
  def validate_mount_points(
19
17
  display_name: str,
20
18
  volume_likes: Mapping[Union[str, PurePosixPath], T],
21
- ) -> List[Tuple[str, T]]:
19
+ ) -> list[tuple[str, T]]:
22
20
  """Mount point path validation for volumes and network file systems."""
23
21
 
22
+ if not isinstance(volume_likes, dict):
23
+ raise InvalidError(
24
+ f"`volume_likes` should be a dict[str | PurePosixPath, {display_name}], got {type(volume_likes)} instead"
25
+ )
26
+
24
27
  validated = []
25
28
  for path, vol in volume_likes.items():
26
29
  path = PurePath(path).as_posix()
@@ -38,17 +41,32 @@ def validate_mount_points(
38
41
  return validated
39
42
 
40
43
 
41
- def validate_volumes(
42
- volumes: Mapping[Union[str, PurePosixPath], Union["_Volume", "_CloudBucketMount"]],
43
- ) -> Sequence[Tuple[str, Union["_Volume", "_NetworkFileSystem", "_CloudBucketMount"]]]:
44
- if not isinstance(volumes, dict):
45
- raise InvalidError("volumes must be a dict[str, Volume] where the keys are paths")
44
+ def validate_network_file_systems(
45
+ network_file_systems: Mapping[Union[str, PurePosixPath], _NetworkFileSystem],
46
+ ):
47
+ validated_network_file_systems = validate_mount_points("NetworkFileSystem", network_file_systems)
48
+
49
+ for path, network_file_system in validated_network_file_systems:
50
+ if not isinstance(network_file_system, (_NetworkFileSystem)):
51
+ raise InvalidError(
52
+ f"Object of type {type(network_file_system)} mounted at '{path}' "
53
+ + "is not useable as a network file system."
54
+ )
55
+
56
+ return validated_network_file_systems
46
57
 
58
+
59
+ def validate_volumes(
60
+ volumes: Mapping[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]],
61
+ ) -> Sequence[tuple[str, Union[_Volume, _CloudBucketMount]]]:
47
62
  validated_volumes = validate_mount_points("Volume", volumes)
48
- # We don't support mounting a volume in more than one location
49
- volume_to_paths: Dict["_Volume", List[str]] = {}
63
+ # We don't support mounting a modal.Volume in more than one location,
64
+ # but the same CloudBucketMount object can be used in more than one location.
65
+ volume_to_paths: dict[_Volume, list[str]] = {}
50
66
  for path, volume in validated_volumes:
51
- if isinstance(volume, _Volume):
67
+ if not isinstance(volume, (_Volume, _CloudBucketMount)):
68
+ raise InvalidError(f"Object of type {type(volume)} mounted at '{path}' is not useable as a volume.")
69
+ elif isinstance(volume, (_Volume)):
52
70
  volume_to_paths.setdefault(volume, []).append(path)
53
71
  for paths in volume_to_paths.values():
54
72
  if len(paths) > 1:
@@ -0,0 +1,58 @@
1
+ # Copyright Modal Labs 2022
2
+ import re
3
+
4
+ from ..exception import InvalidError
5
+
6
+ # https://www.rfc-editor.org/rfc/rfc1035
7
+ subdomain_regex = re.compile("^(?![0-9]+$)(?!-)[a-z0-9-]{,63}(?<!-)$")
8
+
9
+
10
+ def is_valid_subdomain_label(label: str) -> bool:
11
+ return subdomain_regex.match(label) is not None
12
+
13
+
14
+ def replace_invalid_subdomain_chars(label: str) -> str:
15
+ return re.sub("[^a-z0-9-]", "-", label.lower())
16
+
17
+
18
+ def is_valid_object_name(name: str) -> bool:
19
+ return (
20
+ # Limit object name length
21
+ len(name) <= 64
22
+ # Limit character set
23
+ and re.match("^[a-zA-Z0-9-_.]+$", name) is not None
24
+ # Avoid collisions with App IDs
25
+ and re.match("^ap-[a-zA-Z0-9]{22}$", name) is None
26
+ )
27
+
28
+
29
+ def is_valid_environment_name(name: str) -> bool:
30
+ # first char is alnum, the rest allows other chars
31
+ return len(name) <= 64 and re.match(r"^[a-zA-Z0-9][a-zA-Z0-9-_.]+$", name) is not None
32
+
33
+
34
+ def is_valid_tag(tag: str) -> bool:
35
+ """Tags are alphanumeric, dashes, periods, and underscores, and must be 50 characters or less"""
36
+ pattern = r"^[a-zA-Z0-9._-]{1,50}$"
37
+ return bool(re.match(pattern, tag))
38
+
39
+
40
+ def check_object_name(name: str, object_type: str) -> None:
41
+ message = (
42
+ f"Invalid {object_type} name: '{name}'."
43
+ "\n\nNames may contain only alphanumeric characters, dashes, periods, and underscores,"
44
+ " must be shorter than 64 characters, and cannot conflict with App ID strings."
45
+ )
46
+ if not is_valid_object_name(name):
47
+ raise InvalidError(message)
48
+
49
+
50
+ def check_environment_name(name: str) -> None:
51
+ message = (
52
+ f"Invalid environment name: '{name}'."
53
+ "\n\nEnvironment names can only start with alphanumeric characters,"
54
+ " may contain only alphanumeric characters, dashes, periods, and underscores,"
55
+ " and must be shorter than 64 characters."
56
+ )
57
+ if not is_valid_environment_name(name):
58
+ raise InvalidError(message)
@@ -23,7 +23,7 @@ def get_file_formats(module):
23
23
  BINARY_FORMATS = ["so", "S", "s", "asm"] # TODO
24
24
 
25
25
 
26
- def get_module_mount_info(module_name: str) -> typing.Sequence[typing.Tuple[bool, Path]]:
26
+ def get_module_mount_info(module_name: str) -> typing.Sequence[tuple[bool, Path]]:
27
27
  """Returns a list of tuples [(is_dir, path)] describing how to mount a given module."""
28
28
  file_formats = get_file_formats(module_name)
29
29
  if set(BINARY_FORMATS) & set(file_formats):
@@ -46,3 +46,16 @@ def get_module_mount_info(module_name: str) -> typing.Sequence[typing.Tuple[bool
46
46
  if not entries:
47
47
  raise ModuleNotMountable(f"{module_name} has no mountable paths")
48
48
  return entries
49
+
50
+
51
+ def parse_major_minor_version(version_string: str) -> tuple[int, int]:
52
+ parts = version_string.split(".")
53
+ if len(parts) < 2:
54
+ raise ValueError("version_string must have at least an 'X.Y' format")
55
+ try:
56
+ major = int(parts[0])
57
+ minor = int(parts[1])
58
+ except ValueError:
59
+ raise ValueError("version_string must have at least an 'X.Y' format with integral major/minor values")
60
+
61
+ return major, minor