flyte 0.2.0b1__py3-none-any.whl → 0.2.0b3__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. See the package's registry page for more details.

Files changed (57):
  1. flyte/__init__.py +3 -4
  2. flyte/_bin/runtime.py +21 -7
  3. flyte/_cache/cache.py +1 -2
  4. flyte/_cli/_common.py +26 -4
  5. flyte/_cli/_create.py +48 -0
  6. flyte/_cli/_deploy.py +4 -2
  7. flyte/_cli/_get.py +18 -7
  8. flyte/_cli/_run.py +1 -0
  9. flyte/_cli/main.py +11 -5
  10. flyte/_code_bundle/bundle.py +42 -11
  11. flyte/_context.py +1 -1
  12. flyte/_deploy.py +3 -1
  13. flyte/_group.py +1 -1
  14. flyte/_initialize.py +28 -247
  15. flyte/_internal/controllers/__init__.py +6 -6
  16. flyte/_internal/controllers/_local_controller.py +14 -5
  17. flyte/_internal/controllers/_trace.py +1 -1
  18. flyte/_internal/controllers/remote/__init__.py +27 -7
  19. flyte/_internal/controllers/remote/_action.py +1 -1
  20. flyte/_internal/controllers/remote/_client.py +5 -1
  21. flyte/_internal/controllers/remote/_controller.py +68 -24
  22. flyte/_internal/controllers/remote/_core.py +1 -1
  23. flyte/_internal/runtime/convert.py +34 -8
  24. flyte/_internal/runtime/entrypoints.py +1 -1
  25. flyte/_internal/runtime/io.py +3 -3
  26. flyte/_internal/runtime/task_serde.py +31 -1
  27. flyte/_internal/runtime/taskrunner.py +1 -1
  28. flyte/_internal/runtime/types_serde.py +1 -1
  29. flyte/_run.py +47 -28
  30. flyte/_task.py +2 -2
  31. flyte/_task_environment.py +1 -1
  32. flyte/_trace.py +5 -6
  33. flyte/_utils/__init__.py +2 -0
  34. flyte/_utils/async_cache.py +139 -0
  35. flyte/_version.py +2 -2
  36. flyte/config/__init__.py +26 -4
  37. flyte/config/_config.py +13 -4
  38. flyte/extras/_container.py +3 -3
  39. flyte/{_datastructures.py → models.py} +3 -2
  40. flyte/remote/_client/auth/_auth_utils.py +14 -0
  41. flyte/remote/_client/auth/_channel.py +28 -3
  42. flyte/remote/_client/auth/_token_client.py +3 -3
  43. flyte/remote/_client/controlplane.py +13 -13
  44. flyte/remote/_logs.py +1 -1
  45. flyte/remote/_run.py +4 -8
  46. flyte/remote/_task.py +2 -2
  47. flyte/storage/__init__.py +5 -0
  48. flyte/storage/_config.py +233 -0
  49. flyte/storage/_storage.py +23 -3
  50. flyte/types/_interface.py +1 -1
  51. flyte/types/_type_engine.py +1 -1
  52. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/METADATA +2 -2
  53. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/RECORD +56 -54
  54. flyte/_internal/controllers/pbhash.py +0 -39
  55. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/WHEEL +0 -0
  56. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/entry_points.txt +0 -0
  57. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/top_level.txt +0 -0
flyte/_initialize.py CHANGED
@@ -1,14 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
- import datetime
4
3
  import functools
5
- import os
6
4
  import threading
7
5
  import typing
8
6
  from dataclasses import dataclass, replace
9
- from datetime import timedelta
10
7
  from pathlib import Path
11
- from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Literal, Optional, TypeVar
8
+ from typing import TYPE_CHECKING, Callable, List, Literal, Optional, TypeVar
12
9
 
13
10
  from flyte.errors import InitializationError
14
11
 
@@ -20,242 +17,11 @@ if TYPE_CHECKING:
20
17
  from flyte.config import Config
21
18
  from flyte.remote._client.auth import AuthType, ClientConfig
22
19
  from flyte.remote._client.controlplane import ClientSet
20
+ from flyte.storage import Storage
23
21
 
24
22
  Mode = Literal["local", "remote"]
25
23
 
26
24
 
