modal 1.2.1.dev8__py3-none-any.whl → 1.2.2.dev19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. modal/_clustered_functions.py +1 -3
  2. modal/_container_entrypoint.py +4 -1
  3. modal/_functions.py +33 -49
  4. modal/_grpc_client.py +148 -0
  5. modal/_output.py +3 -4
  6. modal/_partial_function.py +22 -2
  7. modal/_runtime/container_io_manager.py +21 -22
  8. modal/_utils/async_utils.py +12 -3
  9. modal/_utils/auth_token_manager.py +1 -4
  10. modal/_utils/blob_utils.py +3 -4
  11. modal/_utils/function_utils.py +4 -0
  12. modal/_utils/grpc_utils.py +80 -51
  13. modal/_utils/mount_utils.py +26 -1
  14. modal/_utils/task_command_router_client.py +536 -0
  15. modal/app.py +7 -5
  16. modal/cli/cluster.py +4 -2
  17. modal/cli/config.py +3 -1
  18. modal/cli/container.py +5 -4
  19. modal/cli/entry_point.py +1 -0
  20. modal/cli/launch.py +1 -2
  21. modal/cli/network_file_system.py +1 -4
  22. modal/cli/queues.py +1 -2
  23. modal/cli/secret.py +1 -2
  24. modal/client.py +5 -115
  25. modal/client.pyi +2 -91
  26. modal/cls.py +1 -2
  27. modal/config.py +3 -1
  28. modal/container_process.py +287 -11
  29. modal/container_process.pyi +95 -32
  30. modal/dict.py +12 -12
  31. modal/environments.py +1 -2
  32. modal/exception.py +4 -0
  33. modal/experimental/__init__.py +2 -3
  34. modal/experimental/flash.py +27 -57
  35. modal/experimental/flash.pyi +6 -20
  36. modal/file_io.py +13 -27
  37. modal/functions.pyi +6 -6
  38. modal/image.py +24 -3
  39. modal/image.pyi +4 -0
  40. modal/io_streams.py +433 -127
  41. modal/io_streams.pyi +236 -171
  42. modal/mount.py +4 -4
  43. modal/network_file_system.py +5 -6
  44. modal/parallel_map.py +29 -31
  45. modal/parallel_map.pyi +3 -9
  46. modal/partial_function.pyi +4 -1
  47. modal/queue.py +17 -18
  48. modal/runner.py +12 -11
  49. modal/sandbox.py +148 -42
  50. modal/sandbox.pyi +139 -0
  51. modal/secret.py +4 -5
  52. modal/snapshot.py +1 -4
  53. modal/token_flow.py +1 -1
  54. modal/volume.py +22 -22
  55. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/METADATA +1 -1
  56. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/RECORD +70 -68
  57. modal_proto/api.proto +2 -24
  58. modal_proto/api_grpc.py +0 -32
  59. modal_proto/api_pb2.py +838 -878
  60. modal_proto/api_pb2.pyi +8 -70
  61. modal_proto/api_pb2_grpc.py +0 -67
  62. modal_proto/api_pb2_grpc.pyi +0 -22
  63. modal_proto/modal_api_grpc.py +175 -177
  64. modal_proto/sandbox_router.proto +0 -4
  65. modal_proto/sandbox_router_pb2.pyi +0 -4
  66. modal_version/__init__.py +1 -1
  67. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/WHEEL +0 -0
  68. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/entry_points.txt +0 -0
  69. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/licenses/LICENSE +0 -0
  70. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/top_level.txt +0 -0
modal/_utils/task_command_router_client.py ADDED
@@ -0,0 +1,536 @@
+ # Copyright Modal Labs 2025
+ import asyncio
+ import base64
+ import json
+ import ssl
+ import time
+ import urllib.parse
+ from typing import AsyncIterator, Optional
+
+ import grpclib.client
+ import grpclib.config
+ import grpclib.events
+ from grpclib import GRPCError, Status
+ from grpclib.exceptions import StreamTerminatedError
+
+ from modal.config import config, logger
+ from modal.exception import ExecTimeoutError
+ from modal_proto import api_pb2, task_command_router_pb2 as sr_pb2
+ from modal_proto.task_command_router_grpc import TaskCommandRouterStub
+
+ from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES, connect_channel
+
+
+ def _b64url_decode(data: str) -> bytes:
+     """Decode a base64url string with missing padding tolerated."""
+     padding = "=" * (-len(data) % 4)
+     return base64.urlsafe_b64decode(data + padding)
+
+
+ def _parse_jwt_expiration(jwt_token: str) -> Optional[float]:
+     """Parse exp from a JWT without verification. Returns UNIX time seconds or None.
+
+     This is best-effort; if parsing fails or claim missing, returns None.
+     """
+     try:
+         parts = jwt_token.split(".")
+         if len(parts) != 3:
+             return None
+         payload_b = _b64url_decode(parts[1])
+         payload = json.loads(payload_b)
+         exp = payload.get("exp")
+         if isinstance(exp, (int, float)):
+             return float(exp)
+     except Exception:
+         # Avoid raising on malformed tokens; fall back to server-driven refresh logic.
+         logger.warning("Failed to parse JWT expiration")
+         return None
+     return None
+
+
+ async def call_with_retries_on_transient_errors(
+     func,
+     *,
+     base_delay_secs: float = 0.01,
+     delay_factor: float = 2,
+     max_retries: Optional[int] = 10,
+ ):
+     """Call func() with transient error retries and exponential backoff.
+
+     Authentication retries are expected to be handled by the caller.
+     """
+     delay_secs = base_delay_secs
+     num_retries = 0
+
+     async def sleep_and_update_delay_and_num_retries_remaining(e: Exception):
+         nonlocal delay_secs, num_retries
+         logger.debug(f"Retrying RPC with delay {delay_secs}s due to error: {e}")
+         await asyncio.sleep(delay_secs)
+         delay_secs *= delay_factor
+         num_retries += 1
+
+     while True:
+         try:
+             return await func()
+         except GRPCError as e:
+             if (max_retries is None or num_retries < max_retries) and e.status in RETRYABLE_GRPC_STATUS_CODES:
+                 await sleep_and_update_delay_and_num_retries_remaining(e)
+             else:
+                 raise e
+         except AttributeError as e:
+             # StreamTerminatedError are not properly raised in grpclib<=0.4.7
+             # fixed in https://github.com/vmagamedov/grpclib/issues/185
+             # TODO: update to newer version (>=0.4.8) once stable
+             if (max_retries is None or num_retries < max_retries) and "_write_appdata" in str(e):
+                 await sleep_and_update_delay_and_num_retries_remaining(e)
+             else:
+                 raise e
+         except StreamTerminatedError as e:
+             if max_retries is None or num_retries < max_retries:
+                 await sleep_and_update_delay_and_num_retries_remaining(e)
+             else:
+                 raise e
+         except (OSError, asyncio.TimeoutError) as e:
+             if max_retries is None or num_retries < max_retries:
+                 await sleep_and_update_delay_and_num_retries_remaining(e)
+             else:
+                 raise ConnectionError(str(e))
+
+
+ async def fetch_command_router_access(server_client, task_id: str) -> api_pb2.TaskGetCommandRouterAccessResponse:
+     """Fetch direct command router access info from Modal server."""
+     return await server_client.stub.TaskGetCommandRouterAccess(
+         api_pb2.TaskGetCommandRouterAccessRequest(task_id=task_id),
+     )
+
+
+ class TaskCommandRouterClient:
+     """
+     Client used to talk directly to TaskCommandRouter service on worker hosts.
+
+     A new instance should be created per task.
+     """
+
+     @classmethod
+     async def try_init(
+         cls,
+         server_client,
+         task_id: str,
+     ) -> Optional["TaskCommandRouterClient"]:
+         """Attempt to initialize a TaskCommandRouterClient by fetching direct access.
+
+         Returns None if command router access is not enabled (FAILED_PRECONDITION).
+         """
+         try:
+             resp = await fetch_command_router_access(server_client, task_id)
+         except GRPCError as exc:
+             if exc.status == Status.FAILED_PRECONDITION:
+                 logger.debug(f"Command router access is not enabled for task {task_id}")
+                 return None
+             raise
+
+         logger.debug(f"Using command router access for task {task_id}")
+
+         # Build and connect a channel to the task command router now that we have access info.
+         o = urllib.parse.urlparse(resp.url)
+         if o.scheme != "https":
+             raise ValueError(f"Task router URL must be https, got: {resp.url}")
+
+         host, _, port_str = o.netloc.partition(":")
+         port = int(port_str) if port_str else 443
+         ssl_context = ssl.create_default_context()
+
+         # Allow insecure TLS when explicitly enabled via config.
+         if config["task_command_router_insecure"]:
+             logger.warning("Using insecure TLS for task command router due to MODAL_TASK_COMMAND_ROUTER_INSECURE")
+             ssl_context.check_hostname = False
+             ssl_context.verify_mode = ssl.CERT_NONE
+
+         channel = grpclib.client.Channel(
+             host,
+             port,
+             ssl=ssl_context,
+             config=grpclib.config.Configuration(
+                 http2_connection_window_size=64 * 1024 * 1024,  # 64 MiB
+                 http2_stream_window_size=64 * 1024 * 1024,  # 64 MiB
+             ),
+         )
+
+         await connect_channel(channel)
+
+         return cls(server_client, task_id, resp.url, resp.jwt, channel)
+
+     def __init__(
+         self,
+         server_client,
+         task_id: str,
+         server_url: str,
+         jwt: str,
+         channel: grpclib.client.Channel,
+         *,
+         stream_stdio_retry_delay_secs: float = 0.01,
+         stream_stdio_retry_delay_factor: float = 2,
+         stream_stdio_max_retries: int = 10,
+     ) -> None:
+         """Callers should not use this directly. Use TaskCommandRouterClient.try_init() instead."""
+         # Attach bearer token on all requests to the worker-side router service.
+         self._server_client = server_client
+         self._task_id = task_id
+         self._server_url = server_url
+         self._jwt = jwt
+         self._channel = channel
+         # Retry configuration for stdio streaming
+         self.stream_stdio_retry_delay_secs = stream_stdio_retry_delay_secs
+         self.stream_stdio_retry_delay_factor = stream_stdio_retry_delay_factor
+         self.stream_stdio_max_retries = stream_stdio_max_retries
+
+         # JWT refresh coordination
+         self._jwt_exp: Optional[float] = _parse_jwt_expiration(jwt)
+         self._jwt_refresh_lock = asyncio.Lock()
+         self._jwt_refresh_event = asyncio.Event()
+         self._closed = False
+
+         # Start background task to eagerly refresh JWT 30s before expiration.
+         self._jwt_refresh_task = asyncio.create_task(self._jwt_refresh_loop())
+
+         async def send_request(event: grpclib.events.SendRequest) -> None:
+             # This will get the most recent JWT for every request. No need to
+             # lock _jwt_refresh_lock: reads and writes happen on the
+             # single-threaded event loop and variable assignment is atomic.
+             event.metadata["authorization"] = f"Bearer {self._jwt}"
+
+         grpclib.events.listen(self._channel, grpclib.events.SendRequest, send_request)
+
+         self._stub = TaskCommandRouterStub(self._channel)
+
+     def __del__(self) -> None:
+         """Clean up the client when it's garbage collected."""
+         if self._closed:
+             return
+
+         self._jwt_refresh_task.cancel()
+
+         try:
+             self._channel.close()
+         except Exception:
+             pass
+
+     async def close(self) -> None:
+         """Close the client and stop the background JWT refresh task."""
+         if self._closed:
+             return
+
+         self._closed = True
+         self._jwt_refresh_task.cancel()
+         try:
+             logger.debug(f"Waiting for JWT refresh task to complete for exec with task ID {self._task_id}")
+             await self._jwt_refresh_task
+         except asyncio.CancelledError:
+             pass
+         self._channel.close()
+
+     async def exec_start(self, request: sr_pb2.TaskExecStartRequest) -> sr_pb2.TaskExecStartResponse:
+         """Start an exec'd command, properly retrying on transient errors."""
+         return await call_with_retries_on_transient_errors(
+             lambda: self._call_with_auth_retry(self._stub.TaskExecStart, request)
+         )
+
+     async def exec_stdio_read(
+         self,
+         task_id: str,
+         exec_id: str,
+         # Quotes around the type required for protobuf 3.19.
+         file_descriptor: "api_pb2.FileDescriptor.ValueType",
+         deadline: Optional[float] = None,
+     ) -> AsyncIterator[sr_pb2.TaskExecStdioReadResponse]:
+         """Stream stdout/stderr batches from the task, properly retrying on transient errors.
+
+         Args:
+             task_id: The task ID of the task running the exec'd command.
+             exec_id: The execution ID of the command to read from.
+             file_descriptor: The file descriptor to read from.
+             deadline: The deadline by which all output must be streamed. If
+                 None, wait forever. If the deadline is exceeded, raises an
+                 ExecTimeoutError.
+         Returns:
+             AsyncIterator[sr_pb2.TaskExecStdioReadResponse]: A stream of stdout/stderr batches.
+         Raises:
+             ExecTimeoutError: If the deadline is exceeded.
+             Other errors: If retries are exhausted on transient errors or if there's an error
+                 from the RPC itself.
+         """
+         if file_descriptor == api_pb2.FILE_DESCRIPTOR_STDOUT:
+             sr_fd = sr_pb2.TASK_EXEC_STDIO_FILE_DESCRIPTOR_STDOUT
+         elif file_descriptor == api_pb2.FILE_DESCRIPTOR_STDERR:
+             sr_fd = sr_pb2.TASK_EXEC_STDIO_FILE_DESCRIPTOR_STDERR
+         elif file_descriptor == api_pb2.FILE_DESCRIPTOR_INFO or file_descriptor == api_pb2.FILE_DESCRIPTOR_UNSPECIFIED:
+             raise ValueError(f"Unsupported file descriptor: {file_descriptor}")
+         else:
+             raise ValueError(f"Invalid file descriptor: {file_descriptor}")
+
+         async for item in self._stream_stdio(task_id, exec_id, sr_fd, deadline):
+             yield item
+
+     async def exec_stdin_write(
+         self, task_id: str, exec_id: str, offset: int, data: bytes, eof: bool
+     ) -> sr_pb2.TaskExecStdinWriteResponse:
+         """Write to the stdin stream of an exec'd command, properly retrying on transient errors.
+
+         Args:
+             task_id: The task ID of the task running the exec'd command.
+             exec_id: The execution ID of the command to write to.
+             offset: The offset to start writing to.
+             data: The data to write to the stdin stream.
+             eof: Whether to close the stdin stream after writing the data.
+         Raises:
+             Other errors: If retries are exhausted on transient errors or if there's an error
+                 from the RPC itself.
+         """
+         request = sr_pb2.TaskExecStdinWriteRequest(task_id=task_id, exec_id=exec_id, offset=offset, data=data, eof=eof)
+         return await call_with_retries_on_transient_errors(
+             lambda: self._call_with_auth_retry(self._stub.TaskExecStdinWrite, request)
+         )
+
+     async def exec_poll(
+         self, task_id: str, exec_id: str, deadline: Optional[float] = None
+     ) -> sr_pb2.TaskExecPollResponse:
+         """Poll for the exit status of an exec'd command, properly retrying on transient errors.
+
+         Args:
+             task_id: The task ID of the task running the exec'd command.
+             exec_id: The execution ID of the command to poll on.
+         Returns:
+             sr_pb2.TaskExecPollResponse: The exit status of the command if it has completed.
+
+         Raises:
+             ExecTimeoutError: If the deadline is exceeded.
+             Other errors: If retries are exhausted on transient errors or if there's an error
+                 from the RPC itself.
+         """
+         request = sr_pb2.TaskExecPollRequest(task_id=task_id, exec_id=exec_id)
+         # The timeout here is really a backstop in the event of a hang contacting
+         # the command router. Poll should usually be instantaneous.
+         timeout = deadline - time.monotonic() if deadline is not None else None
+         if timeout is not None and timeout <= 0:
+             raise ExecTimeoutError(f"Deadline exceeded while polling for exec {exec_id}")
+         try:
+             return await asyncio.wait_for(
+                 call_with_retries_on_transient_errors(
+                     lambda: self._call_with_auth_retry(self._stub.TaskExecPoll, request)
+                 ),
+                 timeout=timeout,
+             )
+         except asyncio.TimeoutError:
+             raise ExecTimeoutError(f"Deadline exceeded while polling for exec {exec_id}")
+
+     async def exec_wait(
+         self,
+         task_id: str,
+         exec_id: str,
+         deadline: Optional[float] = None,
+     ) -> sr_pb2.TaskExecWaitResponse:
+         """Wait for an exec'd command to exit and return the exit code, properly retrying on transient errors.
+
+         Args:
+             task_id: The task ID of the task running the exec'd command.
+             exec_id: The execution ID of the command to wait on.
+         Returns:
+             Optional[sr_pb2.TaskExecWaitResponse]: The exit code of the command.
+         Raises:
+             ExecTimeoutError: If the deadline is exceeded.
+             Other errors: If there's an error from the RPC itself.
+         """
+         request = sr_pb2.TaskExecWaitRequest(task_id=task_id, exec_id=exec_id)
+         timeout = deadline - time.monotonic() if deadline is not None else None
+         if timeout is not None and timeout <= 0:
+             raise ExecTimeoutError(f"Deadline exceeded while waiting for exec {exec_id}")
+         try:
+             return await asyncio.wait_for(
+                 call_with_retries_on_transient_errors(
+                     # We set a 60s timeout here to avoid waiting forever if there's an unanticipated hang
+                     # due to a networking issue. call_with_retries_on_transient_errors will retry if the
+                     # timeout is exceeded, so we'll retry every 60s until the command exits.
+                     #
+                     # Safety:
+                     # * If just the task shuts down, the task command router will return a NOT_FOUND error,
+                     #   and we'll stop retrying.
+                     # * If the task shut down AND the worker shut down, this could
+                     #   infinitely retry. For callers without an exec deadline, this
+                     #   could hang indefinitely.
+                     lambda: self._call_with_auth_retry(self._stub.TaskExecWait, request, timeout=60),
+                     base_delay_secs=1,  # Retry after 1s since total time is expected to be long.
+                     delay_factor=1,  # Fixed delay.
+                     max_retries=None,  # Retry forever.
+                 ),
+                 timeout=timeout,
+             )
+         except asyncio.TimeoutError:
+             raise ExecTimeoutError(f"Deadline exceeded while waiting for exec {exec_id}")
+
+     async def _refresh_jwt(self) -> None:
+         """Refresh JWT from the server and update internal state.
+
+         Concurrency-safe: only one refresh runs at a time.
+         """
+         async with self._jwt_refresh_lock:
+             if self._closed:
+                 return
+
+             # If the current JWT expiration is already far enough in the future, don't refresh.
+             if self._jwt_exp is not None and self._jwt_exp - time.time() > 30:
+                 # This can happen if multiple concurrent requests to the task command router
+                 # get UNAUTHENTICATED errors and all refresh at the same time - one of them
+                 # will win and the others will not refresh.
+                 logger.debug(
+                     f"Skipping JWT refresh for exec with task ID {self._task_id} "
+                     "because its expiration is already far enough in the future"
+                 )
+                 return
+
+             resp = await fetch_command_router_access(self._server_client, self._task_id)
+             # Ensure the server URL remains stable for the lifetime of this client.
+             assert resp.url == self._server_url, "Task router URL changed during session"
+             self._jwt = resp.jwt
+             self._jwt_exp = _parse_jwt_expiration(resp.jwt)
+             # Wake up the background loop to recompute its next sleep.
+             self._jwt_refresh_event.set()
+
+     async def _call_with_auth_retry(self, func, *args, **kwargs):
+         try:
+             return await func(*args, **kwargs)
+         except GRPCError as exc:
+             if exc.status == Status.UNAUTHENTICATED:
+                 await self._refresh_jwt()
+                 # Retry with the original arguments preserved
+                 return await func(*args, **kwargs)
+             raise
+
+     async def _jwt_refresh_loop(self) -> None:
+         """Background task that refreshes JWT 30 seconds before expiration.
+
+         Uses an event to wake early when a manual refresh happens or token changes.
+         """
+         while not self._closed:
+             try:
+                 exp = self._jwt_exp
+                 now = time.time()
+                 if exp is None:
+                     # Unknown expiration: re-check periodically or until event wakes us.
+                     sleep_s = 60.0
+                 else:
+                     refresh_at = exp - 30.0
+                     sleep_s = max(refresh_at - now, 0.0)
+
+                 self._jwt_refresh_event.clear()
+                 if sleep_s > 0:
+                     try:
+                         logger.debug(f"Waiting for JWT refresh for {sleep_s}s for exec with task ID {self._task_id}")
+                         # Wait until it's time to refresh, unless woken early.
+                         await asyncio.wait_for(self._jwt_refresh_event.wait(), timeout=sleep_s)
+                         logger.debug(f"Stopped waiting for JWT refresh for exec with task ID {self._task_id}")
+                         # Event fired (e.g., token changed) -> recompute timings.
+                         continue
+                     except asyncio.TimeoutError:
+                         logger.debug(f"Done waiting for JWT refresh for exec with task ID {self._task_id}")
+                         pass
+
+                 # Time to refresh.
+                 logger.debug(f"Refreshing JWT for exec with task ID {self._task_id}")
+                 await self._refresh_jwt()
+             except asyncio.CancelledError:
+                 logger.debug(f"Cancelled JWT refresh loop for exec with task ID {self._task_id}")
+                 break
+             except Exception as e:
+                 # Exceptions here can stem from non-transient errors against the server sending
+                 # the TaskGetCommandRouterAccess RPC, for instance, if the task has finished.
+                 logger.debug(f"Background JWT refresh failed for exec with task ID {self._task_id}: {e}")
+                 break
+
+     async def _stream_stdio(
+         self,
+         task_id: str,
+         exec_id: str,
+         # Quotes around the type required for protobuf 3.19.
+         file_descriptor: "sr_pb2.TaskExecStdioFileDescriptor.ValueType",
+         deadline: Optional[float] = None,
+     ) -> AsyncIterator[sr_pb2.TaskExecStdioReadResponse]:
+         """Stream stdio from the task, properly updating the offset and retrying on transient errors.
+         Raises ExecTimeoutError if the deadline is exceeded.
+         """
+         offset = 0
+         delay_secs = self.stream_stdio_retry_delay_secs
+         delay_factor = self.stream_stdio_retry_delay_factor
+         num_retries_remaining = self.stream_stdio_max_retries
+         num_auth_retries = 0
+
+         async def sleep_and_update_delay_and_num_retries_remaining(e: Exception):
+             nonlocal delay_secs, num_retries_remaining
+             logger.debug(f"Retrying stdio read with delay {delay_secs}s due to error: {e}")
+             if deadline is not None and deadline - time.monotonic() <= delay_secs:
+                 raise ExecTimeoutError(f"Deadline exceeded while streaming stdio for exec {exec_id}")
+
+             await asyncio.sleep(delay_secs)
+             delay_secs *= delay_factor
+             num_retries_remaining -= 1
+
+         while True:
+             timeout = max(0, deadline - time.monotonic()) if deadline is not None else None
+             try:
+                 stream = self._stub.TaskExecStdioRead.open(timeout=timeout)
+                 async with stream as s:
+                     req = sr_pb2.TaskExecStdioReadRequest(
+                         task_id=task_id,
+                         exec_id=exec_id,
+                         offset=offset,
+                         file_descriptor=file_descriptor,
+                     )
+
+                     # Scope auth retry strictly to the initial send (where headers/auth are sent).
+                     try:
+                         await s.send_message(req, end=True)
+                     except GRPCError as exc:
+                         if exc.status == Status.UNAUTHENTICATED and num_auth_retries < 1:
+                             await self._refresh_jwt()
+                             num_auth_retries += 1
+                             continue
+                         raise
+
+                     # We successfully authenticated, reset the auth retry count.
+                     num_auth_retries = 0
+
+                     async for item in s:
+                         # Reset retry backoff after any successful chunk.
+                         delay_secs = self.stream_stdio_retry_delay_secs
+                         offset += len(item.data)
+                         yield item
+
+                     # We successfully streamed all output.
+                     return
+             except GRPCError as e:
+                 if num_retries_remaining > 0 and e.status in RETRYABLE_GRPC_STATUS_CODES:
+                     await sleep_and_update_delay_and_num_retries_remaining(e)
+                 else:
+                     raise e
+             except AttributeError as e:
+                 # StreamTerminatedError are not properly raised in grpclib<=0.4.7
+                 # fixed in https://github.com/vmagamedov/grpclib/issues/185
+                 # TODO: update to newer version (>=0.4.8) once stable
+                 if num_retries_remaining > 0 and "_write_appdata" in str(e):
+                     await sleep_and_update_delay_and_num_retries_remaining(e)
+                 else:
+                     raise e
+             except StreamTerminatedError as e:
+                 if num_retries_remaining > 0:
+                     await sleep_and_update_delay_and_num_retries_remaining(e)
+                 else:
+                     raise e
+             except asyncio.TimeoutError as e:
+                 if num_retries_remaining > 0:
+                     await sleep_and_update_delay_and_num_retries_remaining(e)
+                 else:
+                     raise ConnectionError(str(e))
+             except OSError as e:
+                 if num_retries_remaining > 0:
+                     await sleep_and_update_delay_and_num_retries_remaining(e)
+                 else:
+                     raise ConnectionError(str(e))
modal/app.py CHANGED
@@ -27,14 +27,14 @@ from ._partial_function import (
      _find_partial_methods_for_user_cls,
      _PartialFunction,
      _PartialFunctionFlags,
+     verify_concurrent_params,
  )
  from ._utils.async_utils import synchronize_api
  from ._utils.deprecation import (
      deprecation_warning,
      warn_on_renamed_autoscaler_settings,
  )
- from ._utils.function_utils import FunctionInfo, is_global_object, is_method_fn
- from ._utils.grpc_utils import retry_transient_errors
+ from ._utils.function_utils import FunctionInfo, is_flash_object, is_global_object, is_method_fn
  from ._utils.mount_utils import validate_volumes
  from ._utils.name_utils import check_object_name, check_tag_dict
  from .client import _Client
@@ -302,7 +302,7 @@ class _App:
              object_creation_type=(api_pb2.OBJECT_CREATION_TYPE_CREATE_IF_MISSING if create_if_missing else None),
          )

-         response = await retry_transient_errors(client.stub.AppGetOrCreate, request)
+         response = await client.stub.AppGetOrCreate(request)

          app = _App(name)  # TODO: this should probably be a distinct constructor, possibly even a distinct type
          app._local_state_attr = None  # this is not a locally defined App, so no local state
@@ -802,6 +802,7 @@
                  batch_max_size = f.params.batch_max_size
                  batch_wait_ms = f.params.batch_wait_ms
                  if f.flags & _PartialFunctionFlags.CONCURRENT:
+                     verify_concurrent_params(params=f.params, is_flash=is_flash_object(experimental_options))
                      max_concurrent_inputs = f.params.max_concurrent_inputs
                      target_concurrent_inputs = f.params.target_concurrent_inputs
                  else:
@@ -996,6 +997,7 @@
                  wrapped_cls.registered = True
                  user_cls = wrapped_cls.user_cls
                  if wrapped_cls.flags & _PartialFunctionFlags.CONCURRENT:
+                     verify_concurrent_params(params=wrapped_cls.params, is_flash=is_flash_object(experimental_options))
                      max_concurrent_inputs = wrapped_cls.params.max_concurrent_inputs
                      target_concurrent_inputs = wrapped_cls.params.target_concurrent_inputs
                  else:
@@ -1180,7 +1182,7 @@
          req = api_pb2.AppSetTagsRequest(app_id=self._app_id, tags=tags)

          client = client or self._client or await _Client.from_env()
-         await retry_transient_errors(client.stub.AppSetTags, req)
+         await client.stub.AppSetTags(req)

      async def get_tags(self, *, client: Optional[_Client] = None) -> dict[str, str]:
          """Get the tags that are currently attached to the App."""
@@ -1188,7 +1190,7 @@
              raise InvalidError("`App.get_tags` cannot be called before the App is running.")
          req = api_pb2.AppGetTagsRequest(app_id=self._app_id)
          client = client or self._client or await _Client.from_env()
-         resp = await retry_transient_errors(client.stub.AppGetTags, req)
+         resp = await client.stub.AppGetTags(req)
          return dict(resp.tags)

      async def _logs(self, client: Optional[_Client] = None) -> AsyncGenerator[str, None]:
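A pattern recurs in this file and throughout the release: call sites stop wrapping RPCs in retry_transient_errors and invoke the stub directly, which is consistent with retry handling moving behind the stub layer in the new modal/_grpc_client.py (+148 in the file list above). That is an inference from the file list, not something these hunks show. A toy before/after sketch with a stand-in stub (FakeStub is hypothetical, not Modal's API):

    import asyncio

    class FakeStub:
        """Stand-in for client.stub; not Modal's real API surface."""
        async def AppGetTags(self, req):
            return {"team": "infra"}

    async def main() -> None:
        stub = FakeStub()
        # 1.2.1.dev8:  resp = await retry_transient_errors(stub.AppGetTags, req)
        resp = await stub.AppGetTags(req={})  # 1.2.2.dev19: call the stub directly
        print(resp)

    asyncio.run(main())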
modal/cli/cluster.py CHANGED
@@ -83,7 +83,9 @@ async def shell(
      )
      exec_res: api_pb2.ContainerExecResponse = await client.stub.ContainerExec(req)
      if pty:
-         await _ContainerProcess(exec_res.exec_id, client).attach()
+         await _ContainerProcess(exec_res.exec_id, task_id, client).attach()
      else:
          # TODO: redirect stderr to its own stream?
-         await _ContainerProcess(exec_res.exec_id, client, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT).wait()
+         await _ContainerProcess(
+             exec_res.exec_id, task_id, client, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT
+         ).wait()
modal/cli/config.py CHANGED
@@ -1,4 +1,6 @@
  # Copyright Modal Labs 2022
+ import json
+
  import typer

  from modal._output import make_console
@@ -25,7 +27,7 @@ def show(redact: bool = typer.Option(True, help="Redact the `token_secret` value
          config_dict["token_secret"] = "***"

      console = make_console()
-     console.print(config_dict)
+     console.print_json(json.dumps(config_dict))


  SET_DEFAULT_ENV_HELP = """Set the default Modal environment for the active profile
modal/cli/container.py CHANGED
@@ -7,7 +7,6 @@ from rich.text import Text
  from modal._object import _get_environment_name
  from modal._pty import get_pty_info
  from modal._utils.async_utils import synchronizer
- from modal._utils.grpc_utils import retry_transient_errors
  from modal._utils.time_utils import timestamp_to_localized_str
  from modal.cli.utils import ENV_OPTION, display_table, is_tty, stream_app_logs
  from modal.client import _Client
@@ -80,10 +79,12 @@ async def exec(
      res: api_pb2.ContainerExecResponse = await client.stub.ContainerExec(req)

      if pty:
-         await _ContainerProcess(res.exec_id, client).attach()
+         await _ContainerProcess(res.exec_id, container_id, client).attach()
      else:
          # TODO: redirect stderr to its own stream?
-         await _ContainerProcess(res.exec_id, client, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT).wait()
+         await _ContainerProcess(
+             res.exec_id, container_id, client, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT
+         ).wait()


  @container_cli.command("stop")
@@ -95,4 +96,4 @@ async def stop(container_id: str = typer.Argument(help="Container ID")):
      """
      client = await _Client.from_env()
      request = api_pb2.ContainerStopRequest(task_id=container_id)
-     await retry_transient_errors(client.stub.ContainerStop, request)
+     await client.stub.ContainerStop(request)
modal/cli/entry_point.py CHANGED
@@ -36,6 +36,7 @@ entrypoint_cli_typer = typer.Typer(
      no_args_is_help=False,
      add_completion=False,
      rich_markup_mode="markdown",
+     context_settings={"help_option_names": ["-h", "--help"]},
      help="""
  Modal is the fastest way to run code in the cloud.

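The added context_settings line enables a short help flag CLI-wide. A minimal standalone illustration of the same click/typer setting (a toy app, not Modal's entrypoint): help_option_names is passed through to click, so both -h and --help trigger the help screen.

    import typer

    app = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]})

    @app.command()
    def hello(name: str = "world") -> None:
        typer.echo(f"Hello, {name}!")

    if __name__ == "__main__":
        app()  # `python cli.py -h` now prints the same help as `--help`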
modal/cli/launch.py CHANGED
@@ -23,8 +23,7 @@ launch_cli = Typer(
      no_args_is_help=True,
      rich_markup_mode="markdown",
      help="""
-     Open a serverless app instance on Modal.
-     >⚠️ `modal launch` is **experimental** and may change in the future.
+     [Experimental] Open a serverless app instance on Modal.
      """,
  )