modal 1.2.1.dev19__py3-none-any.whl → 1.2.2.dev21__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (57)
  1. modal/_clustered_functions.py +1 -3
  2. modal/_container_entrypoint.py +4 -1
  3. modal/_functions.py +33 -49
  4. modal/_grpc_client.py +148 -0
  5. modal/_output.py +3 -4
  6. modal/_runtime/container_io_manager.py +21 -22
  7. modal/_utils/async_utils.py +12 -3
  8. modal/_utils/auth_token_manager.py +1 -4
  9. modal/_utils/blob_utils.py +3 -4
  10. modal/_utils/grpc_utils.py +80 -51
  11. modal/_utils/mount_utils.py +26 -1
  12. modal/_utils/task_command_router_client.py +3 -4
  13. modal/app.py +3 -4
  14. modal/cli/config.py +3 -1
  15. modal/cli/container.py +1 -2
  16. modal/cli/entry_point.py +1 -0
  17. modal/cli/launch.py +1 -2
  18. modal/cli/network_file_system.py +1 -4
  19. modal/cli/queues.py +1 -2
  20. modal/cli/secret.py +1 -2
  21. modal/client.py +5 -115
  22. modal/client.pyi +2 -91
  23. modal/cls.py +1 -2
  24. modal/config.py +1 -1
  25. modal/container_process.py +4 -8
  26. modal/dict.py +12 -12
  27. modal/environments.py +1 -2
  28. modal/experimental/__init__.py +2 -3
  29. modal/experimental/flash.py +6 -10
  30. modal/file_io.py +13 -27
  31. modal/functions.pyi +6 -6
  32. modal/image.py +24 -3
  33. modal/image.pyi +4 -0
  34. modal/io_streams.py +61 -91
  35. modal/io_streams.pyi +33 -95
  36. modal/mount.py +4 -4
  37. modal/network_file_system.py +5 -6
  38. modal/parallel_map.py +29 -31
  39. modal/parallel_map.pyi +3 -9
  40. modal/queue.py +17 -18
  41. modal/runner.py +8 -8
  42. modal/sandbox.py +23 -36
  43. modal/secret.py +4 -5
  44. modal/snapshot.py +1 -4
  45. modal/token_flow.py +1 -1
  46. modal/volume.py +20 -22
  47. {modal-1.2.1.dev19.dist-info → modal-1.2.2.dev21.dist-info}/METADATA +1 -1
  48. {modal-1.2.1.dev19.dist-info → modal-1.2.2.dev21.dist-info}/RECORD +57 -56
  49. modal_proto/api.proto +3 -0
  50. modal_proto/api_pb2.py +1028 -1015
  51. modal_proto/api_pb2.pyi +29 -3
  52. modal_proto/modal_api_grpc.py +175 -175
  53. modal_version/__init__.py +1 -1
  54. {modal-1.2.1.dev19.dist-info → modal-1.2.2.dev21.dist-info}/WHEEL +0 -0
  55. {modal-1.2.1.dev19.dist-info → modal-1.2.2.dev21.dist-info}/entry_points.txt +0 -0
  56. {modal-1.2.1.dev19.dist-info → modal-1.2.2.dev21.dist-info}/licenses/LICENSE +0 -0
  57. {modal-1.2.1.dev19.dist-info → modal-1.2.2.dev21.dist-info}/top_level.txt +0 -0
@@ -8,12 +8,8 @@ import typing
8
8
  import urllib.parse
9
9
  import uuid
10
10
  from collections.abc import AsyncIterator
11
- from dataclasses import dataclass
12
- from typing import (
13
- Any,
14
- Optional,
15
- TypeVar,
16
- )
11
+ from dataclasses import dataclass, field
12
+ from typing import Any, Optional, TypeVar
17
13
 
18
14
  import grpclib.client
19
15
  import grpclib.config
@@ -28,6 +24,7 @@ from grpclib.protocol import H2Protocol
28
24
  from modal.exception import AuthError, ConnectionError