27
- def set_if_exists(d: dict, k: str, val: typing.Any) -> dict:
28
- """
29
- Given a dict ``d`` sets the key ``k`` with value of config ``v``, if the config value ``v`` is set
30
- and return the updated dictionary.
31
- """
32
- exists = isinstance(val, bool) or bool(val is not None and val)
33
- if exists:
34
- d[k] = val
35
- return d
36
-
37
-
38
- @dataclass(init=True, repr=True, eq=True, frozen=True)
39
- class Storage(object):
40
- """
41
- Data storage configuration that applies across any provider.
42
- """
43
-
44
- retries: int = 3
45
- backoff: datetime.timedelta = datetime.timedelta(seconds=5)
46
- enable_debug: bool = False
47
- attach_execution_metadata: bool = True
48
-
49
- _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
50
- "enable_debug": "UNION_STORAGE_DEBUG",
51
- "retries": "UNION_STORAGE_RETRIES",
52
- "backoff": "UNION_STORAGE_BACKOFF_SECONDS",
53
- }
54
-
55
- def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
56
- """
57
- Returns the configuration as kwargs for constructing an fsspec filesystem.
58
- """
59
- return {}
60
-
61
- @classmethod
62
- def _auto_as_kwargs(cls) -> Dict[str, Any]:
63
- retries = os.getenv(cls._KEY_ENV_VAR_MAPPING["retries"])
64
- backoff = os.getenv(cls._KEY_ENV_VAR_MAPPING["backoff"])
65
- enable_debug = os.getenv(cls._KEY_ENV_VAR_MAPPING["enable_debug"])
66
-
67
- kwargs: Dict[str, Any] = {}
68
- kwargs = set_if_exists(kwargs, "enable_debug", enable_debug)
69
- kwargs = set_if_exists(kwargs, "retries", retries)
70
- kwargs = set_if_exists(kwargs, "backoff", backoff)
71
- return kwargs
72
-
73
- @classmethod
74
- def auto(cls) -> Storage:
75
- """
76
- Construct the config object automatically from environment variables.
77
- """
78
- return cls(**cls._auto_as_kwargs())
79
-
80
-
81
- @dataclass(init=True, repr=True, eq=True, frozen=True)
82
- class S3(Storage):
83
- """
84
- S3 specific configuration
85
- """
86
-
87
- endpoint: typing.Optional[str] = None
88
- access_key_id: typing.Optional[str] = None
89
- secret_access_key: typing.Optional[str] = None
90
-
91
- _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
92
- "endpoint": "FLYTE_AWS_ENDPOINT",
93
- "access_key_id": "FLYTE_AWS_ACCESS_KEY_ID",
94
- "secret_access_key": "FLYTE_AWS_SECRET_ACCESS_KEY",
95
- } | Storage._KEY_ENV_VAR_MAPPING
96
-
97
- # Refer to https://github.com/developmentseed/obstore/blob/33654fc37f19a657689eb93327b621e9f9e01494/obstore/python/obstore/store/_aws.pyi#L11
98
- # for key and secret
99
- _CONFIG_KEY_FSSPEC_S3_KEY_ID: ClassVar = "access_key_id"
100
- _CONFIG_KEY_FSSPEC_S3_SECRET: ClassVar = "secret_access_key"
101
- _CONFIG_KEY_ENDPOINT: ClassVar = "endpoint_url"
102
- _KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"
103
-
104
- @classmethod
105
- def auto(cls) -> S3:
106
- """
107
- :return: Config
108
- """
109
- endpoint = os.getenv(cls._KEY_ENV_VAR_MAPPING["endpoint"], None)
110
- access_key_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["access_key_id"], None)
111
- secret_access_key = os.getenv(cls._KEY_ENV_VAR_MAPPING["secret_access_key"], None)
112
-
113
- kwargs = super()._auto_as_kwargs()
114
- kwargs = set_if_exists(kwargs, "endpoint", endpoint)
115
- kwargs = set_if_exists(kwargs, "access_key_id", access_key_id)
116
- kwargs = set_if_exists(kwargs, "secret_access_key", secret_access_key)
117
-
118
- return S3(**kwargs)
119
-
120
- @classmethod
121
- def for_sandbox(cls) -> S3:
122
- """
123
- :return:
124
- """
125
- kwargs = super()._auto_as_kwargs()
126
- final_kwargs = kwargs | {
127
- "endpoint": "http://localhost:4566",
128
- "access_key_id": "minio",
129
- "secret_access_key": "miniostorage",
130
- }
131
- return S3(**final_kwargs)
132
-
133
- def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
134
- # Construct the config object
135
- config: Dict[str, Any] = {}
136
- if self._CONFIG_KEY_FSSPEC_S3_KEY_ID in kwargs or self.access_key_id:
137
- config[self._CONFIG_KEY_FSSPEC_S3_KEY_ID] = kwargs.pop(
138
- self._CONFIG_KEY_FSSPEC_S3_KEY_ID, self.access_key_id
139
- )
140
- if self._CONFIG_KEY_FSSPEC_S3_SECRET in kwargs or self.secret_access_key:
141
- config[self._CONFIG_KEY_FSSPEC_S3_SECRET] = kwargs.pop(
142
- self._CONFIG_KEY_FSSPEC_S3_SECRET, self.secret_access_key
143
- )
144
- if self._CONFIG_KEY_ENDPOINT in kwargs or self.endpoint:
145
- config["endpoint_url"] = kwargs.pop(self._CONFIG_KEY_ENDPOINT, self.endpoint)
146
-
147
- retries = kwargs.pop("retries", self.retries)
148
- backoff = kwargs.pop("backoff", self.backoff)
149
-
150
- if anonymous:
151
- config[self._KEY_SKIP_SIGNATURE] = True
152
-
153
- retry_config = {
154
- "max_retries": retries,
155
- "backoff": {
156
- "base": 2,
157
- "init_backoff": backoff,
158
- "max_backoff": timedelta(seconds=16),
159
- },
160
- "retry_timeout": timedelta(minutes=3),
161
- }
162
-
163
- client_options = {"timeout": "99999s", "allow_http": True}
164
-
165
- if config:
166
- kwargs["config"] = config
167
- kwargs["client_options"] = client_options or None
168
- kwargs["retry_config"] = retry_config or None
169
-
170
- return kwargs
171
-
172
-
173
- @dataclass(init=True, repr=True, eq=True, frozen=True)
174
- class GCS(Storage):
175
- """
176
- Any GCS specific configuration.
177
- """
178
-
179
- gsutil_parallelism: bool = False
180
-
181
- _KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
182
- "gsutil_parallelism": "GCP_GSUTIL_PARALLELISM",
183
- }
184
-
185
- @classmethod
186
- def auto(cls) -> GCS:
187
- gsutil_parallelism = os.getenv(cls._KEY_ENV_VAR_MAPPING["gsutil_parallelism"], None)
188
-
189
- kwargs: Dict[str, Any] = {}
190
- kwargs = set_if_exists(kwargs, "gsutil_parallelism", gsutil_parallelism)
191
- return GCS(**kwargs)
192
-
193
- def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
194
- return kwargs
195
-
196
-
197
- @dataclass(init=True, repr=True, eq=True, frozen=True)
198
- class ABFS(Storage):
199
- """
200
- Any Azure Blob Storage specific configuration.
201
- """
202
-
203
- account_name: typing.Optional[str] = None
204
- account_key: typing.Optional[str] = None
205
- tenant_id: typing.Optional[str] = None
206
- client_id: typing.Optional[str] = None
207
- client_secret: typing.Optional[str] = None
208
-
209
- _KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
210
- "account_name": "AZURE_STORAGE_ACCOUNT_NAME",
211
- "account_key": "AZURE_STORAGE_ACCOUNT_KEY",
212
- "tenant_id": "AZURE_TENANT_ID",
213
- "client_id": "AZURE_CLIENT_ID",
214
- "client_secret": "AZURE_CLIENT_SECRET",
215
- }
216
- _KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"
217
-
218
- @classmethod
219
- def auto(cls) -> ABFS:
220
- account_name = os.getenv(cls._KEY_ENV_VAR_MAPPING["account_name"], None)
221
- account_key = os.getenv(cls._KEY_ENV_VAR_MAPPING["account_key"], None)
222
- tenant_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["tenant_id"], None)
223
- client_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["client_id"], None)
224
- client_secret = os.getenv(cls._KEY_ENV_VAR_MAPPING["client_secret"], None)
225
-
226
- kwargs: Dict[str, Any] = {}
227
- kwargs = set_if_exists(kwargs, "account_name", account_name)
228
- kwargs = set_if_exists(kwargs, "account_key", account_key)
229
- kwargs = set_if_exists(kwargs, "tenant_id", tenant_id)
230
- kwargs = set_if_exists(kwargs, "client_id", client_id)
231
- kwargs = set_if_exists(kwargs, "client_secret", client_secret)
232
- return ABFS(**kwargs)
233
-
234
- def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
235
- config: Dict[str, Any] = {}
236
- if "account_name" in kwargs or self.account_name:
237
- config["account_name"] = kwargs.get("account_name", self.account_name)
238
- if "account_key" in kwargs or self.account_key:
239
- config["account_key"] = kwargs.get("account_key", self.account_key)
240
- if "client_id" in kwargs or self.client_id:
241
- config["client_id"] = kwargs.get("client_id", self.client_id)
242
- if "client_secret" in kwargs or self.client_secret:
243
- config["client_secret"] = kwargs.get("client_secret", self.client_secret)
244
- if "tenant_id" in kwargs or self.tenant_id:
245
- config["tenant_id"] = kwargs.get("tenant_id", self.tenant_id)
246
-
247
- if anonymous:
248
- config[self._KEY_SKIP_SIGNATURE] = True
249
-
250
- client_options = {"timeout": "99999s", "allow_http": "true"}
251
-
252
- if config:
253
- kwargs["config"] = config
254
- kwargs["client_options"] = client_options
255
-
256
- return kwargs
257
-
258
-
259
25
  @dataclass(init=True, repr=True, eq=True, frozen=True, kw_only=True)
