flyte 0.2.0b1__py3-none-any.whl → 0.2.0b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic.
- flyte/__init__.py +3 -4
- flyte/_bin/runtime.py +21 -7
- flyte/_cache/cache.py +1 -2
- flyte/_cli/_common.py +26 -4
- flyte/_cli/_create.py +48 -0
- flyte/_cli/_deploy.py +4 -2
- flyte/_cli/_get.py +18 -7
- flyte/_cli/_run.py +1 -0
- flyte/_cli/main.py +11 -5
- flyte/_code_bundle/bundle.py +42 -11
- flyte/_context.py +1 -1
- flyte/_deploy.py +3 -1
- flyte/_group.py +1 -1
- flyte/_initialize.py +28 -247
- flyte/_internal/controllers/__init__.py +6 -6
- flyte/_internal/controllers/_local_controller.py +14 -5
- flyte/_internal/controllers/_trace.py +1 -1
- flyte/_internal/controllers/remote/__init__.py +27 -7
- flyte/_internal/controllers/remote/_action.py +1 -1
- flyte/_internal/controllers/remote/_client.py +5 -1
- flyte/_internal/controllers/remote/_controller.py +68 -24
- flyte/_internal/controllers/remote/_core.py +1 -1
- flyte/_internal/runtime/convert.py +34 -8
- flyte/_internal/runtime/entrypoints.py +1 -1
- flyte/_internal/runtime/io.py +3 -3
- flyte/_internal/runtime/task_serde.py +31 -1
- flyte/_internal/runtime/taskrunner.py +1 -1
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_run.py +47 -28
- flyte/_task.py +2 -2
- flyte/_task_environment.py +1 -1
- flyte/_trace.py +5 -6
- flyte/_utils/__init__.py +2 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_version.py +2 -2
- flyte/config/__init__.py +26 -4
- flyte/config/_config.py +13 -4
- flyte/extras/_container.py +3 -3
- flyte/{_datastructures.py → models.py} +3 -2
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_channel.py +28 -3
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +13 -13
- flyte/remote/_logs.py +1 -1
- flyte/remote/_run.py +4 -8
- flyte/remote/_task.py +2 -2
- flyte/storage/__init__.py +5 -0
- flyte/storage/_config.py +233 -0
- flyte/storage/_storage.py +23 -3
- flyte/types/_interface.py +1 -1
- flyte/types/_type_engine.py +1 -1
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/METADATA +2 -2
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/RECORD +56 -54
- flyte/_internal/controllers/pbhash.py +0 -39
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/top_level.txt +0 -0
flyte/_initialize.py
CHANGED
@@ -1,14 +1,11 @@
 from __future__ import annotations
 
-import datetime
 import functools
-import os
 import threading
 import typing
 from dataclasses import dataclass, replace
-from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Callable, List, Literal, Optional, TypeVar
 
 from flyte.errors import InitializationError
 
@@ -20,242 +17,11 @@ if TYPE_CHECKING:
     from flyte.config import Config
     from flyte.remote._client.auth import AuthType, ClientConfig
     from flyte.remote._client.controlplane import ClientSet
+    from flyte.storage import Storage
 
 Mode = Literal["local", "remote"]
 
 
-def set_if_exists(d: dict, k: str, val: typing.Any) -> dict:
-    """
-    Given a dict ``d`` sets the key ``k`` with value of config ``v``, if the config value ``v`` is set
-    and return the updated dictionary.
-    """
-    exists = isinstance(val, bool) or bool(val is not None and val)
-    if exists:
-        d[k] = val
-    return d
-
-
-@dataclass(init=True, repr=True, eq=True, frozen=True)
-class Storage(object):
-    """
-    Data storage configuration that applies across any provider.
-    """
-
-    retries: int = 3
-    backoff: datetime.timedelta = datetime.timedelta(seconds=5)
-    enable_debug: bool = False
-    attach_execution_metadata: bool = True
-
-    _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
-        "enable_debug": "UNION_STORAGE_DEBUG",
-        "retries": "UNION_STORAGE_RETRIES",
-        "backoff": "UNION_STORAGE_BACKOFF_SECONDS",
-    }
-
-    def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
-        """
-        Returns the configuration as kwargs for constructing an fsspec filesystem.
-        """
-        return {}
-
-    @classmethod
-    def _auto_as_kwargs(cls) -> Dict[str, Any]:
-        retries = os.getenv(cls._KEY_ENV_VAR_MAPPING["retries"])
-        backoff = os.getenv(cls._KEY_ENV_VAR_MAPPING["backoff"])
-        enable_debug = os.getenv(cls._KEY_ENV_VAR_MAPPING["enable_debug"])
-
-        kwargs: Dict[str, Any] = {}
-        kwargs = set_if_exists(kwargs, "enable_debug", enable_debug)
-        kwargs = set_if_exists(kwargs, "retries", retries)
-        kwargs = set_if_exists(kwargs, "backoff", backoff)
-        return kwargs
-
-    @classmethod
-    def auto(cls) -> Storage:
-        """
-        Construct the config object automatically from environment variables.
-        """
-        return cls(**cls._auto_as_kwargs())
-
-
-@dataclass(init=True, repr=True, eq=True, frozen=True)
-class S3(Storage):
-    """
-    S3 specific configuration
-    """
-
-    endpoint: typing.Optional[str] = None
-    access_key_id: typing.Optional[str] = None
-    secret_access_key: typing.Optional[str] = None
-
-    _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
-        "endpoint": "FLYTE_AWS_ENDPOINT",
-        "access_key_id": "FLYTE_AWS_ACCESS_KEY_ID",
-        "secret_access_key": "FLYTE_AWS_SECRET_ACCESS_KEY",
-    } | Storage._KEY_ENV_VAR_MAPPING
-
-    # Refer to https://github.com/developmentseed/obstore/blob/33654fc37f19a657689eb93327b621e9f9e01494/obstore/python/obstore/store/_aws.pyi#L11
-    # for key and secret
-    _CONFIG_KEY_FSSPEC_S3_KEY_ID: ClassVar = "access_key_id"
-    _CONFIG_KEY_FSSPEC_S3_SECRET: ClassVar = "secret_access_key"
-    _CONFIG_KEY_ENDPOINT: ClassVar = "endpoint_url"
-    _KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"
-
-    @classmethod
-    def auto(cls) -> S3:
-        """
-        :return: Config
-        """
-        endpoint = os.getenv(cls._KEY_ENV_VAR_MAPPING["endpoint"], None)
-        access_key_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["access_key_id"], None)
-        secret_access_key = os.getenv(cls._KEY_ENV_VAR_MAPPING["secret_access_key"], None)
-
-        kwargs = super()._auto_as_kwargs()
-        kwargs = set_if_exists(kwargs, "endpoint", endpoint)
-        kwargs = set_if_exists(kwargs, "access_key_id", access_key_id)
-        kwargs = set_if_exists(kwargs, "secret_access_key", secret_access_key)
-
-        return S3(**kwargs)
-
-    @classmethod
-    def for_sandbox(cls) -> S3:
-        """
-        :return:
-        """
-        kwargs = super()._auto_as_kwargs()
-        final_kwargs = kwargs | {
-            "endpoint": "http://localhost:4566",
-            "access_key_id": "minio",
-            "secret_access_key": "miniostorage",
-        }
-        return S3(**final_kwargs)
-
-    def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
-        # Construct the config object
-        config: Dict[str, Any] = {}
-        if self._CONFIG_KEY_FSSPEC_S3_KEY_ID in kwargs or self.access_key_id:
-            config[self._CONFIG_KEY_FSSPEC_S3_KEY_ID] = kwargs.pop(
-                self._CONFIG_KEY_FSSPEC_S3_KEY_ID, self.access_key_id
-            )
-        if self._CONFIG_KEY_FSSPEC_S3_SECRET in kwargs or self.secret_access_key:
-            config[self._CONFIG_KEY_FSSPEC_S3_SECRET] = kwargs.pop(
-                self._CONFIG_KEY_FSSPEC_S3_SECRET, self.secret_access_key
-            )
-        if self._CONFIG_KEY_ENDPOINT in kwargs or self.endpoint:
-            config["endpoint_url"] = kwargs.pop(self._CONFIG_KEY_ENDPOINT, self.endpoint)
-
-        retries = kwargs.pop("retries", self.retries)
-        backoff = kwargs.pop("backoff", self.backoff)
-
-        if anonymous:
-            config[self._KEY_SKIP_SIGNATURE] = True
-
-        retry_config = {
-            "max_retries": retries,
-            "backoff": {
-                "base": 2,
-                "init_backoff": backoff,
-                "max_backoff": timedelta(seconds=16),
-            },
-            "retry_timeout": timedelta(minutes=3),
-        }
-
-        client_options = {"timeout": "99999s", "allow_http": True}
-
-        if config:
-            kwargs["config"] = config
-        kwargs["client_options"] = client_options or None
-        kwargs["retry_config"] = retry_config or None
-
-        return kwargs
-
-
-@dataclass(init=True, repr=True, eq=True, frozen=True)
-class GCS(Storage):
-    """
-    Any GCS specific configuration.
-    """
-
-    gsutil_parallelism: bool = False
-
-    _KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
-        "gsutil_parallelism": "GCP_GSUTIL_PARALLELISM",
-    }
-
-    @classmethod
-    def auto(cls) -> GCS:
-        gsutil_parallelism = os.getenv(cls._KEY_ENV_VAR_MAPPING["gsutil_parallelism"], None)
-
-        kwargs: Dict[str, Any] = {}
-        kwargs = set_if_exists(kwargs, "gsutil_parallelism", gsutil_parallelism)
-        return GCS(**kwargs)
-
-    def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
-        return kwargs
-
-
-@dataclass(init=True, repr=True, eq=True, frozen=True)
-class ABFS(Storage):
-    """
-    Any Azure Blob Storage specific configuration.
-    """
-
-    account_name: typing.Optional[str] = None
-    account_key: typing.Optional[str] = None
-    tenant_id: typing.Optional[str] = None
-    client_id: typing.Optional[str] = None
-    client_secret: typing.Optional[str] = None
-
-    _KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
-        "account_name": "AZURE_STORAGE_ACCOUNT_NAME",
-        "account_key": "AZURE_STORAGE_ACCOUNT_KEY",
-        "tenant_id": "AZURE_TENANT_ID",
-        "client_id": "AZURE_CLIENT_ID",
-        "client_secret": "AZURE_CLIENT_SECRET",
-    }
-    _KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"
-
-    @classmethod
-    def auto(cls) -> ABFS:
-        account_name = os.getenv(cls._KEY_ENV_VAR_MAPPING["account_name"], None)
-        account_key = os.getenv(cls._KEY_ENV_VAR_MAPPING["account_key"], None)
-        tenant_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["tenant_id"], None)
-        client_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["client_id"], None)
-        client_secret = os.getenv(cls._KEY_ENV_VAR_MAPPING["client_secret"], None)
-
-        kwargs: Dict[str, Any] = {}
-        kwargs = set_if_exists(kwargs, "account_name", account_name)
-        kwargs = set_if_exists(kwargs, "account_key", account_key)
-        kwargs = set_if_exists(kwargs, "tenant_id", tenant_id)
-        kwargs = set_if_exists(kwargs, "client_id", client_id)
-        kwargs = set_if_exists(kwargs, "client_secret", client_secret)
-        return ABFS(**kwargs)
-
-    def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
-        config: Dict[str, Any] = {}
-        if "account_name" in kwargs or self.account_name:
-            config["account_name"] = kwargs.get("account_name", self.account_name)
-        if "account_key" in kwargs or self.account_key:
-            config["account_key"] = kwargs.get("account_key", self.account_key)
-        if "client_id" in kwargs or self.client_id:
-            config["client_id"] = kwargs.get("client_id", self.client_id)
-        if "client_secret" in kwargs or self.client_secret:
-            config["client_secret"] = kwargs.get("client_secret", self.client_secret)
-        if "tenant_id" in kwargs or self.tenant_id:
-            config["tenant_id"] = kwargs.get("tenant_id", self.tenant_id)
-
-        if anonymous:
-            config[self._KEY_SKIP_SIGNATURE] = True
-
-        client_options = {"timeout": "99999s", "allow_http": "true"}
-
-        if config:
-            kwargs["config"] = config
-        kwargs["client_options"] = client_options
-
-        return kwargs
-
-
 @dataclass(init=True, repr=True, eq=True, frozen=True, kw_only=True)
 class CommonInit:
     """
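
Note: the Storage, S3, GCS, and ABFS dataclasses removed above are not dropped from the package; per the file list, flyte/storage/_config.py (+233) and flyte/storage/__init__.py (+5) are added in this release, and _initialize.py now imports Storage from flyte.storage. A minimal sketch of the assumed new import path (only the Storage re-export is confirmed by the diff; S3 and its unchanged classmethods are assumptions):

    from flyte.storage import S3, Storage

    # Assumed to behave like the removed code: build config from FLYTE_AWS_* env vars.
    s3 = S3.auto()
    print(s3.get_fsspec_kwargs())
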
@@ -304,11 +70,10 @@ async def _initialize_client(
     """
     from flyte.remote._client.controlplane import ClientSet
 
-    if endpoint
+    if endpoint:
         return await ClientSet.for_endpoint(
             endpoint,
             insecure=insecure,
-            api_key=api_key,
             insecure_skip_verify=insecure_skip_verify,
             auth_type=auth_type,
             headless=headless,
@@ -321,7 +86,26 @@ async def _initialize_client(
             rpc_retries=rpc_retries,
             http_proxy_url=http_proxy_url,
         )
-
+    elif api_key:
+        return await ClientSet.for_api_key(
+            api_key,
+            insecure=insecure,
+            insecure_skip_verify=insecure_skip_verify,
+            auth_type=auth_type,
+            headless=headless,
+            ca_cert_file_path=ca_cert_file_path,
+            command=command,
+            proxy_command=proxy_command,
+            client_id=client_id,
+            client_credentials_secret=client_credentials_secret,
+            client_config=client_config,
+            rpc_retries=rpc_retries,
+            http_proxy_url=http_proxy_url,
+        )
+
+    raise InitializationError(
+        "MissingEndpointOrApiKeyError", "user", "Either endpoint or api_key must be provided to initialize the client."
+    )
 
 
 @syncer.wrap
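
Note: _initialize_client can now build the client set from an API key as an alternative to an endpoint, and raises InitializationError("MissingEndpointOrApiKeyError", ...) when neither is supplied. A hedged sketch of how this surfaces through the public entry point; the init() keyword names are assumed, since only the private helper is shown in this hunk:

    import flyte

    flyte.init(endpoint="dns:///my-deployment.example.com")  # endpoint-based client, as before (placeholder endpoint)
    flyte.init(api_key="<api-key>")                          # assumed: client built from an API key alone
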
@@ -396,9 +180,9 @@ async def init(
 
     with _init_lock:
         if config is None:
-
+            import flyte.config as _f_cfg
 
-            config = Config
+            config = _f_cfg.Config()
         platform_cfg = config.platform
         task_cfg = config.task
         client = None
@@ -458,7 +242,7 @@ def get_common_config() -> CommonInit:
     return cfg
 
 
-def get_storage() -> Storage:
+def get_storage() -> Storage | None:
     """
     Get the current storage configuration. Thread-safe implementation.
 
@@ -472,9 +256,6 @@ def get_storage() -> Storage:
             "Configuration has not been initialized. Call flyte.init() with a valid endpoint or",
             " api-key before using this function.",
         )
-    if cfg.storage is None:
-        # return default local storage
-        return typing.cast(Storage, cfg.replace(storage=Storage()).storage)
     return cfg.storage
 
 
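
Note: get_storage() now returns Storage | None and no longer substitutes a default Storage() when no storage block was configured, so callers handle the empty case themselves. A hedged caller-side sketch (assumes flyte.init() has already run, otherwise this raises InitializationError):

    from flyte._initialize import get_storage
    from flyte.storage import Storage

    storage = get_storage()
    if storage is None:
        storage = Storage()  # fall back to defaults explicitly, as the removed branch used to do implicitly
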
@@ -504,13 +285,13 @@ def is_initialized() -> bool:
     return _get_init_config() is not None
 
 
-def initialize_in_cluster(
+def initialize_in_cluster() -> None:
     """
     Initialize the system for in-cluster execution. This is a placeholder function and does not perform any actions.
 
     :return: None
     """
-    init(
+    init()
 
 
 # Define a generic type variable for the decorated function
flyte/_internal/controllers/__init__.py
CHANGED
@@ -1,8 +1,8 @@
 import threading
-from typing import Any, Literal, Optional, Protocol, Tuple, TypeVar
+from typing import Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
 
-from flyte._datastructures import ActionID, NativeInterface
 from flyte._task import TaskTemplate
+from flyte.models import ActionID, NativeInterface
 
 from ._trace import TraceInfo
 
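
Note: together with the flyte/{_datastructures.py → models.py} rename in the file list, ActionID, NativeInterface, RawDataPath, and GroupData are now imported from flyte.models throughout the package. A migration sketch for code that imported the old private module:

    # 0.2.0b1
    # from flyte._datastructures import ActionID, NativeInterface
    # 0.2.0b3
    from flyte.models import ActionID, NativeInterface
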
@@ -47,12 +47,12 @@ class Controller(Protocol):
     async def watch_for_errors(self): ...
 
     async def get_action_outputs(
-        self, _interface: NativeInterface,
+        self, _interface: NativeInterface, _func: Callable, *args, **kwargs
     ) -> Tuple[TraceInfo, bool]:
         """
         This method returns the outputs of the action, if it is available.
         :param _interface: NativeInterface
-        :param
+        :param _func: Function name
         :param args: Arguments
         :param kwargs: Keyword arguments
         :return: TraceInfo object and a boolean indicating if the action was found.
@@ -81,13 +81,13 @@ class _ControllerState:
     lock = threading.Lock()
 
 
-
+def get_controller() -> Controller:
     """
     Get the controller instance. Raise an error if it has not been created.
     """
     if _ControllerState.controller is not None:
         return _ControllerState.controller
-    raise RuntimeError("Controller is not initialized. Please call
+    raise RuntimeError("Controller is not initialized. Please call create_controller() first.")
 
 
 def create_controller(
flyte/_internal/controllers/_local_controller.py
CHANGED
@@ -1,14 +1,14 @@
-from typing import Any, Tuple, TypeVar
+from typing import Any, Callable, Tuple, TypeVar
 
 import flyte.errors
 from flyte._context import internal_ctx
-from flyte._datastructures import ActionID, NativeInterface, RawDataPath
 from flyte._internal.controllers import TraceInfo
 from flyte._internal.runtime import convert
 from flyte._internal.runtime.entrypoints import direct_dispatch
 from flyte._logging import log, logger
 from flyte._protos.workflow import task_definition_pb2
 from flyte._task import TaskTemplate
+from flyte.models import ActionID, NativeInterface, RawDataPath
 
 R = TypeVar("R")
 
@@ -28,7 +28,11 @@ class LocalController:
             raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
 
         inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
-
+        serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
+
+        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
+            tctx, _task.name, serialized_inputs, 0
+        )
         sub_action_raw_data_path = RawDataPath(path=sub_action_output_path)
 
         out, err = await direct_dispatch(
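
Note: the local controller now derives the sub-action ID from inputs serialized with SerializeToString(deterministic=True), so identical inputs produce identical bytes (and therefore a stable ID) across runs; this lines up with the removal of flyte/_internal/controllers/pbhash.py in the file list. A small illustration of why the flag matters, using a generic protobuf Struct rather than flyte's internal inputs message:

    import hashlib

    from google.protobuf import struct_pb2

    msg = struct_pb2.Struct()
    msg.update({"x": 1, "y": 2})

    stable_bytes = msg.SerializeToString(deterministic=True)  # map fields emitted in sorted key order
    action_key = hashlib.sha256(stable_bytes).hexdigest()     # reproducible identifier for the same inputs
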
@@ -64,7 +68,7 @@ class LocalController:
         pass
 
     async def get_action_outputs(
-        self, _interface: NativeInterface,
+        self, _interface: NativeInterface, _func: Callable, *args, **kwargs
     ) -> Tuple[TraceInfo, bool]:
         """
         This method returns the outputs of the action, if it is available.
@@ -79,8 +83,13 @@ class LocalController:
         if _interface.inputs:
             converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
             assert converted_inputs
+
+            serialized_inputs = converted_inputs.proto_inputs.SerializeToString(deterministic=True)
             action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
-                tctx,
+                tctx,
+                _func.__name__,
+                serialized_inputs,
+                0,
             )
             assert action_output_path
             return (
@@ -10,13 +10,13 @@ __all__ = ["RemoteController", "create_remote_controller"]
|
|
|
10
10
|
def create_remote_controller(
|
|
11
11
|
*,
|
|
12
12
|
api_key: str | None = None,
|
|
13
|
-
|
|
14
|
-
endpoint: str,
|
|
15
|
-
client_config: ClientConfig | None = None,
|
|
16
|
-
headless: bool = False,
|
|
13
|
+
endpoint: str | None = None,
|
|
17
14
|
insecure: bool = False,
|
|
18
15
|
insecure_skip_verify: bool = False,
|
|
19
16
|
ca_cert_file_path: str | None = None,
|
|
17
|
+
client_config: ClientConfig | None = None,
|
|
18
|
+
auth_type: AuthType = "Pkce",
|
|
19
|
+
headless: bool = False,
|
|
20
20
|
command: List[str] | None = None,
|
|
21
21
|
proxy_command: List[str] | None = None,
|
|
22
22
|
client_id: str | None = None,
|
|
@@ -27,13 +27,33 @@ def create_remote_controller(
     """
     Create a new instance of the remote controller.
     """
+    assert endpoint or api_key, "Either endpoint or api_key must be provided when initializing remote controller"
     from ._client import ControllerClient
     from ._controller import RemoteController
 
+    if endpoint:
+        client_coro = ControllerClient.for_endpoint(
+            endpoint,
+            insecure=insecure,
+            insecure_skip_verify=insecure_skip_verify,
+            ca_cert_file_path=ca_cert_file_path,
+            client_id=client_id,
+            client_credentials_secret=client_credentials_secret,
+            auth_type=auth_type,
+        )
+    elif api_key:
+        client_coro = ControllerClient.for_api_key(
+            api_key,
+            insecure=insecure,
+            insecure_skip_verify=insecure_skip_verify,
+            ca_cert_file_path=ca_cert_file_path,
+            client_id=client_id,
+            client_credentials_secret=client_credentials_secret,
+            auth_type=auth_type,
+        )
+
     controller = RemoteController(
-        client_coro=
-            endpoint=endpoint, insecure=insecure, insecure_skip_verify=insecure_skip_verify
-        ),
+        client_coro=client_coro,
         workers=10,
         max_system_retries=5,
     )
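
Note: create_remote_controller now accepts either endpoint or api_key (the new assert enforces at least one) and picks ControllerClient.for_endpoint or ControllerClient.for_api_key accordingly. A hedged usage sketch with placeholder connection details:

    from flyte._internal.controllers.remote import create_remote_controller

    controller = create_remote_controller(endpoint="dns:///my-deployment.example.com", insecure=False)
    # or, via the new API-key path:
    controller = create_remote_controller(api_key="<api-key>")
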
flyte/_internal/controllers/remote/_action.py
CHANGED
@@ -4,8 +4,8 @@ from dataclasses import dataclass
 
 from flyteidl.core import execution_pb2
 
-from flyte._datastructures import GroupData
 from flyte._protos.workflow import run_definition_pb2, state_service_pb2, task_definition_pb2
+from flyte.models import GroupData
 
 
 @dataclass
flyte/_internal/controllers/remote/_client.py
CHANGED
@@ -20,7 +20,11 @@ class ControllerClient:
 
     @classmethod
     async def for_endpoint(cls, endpoint: str, insecure: bool = False, **kwargs) -> ControllerClient:
-        return cls(await create_channel(endpoint, insecure=insecure, **kwargs))
+        return cls(await create_channel(endpoint, None, insecure=insecure, **kwargs))
+
+    @classmethod
+    async def for_api_key(cls, api_key: str, insecure: bool = False, **kwargs) -> ControllerClient:
+        return cls(await create_channel(None, api_key, insecure=insecure, **kwargs))
 
     @property
     def state_service(self) -> StateService: