modal 1.1.5.dev83__py3-none-any.whl → 1.3.1.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (139)
  1. modal/__init__.py +4 -4
  2. modal/__main__.py +4 -29
  3. modal/_billing.py +84 -0
  4. modal/_clustered_functions.py +1 -3
  5. modal/_container_entrypoint.py +33 -208
  6. modal/_functions.py +146 -121
  7. modal/_grpc_client.py +191 -0
  8. modal/_ipython.py +16 -6
  9. modal/_load_context.py +106 -0
  10. modal/_object.py +72 -21
  11. modal/_output.py +12 -14
  12. modal/_partial_function.py +31 -4
  13. modal/_resolver.py +44 -57
  14. modal/_runtime/container_io_manager.py +26 -28
  15. modal/_runtime/container_io_manager.pyi +42 -44
  16. modal/_runtime/gpu_memory_snapshot.py +9 -7
  17. modal/_runtime/user_code_event_loop.py +80 -0
  18. modal/_runtime/user_code_imports.py +236 -10
  19. modal/_serialization.py +2 -1
  20. modal/_traceback.py +4 -13
  21. modal/_tunnel.py +16 -11
  22. modal/_tunnel.pyi +25 -3
  23. modal/_utils/async_utils.py +337 -10
  24. modal/_utils/auth_token_manager.py +1 -4
  25. modal/_utils/blob_utils.py +29 -22
  26. modal/_utils/function_utils.py +20 -21
  27. modal/_utils/grpc_testing.py +6 -3
  28. modal/_utils/grpc_utils.py +223 -64
  29. modal/_utils/mount_utils.py +26 -1
  30. modal/_utils/package_utils.py +0 -1
  31. modal/_utils/rand_pb_testing.py +8 -1
  32. modal/_utils/task_command_router_client.py +524 -0
  33. modal/_vendor/cloudpickle.py +144 -48
  34. modal/app.py +215 -96
  35. modal/app.pyi +78 -37
  36. modal/billing.py +5 -0
  37. modal/builder/2025.06.txt +6 -3
  38. modal/builder/PREVIEW.txt +2 -1
  39. modal/builder/base-images.json +4 -2
  40. modal/cli/_download.py +19 -3
  41. modal/cli/cluster.py +4 -2
  42. modal/cli/config.py +3 -1
  43. modal/cli/container.py +5 -4
  44. modal/cli/dict.py +5 -2
  45. modal/cli/entry_point.py +26 -2
  46. modal/cli/environment.py +2 -16
  47. modal/cli/launch.py +1 -76
  48. modal/cli/network_file_system.py +5 -20
  49. modal/cli/queues.py +5 -4
  50. modal/cli/run.py +24 -204
  51. modal/cli/secret.py +1 -2
  52. modal/cli/shell.py +375 -0
  53. modal/cli/utils.py +1 -13
  54. modal/cli/volume.py +11 -17
  55. modal/client.py +16 -125
  56. modal/client.pyi +94 -144
  57. modal/cloud_bucket_mount.py +3 -1
  58. modal/cloud_bucket_mount.pyi +4 -0
  59. modal/cls.py +101 -64
  60. modal/cls.pyi +9 -8
  61. modal/config.py +21 -1
  62. modal/container_process.py +288 -12
  63. modal/container_process.pyi +99 -38
  64. modal/dict.py +72 -33
  65. modal/dict.pyi +88 -57
  66. modal/environments.py +16 -8
  67. modal/environments.pyi +6 -2
  68. modal/exception.py +154 -16
  69. modal/experimental/__init__.py +23 -5
  70. modal/experimental/flash.py +161 -74
  71. modal/experimental/flash.pyi +97 -49
  72. modal/file_io.py +50 -92
  73. modal/file_io.pyi +117 -89
  74. modal/functions.pyi +70 -87
  75. modal/image.py +73 -47
  76. modal/image.pyi +33 -30
  77. modal/io_streams.py +500 -149
  78. modal/io_streams.pyi +279 -189
  79. modal/mount.py +60 -45
  80. modal/mount.pyi +41 -17
  81. modal/network_file_system.py +19 -11
  82. modal/network_file_system.pyi +72 -39
  83. modal/object.pyi +114 -22
  84. modal/parallel_map.py +42 -44
  85. modal/parallel_map.pyi +9 -17
  86. modal/partial_function.pyi +4 -2
  87. modal/proxy.py +14 -6
  88. modal/proxy.pyi +10 -2
  89. modal/queue.py +45 -38
  90. modal/queue.pyi +88 -52
  91. modal/runner.py +96 -96
  92. modal/runner.pyi +44 -27
  93. modal/sandbox.py +225 -108
  94. modal/sandbox.pyi +226 -63
  95. modal/secret.py +58 -56
  96. modal/secret.pyi +28 -13
  97. modal/serving.py +7 -11
  98. modal/serving.pyi +7 -8
  99. modal/snapshot.py +29 -15
  100. modal/snapshot.pyi +18 -10
  101. modal/token_flow.py +1 -1
  102. modal/token_flow.pyi +4 -6
  103. modal/volume.py +102 -55
  104. modal/volume.pyi +125 -66
  105. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/METADATA +10 -9
  106. modal-1.3.1.dev8.dist-info/RECORD +189 -0
  107. modal_proto/api.proto +86 -30
  108. modal_proto/api_grpc.py +10 -25
  109. modal_proto/api_pb2.py +1080 -1047
  110. modal_proto/api_pb2.pyi +253 -79
  111. modal_proto/api_pb2_grpc.py +14 -48
  112. modal_proto/api_pb2_grpc.pyi +6 -18
  113. modal_proto/modal_api_grpc.py +175 -176
  114. modal_proto/{sandbox_router.proto → task_command_router.proto} +62 -45
  115. modal_proto/task_command_router_grpc.py +138 -0
  116. modal_proto/task_command_router_pb2.py +180 -0
  117. modal_proto/{sandbox_router_pb2.pyi → task_command_router_pb2.pyi} +110 -63
  118. modal_proto/task_command_router_pb2_grpc.py +272 -0
  119. modal_proto/task_command_router_pb2_grpc.pyi +100 -0
  120. modal_version/__init__.py +1 -1
  121. modal_version/__main__.py +1 -1
  122. modal/cli/programs/launch_instance_ssh.py +0 -94
  123. modal/cli/programs/run_marimo.py +0 -95
  124. modal-1.1.5.dev83.dist-info/RECORD +0 -191
  125. modal_proto/modal_options_grpc.py +0 -3
  126. modal_proto/options.proto +0 -19
  127. modal_proto/options_grpc.py +0 -3
  128. modal_proto/options_pb2.py +0 -35
  129. modal_proto/options_pb2.pyi +0 -20
  130. modal_proto/options_pb2_grpc.py +0 -4
  131. modal_proto/options_pb2_grpc.pyi +0 -7
  132. modal_proto/sandbox_router_grpc.py +0 -105
  133. modal_proto/sandbox_router_pb2.py +0 -148
  134. modal_proto/sandbox_router_pb2_grpc.py +0 -203
  135. modal_proto/sandbox_router_pb2_grpc.pyi +0 -75
  136. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/WHEEL +0 -0
  137. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/entry_points.txt +0 -0
  138. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/licenses/LICENSE +0 -0
  139. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/top_level.txt +0 -0
modal/_utils/task_command_router_client.py (new file)
@@ -0,0 +1,524 @@
+# Copyright Modal Labs 2025
+import asyncio
+import base64
+import json
+import ssl
+import time
+import urllib.parse
+import weakref
+from contextlib import suppress
+from typing import AsyncGenerator, Optional
+
+import grpclib.client
+import grpclib.config
+import grpclib.events
+from grpclib import GRPCError, Status
+from grpclib.exceptions import StreamTerminatedError
+
+from modal.config import config, logger
+from modal.exception import ConflictError, ExecTimeoutError
+from modal_proto import api_pb2, task_command_router_pb2 as sr_pb2
+from modal_proto.task_command_router_grpc import TaskCommandRouterStub
+
+from .._grpc_client import grpc_error_converter
+from .async_utils import aclosing
+from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES, connect_channel
+
+
+def _b64url_decode(data: str) -> bytes:
+    """Decode a base64url string with missing padding tolerated."""
+    padding = "=" * (-len(data) % 4)
+    return base64.urlsafe_b64decode(data + padding)
+
+
+def _parse_jwt_expiration(jwt_token: str) -> Optional[float]:
+    """Parse exp from a JWT without verification. Returns UNIX time seconds or None.
+
+    This is best-effort; if parsing fails or claim missing, returns None.
+    """
+    try:
+        parts = jwt_token.split(".")
+        if len(parts) != 3:
+            return None
+        payload_b = _b64url_decode(parts[1])
+        payload = json.loads(payload_b)
+        exp = payload.get("exp")
+        if isinstance(exp, (int, float)):
+            return float(exp)
+    except Exception:
+        # Avoid raising on malformed tokens; fall back to server-driven refresh logic.
+        logger.warning("Failed to parse JWT expiration")
+        return None
+    return None
+
+
+async def call_with_retries_on_transient_errors(
+    func,
+    *,
+    base_delay_secs: float = 0.01,
+    delay_factor: float = 2,
+    max_retries: Optional[int] = 10,
+):
+    """Call func() with transient error retries and exponential backoff.
+
+    Authentication retries are expected to be handled by the caller.
+    """
+    delay_secs = base_delay_secs
+    num_retries = 0
+
+    async def sleep_and_update_delay_and_num_retries_remaining(e: Exception):
+        nonlocal delay_secs, num_retries
+        logger.debug(f"Retrying RPC with delay {delay_secs}s due to error: {e}")
+        await asyncio.sleep(delay_secs)
+        delay_secs *= delay_factor
+        num_retries += 1
+
+    while True:
+        try:
+            return await func()
+        except GRPCError as e:
+            if (max_retries is None or num_retries < max_retries) and e.status in RETRYABLE_GRPC_STATUS_CODES:
+                await sleep_and_update_delay_and_num_retries_remaining(e)
+            else:
+                raise e
+        except AttributeError as e:
+            # StreamTerminatedError are not properly raised in grpclib<=0.4.7
+            # fixed in https://github.com/vmagamedov/grpclib/issues/185
+            # TODO: update to newer version (>=0.4.8) once stable
+            if (max_retries is None or num_retries < max_retries) and "_write_appdata" in str(e):
+                await sleep_and_update_delay_and_num_retries_remaining(e)
+            else:
+                raise e
+        except StreamTerminatedError as e:
+            if max_retries is None or num_retries < max_retries:
+                await sleep_and_update_delay_and_num_retries_remaining(e)
+            else:
+                raise e
+        except (OSError, asyncio.TimeoutError) as e:
+            if max_retries is None or num_retries < max_retries:
+                await sleep_and_update_delay_and_num_retries_remaining(e)
+            else:
+                raise ConnectionError(str(e))
+
+
+async def fetch_command_router_access(server_client, task_id: str) -> api_pb2.TaskGetCommandRouterAccessResponse:
+    """Fetch direct command router access info from Modal server."""
+    return await server_client.stub.TaskGetCommandRouterAccess(
+        api_pb2.TaskGetCommandRouterAccessRequest(task_id=task_id),
+    )
+
+
+def _finalize_channel(loop, channel):
+    if not loop.is_closed():
+        # only run if loop has not shut down
+        # call_soon_threadsafe could throw if the loop is torn down after calling
+        # is_closed. This is safe to ignore.
+        with suppress(Exception):
+            loop.call_soon_threadsafe(channel.close)
+
+
+class TaskCommandRouterClient:
+    """
+    Client used to talk directly to TaskCommandRouter service on worker hosts.
+
+    A new instance should be created per task.
+    """
+
+    @classmethod
+    async def try_init(
+        cls,
+        server_client,
+        task_id: str,
+    ) -> Optional["TaskCommandRouterClient"]:
+        """Attempt to initialize a TaskCommandRouterClient by fetching direct access.
+
+        Returns None if command router access is not enabled (FAILED_PRECONDITION).
+        """
+        try:
+            resp = await fetch_command_router_access(server_client, task_id)
+        except ConflictError:
+            logger.debug(f"Command router access is not enabled for task {task_id}")
+            return None
+
+        logger.debug(f"Using command router access for task {task_id}")
+
+        # Build and connect a channel to the task command router now that we have access info.
+        o = urllib.parse.urlparse(resp.url)
+        if o.scheme != "https":
+            raise ValueError(f"Task router URL must be https, got: {resp.url}")
+
+        host, _, port_str = o.netloc.partition(":")
+        port = int(port_str) if port_str else 443
+        ssl_context = ssl.create_default_context()
+
+        # Allow insecure TLS when explicitly enabled via config.
+        if config["task_command_router_insecure"]:
+            logger.warning("Using insecure TLS for task command router due to MODAL_TASK_COMMAND_ROUTER_INSECURE")
+            ssl_context.check_hostname = False
+            ssl_context.verify_mode = ssl.CERT_NONE
+
+        channel = grpclib.client.Channel(
+            host,
+            port,
+            ssl=ssl_context,
+            config=grpclib.config.Configuration(
+                http2_connection_window_size=64 * 1024 * 1024,  # 64 MiB
+                http2_stream_window_size=64 * 1024 * 1024,  # 64 MiB
+            ),
+        )
+
+        await connect_channel(channel)
+        loop = asyncio.get_running_loop()
+        jwt_refresh_lock = asyncio.Lock()
+
+        return cls(server_client, task_id, resp.url, resp.jwt, channel, loop, jwt_refresh_lock)
+
+    def __init__(
+        self,
+        server_client,
+        task_id: str,
+        server_url: str,
+        jwt: str,
+        channel: grpclib.client.Channel,
+        loop: asyncio.AbstractEventLoop,
+        jwt_refresh_lock: asyncio.Lock,
+        *,
+        stream_stdio_retry_delay_secs: float = 0.01,
+        stream_stdio_retry_delay_factor: float = 2,
+        stream_stdio_max_retries: int = 10,
+    ) -> None:
+        """Callers should not use this directly. Use TaskCommandRouterClient.try_init() instead."""
+        # Record the loop this instance is bound to so __del__ can safely schedule cleanup
+        # even if finalization happens from a different thread (e.g. via synchronicity).
+        self._loop = loop
+
+        # Attach bearer token on all requests to the worker-side router service.
+        self._server_client = server_client
+        self._task_id = task_id
+        self._server_url = server_url
+        self._jwt = jwt
+        self._channel = channel
+        # Retry configuration for stdio streaming
+        self.stream_stdio_retry_delay_secs = stream_stdio_retry_delay_secs
+        self.stream_stdio_retry_delay_factor = stream_stdio_retry_delay_factor
+        self.stream_stdio_max_retries = stream_stdio_max_retries
+
+        # JWT refresh coordination
+        self._jwt_exp: Optional[float] = _parse_jwt_expiration(jwt)
+        # This is passed in as an argument to ensure it's created from within the correct event loop.
+        self._jwt_refresh_lock = jwt_refresh_lock
+
+        self._closed = False
+
+        self._channel_finalizer = weakref.finalize(
+            self,
+            _finalize_channel,
+            loop,
+            channel,
+        )
+
+        self._stub = TaskCommandRouterStub(self._channel)
+
+    def _get_metadata(self):
+        return {"authorization": f"Bearer {self._jwt}"}
+
+    async def close(self) -> None:
+        """Close the client."""
+        if self._closed:
+            return
+
+        self._closed = True
+        self._channel.close()
+        if self._channel_finalizer.alive:
+            # skip the finalizer if we've closed the channel anyway
+            self._channel_finalizer.detach()
+
+    async def exec_start(self, request: sr_pb2.TaskExecStartRequest) -> sr_pb2.TaskExecStartResponse:
+        """Start an exec'd command, properly retrying on transient errors."""
+        with grpc_error_converter():
+            return await call_with_retries_on_transient_errors(
+                lambda: self._call_with_auth_retry(self._stub.TaskExecStart, request)
+            )
+
+    async def exec_stdio_read(
+        self,
+        task_id: str,
+        exec_id: str,
+        # Quotes around the type required for protobuf 3.19.
+        file_descriptor: "api_pb2.FileDescriptor.ValueType",
+        deadline: Optional[float] = None,
+    ) -> AsyncGenerator[sr_pb2.TaskExecStdioReadResponse, None]:
+        """Stream stdout/stderr batches from the task, properly retrying on transient errors.
+
+        Args:
+            task_id: The task ID of the task running the exec'd command.
+            exec_id: The execution ID of the command to read from.
+            file_descriptor: The file descriptor to read from.
+            deadline: The deadline by which all output must be streamed. If
+                None, wait forever. If the deadline is exceeded, raises an
+                ExecTimeoutError.
+        Returns:
+            AsyncGenerator[sr_pb2.TaskExecStdioReadResponse, None]: A stream of stdout/stderr batches.
+        Raises:
+            ExecTimeoutError: If the deadline is exceeded.
+            Other errors: If retries are exhausted on transient errors or if there's an error
+                from the RPC itself.
+        """
+        if file_descriptor == api_pb2.FILE_DESCRIPTOR_STDOUT:
+            sr_fd = sr_pb2.TASK_EXEC_STDIO_FILE_DESCRIPTOR_STDOUT
+        elif file_descriptor == api_pb2.FILE_DESCRIPTOR_STDERR:
+            sr_fd = sr_pb2.TASK_EXEC_STDIO_FILE_DESCRIPTOR_STDERR
+        elif file_descriptor == api_pb2.FILE_DESCRIPTOR_INFO or file_descriptor == api_pb2.FILE_DESCRIPTOR_UNSPECIFIED:
+            raise ValueError(f"Unsupported file descriptor: {file_descriptor}")
+        else:
+            raise ValueError(f"Invalid file descriptor: {file_descriptor}")
+
+        with grpc_error_converter():
+            async with aclosing(self._stream_stdio(task_id, exec_id, sr_fd, deadline)) as stream:
+                async for item in stream:
+                    yield item
+
+    async def exec_stdin_write(
+        self, task_id: str, exec_id: str, offset: int, data: bytes, eof: bool
+    ) -> sr_pb2.TaskExecStdinWriteResponse:
+        """Write to the stdin stream of an exec'd command, properly retrying on transient errors.
+
+        Args:
+            task_id: The task ID of the task running the exec'd command.
+            exec_id: The execution ID of the command to write to.
+            offset: The offset to start writing to.
+            data: The data to write to the stdin stream.
+            eof: Whether to close the stdin stream after writing the data.
+        Raises:
+            Other errors: If retries are exhausted on transient errors or if there's an error
+                from the RPC itself.
+        """
+        request = sr_pb2.TaskExecStdinWriteRequest(task_id=task_id, exec_id=exec_id, offset=offset, data=data, eof=eof)
+        with grpc_error_converter():
+            return await call_with_retries_on_transient_errors(
+                lambda: self._call_with_auth_retry(self._stub.TaskExecStdinWrite, request)
+            )
+
+    async def exec_poll(
+        self, task_id: str, exec_id: str, deadline: Optional[float] = None
+    ) -> sr_pb2.TaskExecPollResponse:
+        """Poll for the exit status of an exec'd command, properly retrying on transient errors.
+
+        Args:
+            task_id: The task ID of the task running the exec'd command.
+            exec_id: The execution ID of the command to poll on.
+        Returns:
+            sr_pb2.TaskExecPollResponse: The exit status of the command if it has completed.
+
+        Raises:
+            ExecTimeoutError: If the deadline is exceeded.
+            Other errors: If retries are exhausted on transient errors or if there's an error
+                from the RPC itself.
+        """
+        request = sr_pb2.TaskExecPollRequest(task_id=task_id, exec_id=exec_id)
+        # The timeout here is really a backstop in the event of a hang contacting
+        # the command router. Poll should usually be instantaneous.
+        timeout = deadline - time.monotonic() if deadline is not None else None
+        if timeout is not None and timeout <= 0:
+            raise ExecTimeoutError(f"Deadline exceeded while polling for exec {exec_id}")
+
+        with grpc_error_converter():
+            try:
+                return await asyncio.wait_for(
+                    call_with_retries_on_transient_errors(
+                        lambda: self._call_with_auth_retry(self._stub.TaskExecPoll, request)
+                    ),
+                    timeout=timeout,
+                )
+            except asyncio.TimeoutError:
+                raise ExecTimeoutError(f"Deadline exceeded while polling for exec {exec_id}")
+
+    async def exec_wait(
+        self,
+        task_id: str,
+        exec_id: str,
+        deadline: Optional[float] = None,
+    ) -> sr_pb2.TaskExecWaitResponse:
+        """Wait for an exec'd command to exit and return the exit code, properly retrying on transient errors.
+
+        Args:
+            task_id: The task ID of the task running the exec'd command.
+            exec_id: The execution ID of the command to wait on.
+        Returns:
+            Optional[sr_pb2.TaskExecWaitResponse]: The exit code of the command.
+        Raises:
+            ExecTimeoutError: If the deadline is exceeded.
+            Other errors: If there's an error from the RPC itself.
+        """
+        request = sr_pb2.TaskExecWaitRequest(task_id=task_id, exec_id=exec_id)
+        timeout = deadline - time.monotonic() if deadline is not None else None
+        if timeout is not None and timeout <= 0:
+            raise ExecTimeoutError(f"Deadline exceeded while waiting for exec {exec_id}")
+
+        with grpc_error_converter():
+            try:
+                return await asyncio.wait_for(
+                    call_with_retries_on_transient_errors(
+                        # We set a 60s timeout here to avoid waiting forever if there's an unanticipated hang
+                        # due to a networking issue. call_with_retries_on_transient_errors will retry if the
+                        # timeout is exceeded, so we'll retry every 60s until the command exits.
+                        #
+                        # Safety:
+                        # * If just the task shuts down, the task command router will return a NOT_FOUND error,
+                        #   and we'll stop retrying.
+                        # * If the task shut down AND the worker shut down, this could
+                        #   infinitely retry. For callers without an exec deadline, this
+                        #   could hang indefinitely.
+                        lambda: self._call_with_auth_retry(self._stub.TaskExecWait, request, timeout=60),
+                        base_delay_secs=1,  # Retry after 1s since total time is expected to be long.
+                        delay_factor=1,  # Fixed delay.
+                        max_retries=None,  # Retry forever.
+                    ),
+                    timeout=timeout,
+                )
+            except asyncio.TimeoutError:
+                raise ExecTimeoutError(f"Deadline exceeded while waiting for exec {exec_id}")
+
+    async def _refresh_jwt(self) -> None:
+        """Refresh JWT from the server and update internal state."""
+        async with self._jwt_refresh_lock:
+            if self._closed:
+                return
+
+            # If the current JWT expiration is already far enough in the future, don't refresh.
+            if self._jwt_exp is not None and self._jwt_exp - time.time() > 30:
+                # This can happen if multiple concurrent requests to the task command router
+                # get UNAUTHENTICATED errors and all refresh at the same time - one of them
+                # will win and the others will not refresh.
+                logger.debug(
+                    f"Skipping JWT refresh for exec with task ID {self._task_id} "
+                    "because its expiration is already far enough in the future"
+                )
+                return
+
+            logger.debug(f"Refreshing JWT for exec with task ID {self._task_id}")
+            resp = await fetch_command_router_access(self._server_client, self._task_id)
+            logger.debug(f"Finished refreshing JWT for exec with task ID {self._task_id}")
+
+            # Ensure the server URL remains stable for the lifetime of this client.
+            assert resp.url == self._server_url, "Task router URL changed during session"
+            self._jwt = resp.jwt
+            self._jwt_exp = _parse_jwt_expiration(resp.jwt)
+
+    async def _call_with_auth_retry(self, func, *args, **kwargs):
+        try:
+            return await func(*args, **kwargs, metadata=self._get_metadata())
+        except GRPCError as exc:
+            if exc.status == Status.UNAUTHENTICATED:
+                await self._refresh_jwt()
+                # Retry with the original arguments preserved
+                return await func(*args, **kwargs, metadata=self._get_metadata())
+            raise
+
+    async def _stream_stdio(
+        self,
+        task_id: str,
+        exec_id: str,
+        # Quotes around the type required for protobuf 3.19.
+        file_descriptor: "sr_pb2.TaskExecStdioFileDescriptor.ValueType",
+        deadline: Optional[float] = None,
+    ) -> AsyncGenerator[sr_pb2.TaskExecStdioReadResponse, None]:
+        """Stream stdio from the task, properly updating the offset and retrying on transient errors.
+        Raises ExecTimeoutError if the deadline is exceeded.
+        """
+        offset = 0
+        delay_secs = self.stream_stdio_retry_delay_secs
+        delay_factor = self.stream_stdio_retry_delay_factor
+        num_retries_remaining = self.stream_stdio_max_retries
+        # Flag to prevent infinite auth retries in the event that the JWT
+        # refresh yields an invalid JWT somehow or that the JWT is otherwise invalid.
+        did_auth_retry = False
+
+        async def sleep_and_update_delay_and_num_retries_remaining(e: Exception):
+            nonlocal delay_secs, num_retries_remaining
+            logger.debug(f"Retrying stdio read with delay {delay_secs}s due to error: {e}")
+            if deadline is not None and deadline - time.monotonic() <= delay_secs:
+                raise ExecTimeoutError(f"Deadline exceeded while streaming stdio for exec {exec_id}")
+
+            await asyncio.sleep(delay_secs)
+            delay_secs *= delay_factor
+            num_retries_remaining -= 1
+
+        while True:
+            timeout = max(0, deadline - time.monotonic()) if deadline is not None else None
+            try:
+                stream = self._stub.TaskExecStdioRead.open(timeout=timeout, metadata=self._get_metadata())
+                async with stream as s:
+                    req = sr_pb2.TaskExecStdioReadRequest(
+                        task_id=task_id,
+                        exec_id=exec_id,
+                        offset=offset,
+                        file_descriptor=file_descriptor,
+                    )
+
+                    # Auth retry is scoped to a single refresh per streaming attempt. While auth metadata is
+                    # sent on request start, UNAUTHENTICATED may sometimes surface during iteration,
+                    # so we handle it at both send and receive boundaries.
+                    try:
+                        await s.send_message(req, end=True)
+                        async for item in s:
+                            # We successfully authenticated after a JWT refresh, reset the auth retry flag.
+                            if did_auth_retry:
+                                did_auth_retry = False
+                            # Reset retry backoff after any successful chunk.
+                            delay_secs = self.stream_stdio_retry_delay_secs
+                            offset += len(item.data)
+                            yield item
+                    except GRPCError as exc:
+                        if exc.status == Status.UNAUTHENTICATED and not did_auth_retry:
+                            await self._refresh_jwt()
+                            # Mark that we've retried authentication for this streaming attempt, to
+                            # prevent subsequent retries.
+                            did_auth_retry = True
+                            continue
+                        raise
+
+                # We successfully streamed all output.
+                return
+            except GRPCError as e:
+                if num_retries_remaining > 0 and e.status in RETRYABLE_GRPC_STATUS_CODES:
+                    await sleep_and_update_delay_and_num_retries_remaining(e)
+                else:
+                    raise e
+            except AttributeError as e:
+                # StreamTerminatedError are not properly raised in grpclib<=0.4.7
+                # fixed in https://github.com/vmagamedov/grpclib/issues/185
+                # TODO: update to newer version (>=0.4.8) once stable
+                if num_retries_remaining > 0 and "_write_appdata" in str(e):
+                    await sleep_and_update_delay_and_num_retries_remaining(e)
+                else:
+                    raise e
+            except StreamTerminatedError as e:
+                if num_retries_remaining > 0:
+                    await sleep_and_update_delay_and_num_retries_remaining(e)
+                else:
+                    raise e
+            except asyncio.TimeoutError as e:
+                if num_retries_remaining > 0:
+                    await sleep_and_update_delay_and_num_retries_remaining(e)
+                else:
+                    raise ConnectionError(str(e))
+            except OSError as e:
+                if num_retries_remaining > 0:
+                    await sleep_and_update_delay_and_num_retries_remaining(e)
+                else:
+                    raise ConnectionError(str(e))
+
+    async def mount_image(self, request: sr_pb2.TaskMountDirectoryRequest):
+        with grpc_error_converter():
+            return await call_with_retries_on_transient_errors(
+                lambda: self._call_with_auth_retry(self._stub.TaskMountDirectory, request)
+            )
+
+    async def snapshot_directory(
+        self, request: sr_pb2.TaskSnapshotDirectoryRequest
+    ) -> sr_pb2.TaskSnapshotDirectoryResponse:
+        with grpc_error_converter():
+            return await call_with_retries_on_transient_errors(
+                lambda: self._call_with_auth_retry(self._stub.TaskSnapshotDirectory, request)
+            )
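
The new module above is the client half of the direct-to-worker exec path that replaces the old sandbox_router protos (see the sandbox_router → task_command_router renames in the file list). A minimal usage sketch follows, based only on the signatures shown in this file; it assumes an already-connected Modal client object (server_client) plus a running task ID and an exec ID obtained elsewhere (e.g. from an earlier exec_start call), and those surrounding names are illustrative, not part of this release:

# Illustrative sketch (not shipped in this package): driving TaskCommandRouterClient directly.
from modal._utils.task_command_router_client import TaskCommandRouterClient
from modal_proto import api_pb2


async def stream_exec_output(server_client, task_id: str, exec_id: str):
    # try_init returns None when direct command-router access is not enabled for the task.
    router = await TaskCommandRouterClient.try_init(server_client, task_id)
    if router is None:
        return None
    try:
        # Stream stdout batches; each TaskExecStdioReadResponse carries a `data` bytes field.
        async for batch in router.exec_stdio_read(task_id, exec_id, api_pb2.FILE_DESCRIPTOR_STDOUT):
            print(batch.data.decode(), end="")
        # Wait for the command to exit. With no deadline this may retry for a long time,
        # as noted in the exec_wait docstring above.
        return await router.exec_wait(task_id, exec_id)
    finally:
        await router.close()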