260
26
  class CommonInit:
261
27
  """
@@ -304,11 +70,10 @@ async def _initialize_client(
304
70
  """
305
71
  from flyte.remote._client.controlplane import ClientSet
306
72
 
307
- if endpoint is not None:
73
+ if endpoint:
308
74
  return await ClientSet.for_endpoint(
309
75
  endpoint,
310
76
  insecure=insecure,
311
- api_key=api_key,
312
77
  insecure_skip_verify=insecure_skip_verify,
313
78
  auth_type=auth_type,
314
79
  headless=headless,
@@ -321,7 +86,26 @@ async def _initialize_client(
321
86
  rpc_retries=rpc_retries,
322
87
  http_proxy_url=http_proxy_url,
323
88
  )
324
- raise NotImplementedError("Currently only endpoints are supported.")
89
+ elif api_key:
90
+ return await ClientSet.for_api_key(
91
+ api_key,
92
+ insecure=insecure,
93
+ insecure_skip_verify=insecure_skip_verify,
94
+ auth_type=auth_type,
95
+ headless=headless,
96
+ ca_cert_file_path=ca_cert_file_path,
97
+ command=command,
98
+ proxy_command=proxy_command,
99
+ client_id=client_id,
100
+ client_credentials_secret=client_credentials_secret,
101
+ client_config=client_config,
102
+ rpc_retries=rpc_retries,
103
+ http_proxy_url=http_proxy_url,
104
+ )
105
+
106
+ raise InitializationError(
107
+ "MissingEndpointOrApiKeyError", "user", "Either endpoint or api_key must be provided to initialize the client."
108
+ )
325
109
 
326
110
 
327
111
  @syncer.wrap
@@ -396,9 +180,9 @@ async def init(
396
180
 
397
181
  with _init_lock:
398
182
  if config is None:
399
- from flyte.config import Config
183
+ import flyte.config as _f_cfg
400
184
 
401
- config = Config.auto()
185
+ config = _f_cfg.Config()
402
186
  platform_cfg = config.platform
403
187
  task_cfg = config.task
404
188
  client = None
@@ -458,7 +242,7 @@ def get_common_config() -> CommonInit:
458
242
  return cfg
459
243
 
460
244
 
461
- def get_storage() -> Storage:
245
+ def get_storage() -> Storage | None:
462
246
  """
463
247
  Get the current storage configuration. Thread-safe implementation.
464
248
 
@@ -472,9 +256,6 @@ def get_storage() -> Storage:
472
256
  "Configuration has not been initialized. Call flyte.init() with a valid endpoint or",
473
257
  " api-key before using this function.",
474
258
  )
475
- if cfg.storage is None:
476
- # return default local storage
477
- return typing.cast(Storage, cfg.replace(storage=Storage()).storage)
478
259
  return cfg.storage
479
260
 
480
261
 
@@ -504,13 +285,13 @@ def is_initialized() -> bool:
504
285
  return _get_init_config() is not None
505
286
 
506
287
 
507
- def initialize_in_cluster(storage: Storage | None = None) -> None:
288
+ def initialize_in_cluster() -> None:
508
289
  """
509
290
  Initialize the system for in-cluster execution. This is a placeholder function and does not perform any actions.
510
291
 
511
292
  :return: None
512
293
  """
513
- init(storage=storage)
294
+ init()
514
295
 
515
296
 
516
297
  # Define a generic type variable for the decorated function
flyte/_internal/controllers/__init__.py CHANGED
@@ -1,8 +1,8 @@
1
1
  import threading
2
- from typing import Any, Literal, Optional, Protocol, Tuple, TypeVar
2
+ from typing import Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
3
3
 
4
- from flyte._datastructures import ActionID, NativeInterface
5
4
  from flyte._task import TaskTemplate
5
+ from flyte.models import ActionID, NativeInterface
6
6
 
7
7
  from ._trace import TraceInfo
8
8
 
@@ -47,12 +47,12 @@ class Controller(Protocol):
47
47
  async def watch_for_errors(self): ...
48
48
 
49
49
  async def get_action_outputs(
50
- self, _interface: NativeInterface, _func_name: str, *args, **kwargs
50
+ self, _interface: NativeInterface, _func: Callable, *args, **kwargs
51
51
  ) -> Tuple[TraceInfo, bool]:
52
52
  """
53
53
  This method returns the outputs of the action, if it is available.
54
54
  :param _interface: NativeInterface
55
- :param _func_name: Function name
55
+ :param _func: Function name
56
56
  :param args: Arguments
57
57
  :param kwargs: Keyword arguments
58
58
  :return: TraceInfo object and a boolean indicating if the action was found.
@@ -81,13 +81,13 @@ class _ControllerState:
81
81
  lock = threading.Lock()
82
82
 
83
83
 
84
- async def get_controller() -> Controller:
84
+ def get_controller() -> Controller:
85
85
  """
86
86
  Get the controller instance. Raise an error if it has not been created.
87
87
  """
88
88
  if _ControllerState.controller is not None:
89
89
  return _ControllerState.controller
90
- raise RuntimeError("Controller is not initialized. Please call get_or_create_controller() first.")
90
+ raise RuntimeError("Controller is not initialized. Please call create_controller() first.")
91
91
 
92
92
 
93
93
  def create_controller(
flyte/_internal/controllers/_local_controller.py CHANGED
@@ -1,14 +1,14 @@
1
- from typing import Any, Tuple, TypeVar
1
+ from typing import Any, Callable, Tuple, TypeVar
2
2
 
3
3
  import flyte.errors
4
4
  from flyte._context import internal_ctx
5
- from flyte._datastructures import ActionID, NativeInterface, RawDataPath
6
5
  from flyte._internal.controllers import TraceInfo
7
6
  from flyte._internal.runtime import convert
8
7
  from flyte._internal.runtime.entrypoints import direct_dispatch
9
8
  from flyte._logging import log, logger
10
9
  from flyte._protos.workflow import task_definition_pb2
11
10
  from flyte._task import TaskTemplate
11
+ from flyte.models import ActionID, NativeInterface, RawDataPath
12
12
 
13
13
  R = TypeVar("R")
14
14
 
@@ -28,7 +28,11 @@ class LocalController:
28
28
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
29
29
 
30
30
  inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
31
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _task.name, inputs)
31
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
32
+
33
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
34
+ tctx, _task.name, serialized_inputs, 0
35
+ )
32
36
  sub_action_raw_data_path = RawDataPath(path=sub_action_output_path)
33
37
 
34
38
  out, err = await direct_dispatch(
@@ -64,7 +68,7 @@ class LocalController:
64
68
  pass
65
69
 
66
70
  async def get_action_outputs(
67
- self, _interface: NativeInterface, _func_name: str, *args, **kwargs
71
+ self, _interface: NativeInterface, _func: Callable, *args, **kwargs
68
72
  ) -> Tuple[TraceInfo, bool]:
69
73
  """
70
74
  This method returns the outputs of the action, if it is available.
@@ -79,8 +83,13 @@ class LocalController:
79
83
  if _interface.inputs:
80
84
  converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
81
85
  assert converted_inputs
86
+
87
+ serialized_inputs = converted_inputs.proto_inputs.SerializeToString(deterministic=True)
82
88
  action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
83
- tctx, _func_name, converted_inputs
89
+ tctx,
90
+ _func.__name__,
91
+ serialized_inputs,
92
+ 0,
84
93
  )
85
94
  assert action_output_path
86
95
  return (
flyte/_internal/controllers/_trace.py CHANGED
@@ -2,7 +2,7 @@ from dataclasses import dataclass
2
2
  from datetime import timedelta
3
3
  from typing import Any, Optional
4
4
 
5
- from flyte._datastructures import ActionID, NativeInterface
5
+ from flyte.models import ActionID, NativeInterface
6
6
 
7
7
 
8
8
  @dataclass
flyte/_internal/controllers/remote/__init__.py CHANGED
@@ -10,13 +10,13 @@ __all__ = ["RemoteController", "create_remote_controller"]
10
10
  def create_remote_controller(
11
11
  *,
12
12
  api_key: str | None = None,
13
- auth_type: AuthType = "Pkce",
14
- endpoint: str,
15
- client_config: ClientConfig | None = None,
16
- headless: bool = False,
13
+ endpoint: str | None = None,
17
14
  insecure: bool = False,
18
15
  insecure_skip_verify: bool = False,
19
16
  ca_cert_file_path: str | None = None,
17
+ client_config: ClientConfig | None = None,
18
+ auth_type: AuthType = "Pkce",
19
+ headless: bool = False,
20
20
  command: List[str] | None = None,
21
21
  proxy_command: List[str] | None = None,
22
22
  client_id: str | None = None,
@@ -27,13 +27,33 @@ def create_remote_controller(
27
27
  """
28
28
  Create a new instance of the remote controller.
29
29
  """
30
+ assert endpoint or api_key, "Either endpoint or api_key must be provided when initializing remote controller"
30
31
  from ._client import ControllerClient
31
32
  from ._controller import RemoteController
32
33
 
34
+ if endpoint:
35
+ client_coro = ControllerClient.for_endpoint(
36
+ endpoint,
37
+ insecure=insecure,
38
+ insecure_skip_verify=insecure_skip_verify,
39
+ ca_cert_file_path=ca_cert_file_path,
40
+ client_id=client_id,
41
+ client_credentials_secret=client_credentials_secret,
42
+ auth_type=auth_type,
43
+ )
44
+ elif api_key:
45
+ client_coro = ControllerClient.for_api_key(
46
+ api_key,
47
+ insecure=insecure,
48
+ insecure_skip_verify=insecure_skip_verify,
49
+ ca_cert_file_path=ca_cert_file_path,
50
+ client_id=client_id,
51
+ client_credentials_secret=client_credentials_secret,
52
+ auth_type=auth_type,
53
+ )
54
+
33
55
  controller = RemoteController(
34
- client_coro=ControllerClient.for_endpoint(
35
- endpoint=endpoint, insecure=insecure, insecure_skip_verify=insecure_skip_verify
36
- ),
56
+ client_coro=client_coro,
37
57
  workers=10,
38
58
  max_system_retries=5,
39
59
  )
flyte/_internal/controllers/remote/_action.py CHANGED
@@ -4,8 +4,8 @@ from dataclasses import dataclass
4
4
 
5
5
  from flyteidl.core import execution_pb2
6
6
 
7
- from flyte._datastructures import GroupData
8
7
  from flyte._protos.workflow import run_definition_pb2, state_service_pb2, task_definition_pb2
8
+ from flyte.models import GroupData
9
9
 
10
10
 
11
11
  @dataclass
flyte/_internal/controllers/remote/_client.py CHANGED
@@ -20,7 +20,11 @@ class ControllerClient:
20
20
 
21
21
  @classmethod
22
22
  async def for_endpoint(cls, endpoint: str, insecure: bool = False, **kwargs) -> ControllerClient:
23
- return cls(await create_channel(endpoint, insecure=insecure, **kwargs))
23
+ return cls(await create_channel(endpoint, None, insecure=insecure, **kwargs))
24
+
25
+ @classmethod
26
+ async def for_api_key(cls, api_key: str, insecure: bool = False, **kwargs) -> ControllerClient:
27
+ return cls(await create_channel(None, api_key, insecure=insecure, **kwargs))
24
28
 
25
29
  @property
26
30
  def state_service(self) -> StateService: