modal 1.0.3.dev10__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modal might be problematic. Click here for more details.
- modal/__init__.py +0 -2
- modal/__main__.py +3 -4
- modal/_billing.py +80 -0
- modal/_clustered_functions.py +7 -3
- modal/_clustered_functions.pyi +15 -3
- modal/_container_entrypoint.py +51 -69
- modal/_functions.py +508 -240
- modal/_grpc_client.py +171 -0
- modal/_load_context.py +105 -0
- modal/_object.py +81 -21
- modal/_output.py +58 -45
- modal/_partial_function.py +48 -73
- modal/_pty.py +7 -3
- modal/_resolver.py +26 -46
- modal/_runtime/asgi.py +4 -3
- modal/_runtime/container_io_manager.py +358 -220
- modal/_runtime/container_io_manager.pyi +296 -101
- modal/_runtime/execution_context.py +18 -2
- modal/_runtime/execution_context.pyi +64 -7
- modal/_runtime/gpu_memory_snapshot.py +262 -57
- modal/_runtime/user_code_imports.py +28 -58
- modal/_serialization.py +90 -6
- modal/_traceback.py +42 -1
- modal/_tunnel.pyi +380 -12
- modal/_utils/async_utils.py +84 -29
- modal/_utils/auth_token_manager.py +111 -0
- modal/_utils/blob_utils.py +181 -58
- modal/_utils/deprecation.py +19 -0
- modal/_utils/function_utils.py +91 -47
- modal/_utils/grpc_utils.py +89 -66
- modal/_utils/mount_utils.py +26 -1
- modal/_utils/name_utils.py +17 -3
- modal/_utils/task_command_router_client.py +536 -0
- modal/_utils/time_utils.py +34 -6
- modal/app.py +256 -88
- modal/app.pyi +909 -92
- modal/billing.py +5 -0
- modal/builder/2025.06.txt +18 -0
- modal/builder/PREVIEW.txt +18 -0
- modal/builder/base-images.json +58 -0
- modal/cli/_download.py +19 -3
- modal/cli/_traceback.py +3 -2
- modal/cli/app.py +4 -4
- modal/cli/cluster.py +15 -7
- modal/cli/config.py +5 -3
- modal/cli/container.py +7 -6
- modal/cli/dict.py +22 -16
- modal/cli/entry_point.py +12 -5
- modal/cli/environment.py +5 -4
- modal/cli/import_refs.py +3 -3
- modal/cli/launch.py +102 -5
- modal/cli/network_file_system.py +11 -12
- modal/cli/profile.py +3 -2
- modal/cli/programs/launch_instance_ssh.py +94 -0
- modal/cli/programs/run_jupyter.py +1 -1
- modal/cli/programs/run_marimo.py +95 -0
- modal/cli/programs/vscode.py +1 -1
- modal/cli/queues.py +57 -26
- modal/cli/run.py +91 -23
- modal/cli/secret.py +48 -22
- modal/cli/token.py +7 -8
- modal/cli/utils.py +4 -7
- modal/cli/volume.py +31 -25
- modal/client.py +15 -85
- modal/client.pyi +183 -62
- modal/cloud_bucket_mount.py +5 -3
- modal/cloud_bucket_mount.pyi +197 -5
- modal/cls.py +200 -126
- modal/cls.pyi +446 -68
- modal/config.py +29 -11
- modal/container_process.py +319 -19
- modal/container_process.pyi +190 -20
- modal/dict.py +290 -71
- modal/dict.pyi +835 -83
- modal/environments.py +15 -27
- modal/environments.pyi +46 -24
- modal/exception.py +14 -2
- modal/experimental/__init__.py +194 -40
- modal/experimental/flash.py +618 -0
- modal/experimental/flash.pyi +380 -0
- modal/experimental/ipython.py +11 -7
- modal/file_io.py +29 -36
- modal/file_io.pyi +251 -53
- modal/file_pattern_matcher.py +56 -16
- modal/functions.pyi +673 -92
- modal/gpu.py +1 -1
- modal/image.py +528 -176
- modal/image.pyi +1572 -145
- modal/io_streams.py +458 -128
- modal/io_streams.pyi +433 -52
- modal/mount.py +216 -151
- modal/mount.pyi +225 -78
- modal/network_file_system.py +45 -62
- modal/network_file_system.pyi +277 -56
- modal/object.pyi +93 -17
- modal/parallel_map.py +942 -129
- modal/parallel_map.pyi +294 -15
- modal/partial_function.py +0 -2
- modal/partial_function.pyi +234 -19
- modal/proxy.py +17 -8
- modal/proxy.pyi +36 -3
- modal/queue.py +270 -65
- modal/queue.pyi +817 -57
- modal/runner.py +115 -101
- modal/runner.pyi +205 -49
- modal/sandbox.py +512 -136
- modal/sandbox.pyi +845 -111
- modal/schedule.py +1 -1
- modal/secret.py +300 -70
- modal/secret.pyi +589 -34
- modal/serving.py +7 -11
- modal/serving.pyi +7 -8
- modal/snapshot.py +11 -8
- modal/snapshot.pyi +25 -4
- modal/token_flow.py +4 -4
- modal/token_flow.pyi +28 -8
- modal/volume.py +416 -158
- modal/volume.pyi +1117 -121
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +10 -9
- modal-1.2.3.dev7.dist-info/RECORD +195 -0
- modal_docs/mdmd/mdmd.py +17 -4
- modal_proto/api.proto +534 -79
- modal_proto/api_grpc.py +337 -1
- modal_proto/api_pb2.py +1522 -968
- modal_proto/api_pb2.pyi +1619 -134
- modal_proto/api_pb2_grpc.py +699 -4
- modal_proto/api_pb2_grpc.pyi +226 -14
- modal_proto/modal_api_grpc.py +175 -154
- modal_proto/sandbox_router.proto +145 -0
- modal_proto/sandbox_router_grpc.py +105 -0
- modal_proto/sandbox_router_pb2.py +149 -0
- modal_proto/sandbox_router_pb2.pyi +333 -0
- modal_proto/sandbox_router_pb2_grpc.py +203 -0
- modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
- modal_proto/task_command_router.proto +144 -0
- modal_proto/task_command_router_grpc.py +105 -0
- modal_proto/task_command_router_pb2.py +149 -0
- modal_proto/task_command_router_pb2.pyi +333 -0
- modal_proto/task_command_router_pb2_grpc.py +203 -0
- modal_proto/task_command_router_pb2_grpc.pyi +75 -0
- modal_version/__init__.py +1 -1
- modal/requirements/PREVIEW.txt +0 -16
- modal/requirements/base-images.json +0 -26
- modal-1.0.3.dev10.dist-info/RECORD +0 -179
- modal_proto/modal_options_grpc.py +0 -3
- modal_proto/options.proto +0 -19
- modal_proto/options_grpc.py +0 -3
- modal_proto/options_pb2.py +0 -35
- modal_proto/options_pb2.pyi +0 -20
- modal_proto/options_pb2_grpc.py +0 -4
- modal_proto/options_pb2_grpc.pyi +0 -7
- /modal/{requirements → builder}/2023.12.312.txt +0 -0
- /modal/{requirements → builder}/2023.12.txt +0 -0
- /modal/{requirements → builder}/2024.04.txt +0 -0
- /modal/{requirements → builder}/2024.10.txt +0 -0
- /modal/{requirements → builder}/README.md +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
modal/_utils/async_utils.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
|
1
1
|
# Copyright Modal Labs 2022
|
|
2
2
|
import asyncio
|
|
3
3
|
import concurrent.futures
|
|
4
|
+
import contextlib
|
|
4
5
|
import functools
|
|
5
6
|
import inspect
|
|
6
7
|
import itertools
|
|
8
|
+
import sys
|
|
7
9
|
import time
|
|
8
10
|
import typing
|
|
11
|
+
import warnings
|
|
9
12
|
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Iterable, Iterator
|
|
10
13
|
from contextlib import asynccontextmanager
|
|
11
14
|
from dataclasses import dataclass
|
|
@@ -31,6 +34,10 @@ T = TypeVar("T")
|
|
|
31
34
|
P = ParamSpec("P")
|
|
32
35
|
V = TypeVar("V")
|
|
33
36
|
|
|
37
|
+
if sys.platform == "win32":
|
|
38
|
+
# quick workaround for deadlocks on shutdown - need to investigate further
|
|
39
|
+
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
40
|
+
|
|
34
41
|
synchronizer = synchronicity.Synchronizer()
|
|
35
42
|
|
|
36
43
|
|
|
@@ -46,6 +53,10 @@ def synchronize_api(obj, target_module=None):
|
|
|
46
53
|
return synchronizer.create_blocking(obj, blocking_name, target_module=target_module)
|
|
47
54
|
|
|
48
55
|
|
|
56
|
+
# Used for testing to configure the `n_attempts` that `retry` will use.
|
|
57
|
+
RETRY_N_ATTEMPTS_OVERRIDE: Optional[int] = None
|
|
58
|
+
|
|
59
|
+
|
|
49
60
|
def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout=90):
|
|
50
61
|
"""Decorator that calls an async function multiple times, with a given timeout.
|
|
51
62
|
|
|
@@ -70,8 +81,13 @@ def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout
|
|
|
70
81
|
def decorator(fn):
|
|
71
82
|
@functools.wraps(fn)
|
|
72
83
|
async def f_wrapped(*args, **kwargs):
|
|
84
|
+
if RETRY_N_ATTEMPTS_OVERRIDE is not None:
|
|
85
|
+
local_n_attempts = RETRY_N_ATTEMPTS_OVERRIDE
|
|
86
|
+
else:
|
|
87
|
+
local_n_attempts = n_attempts
|
|
88
|
+
|
|
73
89
|
delay = base_delay
|
|
74
|
-
for i in range(
|
|
90
|
+
for i in range(local_n_attempts):
|
|
75
91
|
t0 = time.time()
|
|
76
92
|
try:
|
|
77
93
|
return await asyncio.wait_for(fn(*args, **kwargs), timeout=timeout)
|
|
@@ -79,12 +95,12 @@ def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout
|
|
|
79
95
|
logger.debug(f"Function {fn} was cancelled")
|
|
80
96
|
raise
|
|
81
97
|
except Exception as e:
|
|
82
|
-
if i >=
|
|
98
|
+
if i >= local_n_attempts - 1:
|
|
83
99
|
raise
|
|
84
100
|
logger.debug(
|
|
85
101
|
f"Failed invoking function {fn}: {e}"
|
|
86
102
|
f" (took {time.time() - t0}s, sleeping {delay}s"
|
|
87
|
-
f" and trying {
|
|
103
|
+
f" and trying {local_n_attempts - i - 1} more times)"
|
|
88
104
|
)
|
|
89
105
|
await asyncio.sleep(delay)
|
|
90
106
|
delay *= delay_factor
|
|
@@ -120,7 +136,8 @@ class TaskContext:
|
|
|
120
136
|
_loops: set[asyncio.Task]
|
|
121
137
|
|
|
122
138
|
def __init__(self, grace: Optional[float] = None):
|
|
123
|
-
self._grace = grace
|
|
139
|
+
self._grace = grace # grace is the time we want for tasks to finish before cancelling them
|
|
140
|
+
self._cancellation_grace: float = 1.0 # extra graceperiod for the cancellation itself to "bubble up"
|
|
124
141
|
self._loops = set()
|
|
125
142
|
|
|
126
143
|
async def start(self):
|
|
@@ -152,22 +169,29 @@ class TaskContext:
|
|
|
152
169
|
# still needs to be handled
|
|
153
170
|
# (https://stackoverflow.com/a/63356323/2475114)
|
|
154
171
|
if gather_future:
|
|
155
|
-
|
|
172
|
+
with contextlib.suppress(asyncio.CancelledError):
|
|
156
173
|
await gather_future
|
|
157
|
-
except asyncio.CancelledError:
|
|
158
|
-
pass
|
|
159
174
|
|
|
175
|
+
cancelled_tasks: list[asyncio.Task] = []
|
|
160
176
|
for task in self._tasks:
|
|
161
177
|
if task.done() and not task.cancelled():
|
|
162
178
|
# Raise any exceptions if they happened.
|
|
163
179
|
# Only tasks without a done_callback will still be present in self._tasks
|
|
164
180
|
task.result()
|
|
165
181
|
|
|
166
|
-
if task.done()
|
|
182
|
+
if task.done():
|
|
167
183
|
continue
|
|
168
184
|
|
|
169
185
|
# Cancel any remaining unfinished tasks.
|
|
170
186
|
task.cancel()
|
|
187
|
+
cancelled_tasks.append(task)
|
|
188
|
+
|
|
189
|
+
cancellation_gather = asyncio.gather(*cancelled_tasks, return_exceptions=True)
|
|
190
|
+
try:
|
|
191
|
+
await asyncio.wait_for(cancellation_gather, timeout=self._cancellation_grace)
|
|
192
|
+
except asyncio.TimeoutError:
|
|
193
|
+
warnings.warn(f"Internal warning: Tasks did not cancel in a timely manner: {cancelled_tasks}")
|
|
194
|
+
|
|
171
195
|
await asyncio.sleep(0) # wake up coroutines waiting for cancellations
|
|
172
196
|
|
|
173
197
|
async def __aexit__(self, exc_type, value, tb):
|
|
@@ -274,7 +298,9 @@ class TimestampPriorityQueue(Generic[T]):
|
|
|
274
298
|
|
|
275
299
|
def __init__(self, maxsize: int = 0):
|
|
276
300
|
self.condition = asyncio.Condition()
|
|
277
|
-
self._queue: asyncio.PriorityQueue[tuple[float, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
|
|
301
|
+
self._queue: asyncio.PriorityQueue[tuple[float, int, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
|
|
302
|
+
# Used to tiebreak items with the same timestamp that are not comparable. (eg. protos)
|
|
303
|
+
self._counter = itertools.count()
|
|
278
304
|
|
|
279
305
|
async def close(self):
|
|
280
306
|
await self.put(self._MAX_PRIORITY, None)
|
|
@@ -283,7 +309,7 @@ class TimestampPriorityQueue(Generic[T]):
|
|
|
283
309
|
"""
|
|
284
310
|
Add an item to the queue to be processed at a specific timestamp.
|
|
285
311
|
"""
|
|
286
|
-
await self._queue.put((timestamp, item))
|
|
312
|
+
await self._queue.put((timestamp, next(self._counter), item))
|
|
287
313
|
async with self.condition:
|
|
288
314
|
self.condition.notify_all() # notify any waiting coroutines
|
|
289
315
|
|
|
@@ -296,7 +322,7 @@ class TimestampPriorityQueue(Generic[T]):
|
|
|
296
322
|
while self.empty():
|
|
297
323
|
await self.condition.wait()
|
|
298
324
|
# peek at the next item
|
|
299
|
-
timestamp, item = await self._queue.get()
|
|
325
|
+
timestamp, counter, item = await self._queue.get()
|
|
300
326
|
now = time.time()
|
|
301
327
|
if timestamp < now:
|
|
302
328
|
return item
|
|
@@ -304,7 +330,7 @@ class TimestampPriorityQueue(Generic[T]):
|
|
|
304
330
|
return None
|
|
305
331
|
# not ready yet, calculate sleep time
|
|
306
332
|
sleep_time = timestamp - now
|
|
307
|
-
self._queue.put_nowait((timestamp, item)) # put it back
|
|
333
|
+
self._queue.put_nowait((timestamp, counter, item)) # put it back
|
|
308
334
|
# wait until either the timeout or a new item is added
|
|
309
335
|
try:
|
|
310
336
|
await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
|
|
@@ -396,7 +422,7 @@ class _WarnIfGeneratorIsNotConsumed:
|
|
|
396
422
|
return await self.gen.aclose()
|
|
397
423
|
|
|
398
424
|
|
|
399
|
-
synchronize_api(_WarnIfGeneratorIsNotConsumed)
|
|
425
|
+
_BlockingWarnIfGeneratorIsNotConsumed = synchronize_api(_WarnIfGeneratorIsNotConsumed)
|
|
400
426
|
|
|
401
427
|
|
|
402
428
|
class _WarnIfNonWrappedGeneratorIsNotConsumed(_WarnIfGeneratorIsNotConsumed):
|
|
@@ -647,7 +673,9 @@ class StopSentinelType: ...
|
|
|
647
673
|
STOP_SENTINEL = StopSentinelType()
|
|
648
674
|
|
|
649
675
|
|
|
650
|
-
async def async_merge(
|
|
676
|
+
async def async_merge(
|
|
677
|
+
*generators: AsyncGenerator[T, None], cancellation_timeout: float = 10.0
|
|
678
|
+
) -> AsyncGenerator[T, None]:
|
|
651
679
|
"""
|
|
652
680
|
Asynchronously merges multiple async generators into a single async generator.
|
|
653
681
|
|
|
@@ -692,8 +720,9 @@ async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T,
|
|
|
692
720
|
|
|
693
721
|
async def producer(generator: AsyncGenerator[T, None]):
|
|
694
722
|
try:
|
|
695
|
-
async
|
|
696
|
-
|
|
723
|
+
async with aclosing(generator) as stream:
|
|
724
|
+
async for item in stream:
|
|
725
|
+
await queue.put(ValueWrapper(item))
|
|
697
726
|
except Exception as e:
|
|
698
727
|
await queue.put(ExceptionWrapper(e))
|
|
699
728
|
|
|
@@ -735,15 +764,20 @@ async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T,
|
|
|
735
764
|
new_output_task = asyncio.create_task(queue.get())
|
|
736
765
|
|
|
737
766
|
finally:
|
|
738
|
-
if not
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
767
|
+
unfinished_tasks = [t for t in tasks | {new_output_task} if not t.done()]
|
|
768
|
+
for t in unfinished_tasks:
|
|
769
|
+
t.cancel()
|
|
770
|
+
try:
|
|
771
|
+
await asyncio.wait_for(
|
|
772
|
+
asyncio.shield(
|
|
773
|
+
# we need to `shield` the `gather` to ensure cooperation with the timeout
|
|
774
|
+
# all underlying tasks have been marked as cancelled at this point anyway
|
|
775
|
+
asyncio.gather(*unfinished_tasks, return_exceptions=True)
|
|
776
|
+
),
|
|
777
|
+
timeout=cancellation_timeout,
|
|
778
|
+
)
|
|
779
|
+
except asyncio.TimeoutError:
|
|
780
|
+
logger.debug("Timed out while cleaning up async_merge")
|
|
747
781
|
|
|
748
782
|
|
|
749
783
|
async def callable_to_agen(awaitable: Callable[[], Awaitable[T]]) -> AsyncGenerator[T, None]:
|
|
@@ -761,16 +795,34 @@ async def gather_cancel_on_exc(*coros_or_futures):
|
|
|
761
795
|
raise
|
|
762
796
|
|
|
763
797
|
|
|
798
|
+
async def prevent_cancellation_abortion(coro):
    """Await `coro` in its own task, always surfacing a cancellation of the caller.

    The coroutine is wrapped in a task and awaited through `asyncio.shield`, so
    cancelling the caller does not immediately interrupt the task. When the
    caller is cancelled, the task is cancelled explicitly and awaited so that it
    can run its own cancellation handling; afterwards the original
    `CancelledError` is re-raised unconditionally, even if the wrapped coroutine
    swallowed the cancellation and returned normally.

    Returns the result of `coro` when no cancellation happens. Exceptions raised
    by `coro` (including ones raised while it handles the cancellation)
    propagate to the caller.
    """
    inner_task = asyncio.create_task(coro)
    try:
        result = await asyncio.shield(inner_task)
    except asyncio.CancelledError:
        if not inner_task.cancelled():
            # The cancellation came from outside: forward it to the task and
            # wait for the task's own cancellation handling to finish. This
            # await *normally* raises CancelledError itself.
            inner_task.cancel()
            await inner_task
        # Either the task cancelled itself, or it swallowed the cancellation
        # above - in both cases the caller still observes a CancelledError.
        raise
    return result
|
|
812
|
+
|
|
813
|
+
|
|
764
814
|
async def async_map(
|
|
765
815
|
input_generator: AsyncGenerator[T, None],
|
|
766
816
|
async_mapper_func: Callable[[T], Awaitable[V]],
|
|
767
817
|
concurrency: int,
|
|
818
|
+
cancellation_timeout: float = 10.0,
|
|
768
819
|
) -> AsyncGenerator[V, None]:
|
|
769
820
|
queue: asyncio.Queue[Union[ValueWrapper[T], StopSentinelType]] = asyncio.Queue(maxsize=concurrency * 2)
|
|
770
821
|
|
|
771
822
|
async def producer() -> AsyncGenerator[V, None]:
|
|
772
|
-
async
|
|
773
|
-
|
|
823
|
+
async with aclosing(input_generator) as stream:
|
|
824
|
+
async for item in stream:
|
|
825
|
+
await queue.put(ValueWrapper(item))
|
|
774
826
|
|
|
775
827
|
for _ in range(concurrency):
|
|
776
828
|
await queue.put(STOP_SENTINEL)
|
|
@@ -784,14 +836,17 @@ async def async_map(
|
|
|
784
836
|
while True:
|
|
785
837
|
item = await queue.get()
|
|
786
838
|
if isinstance(item, ValueWrapper):
|
|
787
|
-
|
|
839
|
+
res = await prevent_cancellation_abortion(async_mapper_func(item.value))
|
|
840
|
+
yield res
|
|
788
841
|
elif isinstance(item, ExceptionWrapper):
|
|
789
842
|
raise item.value
|
|
790
843
|
else:
|
|
791
844
|
assert_type(item, StopSentinelType)
|
|
792
845
|
break
|
|
793
846
|
|
|
794
|
-
async with aclosing(
|
|
847
|
+
async with aclosing(
|
|
848
|
+
async_merge(*[worker() for i in range(concurrency)], producer(), cancellation_timeout=cancellation_timeout)
|
|
849
|
+
) as stream:
|
|
795
850
|
async for item in stream:
|
|
796
851
|
yield item
|
|
797
852
|
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
# Copyright Modal Labs 2025
|
|
2
|
+
import asyncio
|
|
3
|
+
import base64
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
import typing
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from modal.exception import ExecutionError
|
|
10
|
+
from modal_proto import api_pb2, modal_api_grpc
|
|
11
|
+
|
|
12
|
+
from .logger import logger
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class _AuthTokenManager:
    """Handles fetching and refreshing of the input plane auth token.

    The token and its expiry (read from the JWT `exp` claim) are cached on the
    instance. A lazily created `asyncio.Lock` ensures that only one coroutine at
    a time requests a new token from the control plane.
    """

    # Start refreshing this many seconds before the token expires
    REFRESH_WINDOW = 5 * 60
    # If the token doesn't have an expiry field, default to current time plus this value (not expected).
    DEFAULT_EXPIRY_OFFSET = 20 * 60

    def __init__(self, stub: "modal_api_grpc.ModalClientModal"):
        self._stub = stub
        self._token = ""  # empty string means "no token fetched yet"
        self._expiry = 0.0  # unix timestamp (seconds) at which `_token` expires
        self._lock: typing.Union[asyncio.Lock, None] = None  # created lazily, see `_get_lock`

    async def get_token(self) -> str:
        """
        When called, the AuthTokenManager can be in one of three states:
        1. Has a valid cached token. It is returned to the caller.
        2. Has no cached token, or the token is expired. We fetch a new one and cache it. If `get_token` is called
        concurrently by multiple coroutines, all requests will block until the token has been fetched. But only one
        coroutine will actually make a request to the control plane to fetch the new token. This ensures we do not hit
        the control plane with more requests than needed.
        3. Has a valid cached token, but it is going to expire in the next 5 minutes. In this case we fetch a new token
        and cache it. If `get_token` is called concurrently, only one request will fetch the new token, and the others
        will be given the old (but still valid) token - i.e. they will not block.
        """
        if not self._token or self._is_expired():
            # We either have no token or it is expired - block everyone until we get a new token
            await self._refresh_token()
        elif self._needs_refresh():
            # The token hasn't expired yet, but will soon, so it needs a refresh.
            lock = await self._get_lock()
            if lock.locked():
                # The lock is taken, so someone else is refreshing. Continue to use the old token.
                return self._token
            else:
                # The lock is not taken, so we need to fetch a new token.
                await self._refresh_token()

        return self._token

    async def _refresh_token(self):
        """
        Fetch a new token from the control plane. If called concurrently, only one coroutine will make a request for a
        new token. The others will block on a lock, until the first coroutine has fetched the new token.

        Raises `ExecutionError` if the server returns an empty token and `ValueError` if the token cannot be parsed.
        In both cases the previously cached token and expiry are left untouched.
        """
        lock = await self._get_lock()
        async with lock:
            # Double check inside lock - maybe another coroutine refreshed already. This happens the first time we fetch
            # the token. The first coroutine will fetch the token, while the others block on the lock, waiting for the
            # new token. Once we have a new token, the other coroutines will unblock and return from here.
            if self._token and not self._needs_refresh():
                return
            resp: api_pb2.AuthTokenGetResponse = await self._stub.AuthTokenGet(api_pb2.AuthTokenGetRequest())
            if not resp.token:
                # Not expected
                raise ExecutionError(
                    "Internal error: Did not receive auth token from server. Please contact Modal support."
                )

            # Parse the token *before* committing it, so that a malformed token never replaces a cached token
            # that concurrent callers may still be using, and token/expiry are always updated together.
            if exp := self._decode_jwt(resp.token).get("exp"):
                expiry = float(exp)
            else:
                # This should never happen.
                logger.warning("x-modal-auth-token does not contain exp field")
                # We'll use the token, and set the expiry to 20 min from now.
                expiry = time.time() + self.DEFAULT_EXPIRY_OFFSET

            self._token = resp.token
            self._expiry = expiry

    async def _get_lock(self) -> asyncio.Lock:
        # Note: this function runs no async code but is marked as async to ensure it's
        # being run inside the synchronicity event loop and binds the lock to the
        # correct event loop on Python 3.9 which eagerly assigns event loops on
        # constructions of locks
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    @staticmethod
    def _decode_jwt(token: str) -> dict[str, Any]:
        """
        Decodes a JWT into a dict without verifying signature. We do this manually instead of using a library to avoid
        adding another dependency to the client.

        Raises `ValueError` if the token is not made of dot-separated segments, the payload segment is not valid
        base64url-encoded JSON, or the decoded payload is not a JSON object.
        """
        error_message = "Internal error: Cannot parse auth token. Please contact Modal support."
        try:
            payload = token.split(".")[1]
            # base64url payloads in JWTs come without padding; restore it to a multiple of 4
            padding = "=" * (-len(payload) % 4)
            decoded_bytes = base64.urlsafe_b64decode(payload + padding)
            claims = json.loads(decoded_bytes)
        except Exception as e:
            raise ValueError(error_message) from e
        if not isinstance(claims, dict):
            # Valid JSON, but not a claims object - callers rely on getting a dict back
            raise ValueError(error_message)
        return claims

    def _needs_refresh(self):
        """Return True once the current time is within `REFRESH_WINDOW` seconds of the expiry (or past it)."""
        return time.time() >= (self._expiry - self.REFRESH_WINDOW)

    def _is_expired(self):
        """Return True once the current time has reached the cached expiry timestamp."""
        return time.time() >= self._expiry
|