modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +402 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1025 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +196 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +943 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
  128. modal-0.72.13.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
modal/client.py CHANGED
@@ -1,32 +1,39 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import asyncio
3
+ import os
3
4
  import platform
4
5
  import warnings
5
- from typing import AsyncIterator, Awaitable, Callable, ClassVar, Dict, Optional, Tuple
6
+ from collections.abc import AsyncGenerator, AsyncIterator, Collection, Mapping
7
+ from typing import (
8
+ Any,
9
+ ClassVar,
10
+ Generic,
11
+ Optional,
12
+ TypeVar,
13
+ Union,
14
+ )
6
15
 
7
16
  import grpclib.client
8
- from aiohttp import ClientConnectorError, ClientResponseError
9
17
  from google.protobuf import empty_pb2
10
- from grpclib import GRPCError, Status
18
+ from google.protobuf.message import Message
11
19
  from synchronicity.async_wrap import asynccontextmanager
12
20
 
13
- from modal_proto import api_grpc, api_pb2
21
+ from modal._utils.async_utils import synchronizer
22
+ from modal_proto import api_grpc, api_pb2, modal_api_grpc
14
23
  from modal_version import __version__
15
24
 
25
+ from ._traceback import print_server_warnings
16
26
  from ._utils import async_utils
17
- from ._utils.async_utils import synchronize_api
18
- from ._utils.grpc_utils import create_channel, retry_transient_errors
19
- from ._utils.http_utils import http_client_with_tls
20
- from .config import _check_config, config, logger
21
- from .exception import AuthError, ConnectionError, DeprecationError, VersionError
27
+ from ._utils.async_utils import TaskContext, synchronize_api
28
+ from ._utils.grpc_utils import connect_channel, create_channel, retry_transient_errors
29
+ from .config import _check_config, _is_remote, config, logger
30
+ from .exception import AuthError, ClientClosed, ConnectionError
22
31
 
23
32
  HEARTBEAT_INTERVAL: float = config.get("heartbeat_interval")
24
33
  HEARTBEAT_TIMEOUT: float = HEARTBEAT_INTERVAL + 0.1
25
- CLIENT_CREATE_ATTEMPT_TIMEOUT: float = 4.0
26
- CLIENT_CREATE_TOTAL_TIMEOUT: float = 15.0
27
34
 
28
35
 
29
- def _get_metadata(client_type: int, credentials: Optional[Tuple[str, str]], version: str) -> Dict[str, str]:
36
+ def _get_metadata(client_type: int, credentials: Optional[tuple[str, str]], version: str) -> dict[str, str]:
30
37
  # This implements a simplified version of platform.platform() that's still machine-readable
31
38
  uname: platform.uname_result = platform.uname()
32
39
  if uname.system == "Darwin":
@@ -50,132 +57,89 @@ def _get_metadata(client_type: int, credentials: Optional[Tuple[str, str]], vers
50
57
  "x-modal-token-secret": token_secret,
51
58
  }
52
59
  )
53
- elif credentials and client_type == api_pb2.CLIENT_TYPE_CONTAINER:
54
- task_id, task_secret = credentials
55
- metadata.update(
56
- {
57
- "x-modal-task-id": task_id,
58
- "x-modal-task-secret": task_secret,
59
- }
60
- )
61
60
  return metadata
62
61
 
63
62
 
64
- async def _http_check(url: str, timeout: float) -> str:
65
- # Used for sanity checking connection issues
66
- try:
67
- async with http_client_with_tls(timeout=timeout) as session:
68
- async with session.get(url) as resp:
69
- return f"HTTP status: {resp.status}"
70
- except ClientResponseError as exc:
71
- return f"HTTP status: {exc.status}"
72
- except ClientConnectorError as exc:
73
- return f"HTTP exception: {exc.os_error.__class__.__name__}"
74
- except Exception as exc:
75
- return f"HTTP exception: {exc.__class__.__name__}"
76
-
77
-
78
- async def _grpc_exc_string(exc: GRPCError, method_name: str, server_url: str, timeout: float) -> str:
79
- http_status = await _http_check(server_url, timeout=timeout)
80
- return f"{method_name}: {exc.message} [gRPC status: {exc.status.name}, {http_status}]"
63
+ ReturnType = TypeVar("ReturnType")
64
+ _Value = Union[str, bytes]
65
+ _MetadataLike = Union[Mapping[str, _Value], Collection[tuple[str, _Value]]]
66
+ RequestType = TypeVar("RequestType", bound=Message)
67
+ ResponseType = TypeVar("ResponseType", bound=Message)
81
68
 
