modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. modal/__init__.py +17 -13
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +420 -937
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -59
  11. modal/_resources.py +51 -0
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/_runtime/execution_context.py +89 -0
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +134 -9
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +52 -16
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +479 -100
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +460 -171
  29. modal/_utils/grpc_testing.py +47 -31
  30. modal/_utils/grpc_utils.py +62 -109
  31. modal/_utils/hash_utils.py +61 -19
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +5 -7
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +14 -12
  43. modal/app.py +1003 -314
  44. modal/app.pyi +540 -264
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +63 -53
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +205 -45
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +62 -14
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +64 -58
  55. modal/cli/launch.py +32 -18
  56. modal/cli/network_file_system.py +64 -83
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +35 -10
  59. modal/cli/programs/vscode.py +60 -10
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +234 -131
  62. modal/cli/secret.py +8 -7
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +79 -10
  65. modal/cli/volume.py +110 -109
  66. modal/client.py +250 -144
  67. modal/client.pyi +157 -118
  68. modal/cloud_bucket_mount.py +108 -34
  69. modal/cloud_bucket_mount.pyi +32 -38
  70. modal/cls.py +535 -148
  71. modal/cls.pyi +190 -146
  72. modal/config.py +41 -19
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +111 -65
  76. modal/dict.pyi +136 -131
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +34 -43
  80. modal/experimental.py +61 -2
  81. modal/extensions/ipython.py +5 -5
  82. modal/file_io.py +537 -0
  83. modal/file_io.pyi +235 -0
  84. modal/file_pattern_matcher.py +197 -0
  85. modal/functions.py +906 -911
  86. modal/functions.pyi +466 -430
  87. modal/gpu.py +57 -44
  88. modal/image.py +1089 -479
  89. modal/image.pyi +584 -228
  90. modal/io_streams.py +434 -0
  91. modal/io_streams.pyi +122 -0
  92. modal/mount.py +314 -101
  93. modal/mount.pyi +241 -235
  94. modal/network_file_system.py +92 -92
  95. modal/network_file_system.pyi +152 -110
  96. modal/object.py +67 -36
  97. modal/object.pyi +166 -143
  98. modal/output.py +63 -0
  99. modal/parallel_map.py +434 -0
  100. modal/parallel_map.pyi +75 -0
  101. modal/partial_function.py +282 -117
  102. modal/partial_function.pyi +222 -129
  103. modal/proxy.py +15 -12
  104. modal/proxy.pyi +3 -8
  105. modal/queue.py +182 -65
  106. modal/queue.pyi +218 -118
  107. modal/requirements/2024.04.txt +29 -0
  108. modal/requirements/2024.10.txt +16 -0
  109. modal/requirements/README.md +21 -0
  110. modal/requirements/base-images.json +22 -0
  111. modal/retries.py +48 -7
  112. modal/runner.py +459 -156
  113. modal/runner.pyi +135 -71
  114. modal/running_app.py +38 -0
  115. modal/sandbox.py +514 -236
  116. modal/sandbox.pyi +397 -169
  117. modal/schedule.py +4 -4
  118. modal/scheduler_placement.py +20 -3
  119. modal/secret.py +56 -31
  120. modal/secret.pyi +62 -42
  121. modal/serving.py +51 -56
  122. modal/serving.pyi +44 -36
  123. modal/stream_type.py +15 -0
  124. modal/token_flow.py +5 -3
  125. modal/token_flow.pyi +37 -32
  126. modal/volume.py +285 -157
  127. modal/volume.pyi +249 -184
  128. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
  129. modal-0.72.11.dist-info/RECORD +174 -0
  130. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  131. modal_docs/gen_reference_docs.py +3 -1
  132. modal_docs/mdmd/mdmd.py +0 -1
  133. modal_docs/mdmd/signatures.py +5 -2
  134. modal_global_objects/images/base_images.py +28 -0
  135. modal_global_objects/mounts/python_standalone.py +2 -2
  136. modal_proto/__init__.py +1 -1
  137. modal_proto/api.proto +1288 -533
  138. modal_proto/api_grpc.py +856 -456
  139. modal_proto/api_pb2.py +2165 -1157
  140. modal_proto/api_pb2.pyi +8859 -0
  141. modal_proto/api_pb2_grpc.py +1674 -855
  142. modal_proto/api_pb2_grpc.pyi +1416 -0
  143. modal_proto/modal_api_grpc.py +149 -0
  144. modal_proto/modal_options_grpc.py +3 -0
  145. modal_proto/options_pb2.pyi +20 -0
  146. modal_proto/options_pb2_grpc.pyi +7 -0
  147. modal_proto/py.typed +0 -0
  148. modal_version/__init__.py +1 -1
  149. modal_version/_version_generated.py +2 -2
  150. modal/_asgi.py +0 -370
  151. modal/_container_entrypoint.pyi +0 -378
  152. modal/_container_exec.py +0 -128
  153. modal/_sandbox_shell.py +0 -49
  154. modal/shared_volume.py +0 -23
  155. modal/shared_volume.pyi +0 -24
  156. modal/stub.py +0 -783
  157. modal/stub.pyi +0 -332
  158. modal-0.62.16.dist-info/RECORD +0 -198
  159. modal_global_objects/images/conda.py +0 -15
  160. modal_global_objects/images/debian_slim.py +0 -15
  161. modal_global_objects/images/micromamba.py +0 -15
  162. test/__init__.py +0 -1
  163. test/aio_test.py +0 -12
  164. test/async_utils_test.py +0 -262
  165. test/blob_test.py +0 -67
  166. test/cli_imports_test.py +0 -149
  167. test/cli_test.py +0 -659
  168. test/client_test.py +0 -194
  169. test/cls_test.py +0 -630
  170. test/config_test.py +0 -137
  171. test/conftest.py +0 -1420
  172. test/container_app_test.py +0 -32
  173. test/container_test.py +0 -1389
  174. test/cpu_test.py +0 -23
  175. test/decorator_test.py +0 -85
  176. test/deprecation_test.py +0 -34
  177. test/dict_test.py +0 -33
  178. test/e2e_test.py +0 -68
  179. test/error_test.py +0 -7
  180. test/function_serialization_test.py +0 -32
  181. test/function_test.py +0 -653
  182. test/function_utils_test.py +0 -101
  183. test/gpu_test.py +0 -159
  184. test/grpc_utils_test.py +0 -141
  185. test/helpers.py +0 -42
  186. test/image_test.py +0 -669
  187. test/live_reload_test.py +0 -80
  188. test/lookup_test.py +0 -70
  189. test/mdmd_test.py +0 -329
  190. test/mount_test.py +0 -162
  191. test/mounted_files_test.py +0 -329
  192. test/network_file_system_test.py +0 -181
  193. test/notebook_test.py +0 -66
  194. test/object_test.py +0 -41
  195. test/package_utils_test.py +0 -25
  196. test/queue_test.py +0 -97
  197. test/resolver_test.py +0 -58
  198. test/retries_test.py +0 -67
  199. test/runner_test.py +0 -85
  200. test/sandbox_test.py +0 -191
  201. test/schedule_test.py +0 -15
  202. test/scheduler_placement_test.py +0 -29
  203. test/secret_test.py +0 -78
  204. test/serialization_test.py +0 -42
  205. test/stub_composition_test.py +0 -10
  206. test/stub_test.py +0 -360
  207. test/test_asgi_wrapper.py +0 -234
  208. test/token_flow_test.py +0 -18
  209. test/traceback_test.py +0 -135
  210. test/tunnel_test.py +0 -29
  211. test/utils_test.py +0 -88
  212. test/version_test.py +0 -14
  213. test/volume_test.py +0 -341
  214. test/watcher_test.py +0 -30
  215. test/webhook_test.py +0 -146
  216. /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
  217. /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
  218. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
