flyte 0.2.0b1__py3-none-any.whl → 0.2.0b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of flyte might be problematic.
- flyte/__init__.py +3 -4
- flyte/_bin/runtime.py +21 -7
- flyte/_cache/cache.py +1 -2
- flyte/_cli/_common.py +26 -4
- flyte/_cli/_create.py +48 -0
- flyte/_cli/_deploy.py +4 -2
- flyte/_cli/_get.py +18 -7
- flyte/_cli/_run.py +1 -0
- flyte/_cli/main.py +11 -5
- flyte/_code_bundle/bundle.py +42 -11
- flyte/_context.py +1 -1
- flyte/_deploy.py +3 -1
- flyte/_group.py +1 -1
- flyte/_initialize.py +28 -247
- flyte/_internal/controllers/__init__.py +6 -6
- flyte/_internal/controllers/_local_controller.py +14 -5
- flyte/_internal/controllers/_trace.py +1 -1
- flyte/_internal/controllers/remote/__init__.py +27 -7
- flyte/_internal/controllers/remote/_action.py +1 -1
- flyte/_internal/controllers/remote/_client.py +5 -1
- flyte/_internal/controllers/remote/_controller.py +68 -24
- flyte/_internal/controllers/remote/_core.py +1 -1
- flyte/_internal/runtime/convert.py +34 -8
- flyte/_internal/runtime/entrypoints.py +1 -1
- flyte/_internal/runtime/io.py +3 -3
- flyte/_internal/runtime/task_serde.py +31 -1
- flyte/_internal/runtime/taskrunner.py +1 -1
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_run.py +47 -28
- flyte/_task.py +2 -2
- flyte/_task_environment.py +1 -1
- flyte/_trace.py +5 -6
- flyte/_utils/__init__.py +2 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_version.py +2 -2
- flyte/config/__init__.py +26 -4
- flyte/config/_config.py +13 -4
- flyte/extras/_container.py +3 -3
- flyte/{_datastructures.py → models.py} +3 -2
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_channel.py +28 -3
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +13 -13
- flyte/remote/_logs.py +1 -1
- flyte/remote/_run.py +4 -8
- flyte/remote/_task.py +2 -2
- flyte/storage/__init__.py +5 -0
- flyte/storage/_config.py +233 -0
- flyte/storage/_storage.py +23 -3
- flyte/types/_interface.py +1 -1
- flyte/types/_type_engine.py +1 -1
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/METADATA +2 -2
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/RECORD +56 -54
- flyte/_internal/controllers/pbhash.py +0 -39
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/top_level.txt +0 -0
flyte/_utils/async_cache.py
ADDED
@@ -0,0 +1,139 @@
+import asyncio
+import time
+from collections import OrderedDict
+from typing import Awaitable, Callable, Dict, Generic, Optional, TypeVar
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class AsyncLRUCache(Generic[K, V]):
+    """
+    A high-performance async-compatible LRU cache.
+
+    Examples:
+    ```python
+    # Create a cache instance
+    cache = AsyncLRUCache[str, dict](maxsize=100)
+
+    async def fetch_data(user_id: str) -> dict:
+        # Define the expensive operation as a local function
+        async def get_user_data():
+            await asyncio.sleep(1)  # Simulating network/DB delay
+            return {"id": user_id, "name": f"User {user_id}"}
+
+        # Use the cache
+        return await cache.get(f"user:{user_id}", get_user_data)
+    ```
+    This cache can be used from async coroutines and handles concurrent access safely.
+    """
+
+    def __init__(self, maxsize: int = 128, ttl: Optional[float] = None):
+        """
+        Initialize the async LRU cache.
+
+        Args:
+            maxsize: Maximum number of items to keep in the cache
+            ttl: Time-to-live for cache entries in seconds, or None for no expiration
+        """
+        self._cache: OrderedDict[K, tuple[V, float]] = OrderedDict()
+        self._maxsize = maxsize
+        self._ttl = ttl
+        self._locks: Dict[K, asyncio.Lock] = {}
+        self._access_lock = asyncio.Lock()
+
+    async def get(self, key: K, value_func: Callable[[], V | Awaitable[V]]) -> V:
+        """
+        Get a value from the cache, computing it if necessary.
+
+        Args:
+            key: The cache key
+            value_func: Function or coroutine to compute the value if not cached
+
+        Returns:
+            The cached or computed value
+        """
+        # Fast path: check if key exists and is not expired
+        if key in self._cache:
+            value, timestamp = self._cache[key]
+            if self._ttl is None or time.time() - timestamp < self._ttl:
+                # Move the accessed item to the end (most recently used)
+                async with self._access_lock:
+                    self._cache.move_to_end(key)
+                return value
+
+        # Slow path: compute the value
+        # Get or create a lock for this key to prevent redundant computation
+        async with self._access_lock:
+            lock = self._locks.get(key)
+            if lock is None:
+                lock = asyncio.Lock()
+                self._locks[key] = lock
+
+        async with lock:
+            # Check again in case another coroutine computed the value while we waited
+            if key in self._cache:
+                value, timestamp = self._cache[key]
+                if self._ttl is None or time.time() - timestamp < self._ttl:
+                    async with self._access_lock:
+                        self._cache.move_to_end(key)
+                    return value
+
+            # Compute the value
+            if asyncio.iscoroutinefunction(value_func):
+                value = await value_func()
+            else:
+                value = value_func()  # type: ignore
+
+            # Store in cache
+            async with self._access_lock:
+                self._cache[key] = (value, time.time())
+                # Evict least recently used items if needed
+                while len(self._cache) > self._maxsize:
+                    self._cache.popitem(last=False)
+                # Clean up the lock
+                self._locks.pop(key, None)
+
+        return value
+
+    async def set(self, key: K, value: V) -> None:
+        """
+        Explicitly set a value in the cache.
+
+        Args:
+            key: The cache key
+            value: The value to cache
+        """
+        async with self._access_lock:
+            self._cache[key] = (value, time.time())
+            # Evict least recently used items if needed
+            while len(self._cache) > self._maxsize:
+                self._cache.popitem(last=False)
+
+    async def invalidate(self, key: K) -> None:
+        """Remove a specific key from the cache."""
+        async with self._access_lock:
+            self._cache.pop(key, None)
+
+    async def clear(self) -> None:
+        """Clear the entire cache."""
+        async with self._access_lock:
+            self._cache.clear()
+            self._locks.clear()
+
+    async def contains(self, key: K) -> bool:
+        """Check if a key exists in the cache and is not expired."""
+        if key not in self._cache:
+            return False
+
+        if self._ttl is None:
+            return True
+
+        _, timestamp = self._cache[key]
+        return time.time() - timestamp < self._ttl
+
+
+# Example usage:
+"""
+
+"""
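For orientation, a small sketch of how this new cache behaves end to end; the `load_user` coroutine and the TTL value are illustrative and not part of the package:

```python
import asyncio

from flyte._utils.async_cache import AsyncLRUCache


async def load_user(user_id: str) -> dict:
    # Hypothetical expensive lookup standing in for a network or DB call.
    await asyncio.sleep(0.1)
    return {"id": user_id}


async def main() -> None:
    cache: AsyncLRUCache[str, dict] = AsyncLRUCache(maxsize=16, ttl=30.0)

    async def load_u1() -> dict:
        return await load_user("u1")

    # Concurrent callers for the same key: the per-key lock lets one caller
    # compute the value while the other is served from the cache afterwards.
    a, b = await asyncio.gather(cache.get("user:u1", load_u1), cache.get("user:u1", load_u1))
    assert a == b == {"id": "u1"}

    await cache.invalidate("user:u1")       # drop a single entry
    print(await cache.contains("user:u1"))  # False until recomputed


asyncio.run(main())
```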
flyte/_version.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '0.2.0b1'
-__version_tuple__ = version_tuple = (0, 2, 0, 'b1')
+__version__ = version = '0.2.0b3'
+__version_tuple__ = version_tuple = (0, 2, 0, 'b3')
flyte/config/__init__.py
CHANGED
@@ -1,8 +1,13 @@
+from __future__ import annotations
+
 import os
+import pathlib
 import typing
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING

+import rich.repr
+
 from flyte._logging import logger
 from flyte.config import _internal
 from flyte.config._config import ConfigFile, get_config_file, read_file_if_exists
@@ -13,6 +18,7 @@ if TYPE_CHECKING:
     from flyte.remote._client.auth import AuthType


+@rich.repr.auto
 @dataclass(init=True, repr=True, eq=True, frozen=True)
 class PlatformConfig(object):
     """
@@ -57,7 +63,6 @@ class PlatformConfig(object):
         :param config_file:
         :return:
         """
-        from .._initialize import set_if_exists

         config_file = get_config_file(config_file)
         kwargs: typing.Dict[str, typing.Any] = {}
@@ -106,6 +111,7 @@ class PlatformConfig(object):
         return PlatformConfig(endpoint=endpoint, insecure=insecure)


+@rich.repr.auto
 @dataclass(init=True, repr=True, eq=True, frozen=True)
 class TaskConfig(object):
     org: str | None = None
@@ -119,8 +125,6 @@ class TaskConfig(object):
         :param config_file:
         :return:
         """
-        from flyte._initialize import set_if_exists
-
         config_file = get_config_file(config_file)
         kwargs: typing.Dict[str, typing.Any] = {}
         kwargs = set_if_exists(kwargs, "org", _internal.Task.ORG.read(config_file))
@@ -129,6 +133,7 @@ class TaskConfig(object):
         return TaskConfig(**kwargs)


+@rich.repr.auto
 @dataclass(init=True, repr=True, eq=True, frozen=True)
 class Config(object):
     """
@@ -142,6 +147,7 @@ class Config(object):

     platform: PlatformConfig = field(default=PlatformConfig())
     task: TaskConfig = field(default=TaskConfig())
+    source: pathlib.Path | None = None

     def with_params(
         self,
@@ -165,4 +171,20 @@
         :return: Config
         """
         config_file = get_config_file(config_file)
-
+        if config_file is None:
+            logger.debug("No config file found, using default values")
+            return Config()
+        return Config(
+            platform=PlatformConfig.auto(config_file), task=TaskConfig.auto(config_file), source=config_file.path
+        )
+
+
+def set_if_exists(d: dict, k: str, val: typing.Any) -> dict:
+    """
+    Given a dict ``d`` sets the key ``k`` with value of config ``v``, if the config value ``v`` is set
+    and return the updated dictionary.
+    """
+    exists = isinstance(val, bool) or bool(val is not None and val)
+    if exists:
+        d[k] = val
+    return d
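Since `set_if_exists` now lives in `flyte.config` itself rather than `flyte._initialize`, a quick sketch of its behaviour; the values below are made up:

```python
import typing

from flyte.config import set_if_exists

kwargs: typing.Dict[str, typing.Any] = {}
kwargs = set_if_exists(kwargs, "org", None)        # skipped: falsy non-bool values are not set
kwargs = set_if_exists(kwargs, "insecure", False)  # kept: booleans are always set
kwargs = set_if_exists(kwargs, "endpoint", "dns:///example.my.domain")
print(kwargs)  # {'insecure': False, 'endpoint': 'dns:///example.my.domain'}
```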
flyte/config/_config.py
CHANGED
@@ -99,6 +99,14 @@ class ConfigFile(object):
         self._location = location
         self._yaml_config = self._read_yaml_config(location)

+    @property
+    def path(self) -> pathlib.Path:
+        """
+        Returns the path to the config file.
+        :return: Path to the config file
+        """
+        return pathlib.Path(self._location)
+
     @staticmethod
     def _read_yaml_config(location: str) -> typing.Optional[typing.Dict[str, typing.Any]]:
         with open(location, "r") as fh:
@@ -136,6 +144,11 @@ def resolve_config_path() -> pathlib.Path | None:
     4. ~/.flyte/config.yaml if it exists
     5. ./config.yaml if it exists
     """
+    current_location_config = Path("config.yaml")
+    if current_location_config.exists():
+        return current_location_config
+    logger.debug("No ./config.yaml found, returning None")
+
     uctl_path_from_env = getenv(UCTL_CONFIG_ENV_VAR, None)
     if uctl_path_from_env:
         return pathlib.Path(uctl_path_from_env)
@@ -156,10 +169,6 @@ def resolve_config_path() -> pathlib.Path | None:
         return home_dir_flytectl_config
     logger.debug("No ~/.flyte/config.yaml found, checking current directory")

-    current_location_config = Path("config.yaml")
-    if current_location_config.exists():
-        return current_location_config
-    logger.debug("No ./config.yaml found, returning None")
     return None

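The net effect of the reordering above is that a `config.yaml` in the current working directory now wins over the environment-variable and home-directory locations. A small sketch, importing from the private module exactly as the diff defines it:

```python
import pathlib

from flyte.config._config import resolve_config_path

# With ./config.yaml present, it is returned before the UCTL_CONFIG_ENV_VAR
# or ~/.flyte/config.yaml locations are even consulted.
cfg = resolve_config_path()
if cfg == pathlib.Path("config.yaml"):
    print("using the config.yaml from the current directory")
```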
flyte/extras/_container.py
CHANGED
@@ -5,9 +5,9 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
 from flyteidl.core import tasks_pb2

 from flyte import Image, storage
-from flyte._datastructures import NativeInterface, SerializationContext
 from flyte._logging import logger
 from flyte._task import TaskTemplate
+from flyte.models import NativeInterface, SerializationContext

 _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"

@@ -258,8 +258,8 @@ class ContainerTask(TaskTemplate):
         }

         return tasks_pb2.DataLoadingConfig(
-            input_path=self._input_data_dir,
-            output_path=self._output_data_dir,
+            input_path=str(self._input_data_dir),
+            output_path=str(self._output_data_dir),
             enabled=True,
             format=literal_to_protobuf.get(self._metadata_format, "JSON"),
         )
flyte/{_datastructures.py → models.py}
RENAMED
@@ -56,11 +56,12 @@
             name = generate_random_name()
         return replace(self, name=name)

-    def new_sub_action_from(self,
+    def new_sub_action_from(self, task_call_seq: int, task_hash: str, input_hash: str, group: str | None) -> ActionID:
         """Make a deterministic name"""
         import hashlib

-        components = f"{self.
+        components = f"{self.name}-{input_hash}-{task_hash}-{task_call_seq}" + (f"-{group}" if group else "")
+        logger.debug(f"----- Generating sub-action ID from components: {components}")
         # has the components into something deterministic
         bytes_digest = hashlib.md5(components.encode()).digest()
         new_name = base36_encode(bytes_digest)
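The new `new_sub_action_from` signature makes the sub-action name a pure function of the parent action name, input hash, task hash, call sequence, and optional group. A standalone sketch of that scheme; `_base36` below is only an illustrative stand-in for the package's `base36_encode` helper:

```python
import hashlib


def _base36(data: bytes) -> str:
    # Illustrative stand-in for flyte's base36_encode.
    n = int.from_bytes(data, "big")
    digits = "0123456789abcdefghijklmnopqrstuvwxyz"
    out = ""
    while n:
        n, r = divmod(n, 36)
        out = digits[r] + out
    return out or "0"


def sub_action_name(parent: str, input_hash: str, task_hash: str, seq: int, group: str | None = None) -> str:
    components = f"{parent}-{input_hash}-{task_hash}-{seq}" + (f"-{group}" if group else "")
    return _base36(hashlib.md5(components.encode()).digest())


# Identical inputs always yield the same sub-action name; changing the inputs,
# task hash, call sequence, or group produces a different one.
assert sub_action_name("run-abc", "ih0", "th0", 0) == sub_action_name("run-abc", "ih0", "th0", 0)
assert sub_action_name("run-abc", "ih0", "th0", 0) != sub_action_name("run-abc", "ih0", "th0", 1)
```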
flyte/remote/_client/auth/_auth_utils.py
ADDED
@@ -0,0 +1,14 @@
+from __future__ import annotations
+
+import base64
+from typing import Literal
+
+
+def decode_api_key(encoded_str: str) -> tuple[str, str, str, str | Literal["None"]]:
+    """Decode encoded base64 string into app credentials. endpoint, client_id, client_secret, org"""
+    endpoint, client_id, client_secret, org = base64.b64decode(encoded_str.encode("utf-8")).decode("utf-8").split(":")
+    # For consistency, let's make sure org is always a non-empty string
+    if not org:
+        org = "None"
+
+    return endpoint, client_id, client_secret, org
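For reference, the API key this helper expects is a base64-encoded `endpoint:client_id:client_secret:org` string; a round-trip sketch with made-up values:

```python
import base64

from flyte.remote._client.auth._auth_utils import decode_api_key

# Hypothetical credentials; real keys are issued by the control plane.
raw = "example.my.domain:my-client-id:my-client-secret:my-org"
api_key = base64.b64encode(raw.encode("utf-8")).decode("utf-8")

endpoint, client_id, client_secret, org = decode_api_key(api_key)
print(endpoint, client_id, org)  # example.my.domain my-client-id my-org
```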
flyte/remote/_client/auth/_channel.py
CHANGED
@@ -44,7 +44,9 @@ def bootstrap_ssl_from_server(endpoint: str) -> grpc.ChannelCredentials:


 async def create_channel(
-    endpoint: str,
+    endpoint: str | None,
+    api_key: str | None = None,
+    /,
     insecure: typing.Optional[bool] = None,
     insecure_skip_verify: typing.Optional[bool] = False,
     ca_cert_file_path: typing.Optional[str] = None,
@@ -66,6 +68,7 @@ async def create_channel(
     and create authentication interceptors that perform async operations.

     :param endpoint: The endpoint URL for the gRPC channel
+    :param api_key: API key for authentication; if provided, it will be used to detect the endpoint and credentials.
     :param insecure: Whether to use an insecure channel (no SSL)
     :param insecure_skip_verify: Whether to skip SSL certificate verification
     :param ca_cert_file_path: Path to CA certificate file for SSL verification
@@ -104,6 +107,18 @@
     - refresh_access_token_params: Parameters to add when refreshing access token
     :return: grpc.aio.Channel with authentication interceptors configured
     """
+    assert endpoint or api_key, "Either endpoint or api_key must be specified"
+
+    if api_key:
+        from flyte.remote._client.auth._auth_utils import decode_api_key
+
+        endpoint, client_id, client_secret, org = decode_api_key(api_key)
+        kwargs["auth_type"] = "ClientSecret"
+        kwargs["client_id"] = client_id
+        kwargs["client_secret"] = client_secret
+        kwargs["client_credentials_secret"] = client_secret
+
+    assert endpoint, "Endpoint must be specified by this point"

     if not ssl_credentials:
         if insecure_skip_verify:
@@ -119,7 +134,12 @@

     # Create an unauthenticated channel first to use to get the server metadata
     if insecure:
-
+        insecure_kwargs = {}
+        if kw_opts := kwargs.get("options"):
+            insecure_kwargs["options"] = kw_opts
+        if compression:
+            insecure_kwargs["compression"] = compression
+        unauthenticated_channel = grpc.aio.insecure_channel(endpoint, **insecure_kwargs)
     else:
         unauthenticated_channel = grpc.aio.secure_channel(
             target=endpoint,
@@ -173,7 +193,12 @@
     interceptors.extend(auth_interceptors)

     if insecure:
-
+        insecure_kwargs = {}
+        if kw_opts := kwargs.get("options"):
+            insecure_kwargs["options"] = kw_opts
+        if compression:
+            insecure_kwargs["compression"] = compression
+        return grpc.aio.insecure_channel(endpoint, interceptors=interceptors, **insecure_kwargs)

     return grpc.aio.secure_channel(
         target=endpoint,
flyte/remote/_client/auth/_token_client.py
CHANGED
@@ -94,7 +94,7 @@ async def get_token(
     http_proxy_url: typing.Optional[str] = None,
     verify: typing.Optional[typing.Union[bool, str]] = None,
     refresh_token: typing.Optional[str] = None,
-) -> typing.Tuple[str, str, int]:
+) -> typing.Tuple[str, str | None, int]:
     """
     Retrieves an access token from the specified token endpoint.

@@ -165,7 +165,7 @@ async def get_token(
     if "refresh_token" in j:
         new_refresh_token = j["refresh_token"]
     else:
-
+        logger.info("No refresh token received, this is expected for client credentials flow")

     return j["access_token"], new_refresh_token, j["expires_in"]

@@ -213,7 +213,7 @@
     scopes: typing.Optional[typing.List[str]] = None,
     http_proxy_url: typing.Optional[str] = None,
     verify: typing.Optional[typing.Union[bool, str]] = None,
-) -> typing.Tuple[str, str, int]:
+) -> typing.Tuple[str, str | None, int]:
     """
     Polls the token endpoint until authentication is complete or times out.

flyte/remote/_client/controlplane.py
CHANGED
@@ -39,21 +39,21 @@

     @classmethod
     async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
-
-
-
-        del kwargs["headless"]
-        del kwargs["command"]
-        del kwargs["client_id"]
-        del kwargs["client_credentials_secret"]
-        del kwargs["client_config"]
-        del kwargs["rpc_retries"]
-        del kwargs["http_proxy_url"]
-        return cls(await create_channel(endpoint, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs)
+        return cls(
+            await create_channel(endpoint, None, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs
+        )

     @classmethod
-    async def for_api_key(cls, api_key: str, **kwargs) -> ClientSet:
-
+    async def for_api_key(cls, api_key: str, *, insecure: bool = False, **kwargs) -> ClientSet:
+        from flyte.remote._client.auth._auth_utils import decode_api_key
+
+        # Parsing the API key is done in create_channel, but cleaner to redo it here rather than getting create_channel
+        # to return the endpoint
+        endpoint, _, _, _ = decode_api_key(api_key)
+
+        return cls(
+            await create_channel(None, api_key, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs
+        )

     @classmethod
     async def for_serverless(cls) -> ClientSet:
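A hedged sketch of how the two constructors above are meant to be called after this change; the endpoint and the key contents are placeholders, and a reachable control plane is assumed:

```python
import asyncio
import base64

from flyte.remote._client.controlplane import ClientSet


async def main() -> None:
    # Explicit endpoint (placeholder address):
    clients = await ClientSet.for_endpoint("dns:///example.my.domain", insecure=False)

    # Or an API key, which now also carries the endpoint and client credentials:
    fake_key = base64.b64encode(b"example.my.domain:client-id:client-secret:my-org").decode("utf-8")
    clients = await ClientSet.for_api_key(fake_key, insecure=False)


asyncio.run(main())
```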
flyte/remote/_logs.py
CHANGED
@@ -111,6 +111,6 @@ class Logs:
             log_source=cls.tail.aio(cls, action_id=action_id, attempt=attempt),
             max_lines=max_lines,
             show_ts=show_ts,
-            name=f"{action_id.run.name}:{action_id.name}",
+            name=f"{action_id.run.name}:{action_id.name} ({attempt})",
         )
         await viewer.run()
flyte/remote/_run.py
CHANGED
@@ -243,14 +243,10 @@ class Run:
                 return
             yield ad

-    async def show_logs(
-
-
-
-            max_lines=max_lines,
-            show_ts=show_ts,
-            raw=raw,
-        )
+    async def show_logs(
+        self, attempt: int | None = None, max_lines: int = 100, show_ts: bool = False, raw: bool = False
+    ):
+        await self.action.show_logs(attempt, max_lines, show_ts, raw)

     async def details(self) -> RunDetails:
         """
flyte/remote/_task.py
CHANGED
@@ -11,9 +11,9 @@ import flyte
 import flyte.errors
 from flyte._api_commons import syncer
 from flyte._context import internal_ctx
-from flyte._datastructures import NativeInterface
 from flyte._initialize import get_client, get_common_config
 from flyte._protos.workflow import task_definition_pb2, task_service_pb2
+from flyte.models import NativeInterface


 class LazyEntity:
@@ -187,7 +187,7 @@ class Task:
         # We will also check if we are not initialized, It is not expected to be not initialized
         from flyte._internal.controllers import get_controller

-        controller =
+        controller = get_controller()
         if controller:
             return await controller.submit_task_ref(self.pb2, *args, **kwargs)
         raise flyte.errors
flyte/storage/__init__.py
CHANGED
@@ -1,4 +1,8 @@
 __all__ = [
+    "ABFS",
+    "GCS",
+    "S3",
+    "Storage",
     "get",
     "get_random_local_directory",
     "get_random_local_path",
@@ -11,6 +15,7 @@ __all__ = [
     "put_stream",
 ]

+from ._config import ABFS, GCS, S3, Storage
 from ._storage import (
     get,
     get_random_local_directory,