82
69
 
83
70
  class _Client:
84
71
  _client_from_env: ClassVar[Optional["_Client"]] = None
85
72
  _client_from_env_lock: ClassVar[Optional[asyncio.Lock]] = None
73
+ _cancellation_context: TaskContext
74
+ _cancellation_context_event_loop: asyncio.AbstractEventLoop = None
75
+ _stub: Optional[api_grpc.ModalClientStub]
86
76
 
87
77
  def __init__(
88
78
  self,
89
79
  server_url: str,
90
80
  client_type: int,
91
- credentials: Optional[Tuple[str, str]],
81
+ credentials: Optional[tuple[str, str]],
92
82
  version: str = __version__,
93
83
  ):
94
- """The Modal client object is not intended to be instantiated directly by users."""
84
+ """mdmd:hidden
85
+ The Modal client object is not intended to be instantiated directly by users.
86
+ """
95
87
  self.server_url = server_url
96
88
  self.client_type = client_type
97
- self.credentials = credentials
89
+ self._credentials = credentials
98
90
  self.version = version
99
- self._authenticated = False
100
- self.image_builder_version: Optional[str] = None
101
- self._pre_stop: Optional[Callable[[], Awaitable[None]]] = None
91
+ self._closed = False
102
92
  self._channel: Optional[grpclib.client.Channel] = None
103
- self._stub: Optional[api_grpc.ModalClientStub] = None
93
+ self._stub: Optional[modal_api_grpc.ModalClientModal] = None
94
+ self._snapshotted = False
95
+ self._owner_pid = None
104
96
 
105
- @property
106
- def stub(self) -> Optional[api_grpc.ModalClientStub]:
107
- """mdmd:hidden"""
108
- return self._stub
97
+ def is_closed(self) -> bool:
98
+ return self._closed
109
99
 
110
100
  @property
111
- def authenticated(self) -> bool:
101
+ def stub(self) -> modal_api_grpc.ModalClientModal:
112
102
  """mdmd:hidden"""
113
- return self._authenticated
103
+ assert self._stub
104
+ return self._stub
114
105
 
115
106
  async def _open(self):
107
+ self._closed = False
116
108
  assert self._stub is None
117
- metadata = _get_metadata(self.client_type, self.credentials, self.version)
109
+ metadata = _get_metadata(self.client_type, self._credentials, self.version)
118
110
  self._channel = create_channel(self.server_url, metadata=metadata)
119
- self._stub = api_grpc.ModalClientStub(self._channel) # type: ignore
120
-
121
- async def _close(self):
122
- if self._pre_stop is not None:
123
- logger.debug("Client: running pre-stop coroutine before shutting down")
124
- await self._pre_stop() # type: ignore
125
-
111
+ try:
112
+ await connect_channel(self._channel)
113
+ except OSError as exc:
114
+ raise ConnectionError(str(exc))
115
+ self._cancellation_context = TaskContext(grace=0.5) # allow running rpcs to finish in 0.5s when closing client
116
+ self._cancellation_context_event_loop = asyncio.get_running_loop()
117
+ await self._cancellation_context.__aenter__()
118
+ self._grpclib_stub = api_grpc.ModalClientStub(self._channel)
119
+ self._stub = modal_api_grpc.ModalClientModal(self._grpclib_stub, client=self)
120
+ self._owner_pid = os.getpid()
121
+
122
+ async def _close(self, prep_for_restore: bool = False):
123
+ logger.debug(f"Client ({id(self)}): closing")
124
+ self._closed = True
125
+ await self._cancellation_context.__aexit__(None, None, None) # wait for all rpcs to be finished/cancelled
126
126
  if self._channel is not None:
127
127
  self._channel.close()
128
128
 
129
+ if prep_for_restore:
130
+ self._snapshotted = True
131
+
129
132
  # Remove cached client.
130
133
  self.set_env_client(None)
131
134
 
132
- def set_pre_stop(self, pre_stop: Callable[[], Awaitable[None]]):
133
- """mdmd:hidden"""
134
- # hack: stub.serve() gets into a losing race with the `on_shutdown` client
135
- # teardown when an interrupt signal is received (eg. KeyboardInterrupt).
136
- # By registering a pre-stop fn stub.serve() can have its teardown
137
- # performed before the client is disconnected.
138
- #
139
- # ref: github.com/modal-labs/modal-client/pull/108
140
- self._pre_stop = pre_stop
141
-
142
- async def _init(self):
135
+ async def hello(self):
143
136
  """Connect to server and retrieve version information; raise appropriate error for various failures."""
144
- logger.debug("Client: Starting")
145
- _check_config()
146
- try:
147
- req = empty_pb2.Empty()
148
- resp = await retry_transient_errors(
149
- self.stub.ClientHello,
150
- req,
151
- attempt_timeout=CLIENT_CREATE_ATTEMPT_TIMEOUT,
152
- total_timeout=CLIENT_CREATE_TOTAL_TIMEOUT,
153
- )
154
- if resp.warning:
155
- ALARM_EMOJI = chr(0x1F6A8)
156
- warnings.warn(f"{ALARM_EMOJI} {resp.warning} {ALARM_EMOJI}", DeprecationError)
157
- self._authenticated = True
158
- self.image_builder_version = resp.image_builder_version
159
- except GRPCError as exc:
160
- if exc.status == Status.FAILED_PRECONDITION:
161
- raise VersionError(
162
- f"The client version ({self.version}) is too old. Please update (pip install --upgrade modal)."
163
- )
164
- elif exc.status == Status.UNAUTHENTICATED:
165
- raise AuthError(exc.message)
166
- else:
167
- exc_string = await _grpc_exc_string(exc, "ClientHello", self.server_url, CLIENT_CREATE_TOTAL_TIMEOUT)
168
- raise ConnectionError(exc_string)
169
- except (OSError, asyncio.TimeoutError) as exc:
170
- raise ConnectionError(str(exc))
137
+ logger.debug(f"Client ({id(self)}): Starting")
138
+ resp = await retry_transient_errors(self.stub.ClientHello, empty_pb2.Empty())
139
+ print_server_warnings(resp.server_warnings)
171
140
 
172
141
  async def __aenter__(self):
173
142
  await self._open()
174
- try:
175
- await self._init()
176
- except BaseException:
177
- await self._close()
178
- raise
179
143
  return self
180
144
 
181
145
  async def __aexit__(self, exc_type, exc, tb):
@@ -191,7 +155,6 @@ class _Client:
191
155
  client = cls(server_url, api_pb2.CLIENT_TYPE_CLIENT, credentials=None)
192
156
  try:
193
157
  await client._open()
194
- # Skip client._init
195
158
  yield client
196
159
  finally:
197
160
  await client._close()
@@ -201,28 +164,15 @@ class _Client:
201
164
  """mdmd:hidden
202
165
  Singleton that is instantiated from the Modal config and reused on subsequent calls.
203
166
  """
167
+ _check_config()
168
+
204
169
  if _override_config:
205
170
  # Only used for testing
206
171
  c = _override_config
207
172
  else:
208
173
  c = config
209
174
 
210
- server_url = c["server_url"]
211
-
212
- token_id = c["token_id"]
213
- token_secret = c["token_secret"]
214
- task_id = c["task_id"]
215
- task_secret = c["task_secret"]
216
-
217
- if task_id and task_secret:
218
- client_type = api_pb2.CLIENT_TYPE_CONTAINER
219
- credentials = (task_id, task_secret)
220
- elif token_id and token_secret:
221
- client_type = api_pb2.CLIENT_TYPE_CLIENT
222
- credentials = (token_id, token_secret)
223
- else:
224
- client_type = api_pb2.CLIENT_TYPE_CLIENT
225
- credentials = None
175
+ credentials: Optional[tuple[str, str]]
226
176
 
227
177
  if cls._client_from_env_lock is None:
228
178
  cls._client_from_env_lock = asyncio.Lock()
@@ -230,50 +180,63 @@ class _Client:
230
180
  async with cls._client_from_env_lock:
231
181
  if cls._client_from_env:
232
182
  return cls._client_from_env
183
+
184
+ token_id = c["token_id"]
185
+ token_secret = c["token_secret"]
186
+ if _is_remote():
187
+ if token_id or token_secret:
188
+ warnings.warn(
189
+ "Modal tokens provided by MODAL_TOKEN_ID and MODAL_TOKEN_SECRET"
190
+ " (or through the config file) are ignored inside containers."
191
+ )
192
+ client_type = api_pb2.CLIENT_TYPE_CONTAINER
193
+ credentials = None
194
+ elif token_id and token_secret:
195
+ client_type = api_pb2.CLIENT_TYPE_CLIENT
196
+ credentials = (token_id, token_secret)
233
197
  else:
234
- client = _Client(server_url, client_type, credentials)
235
- await client._open()
236
- async_utils.on_shutdown(client._close())
237
- try:
238
- await client._init()
239
- except AuthError:
240
- if not credentials:
241
- creds_missing_msg = (
242
- "Token missing. Could not authenticate client."
243
- " If you have token credentials, see modal.com/docs/reference/modal.config for setup help."
244
- " If you are a new user, register an account at modal.com, then run `modal token new`."
245
- )
246
- raise AuthError(creds_missing_msg)
247
- else:
248
- raise
249
- cls._client_from_env = client
250
- return client
198
+ raise AuthError(
199
+ "Token missing. Could not authenticate client."
200
+ " If you have token credentials, see modal.com/docs/reference/modal.config for setup help."
201
+ " If you are a new user, register an account at modal.com, then run `modal token new`."
202
+ )
203
+
204
+ server_url = c["server_url"]
205
+ client = _Client(server_url, client_type, credentials)
206
+ await client._open()
207
+ async_utils.on_shutdown(client._close())
208
+ cls._client_from_env = client
209
+ return client
251
210
 
252
211
  @classmethod
253
212
  async def from_credentials(cls, token_id: str, token_secret: str) -> "_Client":
254
- """mdmd:hidden
213
+ """
255
214
  Constructor based on token credentials; useful for managing Modal on behalf of third-party users.
215
+
216
+ **Usage:**
217
+
218
+ ```python notest
219
+ client = modal.Client.from_credentials("my_token_id", "my_token_secret")
220
+
221
+ modal.Sandbox.create("echo", "hi", client=client, app=app)
222
+ ```
256
223
  """
224
+ _check_config()
257
225
  server_url = config["server_url"]
258
226
  client_type = api_pb2.CLIENT_TYPE_CLIENT
259
227
  credentials = (token_id, token_secret)
260
228
  client = _Client(server_url, client_type, credentials)
261
229
  await client._open()
262
- try:
263
- await client._init()
264
- except BaseException:
265
- await client._close()
266
- raise
267
230
  async_utils.on_shutdown(client._close())
268
231
  return client
269
232
 
270
233
  @classmethod
271
- async def verify(cls, server_url: str, credentials: Tuple[str, str]) -> None:
234
+ async def verify(cls, server_url: str, credentials: tuple[str, str]) -> None:
272
235
  """mdmd:hidden
273
236
  Check whether can the client can connect to this server with these credentials; raise if not.
