modal 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modal might be problematic. Click here for more details.
- modal/_container_entrypoint.py +4 -1
- modal/_partial_function.py +28 -3
- modal/_utils/function_utils.py +4 -0
- modal/_utils/task_command_router_client.py +537 -0
- modal/app.py +93 -54
- modal/app.pyi +48 -18
- modal/cli/_download.py +19 -3
- modal/cli/cluster.py +4 -2
- modal/cli/container.py +4 -2
- modal/cli/entry_point.py +1 -0
- modal/cli/launch.py +1 -2
- modal/cli/run.py +6 -0
- modal/cli/volume.py +7 -1
- modal/client.pyi +2 -2
- modal/cls.py +5 -12
- modal/config.py +14 -0
- modal/container_process.py +283 -3
- modal/container_process.pyi +95 -32
- modal/exception.py +4 -0
- modal/experimental/flash.py +21 -47
- modal/experimental/flash.pyi +6 -20
- modal/functions.pyi +6 -6
- modal/io_streams.py +455 -122
- modal/io_streams.pyi +220 -95
- modal/partial_function.pyi +4 -1
- modal/runner.py +39 -36
- modal/runner.pyi +40 -24
- modal/sandbox.py +130 -11
- modal/sandbox.pyi +145 -9
- modal/volume.py +23 -3
- modal/volume.pyi +30 -0
- {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/METADATA +5 -5
- {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/RECORD +49 -48
- modal_proto/api.proto +2 -26
- modal_proto/api_grpc.py +0 -32
- modal_proto/api_pb2.py +327 -367
- modal_proto/api_pb2.pyi +6 -69
- modal_proto/api_pb2_grpc.py +0 -67
- modal_proto/api_pb2_grpc.pyi +0 -22
- modal_proto/modal_api_grpc.py +0 -2
- modal_proto/sandbox_router.proto +0 -4
- modal_proto/sandbox_router_pb2.pyi +0 -4
- modal_proto/task_command_router.proto +1 -1
- modal_proto/task_command_router_pb2.py +2 -2
- modal_version/__init__.py +1 -1
- {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/WHEEL +0 -0
- {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/entry_points.txt +0 -0
- {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/top_level.txt +0 -0
modal/_container_entrypoint.py
CHANGED
|
@@ -450,7 +450,10 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
|
|
|
450
450
|
f"Function has {len(service.service_deps)} dependencies"
|
|
451
451
|
f" but container got {len(dep_object_ids)} object ids.\n"
|
|
452
452
|
f"Code deps: {service.service_deps}\n"
|
|
453
|
-
f"Object ids: {dep_object_ids}"
|
|
453
|
+
f"Object ids: {dep_object_ids}\n"
|
|
454
|
+
"\n"
|
|
455
|
+
"This can happen if you are defining Modal objects under a conditional statement "
|
|
456
|
+
"that evaluates differently in the local and remote environments."
|
|
454
457
|
)
|
|
455
458
|
for object_id, obj in zip(dep_object_ids, service.service_deps):
|
|
456
459
|
metadata: Message = container_app.object_handle_metadata[object_id]
|
modal/_partial_function.py
CHANGED
|
@@ -19,7 +19,7 @@ from ._functions import _Function
|
|
|
19
19
|
from ._utils.async_utils import synchronizer
|
|
20
20
|
from ._utils.deprecation import deprecation_warning
|
|
21
21
|
from ._utils.function_utils import callable_has_non_self_params
|
|
22
|
-
from .config import logger
|
|
22
|
+
from .config import config, logger
|
|
23
23
|
from .exception import InvalidError
|
|
24
24
|
|
|
25
25
|
MAX_MAX_BATCH_SIZE = 1000
|
|
@@ -93,6 +93,26 @@ NullaryFuncOrMethod = Union[Callable[[], Any], Callable[[Any], Any]]
|
|
|
93
93
|
NullaryMethod = Callable[[Any], Any]
|
|
94
94
|
|
|
95
95
|
|
|
96
|
+
def verify_concurrent_params(params: _PartialFunctionParams, is_flash: bool = False) -> None:
    """Validate the `@modal.concurrent` settings recorded on *params*.

    Non-Flash functions must provide `max_inputs`. Flash functions must provide
    `target_inputs` and must not provide `max_inputs`.

    Raises:
        TypeError: if the combination of settings is not allowed.
    """
    if not is_flash:
        if params.max_concurrent_inputs is None:
            raise TypeError("`@modal.concurrent()` missing required argument: `max_inputs`.")
        return

    # Flash functions only support the autoscaler target, not a hard limit.
    if params.max_concurrent_inputs is not None:
        raise TypeError(
            "@modal.concurrent(max_inputs=...) is not yet supported for Flash functions. "
            "Use `@modal.concurrent(target_inputs=...)` instead."
        )
    if params.target_concurrent_inputs is None:
        raise TypeError("`@modal.concurrent()` missing required argument: `target_inputs`.")
|
|
114
|
+
|
|
115
|
+
|
|
96
116
|
class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
|
|
97
117
|
"""Object produced by a decorator in the `modal` namespace
|
|
98
118
|
|
|
@@ -378,6 +398,7 @@ def _fastapi_endpoint(
|
|
|
378
398
|
method=method,
|
|
379
399
|
web_endpoint_docs=docs,
|
|
380
400
|
requested_suffix=label or "",
|
|
401
|
+
ephemeral_suffix=config.get("dev_suffix"),
|
|
381
402
|
async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
|
|
382
403
|
custom_domains=_parse_custom_domains(custom_domains),
|
|
383
404
|
requires_proxy_auth=requires_proxy_auth,
|
|
@@ -446,6 +467,7 @@ def _web_endpoint(
|
|
|
446
467
|
method=method,
|
|
447
468
|
web_endpoint_docs=docs,
|
|
448
469
|
requested_suffix=label or "",
|
|
470
|
+
ephemeral_suffix=config.get("dev_suffix"),
|
|
449
471
|
async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
|
|
450
472
|
custom_domains=_parse_custom_domains(custom_domains),
|
|
451
473
|
requires_proxy_auth=requires_proxy_auth,
|
|
@@ -505,6 +527,7 @@ def _asgi_app(
|
|
|
505
527
|
webhook_config = api_pb2.WebhookConfig(
|
|
506
528
|
type=api_pb2.WEBHOOK_TYPE_ASGI_APP,
|
|
507
529
|
requested_suffix=label or "",
|
|
530
|
+
ephemeral_suffix=config.get("dev_suffix"),
|
|
508
531
|
async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
|
|
509
532
|
custom_domains=_parse_custom_domains(custom_domains),
|
|
510
533
|
requires_proxy_auth=requires_proxy_auth,
|
|
@@ -562,6 +585,7 @@ def _wsgi_app(
|
|
|
562
585
|
webhook_config = api_pb2.WebhookConfig(
|
|
563
586
|
type=api_pb2.WEBHOOK_TYPE_WSGI_APP,
|
|
564
587
|
requested_suffix=label or "",
|
|
588
|
+
ephemeral_suffix=config.get("dev_suffix"),
|
|
565
589
|
async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
|
|
566
590
|
custom_domains=_parse_custom_domains(custom_domains),
|
|
567
591
|
requires_proxy_auth=requires_proxy_auth,
|
|
@@ -623,6 +647,7 @@ def _web_server(
|
|
|
623
647
|
webhook_config = api_pb2.WebhookConfig(
|
|
624
648
|
type=api_pb2.WEBHOOK_TYPE_WEB_SERVER,
|
|
625
649
|
requested_suffix=label or "",
|
|
650
|
+
ephemeral_suffix=config.get("dev_suffix"),
|
|
626
651
|
async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
|
|
627
652
|
custom_domains=_parse_custom_domains(custom_domains),
|
|
628
653
|
web_server_port=port,
|
|
@@ -760,7 +785,7 @@ def _batched(
|
|
|
760
785
|
def _concurrent(
|
|
761
786
|
_warn_parentheses_missing=None, # mdmd:line-hidden
|
|
762
787
|
*,
|
|
763
|
-
max_inputs: int, # Hard limit on each container's input concurrency
|
|
788
|
+
max_inputs: Optional[int] = None, # Hard limit on each container's input concurrency
|
|
764
789
|
target_inputs: Optional[int] = None, # Input concurrency that Modal's autoscaler should target
|
|
765
790
|
) -> Callable[
|
|
766
791
|
[Union[Callable[P, ReturnType], _PartialFunction[P, ReturnType, ReturnType]]],
|
|
@@ -812,7 +837,7 @@ def _concurrent(
|
|
|
812
837
|
"Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.concurrent()`."
|
|
813
838
|
)
|
|
814
839
|
|
|
815
|
-
if target_inputs and target_inputs > max_inputs:
|
|
840
|
+
if max_inputs is not None and target_inputs is not None and target_inputs > max_inputs:
|
|
816
841
|
raise InvalidError("`target_inputs` parameter cannot be greater than `max_inputs`.")
|
|
817
842
|
|
|
818
843
|
flags = _PartialFunctionFlags.CONCURRENT
|
modal/_utils/function_utils.py
CHANGED
|
@@ -75,6 +75,10 @@ def is_global_object(object_qual_name: str):
|
|
|
75
75
|
return "<locals>" not in object_qual_name.split(".")
|
|
76
76
|
|
|
77
77
|
|
|
78
|
+
def is_flash_object(experimental_options: Optional[dict[str, Any]]) -> bool:
    """Report whether *experimental_options* turns on the "flash" option.

    ``None`` or an empty dict yields ``False``. Otherwise the value stored under
    ``"flash"`` is returned as-is, defaulting to ``False`` when the key is absent.
    """
    if not experimental_options:
        return False
    return experimental_options.get("flash", False)
|
|
80
|
+
|
|
81
|
+
|
|
78
82
|
def is_method_fn(object_qual_name: str):
|
|
79
83
|
# methods have names like Cls.foo.
|
|
80
84
|
if "<locals>" in object_qual_name:
|
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
# Copyright Modal Labs 2025
|
|
2
|
+
import asyncio
|
|
3
|
+
import base64
|
|
4
|
+
import json
|
|
5
|
+
import ssl
|
|
6
|
+
import time
|
|
7
|
+
import urllib.parse
|
|
8
|
+
from typing import AsyncIterator, Optional
|
|
9
|
+
|
|
10
|
+
import grpclib.client
|
|
11
|
+
import grpclib.config
|
|
12
|
+
import grpclib.events
|
|
13
|
+
from grpclib import GRPCError, Status
|
|
14
|
+
from grpclib.exceptions import StreamTerminatedError
|
|
15
|
+
|
|
16
|
+
from modal.config import config, logger
|
|
17
|
+
from modal.exception import ExecTimeoutError
|
|
18
|
+
from modal_proto import api_pb2, task_command_router_pb2 as sr_pb2
|
|
19
|
+
from modal_proto.task_command_router_grpc import TaskCommandRouterStub
|
|
20
|
+
|
|
21
|
+
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES, connect_channel, retry_transient_errors
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _b64url_decode(data: str) -> bytes:
    """Decode base64url text whose trailing ``=`` padding may have been stripped (as in JWTs)."""
    # Pad up to the next multiple of 4 characters; already-aligned input gets no padding.
    pad_len = -len(data) % 4
    return base64.urlsafe_b64decode(data + "=" * pad_len)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _parse_jwt_expiration(jwt_token: str) -> Optional[float]:
    """Best-effort read of the ``exp`` claim from a JWT; the signature is NOT verified.

    Returns the claim as UNIX time in seconds (float), or ``None`` when:
      * the token does not consist of exactly three dot-separated segments,
      * the claim is missing or is not an int/float,
      * the payload cannot be decoded or parsed (a warning is logged in this case).

    Never raises for malformed input, so callers can fall back to server-driven refresh.
    """
    try:
        segments = jwt_token.split(".")
        if len(segments) != 3:
            return None
        claims = json.loads(_b64url_decode(segments[1]))
        expiry = claims.get("exp")
        return float(expiry) if isinstance(expiry, (int, float)) else None
    except Exception:
        # Malformed token: don't raise; rely on server-driven refresh instead.
        logger.warning("Failed to parse JWT expiration")
        return None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
async def call_with_retries_on_transient_errors(
    func,
    *,
    base_delay_secs: float = 0.01,
    delay_factor: float = 2,
    max_retries: Optional[int] = 10,
):
    """Await ``func()``, retrying transient failures with exponential backoff.

    Failures that are retried:
      * ``GRPCError`` whose status is in ``RETRYABLE_GRPC_STATUS_CODES``;
      * an ``AttributeError`` mentioning ``_write_appdata`` (how grpclib<=0.4.7 surfaces a
        terminated stream instead of raising ``StreamTerminatedError``);
      * ``StreamTerminatedError``;
      * ``OSError`` and ``asyncio.TimeoutError``.

    Authentication retries are expected to be handled by the caller.

    Args:
        func: Zero-argument callable returning an awaitable; it is called anew on every attempt.
        base_delay_secs: Delay before the first retry.
        delay_factor: Multiplier applied to the delay after each retry.
        max_retries: Maximum number of retries, or ``None`` to retry forever.

    Returns:
        Whatever the first successful ``func()`` call returns.

    Raises:
        GRPCError, AttributeError, StreamTerminatedError: re-raised unchanged when the
            error is not retryable or retries are exhausted.
        ConnectionError: when retries are exhausted on an ``OSError`` /
            ``asyncio.TimeoutError``; the original error is chained as ``__cause__``.
    """
    delay_secs = base_delay_secs
    num_retries = 0

    def retries_left() -> bool:
        return max_retries is None or num_retries < max_retries

    async def sleep_and_update_delay_and_num_retries_remaining(e: Exception):
        nonlocal delay_secs, num_retries
        logger.debug(f"Retrying RPC with delay {delay_secs}s due to error: {e}")
        await asyncio.sleep(delay_secs)
        delay_secs *= delay_factor
        num_retries += 1

    while True:
        try:
            return await func()
        except GRPCError as e:
            if not (retries_left() and e.status in RETRYABLE_GRPC_STATUS_CODES):
                raise
            await sleep_and_update_delay_and_num_retries_remaining(e)
        except AttributeError as e:
            # StreamTerminatedError are not properly raised in grpclib<=0.4.7
            # fixed in https://github.com/vmagamedov/grpclib/issues/185
            # TODO: update to newer version (>=0.4.8) once stable
            if not (retries_left() and "_write_appdata" in str(e)):
                raise
            await sleep_and_update_delay_and_num_retries_remaining(e)
        except StreamTerminatedError as e:
            if not retries_left():
                raise
            await sleep_and_update_delay_and_num_retries_remaining(e)
        except (OSError, asyncio.TimeoutError) as e:
            if not retries_left():
                # Keep the underlying network error reachable for debugging.
                raise ConnectionError(str(e)) from e
            await sleep_and_update_delay_and_num_retries_remaining(e)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
async def fetch_command_router_access(server_client, task_id: str) -> api_pb2.TaskGetCommandRouterAccessResponse:
    """Ask the Modal server for direct TaskCommandRouter access info for *task_id*.

    Transient server errors are retried by ``retry_transient_errors``.
    """
    access_rpc = server_client.stub.TaskGetCommandRouterAccess
    access_request = api_pb2.TaskGetCommandRouterAccessRequest(task_id=task_id)
    return await retry_transient_errors(access_rpc, access_request)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class TaskCommandRouterClient:
    """
    Client used to talk directly to TaskCommandRouter service on worker hosts.

    A new instance should be created per task.

    Every request carries the current JWT as a bearer token. A background task refreshes the
    JWT shortly before its `exp` claim, and RPCs rejected with UNAUTHENTICATED trigger one
    on-demand refresh followed by a single retry.
    """

    # Refresh the JWT this many seconds before its `exp` claim.
    _JWT_REFRESH_MARGIN_SECS = 30.0
    # How long the background loop waits before re-checking when the JWT has no parseable `exp`.
    _JWT_UNKNOWN_EXPIRY_RECHECK_SECS = 60.0
    # Minimum wait between background refresh attempts. Without it, a token whose `exp` already
    # falls inside the refresh margin (short-lived token, or a local clock ahead of the issuer's)
    # makes the loop issue refresh RPCs back-to-back.
    _JWT_MIN_REFRESH_INTERVAL_SECS = 1.0

    @classmethod
    async def try_init(
        cls,
        server_client,
        task_id: str,
    ) -> Optional["TaskCommandRouterClient"]:
        """Attempt to initialize a TaskCommandRouterClient by fetching direct access.

        Returns None if command router access is not enabled (FAILED_PRECONDITION).

        Raises:
            GRPCError: for any other error returned by the access RPC.
            ValueError: if the router URL returned by the server is not https.
        """
        try:
            resp = await fetch_command_router_access(server_client, task_id)
        except GRPCError as exc:
            if exc.status == Status.FAILED_PRECONDITION:
                logger.debug(f"Command router access is not enabled for task {task_id}")
                return None
            raise

        logger.debug(f"Using command router access for task {task_id}")

        # Build and connect a channel to the task command router now that we have access info.
        o = urllib.parse.urlparse(resp.url)
        if o.scheme != "https":
            raise ValueError(f"Task router URL must be https, got: {resp.url}")

        host, _, port_str = o.netloc.partition(":")
        port = int(port_str) if port_str else 443
        ssl_context = ssl.create_default_context()

        # Allow insecure TLS when explicitly enabled via config.
        if config["task_command_router_insecure"]:
            logger.warning("Using insecure TLS for task command router due to MODAL_TASK_COMMAND_ROUTER_INSECURE")
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE

        channel = grpclib.client.Channel(
            host,
            port,
            ssl=ssl_context,
            config=grpclib.config.Configuration(
                http2_connection_window_size=64 * 1024 * 1024,  # 64 MiB
                http2_stream_window_size=64 * 1024 * 1024,  # 64 MiB
            ),
        )

        try:
            await connect_channel(channel)
            return cls(server_client, task_id, resp.url, resp.jwt, channel)
        except BaseException:
            # Don't leak the connection if connecting or constructing the client fails
            # (including cancellation while connecting).
            channel.close()
            raise

    def __init__(
        self,
        server_client,
        task_id: str,
        server_url: str,
        jwt: str,
        channel: grpclib.client.Channel,
        *,
        stream_stdio_retry_delay_secs: float = 0.01,
        stream_stdio_retry_delay_factor: float = 2,
        stream_stdio_max_retries: int = 10,
    ) -> None:
        """Callers should not use this directly. Use TaskCommandRouterClient.try_init() instead.

        Must be called while an event loop is running, because it starts the background JWT
        refresh task with ``asyncio.create_task``.
        """
        self._server_client = server_client
        self._task_id = task_id
        self._server_url = server_url
        self._jwt = jwt
        self._channel = channel
        # Retry configuration for stdio streaming
        self.stream_stdio_retry_delay_secs = stream_stdio_retry_delay_secs
        self.stream_stdio_retry_delay_factor = stream_stdio_retry_delay_factor
        self.stream_stdio_max_retries = stream_stdio_max_retries

        # JWT refresh coordination
        self._jwt_exp: Optional[float] = _parse_jwt_expiration(jwt)
        self._jwt_refresh_lock = asyncio.Lock()
        self._jwt_refresh_event = asyncio.Event()
        self._closed = False

        # Start background task to eagerly refresh JWT shortly before expiration.
        self._jwt_refresh_task = asyncio.create_task(self._jwt_refresh_loop())

        # Attach bearer token on all requests to the worker-side router service.
        async def send_request(event: grpclib.events.SendRequest) -> None:
            # This will get the most recent JWT for every request. No need to
            # lock _jwt_refresh_lock: reads and writes happen on the
            # single-threaded event loop and variable assignment is atomic.
            event.metadata["authorization"] = f"Bearer {self._jwt}"

        grpclib.events.listen(self._channel, grpclib.events.SendRequest, send_request)

        self._stub = TaskCommandRouterStub(self._channel)

    def __del__(self) -> None:
        """Best-effort cleanup when the client is garbage collected without close().

        Tolerates partially initialized instances (e.g. ``__init__`` raised because no event
        loop was running) and an already-closed event loop. Never raises.
        """
        if getattr(self, "_closed", True):
            return
        self._closed = True

        refresh_task = getattr(self, "_jwt_refresh_task", None)
        if refresh_task is not None:
            try:
                refresh_task.cancel()
            except Exception:
                # e.g. RuntimeError when the task's event loop has already been closed.
                pass

        try:
            self._channel.close()
        except Exception:
            pass

    async def close(self) -> None:
        """Close the client and stop the background JWT refresh task. Safe to call twice."""
        if self._closed:
            return

        self._closed = True
        self._jwt_refresh_task.cancel()
        try:
            logger.debug(f"Waiting for JWT refresh task to complete for exec with task ID {self._task_id}")
            await self._jwt_refresh_task
        except asyncio.CancelledError:
            pass
        finally:
            # Close the channel even if waiting for the refresh task is interrupted.
            self._channel.close()

    async def exec_start(self, request: sr_pb2.TaskExecStartRequest) -> sr_pb2.TaskExecStartResponse:
        """Start an exec'd command, properly retrying on transient errors."""
        return await call_with_retries_on_transient_errors(
            lambda: self._call_with_auth_retry(self._stub.TaskExecStart, request)
        )

    async def exec_stdio_read(
        self,
        task_id: str,
        exec_id: str,
        # Quotes around the type required for protobuf 3.19.
        file_descriptor: "api_pb2.FileDescriptor.ValueType",
        deadline: Optional[float] = None,
    ) -> AsyncIterator[sr_pb2.TaskExecStdioReadResponse]:
        """Stream stdout/stderr batches from the task, properly retrying on transient errors.

        Args:
            task_id: The task ID of the task running the exec'd command.
            exec_id: The execution ID of the command to read from.
            file_descriptor: The file descriptor to read from (stdout or stderr).
            deadline: The ``time.monotonic()`` deadline by which all output must be
                streamed. If None, wait forever. If the deadline is exceeded, raises an
                ExecTimeoutError.
        Returns:
            AsyncIterator[sr_pb2.TaskExecStdioReadResponse]: A stream of stdout/stderr batches.
        Raises:
            ValueError: If file_descriptor is not stdout or stderr.
            ExecTimeoutError: If the deadline is exceeded.
            Other errors: If retries are exhausted on transient errors or if there's an error
                from the RPC itself.
        """
        if file_descriptor == api_pb2.FILE_DESCRIPTOR_STDOUT:
            sr_fd = sr_pb2.TASK_EXEC_STDIO_FILE_DESCRIPTOR_STDOUT
        elif file_descriptor == api_pb2.FILE_DESCRIPTOR_STDERR:
            sr_fd = sr_pb2.TASK_EXEC_STDIO_FILE_DESCRIPTOR_STDERR
        elif file_descriptor == api_pb2.FILE_DESCRIPTOR_INFO or file_descriptor == api_pb2.FILE_DESCRIPTOR_UNSPECIFIED:
            raise ValueError(f"Unsupported file descriptor: {file_descriptor}")
        else:
            raise ValueError(f"Invalid file descriptor: {file_descriptor}")

        async for item in self._stream_stdio(task_id, exec_id, sr_fd, deadline):
            yield item

    async def exec_stdin_write(
        self, task_id: str, exec_id: str, offset: int, data: bytes, eof: bool
    ) -> sr_pb2.TaskExecStdinWriteResponse:
        """Write to the stdin stream of an exec'd command, properly retrying on transient errors.

        Args:
            task_id: The task ID of the task running the exec'd command.
            exec_id: The execution ID of the command to write to.
            offset: The offset to start writing to.
            data: The data to write to the stdin stream.
            eof: Whether to close the stdin stream after writing the data.
        Raises:
            Other errors: If retries are exhausted on transient errors or if there's an error
                from the RPC itself.
        """
        request = sr_pb2.TaskExecStdinWriteRequest(task_id=task_id, exec_id=exec_id, offset=offset, data=data, eof=eof)
        return await call_with_retries_on_transient_errors(
            lambda: self._call_with_auth_retry(self._stub.TaskExecStdinWrite, request)
        )

    async def exec_poll(
        self, task_id: str, exec_id: str, deadline: Optional[float] = None
    ) -> sr_pb2.TaskExecPollResponse:
        """Poll for the exit status of an exec'd command, properly retrying on transient errors.

        Args:
            task_id: The task ID of the task running the exec'd command.
            exec_id: The execution ID of the command to poll on.
            deadline: Optional ``time.monotonic()`` deadline for the whole call.
        Returns:
            sr_pb2.TaskExecPollResponse: The exit status of the command if it has completed.

        Raises:
            ExecTimeoutError: If the deadline is exceeded.
            Other errors: If retries are exhausted on transient errors or if there's an error
                from the RPC itself.
        """
        request = sr_pb2.TaskExecPollRequest(task_id=task_id, exec_id=exec_id)
        # The timeout here is really a backstop in the event of a hang contacting
        # the command router. Poll should usually be instantaneous.
        timeout = deadline - time.monotonic() if deadline is not None else None
        if timeout is not None and timeout <= 0:
            raise ExecTimeoutError(f"Deadline exceeded while polling for exec {exec_id}")
        try:
            return await asyncio.wait_for(
                call_with_retries_on_transient_errors(
                    lambda: self._call_with_auth_retry(self._stub.TaskExecPoll, request)
                ),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            raise ExecTimeoutError(f"Deadline exceeded while polling for exec {exec_id}")

    async def exec_wait(
        self,
        task_id: str,
        exec_id: str,
        deadline: Optional[float] = None,
    ) -> sr_pb2.TaskExecWaitResponse:
        """Wait for an exec'd command to exit and return the exit code, properly retrying on transient errors.

        Args:
            task_id: The task ID of the task running the exec'd command.
            exec_id: The execution ID of the command to wait on.
            deadline: Optional ``time.monotonic()`` deadline; None waits indefinitely.
        Returns:
            sr_pb2.TaskExecWaitResponse: The exit code of the command.
        Raises:
            ExecTimeoutError: If the deadline is exceeded.
            Other errors: If there's an error from the RPC itself.
        """
        request = sr_pb2.TaskExecWaitRequest(task_id=task_id, exec_id=exec_id)
        timeout = deadline - time.monotonic() if deadline is not None else None
        if timeout is not None and timeout <= 0:
            raise ExecTimeoutError(f"Deadline exceeded while waiting for exec {exec_id}")
        try:
            return await asyncio.wait_for(
                call_with_retries_on_transient_errors(
                    # We set a 60s timeout here to avoid waiting forever if there's an unanticipated hang
                    # due to a networking issue. call_with_retries_on_transient_errors will retry if the
                    # timeout is exceeded, so we'll retry every 60s until the command exits.
                    #
                    # Safety:
                    # * If just the task shuts down, the task command router will return a NOT_FOUND error,
                    #   and we'll stop retrying.
                    # * If the task shut down AND the worker shut down, this could
                    #   infinitely retry. For callers without an exec deadline, this
                    #   could hang indefinitely.
                    lambda: self._call_with_auth_retry(self._stub.TaskExecWait, request, timeout=60),
                    base_delay_secs=1,  # Retry after 1s since total time is expected to be long.
                    delay_factor=1,  # Fixed delay.
                    max_retries=None,  # Retry forever.
                ),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            raise ExecTimeoutError(f"Deadline exceeded while waiting for exec {exec_id}")

    async def _refresh_jwt(self) -> None:
        """Refresh JWT from the server and update internal state.

        Concurrency-safe: only one refresh runs at a time. A refresh is skipped if the
        current token's `exp` is still more than the refresh margin away, so concurrent
        callers that all hit UNAUTHENTICATED collapse into a single RPC.

        Raises:
            RuntimeError: if the server returns a different router URL than this client uses.
        """
        async with self._jwt_refresh_lock:
            if self._closed:
                return

            # If the current JWT expiration is already far enough in the future, don't refresh.
            if self._jwt_exp is not None and self._jwt_exp - time.time() > self._JWT_REFRESH_MARGIN_SECS:
                # This can happen if multiple concurrent requests to the task command router
                # get UNAUTHENTICATED errors and all refresh at the same time - one of them
                # will win and the others will not refresh.
                logger.debug(
                    f"Skipping JWT refresh for exec with task ID {self._task_id} "
                    "because its expiration is already far enough in the future"
                )
                return

            resp = await fetch_command_router_access(self._server_client, self._task_id)
            # Ensure the server URL remains stable for the lifetime of this client. This is an
            # explicit check rather than an `assert` so it also holds under `python -O`.
            if resp.url != self._server_url:
                raise RuntimeError("Task router URL changed during session")
            self._jwt = resp.jwt
            self._jwt_exp = _parse_jwt_expiration(resp.jwt)
            # Wake up the background loop to recompute its next sleep.
            self._jwt_refresh_event.set()

    async def _call_with_auth_retry(self, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``; on UNAUTHENTICATED, refresh the JWT and retry once."""
        try:
            return await func(*args, **kwargs)
        except GRPCError as exc:
            if exc.status == Status.UNAUTHENTICATED:
                await self._refresh_jwt()
                # Retry with the original arguments preserved
                return await func(*args, **kwargs)
            raise

    async def _jwt_refresh_loop(self) -> None:
        """Background task that refreshes the JWT shortly before its expiration.

        Uses an event to wake early when a manual refresh happens or the token changes.
        Exits when the client is closed, the task is cancelled, or a refresh fails.
        """
        while not self._closed:
            try:
                exp = self._jwt_exp
                if exp is None:
                    # Unknown expiration: re-check periodically or until event wakes us.
                    sleep_s = self._JWT_UNKNOWN_EXPIRY_RECHECK_SECS
                else:
                    refresh_at = exp - self._JWT_REFRESH_MARGIN_SECS
                    # Never spin: wait at least the minimum interval even if the token already
                    # looks due for a refresh.
                    sleep_s = max(refresh_at - time.time(), self._JWT_MIN_REFRESH_INTERVAL_SECS)

                self._jwt_refresh_event.clear()
                try:
                    logger.debug(f"Waiting for JWT refresh for {sleep_s}s for exec with task ID {self._task_id}")
                    # Wait until it's time to refresh, unless woken early.
                    await asyncio.wait_for(self._jwt_refresh_event.wait(), timeout=sleep_s)
                    logger.debug(f"Stopped waiting for JWT refresh for exec with task ID {self._task_id}")
                    # Event fired (e.g., token changed) -> recompute timings.
                    continue
                except asyncio.TimeoutError:
                    logger.debug(f"Done waiting for JWT refresh for exec with task ID {self._task_id}")

                # Time to refresh.
                logger.debug(f"Refreshing JWT for exec with task ID {self._task_id}")
                await self._refresh_jwt()
            except asyncio.CancelledError:
                logger.debug(f"Cancelled JWT refresh loop for exec with task ID {self._task_id}")
                break
            except Exception as e:
                # Exceptions here can stem from non-transient errors against the server sending
                # the TaskGetCommandRouterAccess RPC, for instance, if the task has finished.
                logger.warning(f"Background JWT refresh failed for exec with task ID {self._task_id}: {e}")
                break

    async def _stream_stdio(
        self,
        task_id: str,
        exec_id: str,
        # Quotes around the type required for protobuf 3.19.
        file_descriptor: "sr_pb2.TaskExecStdioFileDescriptor.ValueType",
        deadline: Optional[float] = None,
    ) -> AsyncIterator[sr_pb2.TaskExecStdioReadResponse]:
        """Stream stdio from the task, properly updating the offset and retrying on transient errors.

        Each (re)connection resumes at the number of bytes already yielded, so retries neither
        duplicate nor drop output. An UNAUTHENTICATED response triggers one JWT refresh and a
        reconnect. That allowance is restored once data is received again.

        Raises ExecTimeoutError if the deadline (a ``time.monotonic()`` value) is exceeded.
        """
        offset = 0
        delay_secs = self.stream_stdio_retry_delay_secs
        delay_factor = self.stream_stdio_retry_delay_factor
        num_retries_remaining = self.stream_stdio_max_retries
        num_auth_retries = 0

        async def sleep_and_update_delay_and_num_retries_remaining(e: Exception):
            nonlocal delay_secs, num_retries_remaining
            logger.debug(f"Retrying stdio read with delay {delay_secs}s due to error: {e}")
            if deadline is not None and deadline - time.monotonic() <= delay_secs:
                raise ExecTimeoutError(f"Deadline exceeded while streaming stdio for exec {exec_id}")

            await asyncio.sleep(delay_secs)
            delay_secs *= delay_factor
            num_retries_remaining -= 1

        while True:
            timeout = max(0, deadline - time.monotonic()) if deadline is not None else None
            try:
                stream = self._stub.TaskExecStdioRead.open(timeout=timeout)
                async with stream as s:
                    req = sr_pb2.TaskExecStdioReadRequest(
                        task_id=task_id,
                        exec_id=exec_id,
                        offset=offset,
                        file_descriptor=file_descriptor,
                    )
                    await s.send_message(req, end=True)

                    async for item in s:
                        # Receiving data proves the current JWT was accepted: restore the
                        # auth-retry allowance, and reset retry backoff.
                        num_auth_retries = 0
                        delay_secs = self.stream_stdio_retry_delay_secs
                        offset += len(item.data)
                        yield item

                # We successfully streamed all output.
                return
            except GRPCError as e:
                # grpclib reports the server's status (including UNAUTHENTICATED) while the
                # response is being read, not from send_message, so auth failures are handled
                # here. Reconnecting is safe because we resume from `offset`.
                if e.status == Status.UNAUTHENTICATED and num_auth_retries < 1:
                    await self._refresh_jwt()
                    num_auth_retries += 1
                    continue
                if num_retries_remaining > 0 and e.status in RETRYABLE_GRPC_STATUS_CODES:
                    await sleep_and_update_delay_and_num_retries_remaining(e)
                else:
                    raise
            except AttributeError as e:
                # StreamTerminatedError are not properly raised in grpclib<=0.4.7
                # fixed in https://github.com/vmagamedov/grpclib/issues/185
                # TODO: update to newer version (>=0.4.8) once stable
                if num_retries_remaining > 0 and "_write_appdata" in str(e):
                    await sleep_and_update_delay_and_num_retries_remaining(e)
                else:
                    raise
            except StreamTerminatedError as e:
                if num_retries_remaining > 0:
                    await sleep_and_update_delay_and_num_retries_remaining(e)
                else:
                    raise
            except asyncio.TimeoutError as e:
                if num_retries_remaining > 0:
                    await sleep_and_update_delay_and_num_retries_remaining(e)
                elif deadline is not None and time.monotonic() >= deadline:
                    # The stream timed out because the caller's deadline passed.
                    raise ExecTimeoutError(f"Deadline exceeded while streaming stdio for exec {exec_id}") from e
                else:
                    raise ConnectionError(str(e)) from e
            except OSError as e:
                if num_retries_remaining > 0:
                    await sleep_and_update_delay_and_num_retries_remaining(e)
                else:
                    raise ConnectionError(str(e)) from e
|