modal/client.py CHANGED
@@ -1,32 +1,39 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import asyncio
3
+ import os
3
4
  import platform
4
5
  import warnings
5
- from typing import AsyncIterator, Awaitable, Callable, Dict, Optional, Tuple
6
+ from collections.abc import AsyncGenerator, AsyncIterator, Collection, Mapping
7
+ from typing import (
8
+ Any,
9
+ ClassVar,
10
+ Generic,
11
+ Optional,
12
+ TypeVar,
13
+ Union,
14
+ )
6
15
 
7
16
  import grpclib.client
8
- from aiohttp import ClientConnectorError, ClientResponseError
9
17
  from google.protobuf import empty_pb2
10
- from grpclib import GRPCError, Status
18
+ from google.protobuf.message import Message
11
19
  from synchronicity.async_wrap import asynccontextmanager
12
20
 
13
- from modal_proto import api_grpc, api_pb2
21
+ from modal._utils.async_utils import synchronizer
22
+ from modal_proto import api_grpc, api_pb2, modal_api_grpc
14
23
  from modal_version import __version__
15
24
 
25
+ from ._traceback import print_server_warnings
16
26
  from ._utils import async_utils
17
- from ._utils.async_utils import synchronize_api
18
- from ._utils.grpc_utils import create_channel, retry_transient_errors
19
- from ._utils.http_utils import http_client_with_tls
20
- from .config import _check_config, config, logger
21
- from .exception import AuthError, ConnectionError, DeprecationError, VersionError
27
+ from ._utils.async_utils import TaskContext, synchronize_api
28
+ from ._utils.grpc_utils import connect_channel, create_channel, retry_transient_errors
29
+ from .config import _check_config, _is_remote, config, logger
30
+ from .exception import AuthError, ClientClosed, ConnectionError
22
31
 