274
237
  """
275
- async with cls(server_url, api_pb2.CLIENT_TYPE_CLIENT, credentials):
276
- pass # Will call ClientHello RPC and possibly raise AuthError or ConnectionError
238
+ async with cls(server_url, api_pb2.CLIENT_TYPE_CLIENT, credentials) as client:
239
+ await client.hello() # Will call ClientHello RPC and possibly raise AuthError or ConnectionError
277
240
 
278
241
  @classmethod
279
242
  def set_env_client(cls, client: Optional["_Client"]):
@@ -281,5 +244,146 @@ class _Client:
281
244
  # Just used from tests.
282
245
  cls._client_from_env = client
283
246
 
247
+ async def _call_safely(self, coro, readable_method: str):
248
+ """Runs coroutine wrapped in a task that's part of the client's task context
249
+
250
+ * Raises ClientClosed in case the client is closed while the coroutine is executed
251
+ * Logs warning if call is made outside of the event loop that the client is running in,
252
+ and execute without the cancellation context in that case
253
+ """
254
+
255
+ if self.is_closed():
256
+ coro.close() # prevent "was never awaited"
257
+ raise ClientClosed(id(self))
258
+
259
+ current_event_loop = asyncio.get_running_loop()
260
+ if current_event_loop == self._cancellation_context_event_loop:
261
+ # make request cancellable if we are in the same event loop as the rpc context
262
+ # this should usually be the case!
263
+ try:
264
+ request_task = self._cancellation_context.create_task(coro)
265
+ request_task.set_name(readable_method)
266
+ return await request_task
267
+ except asyncio.CancelledError:
268
+ if self.is_closed():
269
+ raise ClientClosed(id(self)) from None
270
+ raise # if the task is cancelled as part of synchronizer shutdown or similar, don't raise ClientClosed
271
+ else:
272
+ # this should be rare - mostly used in tests where rpc requests sometimes are triggered
273
+ # outside of a client context/synchronicity loop
274
+ logger.warning(f"RPC request to {readable_method} made outside of task context")
275
+ return await coro
276
+
277
+ async def _reset_on_pid_change(self):
278
+ if self._owner_pid and self._owner_pid != os.getpid():
279
+ # not calling .close() since that would also interact with stale resources
280
+ # just reset the internal state
281
+ self._channel = None
282
+ self._stub = None
283
+ self._grpclib_stub = None
284
+ self._owner_pid = None
285
+
286
+ self.set_env_client(None)
287
+ # TODO(elias): reset _cancellation_context in case ?
288
+ await self._open()
289
+
290
+ async def _get_grpclib_method(self, method_name: str) -> Any:
291
+ # safely get grcplib method that is bound to a valid channel
292
+ # This prevents usage of stale methods across forks of processes
293
+ await self._reset_on_pid_change()
294
+ return getattr(self._grpclib_stub, method_name)
295
+
296
+ @synchronizer.nowrap
297
+ async def _call_unary(
298
+ self,
299
+ method_name: str,
300
+ request: Any,
301
+ *,
302
+ timeout: Optional[float] = None,
303
+ metadata: Optional[_MetadataLike] = None,
304
+ ) -> Any:
305
+ grpclib_method = await self._get_grpclib_method(method_name)
306
+ coro = grpclib_method(request, timeout=timeout, metadata=metadata)
307
+ return await self._call_safely(coro, grpclib_method.name)
308
+
309
+ @synchronizer.nowrap
310
+ async def _call_stream(
311
+ self,
312
+ method_name: str,
313
+ request: Any,
314
+ *,
315
+ metadata: Optional[_MetadataLike],
316
+ ) -> AsyncGenerator[Any, None]:
317
+ grpclib_method = await self._get_grpclib_method(method_name)
318
+ stream_context = grpclib_method.open(metadata=metadata)
319
+ stream = await self._call_safely(stream_context.__aenter__(), f"{grpclib_method.name}.open")
320
+ try:
321
+ await self._call_safely(stream.send_message(request, end=True), f"{grpclib_method.name}.send_message")
322
+ while 1:
323
+ try:
324
+ yield await self._call_safely(stream.__anext__(), f"{grpclib_method.name}.recv")
325
+ except StopAsyncIteration:
326
+ break
327
+ except BaseException as exc:
328
+ did_handle_exception = await stream_context.__aexit__(type(exc), exc, exc.__traceback__)
329
+ if not did_handle_exception:
330
+ raise
331
+ else:
332
+ await stream_context.__aexit__(None, None, None)
333
+
284
334
 
285
335
  Client = synchronize_api(_Client)
336
+
337
+
338
+ class UnaryUnaryWrapper(Generic[RequestType, ResponseType]):
339
+ # Calls a grpclib.UnaryUnaryMethod using a specific Client instance, respecting
340
+ # if that client is closed etc. and possibly introducing Modal-specific retry logic
341
+ wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]
342
+ client: _Client
343
+
344
+ def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType], client: _Client):
345
+ # we pass in the wrapped_method here to get the correct static types
346
+ # but don't use the reference directly, see `def wrapped_method` below
347
+ self._wrapped_full_name = wrapped_method.name
348
+ self._wrapped_method_name = wrapped_method.name.rsplit("/", 1)[1]
349
+ self.client = client
350
+
351
+ @property
352
+ def name(self) -> str:
353
+ return self._wrapped_full_name
354
+
355
+ async def __call__(
356
+ self,
357
+ req: RequestType,
358
+ *,
359
+ timeout: Optional[float] = None,
360
+ metadata: Optional[_MetadataLike] = None,
361
+ ) -> ResponseType:
362
+ if self.client._snapshotted:
363
+ logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}")
364
+ self.client = await _Client.from_env()
365
+ return await self.client._call_unary(self._wrapped_method_name, req, timeout=timeout, metadata=metadata)
366
+
367
+
368
+ class UnaryStreamWrapper(Generic[RequestType, ResponseType]):
369
+ wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]
370
+
371
+ def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType], client: _Client):
372
+ self._wrapped_full_name = wrapped_method.name
373
+ self._wrapped_method_name = wrapped_method.name.rsplit("/", 1)[1]
374
+ self.client = client
375
+
376
+ @property
377
+ def name(self) -> str:
378
+ return self._wrapped_full_name
379
+
380
+ async def unary_stream(
381
+ self,
382
+ request,
383
+ metadata: Optional[Any] = None,
384
+ ):
385
+ if self.client._snapshotted:
386
+ logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}")
387
+ self.client = await _Client.from_env()
388
+ async for response in self.client._call_stream(self._wrapped_method_name, request, metadata=metadata):
389
+ yield response