29
25
  from modal_version import __version__
30
26
 
27
+ from .._traceback import suppress_tb_frames
31
28
  from .async_utils import retry
32
29
  from .logger import logger
33
30
 
@@ -35,6 +32,7 @@ RequestType = TypeVar("RequestType", bound=Message)
35
32
  ResponseType = TypeVar("ResponseType", bound=Message)
36
33
 
37
34
  if typing.TYPE_CHECKING:
35
+ import modal._grpc_client
38
36
  import modal.client
39
37
 
40
38
  # Monkey patches grpclib to have a Modal User Agent header.
@@ -165,7 +163,7 @@ if typing.TYPE_CHECKING:
165
163
 
166
164
 
167
165
  async def unary_stream(
168
- method: "modal.client.UnaryStreamWrapper[RequestType, ResponseType]",
166
+ method: "modal._grpc_client.UnaryStreamWrapper[RequestType, ResponseType]",
169
167
  request: RequestType,
170
168
  metadata: Optional[Any] = None,
171
169
  ) -> AsyncIterator[ResponseType]:
@@ -174,37 +172,66 @@ async def unary_stream(
174
172
  yield item
175
173
 
176
174
 
175
+ @dataclass(frozen=True)
176
+ class Retry:
177
+ base_delay: float = 0.1
178
+ max_delay: float = 1
179
+ delay_factor: float = 2
180
+ max_retries: Optional[int] = 3
181
+ additional_status_codes: list = field(default_factory=list)
182
+ attempt_timeout: Optional[float] = None # timeout for each attempt
183
+ total_timeout: Optional[float] = None # timeout for the entire function call
184
+ attempt_timeout_floor: float = 2.0 # always have at least this much timeout (only for total_timeout)
185
+ warning_message: Optional[RetryWarningMessage] = None
186
+
187
+
177
188
  async def retry_transient_errors(
178
- fn: "modal.client.UnaryUnaryWrapper[RequestType, ResponseType]",
179
- *args,
180
- base_delay: float = 0.1,
181
- max_delay: float = 1,
182
- delay_factor: float = 2,
189
+ fn: "grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]",
190
+ req: RequestType,
183
191
  max_retries: Optional[int] = 3,
184
- additional_status_codes: list = [],
185
- attempt_timeout: Optional[float] = None, # timeout for each attempt
186
- total_timeout: Optional[float] = None, # timeout for the entire function call
187
- attempt_timeout_floor=2.0, # always have at least this much timeout (only for total_timeout)
188
- retry_warning_message: Optional[RetryWarningMessage] = None,
189
- metadata: list[tuple[str, str]] = [],
192
+ ) -> ResponseType:
193
+ """Minimum API version of _retry_transient_errors that works with grpclib.client.UnaryUnaryMethod.
194
+
195
+ Used by modal server.
196
+ """
197
+ return await _retry_transient_errors(fn, req, retry=Retry(max_retries=max_retries))
198
+
199
+
200
+ async def _retry_transient_errors(
201
+ fn: typing.Union[
202
+ "modal._grpc_client.UnaryUnaryWrapper[RequestType, ResponseType]",
203
+ "grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]",
204
+ ],
205
+ req: RequestType,
206
+ retry: Retry,
207
+ metadata: Optional[list[tuple[str, str]]] = None,
190
208
  ) -> ResponseType:
191
209
  """Retry on transient gRPC failures with back-off until max_retries is reached.
192
210
  If max_retries is None, retry forever."""
211
+ import modal._grpc_client
212
+
213
+ if isinstance(fn, modal._grpc_client.UnaryUnaryWrapper):
214
+ fn_callable = fn.direct
215
+ elif isinstance(fn, grpclib.client.UnaryUnaryMethod):
216
+ fn_callable = fn # type: ignore
217
+ else:
218
+ raise ValueError("Only modal._grpc_client.UnaryUnaryWrapper and grpclib.client.UnaryUnaryMethod are supported")
193
219
 
194
- delay = base_delay
220
+ delay = retry.base_delay
195
221
  n_retries = 0
196
222
 
197
- status_codes = [*RETRYABLE_GRPC_STATUS_CODES, *additional_status_codes]
223
+ status_codes = [*RETRYABLE_GRPC_STATUS_CODES, *retry.additional_status_codes]
198
224
 
199
225
  idempotency_key = str(uuid.uuid4())
200
226
 
201
227
  t0 = time.time()
202
- if total_timeout is not None:
203
- total_deadline = t0 + total_timeout
228
+ if retry.total_timeout is not None:
229
+ total_deadline = t0 + retry.total_timeout
204
230
  else:
205
231
  total_deadline = None
206
232
 
207
- metadata = metadata + [("x-modal-timestamp", str(time.time()))]
233
+ metadata = (metadata or []) + [("x-modal-timestamp", str(time.time()))]
234
+
208
235
  while True:
209
236
  attempt_metadata = [
210
237
  ("x-idempotency-key", idempotency_key),
@@ -214,16 +241,17 @@ async def retry_transient_errors(
214
241
  if n_retries > 0:
215
242
  attempt_metadata.append(("x-retry-delay", str(time.time() - t0)))
216
243
  timeouts = []
217
- if attempt_timeout is not None:
218
- timeouts.append(attempt_timeout)
219
- if total_timeout is not None:
220
- timeouts.append(max(total_deadline - time.time(), attempt_timeout_floor))
244
+ if retry.attempt_timeout is not None:
245
+ timeouts.append(retry.attempt_timeout)
246
+ if retry.total_timeout is not None and total_deadline is not None:
247
+ timeouts.append(max(total_deadline - time.time(), retry.attempt_timeout_floor))
221
248
  if timeouts:
222
249
  timeout = min(timeouts) # In case the function provided both types of timeouts
223
250
  else:
224
251
  timeout = None
225
252
  try:
226
- return await fn(*args, metadata=attempt_metadata, timeout=timeout)
253
+ with suppress_tb_frames(1):
254
+ return await fn_callable(req, metadata=attempt_metadata, timeout=timeout)
227
255
  except (StreamTerminatedError, GRPCError, OSError, asyncio.TimeoutError, AttributeError) as exc:
228
256
  if isinstance(exc, GRPCError) and exc.status not in status_codes:
229
257
  if exc.status == Status.UNAUTHENTICATED:
@@ -231,45 +259,46 @@ async def retry_transient_errors(
231
259
  else:
232
260
  raise exc
233
261
 
234
- if max_retries is not None and n_retries >= max_retries:
262
+ if retry.max_retries is not None and n_retries >= retry.max_retries:
235
263
  final_attempt = True
236
- elif total_deadline is not None and time.time() + delay + attempt_timeout_floor >= total_deadline:
264
+ elif total_deadline is not None and time.time() + delay + retry.attempt_timeout_floor >= total_deadline:
237
265
  final_attempt = True
238
266
  else:
239
267
  final_attempt = False
240
268
 
241
- if final_attempt:
242
- logger.debug(
243
- f"Final attempt failed with {repr(exc)} {n_retries=} {delay=} "
244
- f"{total_deadline=} for {fn.name} ({idempotency_key[:8]})"
245
- )
246
- if isinstance(exc, OSError):
247
- raise ConnectionError(str(exc))
248
- elif isinstance(exc, asyncio.TimeoutError):
249
- raise ConnectionError(str(exc))
250
- else:
269
+ with suppress_tb_frames(1):
270
+ if final_attempt:
271
+ logger.debug(
272
+ f"Final attempt failed with {repr(exc)} {n_retries=} {delay=} "
273
+ f"{total_deadline=} for {fn.name} ({idempotency_key[:8]})"
274
+ )
275
+ if isinstance(exc, OSError):
276
+ raise ConnectionError(str(exc))
277
+ elif isinstance(exc, asyncio.TimeoutError):
278
+ raise ConnectionError(str(exc))
279
+ else:
280
+ raise exc
281
+
282
+ if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
283
+ # StreamTerminatedError are not properly raised in grpclib<=0.4.7
284
+ # fixed in https://github.com/vmagamedov/grpclib/issues/185
285
+ # TODO: update to newer version (>=0.4.8) once stable
251
286
  raise exc
252
287
 
253
- if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
254
- # StreamTerminatedError are not properly raised in grpclib<=0.4.7
255
- # fixed in https://github.com/vmagamedov/grpclib/issues/185
256
- # TODO: update to newer version (>=0.4.8) once stable
257
- raise exc
258
-
259
288
  logger.debug(f"Retryable failure {repr(exc)} {n_retries=} {delay=} for {fn.name} ({idempotency_key[:8]})")
260
289
 
261
290
  n_retries += 1
262
291
 
263
292
  if (
264
- retry_warning_message
265
- and n_retries % retry_warning_message.warning_interval == 0
293
+ retry.warning_message
294
+ and n_retries % retry.warning_message.warning_interval == 0
266
295
  and isinstance(exc, GRPCError)
267
- and exc.status in retry_warning_message.errors_to_warn_for
296
+ and exc.status in retry.warning_message.errors_to_warn_for
268
297
  ):
269
- logger.warning(retry_warning_message.message)
298
+ logger.warning(retry.warning_message.message)
270
299
 
271
300
  await asyncio.sleep(delay)
272
- delay = min(delay * delay_factor, max_delay)
301
+ delay = min(delay * retry.delay_factor, retry.max_delay)
273
302
 
274
303
 
275
304
  def find_free_port() -> int:
@@ -3,7 +3,9 @@ import posixpath
3
3
  import typing
4
4
  from collections.abc import Mapping, Sequence
5
5
  from pathlib import PurePath, PurePosixPath
6
- from typing import Union
6
+ from typing import Optional, Union
7
+
8
+ from typing_extensions import TypeGuard
7
9
 
8
10
  from ..cloud_bucket_mount import _CloudBucketMount
9
11
  from ..exception import InvalidError
@@ -76,3 +78,26 @@ def validate_volumes(
76
78
  )
77
79
 
78
80
  return validated_volumes
81
+
82
+
83
+ def validate_only_modal_volumes(
84
+ volumes: Optional[Optional[dict[Union[str, PurePosixPath], _Volume]]],
85
+ caller_name: str,
86
+ ) -> Sequence[tuple[str, _Volume]]:
87
+ """Validate all volumes are `modal.Volume`."""
88
+ if volumes is None:
89
+ return []
90
+
91
+ validated_volumes = validate_volumes(volumes)
92
+
93
+ # Although the typing forbids `_CloudBucketMount` for type checking, one can still pass a `_CloudBucketMount`
94
+ # during runtime, so we'll check the type here.
95
+ def all_modal_volumes(
96
+ vols: Sequence[tuple[str, Union[_Volume, _CloudBucketMount]]],
97
+ ) -> TypeGuard[Sequence[tuple[str, _Volume]]]:
98
+ return all(isinstance(v, _Volume) for _, v in vols)
99
+
100
+ if not all_modal_volumes(validated_volumes):
101
+ raise InvalidError(f"{caller_name} only supports volumes that are modal.Volume")
102
+
103
+ return validated_volumes
@@ -18,7 +18,7 @@ from modal.exception import ExecTimeoutError
18
18
  from modal_proto import api_pb2, task_command_router_pb2 as sr_pb2
19
19
  from modal_proto.task_command_router_grpc import TaskCommandRouterStub
20
20
 
21
- from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES, connect_channel, retry_transient_errors
21
+ from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES, connect_channel
22
22
 
23
23
 
24
24
  def _b64url_decode(data: str) -> bytes:
@@ -99,8 +99,7 @@ async def call_with_retries_on_transient_errors(
99
99
 
100
100
  async def fetch_command_router_access(server_client, task_id: str) -> api_pb2.TaskGetCommandRouterAccessResponse:
101
101
  """Fetch direct command router access info from Modal server."""
102
- return await retry_transient_errors(
103
- server_client.stub.TaskGetCommandRouterAccess,
102
+ return await server_client.stub.TaskGetCommandRouterAccess(
104
103
  api_pb2.TaskGetCommandRouterAccessRequest(task_id=task_id),
105
104
  )
106
105
 
@@ -444,7 +443,7 @@ class TaskCommandRouterClient:
444
443
  except Exception as e:
445
444
  # Exceptions here can stem from non-transient errors against the server sending
446
445
  # the TaskGetCommandRouterAccess RPC, for instance, if the task has finished.
447
- logger.warning(f"Background JWT refresh failed for exec with task ID {self._task_id}: {e}")
446
+ logger.debug(f"Background JWT refresh failed for exec with task ID {self._task_id}: {e}")
448
447
  break
449
448
 
450
449
  async def _stream_stdio(
modal/app.py CHANGED
@@ -35,7 +35,6 @@ from ._utils.deprecation import (
35
35
  warn_on_renamed_autoscaler_settings,
36
36
  )
37
37
  from ._utils.function_utils import FunctionInfo, is_flash_object, is_global_object, is_method_fn
38
- from ._utils.grpc_utils import retry_transient_errors
39
38
  from ._utils.mount_utils import validate_volumes
40
39
  from ._utils.name_utils import check_object_name, check_tag_dict
41
40
  from .client import _Client
@@ -303,7 +302,7 @@ class _App:
303
302
  object_creation_type=(api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING if create_if_missing else None),
304
303
  )
305
304
 
306
- response = await retry_transient_errors(client.stub.AppGetOrCreate, request)
305
+ response = await client.stub.AppGetOrCreate(request)
307
306
 
308
307
  app = _App(name) # TODO: this should probably be a distinct constructor, possibly even a distinct type
309
308
  app._local_state_attr = None # this is not a locally defined App, so no local state
@@ -1183,7 +1182,7 @@ class _App:
1183
1182
  req = api_pb2.AppSetTagsRequest(app_id=self._app_id, tags=tags)
1184
1183
 
1185
1184
  client = client or self._client or await _Client.from_env()
1186
- await retry_transient_errors(client.stub.AppSetTags, req)
1185
+ await client.stub.AppSetTags(req)
1187
1186
 
1188
1187
  async def get_tags(self, *, client: Optional[_Client] = None) -> dict[str, str]:
1189
1188
  """Get the tags that are currently attached to the App."""
@@ -1191,7 +1190,7 @@ class _App:
1191
1190
  raise InvalidError("`App.get_tags` cannot be called before the App is running.")
1192
1191
  req = api_pb2.AppGetTagsRequest(app_id=self._app_id)
1193
1192
  client = client or self._client or await _Client.from_env()
1194
- resp = await retry_transient_errors(client.stub.AppGetTags, req)
1193
+ resp = await client.stub.AppGetTags(req)
1195
1194
  return dict(resp.tags)
1196
1195
 
1197
1196
  async def _logs(self, client: Optional[_Client] = None) -> AsyncGenerator[str, None]:
modal/cli/config.py CHANGED
@@ -1,4 +1,6 @@
1
1
  # Copyright Modal Labs 2022
2
+ import json
3
+
2
4
  import typer
3
5
 
4
6
  from modal._output import make_console
@@ -25,7 +27,7 @@ def show(redact: bool = typer.Option(True, help="Redact the `token_secret` value
25
27
  config_dict["token_secret"] = "***"
26
28
 
27
29
  console = make_console()
28
- console.print(config_dict)
30
+ console.print_json(json.dumps(config_dict))
29
31
 
30
32
 
31
33
  SET_DEFAULT_ENV_HELP = """Set the default Modal environment for the active profile
modal/cli/container.py CHANGED
@@ -7,7 +7,6 @@ from rich.text import Text
7
7
  from modal._object import _get_environment_name
8
8
  from modal._pty import get_pty_info
9
9
  from modal._utils.async_utils import synchronizer
10
- from modal._utils.grpc_utils import retry_transient_errors
11
10
  from modal._utils.time_utils import timestamp_to_localized_str
12
11
  from modal.cli.utils import ENV_OPTION, display_table, is_tty, stream_app_logs
13
12
  from modal.client import _Client
@@ -97,4 +96,4 @@ async def stop(container_id: str = typer.Argument(help="Container ID")):
97
96
  """
98
97
  client = await _Client.from_env()
99
98
  request = api_pb2.ContainerStopRequest(task_id=container_id)
100
- await retry_transient_errors(client.stub.ContainerStop, request)
99
+ await client.stub.ContainerStop(request)
modal/cli/entry_point.py CHANGED
@@ -36,6 +36,7 @@ entrypoint_cli_typer = typer.Typer(
36
36
  no_args_is_help=False,
37
37
  add_completion=False,
38
38
  rich_markup_mode="markdown",
39
+ context_settings={"help_option_names": ["-h", "--help"]},
39
40
  help="""
40
41
  Modal is the fastest way to run code in the cloud.
41
42
 
modal/cli/launch.py CHANGED
@@ -23,8 +23,7 @@ launch_cli = Typer(
23
23
  no_args_is_help=True,
24
24
  rich_markup_mode="markdown",
25
25
  help="""
26
- Open a serverless app instance on Modal.
27
- >⚠️ `modal launch` is **experimental** and may change in the future.
26
+ [Experimental] Open a serverless app instance on Modal.
28
27
  """,
29
28
  )
30
29
 
@@ -15,7 +15,6 @@ import modal
15
15
  from modal._location import display_location
16
16
  from modal._output import OutputManager, ProgressHandler, make_console
17
17
  from modal._utils.async_utils import synchronizer
18
- from modal._utils.grpc_utils import retry_transient_errors
19
18
  from modal._utils.time_utils import timestamp_to_localized_str
20
19
  from modal.cli._download import _volume_download
21
20
  from modal.cli.utils import ENV_OPTION, YES_OPTION, display_table
@@ -33,9 +32,7 @@ async def list_(env: Optional[str] = ENV_OPTION, json: Optional[bool] = False):
33
32
  env = ensure_env(env)
34
33
 
35
34
  client = await _Client.from_env()
36
- response = await retry_transient_errors(
37
- client.stub.SharedVolumeList, api_pb2.SharedVolumeListRequest(environment_name=env)
38
- )
35
+ response = await client.stub.SharedVolumeList(api_pb2.SharedVolumeListRequest(environment_name=env))
39
36
  env_part = f" in environment '{env}'" if env else ""
40
37
  column_names = ["Name", "Location", "Created at"]
41
38
  rows = []
modal/cli/queues.py CHANGED
@@ -8,7 +8,6 @@ from typer import Argument, Option, Typer
8
8
  from modal._output import make_console
9
9
  from modal._resolver import Resolver
10
10
  from modal._utils.async_utils import synchronizer
11
- from modal._utils.grpc_utils import retry_transient_errors
12
11
  from modal._utils.time_utils import timestamp_to_localized_str
13
12
  from modal.cli.utils import ENV_OPTION, YES_OPTION, display_table
14
13
  from modal.client import _Client
@@ -83,7 +82,7 @@ async def list_(*, json: bool = False, env: Optional[str] = ENV_OPTION):
83
82
  max_page_size = 100
84
83
  pagination = api_pb2.ListPagination(max_objects=max_page_size, created_before=created_before)
85
84
  req = api_pb2.QueueListRequest(environment_name=env, pagination=pagination, total_size_limit=max_total_size)
86
- resp = await retry_transient_errors(client.stub.QueueList, req)
85
+ resp = await client.stub.QueueList(req)
87
86
  items.extend(resp.queues)
88
87
  return len(resp.queues) < max_page_size
89
88
 
modal/cli/secret.py CHANGED
@@ -15,7 +15,6 @@ from typer import Argument, Option
15
15
 
16
16
  from modal._output import make_console
17
17
  from modal._utils.async_utils import synchronizer
18
- from modal._utils.grpc_utils import retry_transient_errors
19
18
  from modal._utils.time_utils import timestamp_to_localized_str
20
19
  from modal.cli.utils import ENV_OPTION, YES_OPTION, display_table
21
20
  from modal.client import _Client
@@ -44,7 +43,7 @@ async def list_(env: Optional[str] = ENV_OPTION, json: bool = False):
44
43
  max_page_size = 100
45
44
  pagination = api_pb2.ListPagination(max_objects=max_page_size, created_before=created_before)
46
45
  req = api_pb2.SecretListRequest(environment_name=env, pagination=pagination)
47
- resp = await retry_transient_errors(client.stub.SecretList, req)
46
+ resp = await client.stub.SecretList(req)
48
47
  items.extend(resp.items)
49
48
  return len(resp.items) < max_page_size
50
49
 
modal/client.py CHANGED
@@ -6,32 +6,24 @@ import sys
6
6
  import urllib.parse
7
7
  import warnings
8
8
  from collections.abc import AsyncGenerator, AsyncIterator, Collection, Mapping
9
- from typing import (
10
- Any,
11
- ClassVar,
12
- Generic,
13
- Optional,
14
- TypeVar,
15
- Union,
16
- )
9
+ from typing import Any, ClassVar, Optional, TypeVar, Union
17
10
 
18
11
  import grpclib.client
19
12
  from google.protobuf import empty_pb2
20
13
  from google.protobuf.message import Message
21
- from grpclib import GRPCError, Status
22
14
  from synchronicity.async_wrap import asynccontextmanager
23
15
 
24
16
  from modal._utils.async_utils import synchronizer
25
17
  from modal_proto import api_grpc, api_pb2, modal_api_grpc
26
18
  from modal_version import __version__
27
19
 
28
- from ._traceback import print_server_warnings, suppress_tb_frames
20
+ from ._traceback import print_server_warnings
29
21
  from ._utils import async_utils
30
22
  from ._utils.async_utils import TaskContext, synchronize_api
31
23
  from ._utils.auth_token_manager import _AuthTokenManager
32
- from ._utils.grpc_utils import ConnectionManager, retry_transient_errors
24
+ from ._utils.grpc_utils import ConnectionManager
33
25
  from .config import _check_config, _is_remote, config, logger
34
- from .exception import AuthError, ClientClosed, NotFoundError
26
+ from .exception import AuthError, ClientClosed
35
27
 
36
28
  HEARTBEAT_INTERVAL: float = config.get("heartbeat_interval")
37
29
  HEARTBEAT_TIMEOUT: float = HEARTBEAT_INTERVAL + 0.1
@@ -159,7 +151,7 @@ class _Client:
159
151
  async def hello(self):
160
152
  """Connect to server and retrieve version information; raise appropriate error for various failures."""
161
153
  logger.debug(f"Client ({id(self)}): Starting")
162
- resp = await retry_transient_errors(self.stub.ClientHello, empty_pb2.Empty())
154
+ resp = await self.stub.ClientHello(empty_pb2.Empty())
163
155
  print_server_warnings(resp.server_warnings)
164
156
 
165
157
  async def __aenter__(self):
@@ -362,105 +354,3 @@ class _Client:
362
354
 
363
355
 
364
356
  Client = synchronize_api(_Client)
365
-
366
-
367
- class grpc_error_converter:
368
- def __enter__(self):
369
- pass
370
-
371
- def __exit__(self, exc_type, exc, traceback) -> bool:
372
- # skip all internal frames from grpclib
373
- use_full_traceback = config.get("traceback")
374
- with suppress_tb_frames(1):
375
- if isinstance(exc, GRPCError):
376
- if exc.status == Status.NOT_FOUND:
377
- if use_full_traceback:
378
- raise NotFoundError(exc.message)
379
- else:
380
- raise NotFoundError(exc.message) from None # from None to skip the grpc-internal cause
381
-
382
- if not use_full_traceback:
383
- # just include the frame in grpclib that actually raises the GRPCError
384
- tb = exc.__traceback__
385
- while tb.tb_next:
386
- tb = tb.tb_next
387
- exc.with_traceback(tb)
388
- raise exc from None # from None to skip the grpc-internal cause
389
- raise exc
390
-
391
- return False
392
-
393
-
394
- class UnaryUnaryWrapper(Generic[RequestType, ResponseType]):
395
- # Calls a grpclib.UnaryUnaryMethod using a specific Client instance, respecting
396
- # if that client is closed etc. and possibly introducing Modal-specific retry logic
397
- wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]
398
- client: _Client
399
-
400
- def __init__(
401
- self,
402
- wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType],
403
- client: _Client,
404
- server_url: str,
405
- ):
406
- self.wrapped_method = wrapped_method
407
- self.client = client
408
- self.server_url = server_url
409
-
410
- @property
411
- def name(self) -> str:
412
- return self.wrapped_method.name
413
-
414
- async def __call__(
415
- self,
416
- req: RequestType,
417
- *,
418
- timeout: Optional[float] = None,
419
- metadata: Optional[_MetadataLike] = None,
420
- ) -> ResponseType:
421
- if self.client._snapshotted:
422
- logger.debug(f"refreshing client after snapshot for {self.name.rsplit('/', 1)[1]}")
423
- self.client = await _Client.from_env()
424
-
425
- # Note: We override the grpclib method's channel (see grpclib's code [1]). I think this is fine
426
- # since grpclib's code doesn't seem to change very much, but we could also recreate the
427
- # grpclib stub if we aren't comfortable with this. The downside is then we need to cache
428
- # the grpclib stub so the rest of our code becomes a bit more complicated.
429
- #
430
- # We need to override the channel because after the process is forked or the client is
431
- # snapshotted, the existing channel may be stale / unusable.
432
- #
433
- # [1]: https://github.com/vmagamedov/grpclib/blob/62f968a4c84e3f64e6966097574ff0a59969ea9b/grpclib/client.py#L844
434
- self.wrapped_method.channel = await self.client._get_channel(self.server_url)
435
- with suppress_tb_frames(1), grpc_error_converter():
436
- return await self.client._call_unary(self.wrapped_method, req, timeout=timeout, metadata=metadata)
437
-
438
-
439
- class UnaryStreamWrapper(Generic[RequestType, ResponseType]):
440
- wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]
441
-
442
- def __init__(
443
- self,
444
- wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType],
445
- client: _Client,
446
- server_url: str,
447
- ):
448
- self.wrapped_method = wrapped_method
449
- self.client = client
450
- self.server_url = server_url
451
-
452
- @property
453
- def name(self) -> str:
454
- return self.wrapped_method.name
455
-
456
- async def unary_stream(
457
- self,
458
- request,
459
- metadata: Optional[Any] = None,
460
- ):
461
- if self.client._snapshotted:
462
- logger.debug(f"refreshing client after snapshot for {self.name.rsplit('/', 1)[1]}")
463
- self.client = await _Client.from_env()
464
- self.wrapped_method.channel = await self.client._get_channel(self.server_url)
465
- async for response in self.client._call_stream(self.wrapped_method, request, metadata=metadata):
466
- yield response