23
32
  HEARTBEAT_INTERVAL: float = config.get("heartbeat_interval")
24
33
  HEARTBEAT_TIMEOUT: float = HEARTBEAT_INTERVAL + 0.1
25
- CLIENT_CREATE_ATTEMPT_TIMEOUT: float = 4.0
26
- CLIENT_CREATE_TOTAL_TIMEOUT: float = 15.0
27
34
 
28
35
 
29
- def _get_metadata(client_type: int, credentials: Optional[Tuple[str, str]], version: str) -> Dict[str, str]:
36
+ def _get_metadata(client_type: int, credentials: Optional[tuple[str, str]], version: str) -> dict[str, str]:
30
37
  # This implements a simplified version of platform.platform() that's still machine-readable
31
38
  uname: platform.uname_result = platform.uname()
32
39
  if uname.system == "Darwin":
@@ -50,130 +57,89 @@ def _get_metadata(client_type: int, credentials: Optional[Tuple[str, str]], vers
50
57
  "x-modal-token-secret": token_secret,
51
58
  }
52
59
  )
53
- elif credentials and client_type == api_pb2.CLIENT_TYPE_CONTAINER:
54
- task_id, task_secret = credentials
55
- metadata.update(
56
- {
57
- "x-modal-task-id": task_id,
58
- "x-modal-task-secret": task_secret,
59
- }
60
- )
61
60
  return metadata
62
61
 
63
62
 
64
- async def _http_check(url: str, timeout: float) -> str:
65
- # Used for sanity checking connection issues
66
- try:
67
- async with http_client_with_tls(timeout=timeout) as session:
68
- async with session.get(url) as resp:
69
- return f"HTTP status: {resp.status}"
70
- except ClientResponseError as exc:
71
- return f"HTTP status: {exc.status}"
72
- except ClientConnectorError as exc:
73
- return f"HTTP exception: {exc.os_error.__class__.__name__}"
74
- except Exception as exc:
75
- return f"HTTP exception: {exc.__class__.__name__}"
76
-
77
-
78
- async def _grpc_exc_string(exc: GRPCError, method_name: str, server_url: str, timeout: float) -> str:
79
- http_status = await _http_check(server_url, timeout=timeout)
80
- return f"{method_name}: {exc.message} [gRPC status: {exc.status.name}, {http_status}]"
63
+ ReturnType = TypeVar("ReturnType")
64
+ _Value = Union[str, bytes]
65
+ _MetadataLike = Union[Mapping[str, _Value], Collection[tuple[str, _Value]]]
66
+ RequestType = TypeVar("RequestType", bound=Message)
67
+ ResponseType = TypeVar("ResponseType", bound=Message)
81
68
 
82
69
 
83
70
  class _Client:
84
- _client_from_env = None
85
- _client_from_env_lock = None
71
+ _client_from_env: ClassVar[Optional["_Client"]] = None
72
+ _client_from_env_lock: ClassVar[Optional[asyncio.Lock]] = None
73
+ _cancellation_context: TaskContext
74
+ _cancellation_context_event_loop: asyncio.AbstractEventLoop = None
75
+ _stub: Optional[api_grpc.ModalClientStub]
86
76
 
87
77
  def __init__(
88
78
  self,
89
79
  server_url: str,
90
80
  client_type: int,
91
- credentials: Optional[Tuple[str, str]],
81
+ credentials: Optional[tuple[str, str]],
92
82
  version: str = __version__,
93
83
  ):
94
- """The Modal client object is not intended to be instantiated directly by users."""
84
+ """mdmd:hidden
85
+ The Modal client object is not intended to be instantiated directly by users.
86
+ """
95
87
  self.server_url = server_url
96
88
  self.client_type = client_type
97
- self.credentials = credentials
89
+ self._credentials = credentials
98
90
  self.version = version
99
- self._authenticated = False
100
- self._pre_stop: Optional[Callable[[], Awaitable[None]]] = None
91
+ self._closed = False
101
92
  self._channel: Optional[grpclib.client.Channel] = None
102
- self._stub: Optional[api_grpc.ModalClientStub] = None
93
+ self._stub: Optional[modal_api_grpc.ModalClientModal] = None
94
+ self._snapshotted = False
95
+ self._owner_pid = None
103
96
 
104
- @property
105
- def stub(self) -> Optional[api_grpc.ModalClientStub]:
106
- """mdmd:hidden"""
107
- return self._stub
97
+ def is_closed(self) -> bool:
98
+ return self._closed
108
99
 
109
100
  @property
110
- def authenticated(self) -> bool:
101
+ def stub(self) -> modal_api_grpc.ModalClientModal:
111
102
  """mdmd:hidden"""
112
- return self._authenticated
103
+ assert self._stub
104
+ return self._stub
113
105
 
114
106
  async def _open(self):
107
+ self._closed = False
115
108
  assert self._stub is None
116
- metadata = _get_metadata(self.client_type, self.credentials, self.version)
109
+ metadata = _get_metadata(self.client_type, self._credentials, self.version)
117
110
  self._channel = create_channel(self.server_url, metadata=metadata)
118
- self._stub = api_grpc.ModalClientStub(self._channel) # type: ignore
119
-
120
- async def _close(self):
121
- if self._pre_stop is not None:
122
- logger.debug("Client: running pre-stop coroutine before shutting down")
123
- await self._pre_stop() # type: ignore
124
-
111
+ try:
112
+ await connect_channel(self._channel)
113
+ except OSError as exc:
114
+ raise ConnectionError(str(exc))
115
+ self._cancellation_context = TaskContext(grace=0.5) # allow running rpcs to finish in 0.5s when closing client
116
+ self._cancellation_context_event_loop = asyncio.get_running_loop()
117
+ await self._cancellation_context.__aenter__()
118
+ self._grpclib_stub = api_grpc.ModalClientStub(self._channel)
119
+ self._stub = modal_api_grpc.ModalClientModal(self._grpclib_stub, client=self)
120
+ self._owner_pid = os.getpid()
121
+
122
+ async def _close(self, prep_for_restore: bool = False):
123
+ logger.debug(f"Client ({id(self)}): closing")
124
+ self._closed = True
125
+ await self._cancellation_context.__aexit__(None, None, None) # wait for all rpcs to be finished/cancelled
125
126
  if self._channel is not None:
126
127
  self._channel.close()
127
128
 
129
+ if prep_for_restore:
130
+ self._snapshotted = True
131
+
128
132
  # Remove cached client.
129
133
  self.set_env_client(None)
130
134
 
131
- def set_pre_stop(self, pre_stop: Callable[[], Awaitable[None]]):
132
- """mdmd:hidden"""
133
- # hack: stub.serve() gets into a losing race with the `on_shutdown` client
134
- # teardown when an interrupt signal is received (eg. KeyboardInterrupt).
135
- # By registering a pre-stop fn stub.serve() can have its teardown
136
- # performed before the client is disconnected.
137
- #
138
- # ref: github.com/modal-labs/modal-client/pull/108
139
- self._pre_stop = pre_stop
140
-
141
- async def _init(self):
135
+ async def hello(self):
142
136
  """Connect to server and retrieve version information; raise appropriate error for various failures."""
143
- logger.debug("Client: Starting")
144
- _check_config()
145
- try:
146
- req = empty_pb2.Empty()
147
- resp = await retry_transient_errors(
148
- self.stub.ClientHello,
149
- req,
150
- attempt_timeout=CLIENT_CREATE_ATTEMPT_TIMEOUT,
151
- total_timeout=CLIENT_CREATE_TOTAL_TIMEOUT,
152
- )
153
- if resp.warning:
154
- ALARM_EMOJI = chr(0x1F6A8)
155
- warnings.warn(f"{ALARM_EMOJI} {resp.warning} {ALARM_EMOJI}", DeprecationError)
156
- self._authenticated = True
157
- except GRPCError as exc:
158
- if exc.status == Status.FAILED_PRECONDITION:
159
- raise VersionError(
160
- f"The client version ({self.version}) is too old. Please update (pip install --update modal)."
161
- )
162
- elif exc.status == Status.UNAUTHENTICATED:
163
- raise AuthError(exc.message)
164
- else:
165
- exc_string = await _grpc_exc_string(exc, "ClientHello", self.server_url, CLIENT_CREATE_TOTAL_TIMEOUT)
166
- raise ConnectionError(exc_string)
167
- except (OSError, asyncio.TimeoutError) as exc:
168
- raise ConnectionError(str(exc))
137
+ logger.debug(f"Client ({id(self)}): Starting")
138
+ resp = await retry_transient_errors(self.stub.ClientHello, empty_pb2.Empty())
139
+ print_server_warnings(resp.server_warnings)
169
140
 
170
141
  async def __aenter__(self):
171
142
  await self._open()
172
- try:
173
- await self._init()
174
- except BaseException:
175
- await self._close()
176
- raise
177
143
  return self
178
144
 
179
145
  async def __aexit__(self, exc_type, exc, tb):
@@ -189,7 +155,6 @@ class _Client:
189
155
  client = cls(server_url, api_pb2.CLIENT_TYPE_CLIENT, credentials=None)
190
156
  try:
191
157
  await client._open()
192
- # Skip client._init
193
158
  yield client
194
159
  finally:
195
160
  await client._close()
@@ -199,28 +164,15 @@ class _Client:
199
164
  """mdmd:hidden
200
165
  Singleton that is instantiated from the Modal config and reused on subsequent calls.
201
166
  """
167
+ _check_config()
168
+
202
169
  if _override_config:
203
170
  # Only used for testing
204
171
  c = _override_config
205
172
  else:
206
173
  c = config
207
174
 
208
- server_url = c["server_url"]
209
-
210
- token_id = c["token_id"]
211
- token_secret = c["token_secret"]
212
- task_id = c["task_id"]
213
- task_secret = c["task_secret"]
214
-
215
- if task_id and task_secret:
216
- client_type = api_pb2.CLIENT_TYPE_CONTAINER
217
- credentials = (task_id, task_secret)
218
- elif token_id and token_secret:
219
- client_type = api_pb2.CLIENT_TYPE_CLIENT
220
- credentials = (token_id, token_secret)
221
- else:
222
- client_type = api_pb2.CLIENT_TYPE_CLIENT
223
- credentials = None
175
+ credentials: Optional[tuple[str, str]]
224
176
 
225
177
  if cls._client_from_env_lock is None:
226
178
  cls._client_from_env_lock = asyncio.Lock()
@@ -228,50 +180,63 @@ class _Client:
228
180
  async with cls._client_from_env_lock:
229
181
  if cls._client_from_env:
230
182
  return cls._client_from_env
183
+
184
+ token_id = c["token_id"]
185
+ token_secret = c["token_secret"]
186
+ if _is_remote():
187
+ if token_id or token_secret:
188
+ warnings.warn(
189
+ "Modal tokens provided by MODAL_TOKEN_ID and MODAL_TOKEN_SECRET"
190
+ " (or through the config file) are ignored inside containers."
191
+ )
192
+ client_type = api_pb2.CLIENT_TYPE_CONTAINER
193
+ credentials = None
194
+ elif token_id and token_secret:
195
+ client_type = api_pb2.CLIENT_TYPE_CLIENT
196
+ credentials = (token_id, token_secret)
231
197
  else:
232
- client = _Client(server_url, client_type, credentials)
233
- await client._open()
234
- async_utils.on_shutdown(client._close())
235
- try:
236
- await client._init()
237
- except AuthError:
238
- if not credentials:
239
- creds_missing_msg = (
240
- "Token missing. Could not authenticate client."
241
- " If you have token credentials, see modal.com/docs/reference/modal.config for setup help."
242
- " If you are a new user, register an account at modal.com, then run `modal token new`."
243
- )
244
- raise AuthError(creds_missing_msg)
245
- else:
246
- raise
247
- cls._client_from_env = client
248
- return client
198
+ raise AuthError(
199
+ "Token missing. Could not authenticate client."
200
+ " If you have token credentials, see modal.com/docs/reference/modal.config for setup help."
201
+ " If you are a new user, register an account at modal.com, then run `modal token new`."
202
+ )
203
+
204
+ server_url = c["server_url"]
205
+ client = _Client(server_url, client_type, credentials)
206
+ await client._open()
207
+ async_utils.on_shutdown(client._close())
208
+ cls._client_from_env = client
209
+ return client
249
210
 
250
211
  @classmethod
251
212
  async def from_credentials(cls, token_id: str, token_secret: str) -> "_Client":
252
- """mdmd:hidden
213
+ """
253
214
  Constructor based on token credentials; useful for managing Modal on behalf of third-party users.
215
+
216
+ **Usage:**
217
+
218
+ ```python notest
219
+ client = modal.Client.from_credentials("my_token_id", "my_token_secret")
220
+
221
+ modal.Sandbox.create("echo", "hi", client=client, app=app)
222
+ ```
254
223
  """
224
+ _check_config()
255
225
  server_url = config["server_url"]
256
226
  client_type = api_pb2.CLIENT_TYPE_CLIENT
257
227
  credentials = (token_id, token_secret)
258
228
  client = _Client(server_url, client_type, credentials)
259
229
  await client._open()
260
- try:
261
- await client._init()
262
- except BaseException:
263
- await client._close()
264
- raise
265
230
  async_utils.on_shutdown(client._close())
266
231
  return client
267
232
 
268
233
  @classmethod
269
- async def verify(cls, server_url: str, credentials: Tuple[str, str]) -> None:
234
+ async def verify(cls, server_url: str, credentials: tuple[str, str]) -> None:
270
235
  """mdmd:hidden
271
236
  Check whether can the client can connect to this server with these credentials; raise if not.
272
237
  """
273
- async with cls(server_url, api_pb2.CLIENT_TYPE_CLIENT, credentials):
274
- pass # Will call ClientHello RPC and possibly raise AuthError or ConnectionError
238
+ async with cls(server_url, api_pb2.CLIENT_TYPE_CLIENT, credentials) as client:
239
+ await client.hello() # Will call ClientHello RPC and possibly raise AuthError or ConnectionError
275
240
 
276
241
  @classmethod
277
242
  def set_env_client(cls, client: Optional["_Client"]):
@@ -279,5 +244,146 @@ class _Client:
279
244
  # Just used from tests.
280
245
  cls._client_from_env = client
281
246
 
247
+ async def _call_safely(self, coro, readable_method: str):
248
+ """Runs coroutine wrapped in a task that's part of the client's task context
249
+
250
+ * Raises ClientClosed in case the client is closed while the coroutine is executed
251
+ * Logs warning if call is made outside of the event loop that the client is running in,
252
+ and execute without the cancellation context in that case
253
+ """
254
+
255
+ if self.is_closed():
256
+ coro.close() # prevent "was never awaited"
257
+ raise ClientClosed(id(self))
258
+
259
+ current_event_loop = asyncio.get_running_loop()
260
+ if current_event_loop == self._cancellation_context_event_loop:
261
+ # make request cancellable if we are in the same event loop as the rpc context
262
+ # this should usually be the case!
263
+ try:
264
+ request_task = self._cancellation_context.create_task(coro)
265
+ request_task.set_name(readable_method)
266
+ return await request_task
267
+ except asyncio.CancelledError:
268
+ if self.is_closed():
269
+ raise ClientClosed(id(self)) from None
270
+ raise # if the task is cancelled as part of synchronizer shutdown or similar, don't raise ClientClosed
271
+ else:
272
+ # this should be rare - mostly used in tests where rpc requests sometimes are triggered
273
+ # outside of a client context/synchronicity loop
274
+ logger.warning(f"RPC request to {readable_method} made outside of task context")
275
+ return await coro
276
+
277
+ async def _reset_on_pid_change(self):
278
+ if self._owner_pid and self._owner_pid != os.getpid():
279
+ # not calling .close() since that would also interact with stale resources
280
+ # just reset the internal state
281
+ self._channel = None
282
+ self._stub = None
283
+ self._grpclib_stub = None
284
+ self._owner_pid = None
285
+
286
+ self.set_env_client(None)
287
+ # TODO(elias): reset _cancellation_context in case ?
288
+ await self._open()
289
+
290
+ async def _get_grpclib_method(self, method_name: str) -> Any:
291
+ # safely get grcplib method that is bound to a valid channel
292
+ # This prevents usage of stale methods across forks of processes
293
+ await self._reset_on_pid_change()
294
+ return getattr(self._grpclib_stub, method_name)
295
+
296
+ @synchronizer.nowrap
297
+ async def _call_unary(
298
+ self,
299
+ method_name: str,
300
+ request: Any,
301
+ *,
302
+ timeout: Optional[float] = None,
303
+ metadata: Optional[_MetadataLike] = None,
304
+ ) -> Any:
305
+ grpclib_method = await self._get_grpclib_method(method_name)
306
+ coro = grpclib_method(request, timeout=timeout, metadata=metadata)
307
+ return await self._call_safely(coro, grpclib_method.name)
308
+
309
+ @synchronizer.nowrap
310
+ async def _call_stream(
311
+ self,
312
+ method_name: str,
313
+ request: Any,
314
+ *,
315
+ metadata: Optional[_MetadataLike],
316
+ ) -> AsyncGenerator[Any, None]:
317
+ grpclib_method = await self._get_grpclib_method(method_name)
318
+ stream_context = grpclib_method.open(metadata=metadata)
319
+ stream = await self._call_safely(stream_context.__aenter__(), f"{grpclib_method.name}.open")
320
+ try:
321
+ await self._call_safely(stream.send_message(request, end=True), f"{grpclib_method.name}.send_message")
322
+ while 1:
323
+ try:
324
+ yield await self._call_safely(stream.__anext__(), f"{grpclib_method.name}.recv")
325
+ except StopAsyncIteration:
326
+ break
327
+ except BaseException as exc:
328
+ did_handle_exception = await stream_context.__aexit__(type(exc), exc, exc.__traceback__)
329
+ if not did_handle_exception:
330
+ raise
331
+ else:
332
+ await stream_context.__aexit__(None, None, None)
333
+
282
334
 
283
335
  Client = synchronize_api(_Client)
336
+
337
+
338
+ class UnaryUnaryWrapper(Generic[RequestType, ResponseType]):
339
+ # Calls a grpclib.UnaryUnaryMethod using a specific Client instance, respecting
340
+ # if that client is closed etc. and possibly introducing Modal-specific retry logic
341
+ wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]
342
+ client: _Client
343
+
344
+ def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType], client: _Client):
345
+ # we pass in the wrapped_method here to get the correct static types
346
+ # but don't use the reference directly, see `def wrapped_method` below
347
+ self._wrapped_full_name = wrapped_method.name
348
+ self._wrapped_method_name = wrapped_method.name.rsplit("/", 1)[1]
349
+ self.client = client
350
+
351
+ @property
352
+ def name(self) -> str:
353
+ return self._wrapped_full_name
354
+
355
+ async def __call__(
356
+ self,
357
+ req: RequestType,
358
+ *,
359
+ timeout: Optional[float] = None,
360
+ metadata: Optional[_MetadataLike] = None,
361
+ ) -> ResponseType:
362
+ if self.client._snapshotted:
363
+ logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}")
364
+ self.client = await _Client.from_env()
365
+ return await self.client._call_unary(self._wrapped_method_name, req, timeout=timeout, metadata=metadata)
366
+
367
+
368
+ class UnaryStreamWrapper(Generic[RequestType, ResponseType]):
369
+ wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]
370
+
371
+ def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType], client: _Client):
372
+ self._wrapped_full_name = wrapped_method.name
373
+ self._wrapped_method_name = wrapped_method.name.rsplit("/", 1)[1]
374
+ self.client = client
375
+
376
+ @property
377
+ def name(self) -> str:
378
+ return self._wrapped_full_name
379
+
380
+ async def unary_stream(
381
+ self,
382
+ request,
383
+ metadata: Optional[Any] = None,
384
+ ):
385
+ if self.client._snapshotted:
386
+ logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}")
387
+ self.client = await _Client.from_env()
388
+ async for response in self.client._call_stream(self._wrapped_method_name, request, metadata=metadata):
389
+ yield response