flyte 0.2.0b1__py3-none-any.whl → 0.2.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (57)
  1. flyte/__init__.py +3 -4
  2. flyte/_bin/runtime.py +21 -7
  3. flyte/_cache/cache.py +1 -2
  4. flyte/_cli/_common.py +26 -4
  5. flyte/_cli/_create.py +48 -0
  6. flyte/_cli/_deploy.py +4 -2
  7. flyte/_cli/_get.py +18 -7
  8. flyte/_cli/_run.py +1 -0
  9. flyte/_cli/main.py +11 -5
  10. flyte/_code_bundle/bundle.py +42 -11
  11. flyte/_context.py +1 -1
  12. flyte/_deploy.py +3 -1
  13. flyte/_group.py +1 -1
  14. flyte/_initialize.py +28 -247
  15. flyte/_internal/controllers/__init__.py +6 -6
  16. flyte/_internal/controllers/_local_controller.py +14 -5
  17. flyte/_internal/controllers/_trace.py +1 -1
  18. flyte/_internal/controllers/remote/__init__.py +27 -7
  19. flyte/_internal/controllers/remote/_action.py +1 -1
  20. flyte/_internal/controllers/remote/_client.py +5 -1
  21. flyte/_internal/controllers/remote/_controller.py +68 -24
  22. flyte/_internal/controllers/remote/_core.py +1 -1
  23. flyte/_internal/runtime/convert.py +34 -8
  24. flyte/_internal/runtime/entrypoints.py +1 -1
  25. flyte/_internal/runtime/io.py +3 -3
  26. flyte/_internal/runtime/task_serde.py +31 -1
  27. flyte/_internal/runtime/taskrunner.py +1 -1
  28. flyte/_internal/runtime/types_serde.py +1 -1
  29. flyte/_run.py +47 -28
  30. flyte/_task.py +2 -2
  31. flyte/_task_environment.py +1 -1
  32. flyte/_trace.py +5 -6
  33. flyte/_utils/__init__.py +2 -0
  34. flyte/_utils/async_cache.py +139 -0
  35. flyte/_version.py +2 -2
  36. flyte/config/__init__.py +26 -4
  37. flyte/config/_config.py +13 -4
  38. flyte/extras/_container.py +3 -3
  39. flyte/{_datastructures.py → models.py} +3 -2
  40. flyte/remote/_client/auth/_auth_utils.py +14 -0
  41. flyte/remote/_client/auth/_channel.py +28 -3
  42. flyte/remote/_client/auth/_token_client.py +3 -3
  43. flyte/remote/_client/controlplane.py +13 -13
  44. flyte/remote/_logs.py +1 -1
  45. flyte/remote/_run.py +4 -8
  46. flyte/remote/_task.py +2 -2
  47. flyte/storage/__init__.py +5 -0
  48. flyte/storage/_config.py +233 -0
  49. flyte/storage/_storage.py +23 -3
  50. flyte/types/_interface.py +1 -1
  51. flyte/types/_type_engine.py +1 -1
  52. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/METADATA +2 -2
  53. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/RECORD +56 -54
  54. flyte/_internal/controllers/pbhash.py +0 -39
  55. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/WHEEL +0 -0
  56. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/entry_points.txt +0 -0
  57. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/top_level.txt +0 -0
flyte/_utils/async_cache.py ADDED
@@ -0,0 +1,139 @@
1
+ import asyncio
2
+ import time
3
+ from collections import OrderedDict
4
+ from typing import Awaitable, Callable, Dict, Generic, Optional, TypeVar
5
+
6
+ K = TypeVar("K")
7
+ V = TypeVar("V")
8
+
9
+
10
+ class AsyncLRUCache(Generic[K, V]):
11
+ """
12
+ A high-performance async-compatible LRU cache.
13
+
14
+ Examples:
15
+ ```python
16
+ # Create a cache instance
17
+ cache = AsyncLRUCache[str, dict](maxsize=100)
18
+
19
+ async def fetch_data(user_id: str) -> dict:
20
+ # Define the expensive operation as a local function
21
+ async def get_user_data():
22
+ await asyncio.sleep(1) # Simulating network/DB delay
23
+ return {"id": user_id, "name": f"User {user_id}"}
24
+
25
+ # Use the cache
26
+ return await cache.get(f"user:{user_id}", get_user_data)
27
+ ```
28
+ This cache can be used from async coroutines and handles concurrent access safely.
29
+ """
30
+
31
+ def __init__(self, maxsize: int = 128, ttl: Optional[float] = None):
32
+ """
33
+ Initialize the async LRU cache.
34
+
35
+ Args:
36
+ maxsize: Maximum number of items to keep in the cache
37
+ ttl: Time-to-live for cache entries in seconds, or None for no expiration
38
+ """
39
+ self._cache: OrderedDict[K, tuple[V, float]] = OrderedDict()
40
+ self._maxsize = maxsize
41
+ self._ttl = ttl
42
+ self._locks: Dict[K, asyncio.Lock] = {}
43
+ self._access_lock = asyncio.Lock()
44
+
45
+ async def get(self, key: K, value_func: Callable[[], V | Awaitable[V]]) -> V:
46
+ """
47
+ Get a value from the cache, computing it if necessary.
48
+
49
+ Args:
50
+ key: The cache key
51
+ value_func: Function or coroutine to compute the value if not cached
52
+
53
+ Returns:
54
+ The cached or computed value
55
+ """
56
+ # Fast path: check if key exists and is not expired
57
+ if key in self._cache:
58
+ value, timestamp = self._cache[key]
59
+ if self._ttl is None or time.time() - timestamp < self._ttl:
60
+ # Move the accessed item to the end (most recently used)
61
+ async with self._access_lock:
62
+ self._cache.move_to_end(key)
63
+ return value
64
+
65
+ # Slow path: compute the value
66
+ # Get or create a lock for this key to prevent redundant computation
67
+ async with self._access_lock:
68
+ lock = self._locks.get(key)
69
+ if lock is None:
70
+ lock = asyncio.Lock()
71
+ self._locks[key] = lock
72
+
73
+ async with lock:
74
+ # Check again in case another coroutine computed the value while we waited
75
+ if key in self._cache:
76
+ value, timestamp = self._cache[key]
77
+ if self._ttl is None or time.time() - timestamp < self._ttl:
78
+ async with self._access_lock:
79
+ self._cache.move_to_end(key)
80
+ return value
81
+
82
+ # Compute the value
83
+ if asyncio.iscoroutinefunction(value_func):
84
+ value = await value_func()
85
+ else:
86
+ value = value_func() # type: ignore
87
+
88
+ # Store in cache
89
+ async with self._access_lock:
90
+ self._cache[key] = (value, time.time())
91
+ # Evict least recently used items if needed
92
+ while len(self._cache) > self._maxsize:
93
+ self._cache.popitem(last=False)
94
+ # Clean up the lock
95
+ self._locks.pop(key, None)
96
+
97
+ return value
98
+
99
+ async def set(self, key: K, value: V) -> None:
100
+ """
101
+ Explicitly set a value in the cache.
102
+
103
+ Args:
104
+ key: The cache key
105
+ value: The value to cache
106
+ """
107
+ async with self._access_lock:
108
+ self._cache[key] = (value, time.time())
109
+ # Evict least recently used items if needed
110
+ while len(self._cache) > self._maxsize:
111
+ self._cache.popitem(last=False)
112
+
113
+ async def invalidate(self, key: K) -> None:
114
+ """Remove a specific key from the cache."""
115
+ async with self._access_lock:
116
+ self._cache.pop(key, None)
117
+
118
+ async def clear(self) -> None:
119
+ """Clear the entire cache."""
120
+ async with self._access_lock:
121
+ self._cache.clear()
122
+ self._locks.clear()
123
+
124
+ async def contains(self, key: K) -> bool:
125
+ """Check if a key exists in the cache and is not expired."""
126
+ if key not in self._cache:
127
+ return False
128
+
129
+ if self._ttl is None:
130
+ return True
131
+
132
+ _, timestamp = self._cache[key]
133
+ return time.time() - timestamp < self._ttl
134
+
135
+
136
+ # Example usage:
137
+ """
138
+
139
+ """
flyte/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.0b1'
21
- __version_tuple__ = version_tuple = (0, 2, 0, 'b1')
20
+ __version__ = version = '0.2.0b3'
21
+ __version_tuple__ = version_tuple = (0, 2, 0, 'b3')
flyte/config/__init__.py CHANGED
@@ -1,8 +1,13 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
4
+ import pathlib
2
5
  import typing
3
6
  from dataclasses import dataclass, field
4
7
  from typing import TYPE_CHECKING
5
8
 
9
+ import rich.repr
10
+
6
11
  from flyte._logging import logger
7
12
  from flyte.config import _internal
8
13
  from flyte.config._config import ConfigFile, get_config_file, read_file_if_exists
@@ -13,6 +18,7 @@ if TYPE_CHECKING:
13
18
  from flyte.remote._client.auth import AuthType
14
19
 
15
20
 
21
+ @rich.repr.auto
16
22
  @dataclass(init=True, repr=True, eq=True, frozen=True)
17
23
  class PlatformConfig(object):
18
24
  """
@@ -57,7 +63,6 @@ class PlatformConfig(object):
57
63
  :param config_file:
58
64
  :return:
59
65
  """
60
- from .._initialize import set_if_exists
61
66
 
62
67
  config_file = get_config_file(config_file)
63
68
  kwargs: typing.Dict[str, typing.Any] = {}
@@ -106,6 +111,7 @@ class PlatformConfig(object):
106
111
  return PlatformConfig(endpoint=endpoint, insecure=insecure)
107
112
 
108
113
 
114
+ @rich.repr.auto
109
115
  @dataclass(init=True, repr=True, eq=True, frozen=True)
110
116
  class TaskConfig(object):
111
117
  org: str | None = None
@@ -119,8 +125,6 @@ class TaskConfig(object):
119
125
  :param config_file:
120
126
  :return:
121
127
  """
122
- from flyte._initialize import set_if_exists
123
-
124
128
  config_file = get_config_file(config_file)
125
129
  kwargs: typing.Dict[str, typing.Any] = {}
126
130
  kwargs = set_if_exists(kwargs, "org", _internal.Task.ORG.read(config_file))
@@ -129,6 +133,7 @@ class TaskConfig(object):
129
133
  return TaskConfig(**kwargs)
130
134
 
131
135
 
136
+ @rich.repr.auto
132
137
  @dataclass(init=True, repr=True, eq=True, frozen=True)
133
138
  class Config(object):
134
139
  """
@@ -142,6 +147,7 @@ class Config(object):
142
147
 
143
148
  platform: PlatformConfig = field(default=PlatformConfig())
144
149
  task: TaskConfig = field(default=TaskConfig())
150
+ source: pathlib.Path | None = None
145
151
 
146
152
  def with_params(
147
153
  self,
@@ -165,4 +171,20 @@ class Config(object):
165
171
  :return: Config
166
172
  """
167
173
  config_file = get_config_file(config_file)
168
- return Config(platform=PlatformConfig.auto(config_file), task=TaskConfig.auto(config_file))
174
+ if config_file is None:
175
+ logger.debug("No config file found, using default values")
176
+ return Config()
177
+ return Config(
178
+ platform=PlatformConfig.auto(config_file), task=TaskConfig.auto(config_file), source=config_file.path
179
+ )
180
+
181
+
182
def set_if_exists(d: dict, k: str, val: typing.Any) -> dict:
    """
    Store ``val`` under ``k`` in ``d`` when the value counts as "set", and return ``d``.

    A value counts as set when it is a ``bool`` (so an explicit ``False`` is still stored) or when
    it is truthy. ``None`` and other falsy values such as ``0``, ``""`` or empty containers are
    skipped. ``d`` is mutated in place and the same object is returned to allow chaining.
    """
    if isinstance(val, bool) or val:
        d[k] = val
    return d
flyte/config/_config.py CHANGED
@@ -99,6 +99,14 @@ class ConfigFile(object):
99
99
  self._location = location
100
100
  self._yaml_config = self._read_yaml_config(location)
101
101
 
102
+ @property
103
+ def path(self) -> pathlib.Path:
104
+ """
105
+ Returns the path to the config file.
106
+ :return: Path to the config file
107
+ """
108
+ return pathlib.Path(self._location)
109
+
102
110
  @staticmethod
103
111
  def _read_yaml_config(location: str) -> typing.Optional[typing.Dict[str, typing.Any]]:
104
112
  with open(location, "r") as fh:
@@ -136,6 +144,11 @@ def resolve_config_path() -> pathlib.Path | None:
136
144
  4. ~/.flyte/config.yaml if it exists
137
145
  5. ./config.yaml if it exists
138
146
  """
147
+ current_location_config = Path("config.yaml")
148
+ if current_location_config.exists():
149
+ return current_location_config
150
+ logger.debug("No ./config.yaml found, returning None")
151
+
139
152
  uctl_path_from_env = getenv(UCTL_CONFIG_ENV_VAR, None)
140
153
  if uctl_path_from_env:
141
154
  return pathlib.Path(uctl_path_from_env)
@@ -156,10 +169,6 @@ def resolve_config_path() -> pathlib.Path | None:
156
169
  return home_dir_flytectl_config
157
170
  logger.debug("No ~/.flyte/config.yaml found, checking current directory")
158
171
 
159
- current_location_config = Path("config.yaml")
160
- if current_location_config.exists():
161
- return current_location_config
162
- logger.debug("No ./config.yaml found, returning None")
163
172
  return None
164
173
 
165
174
 
flyte/extras/_container.py CHANGED
@@ -5,9 +5,9 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
5
5
  from flyteidl.core import tasks_pb2
6
6
 
7
7
  from flyte import Image, storage
8
- from flyte._datastructures import NativeInterface, SerializationContext
9
8
  from flyte._logging import logger
10
9
  from flyte._task import TaskTemplate
10
+ from flyte.models import NativeInterface, SerializationContext
11
11
 
12
12
  _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
13
13
 
@@ -258,8 +258,8 @@ class ContainerTask(TaskTemplate):
258
258
  }
259
259
 
260
260
  return tasks_pb2.DataLoadingConfig(
261
- input_path=self._input_data_dir,
262
- output_path=self._output_data_dir,
261
+ input_path=str(self._input_data_dir),
262
+ output_path=str(self._output_data_dir),
263
263
  enabled=True,
264
264
  format=literal_to_protobuf.get(self._metadata_format, "JSON"),
265
265
  )
flyte/{_datastructures.py → models.py} RENAMED
@@ -56,11 +56,12 @@ class ActionID:
56
56
  name = generate_random_name()
57
57
  return replace(self, name=name)
58
58
 
59
- def new_sub_action_from(self, task_name: str, input_hash: str, group: str | None) -> ActionID:
59
+ def new_sub_action_from(self, task_call_seq: int, task_hash: str, input_hash: str, group: str | None) -> ActionID:
60
60
  """Make a deterministic name"""
61
61
  import hashlib
62
62
 
63
- components = f"{self.run_name}-{self.name}-{input_hash}-{task_name}" + (f"-{group}" if group else "")
63
+ components = f"{self.name}-{input_hash}-{task_hash}-{task_call_seq}" + (f"-{group}" if group else "")
64
+ logger.debug(f"----- Generating sub-action ID from components: {components}")
64
65
  # has the components into something deterministic
65
66
  bytes_digest = hashlib.md5(components.encode()).digest()
66
67
  new_name = base36_encode(bytes_digest)
flyte/remote/_client/auth/_auth_utils.py ADDED
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ from typing import Literal
5
+
6
+
7
+ def decode_api_key(encoded_str: str) -> tuple[str, str, str, str | Literal["None"]]:
8
+ """Decode encoded base64 string into app credentials. endpoint, client_id, client_secret, org"""
9
+ endpoint, client_id, client_secret, org = base64.b64decode(encoded_str.encode("utf-8")).decode("utf-8").split(":")
10
+ # For consistency, let's make sure org is always a non-empty string
11
+ if not org:
12
+ org = "None"
13
+
14
+ return endpoint, client_id, client_secret, org
flyte/remote/_client/auth/_channel.py CHANGED
@@ -44,7 +44,9 @@ def bootstrap_ssl_from_server(endpoint: str) -> grpc.ChannelCredentials:
44
44
 
45
45
 
46
46
  async def create_channel(
47
- endpoint: str,
47
+ endpoint: str | None,
48
+ api_key: str | None = None,
49
+ /,
48
50
  insecure: typing.Optional[bool] = None,
49
51
  insecure_skip_verify: typing.Optional[bool] = False,
50
52
  ca_cert_file_path: typing.Optional[str] = None,
@@ -66,6 +68,7 @@ async def create_channel(
66
68
  and create authentication interceptors that perform async operations.
67
69
 
68
70
  :param endpoint: The endpoint URL for the gRPC channel
71
+ :param api_key: API key for authentication; if provided, it will be used to detect the endpoint and credentials.
69
72
  :param insecure: Whether to use an insecure channel (no SSL)
70
73
  :param insecure_skip_verify: Whether to skip SSL certificate verification
71
74
  :param ca_cert_file_path: Path to CA certificate file for SSL verification
@@ -104,6 +107,18 @@ async def create_channel(
104
107
  - refresh_access_token_params: Parameters to add when refreshing access token
105
108
  :return: grpc.aio.Channel with authentication interceptors configured
106
109
  """
110
+ assert endpoint or api_key, "Either endpoint or api_key must be specified"
111
+
112
+ if api_key:
113
+ from flyte.remote._client.auth._auth_utils import decode_api_key
114
+
115
+ endpoint, client_id, client_secret, org = decode_api_key(api_key)
116
+ kwargs["auth_type"] = "ClientSecret"
117
+ kwargs["client_id"] = client_id
118
+ kwargs["client_secret"] = client_secret
119
+ kwargs["client_credentials_secret"] = client_secret
120
+
121
+ assert endpoint, "Endpoint must be specified by this point"
107
122
 
108
123
  if not ssl_credentials:
109
124
  if insecure_skip_verify:
@@ -119,7 +134,12 @@ async def create_channel(
119
134
 
120
135
  # Create an unauthenticated channel first to use to get the server metadata
121
136
  if insecure:
122
- unauthenticated_channel = grpc.aio.insecure_channel(endpoint, **kwargs)
137
+ insecure_kwargs = {}
138
+ if kw_opts := kwargs.get("options"):
139
+ insecure_kwargs["options"] = kw_opts
140
+ if compression:
141
+ insecure_kwargs["compression"] = compression
142
+ unauthenticated_channel = grpc.aio.insecure_channel(endpoint, **insecure_kwargs)
123
143
  else:
124
144
  unauthenticated_channel = grpc.aio.secure_channel(
125
145
  target=endpoint,
@@ -173,7 +193,12 @@ async def create_channel(
173
193
  interceptors.extend(auth_interceptors)
174
194
 
175
195
  if insecure:
176
- return grpc.aio.insecure_channel(endpoint, interceptors=interceptors, **kwargs)
196
+ insecure_kwargs = {}
197
+ if kw_opts := kwargs.get("options"):
198
+ insecure_kwargs["options"] = kw_opts
199
+ if compression:
200
+ insecure_kwargs["compression"] = compression
201
+ return grpc.aio.insecure_channel(endpoint, interceptors=interceptors, **insecure_kwargs)
177
202
 
178
203
  return grpc.aio.secure_channel(
179
204
  target=endpoint,
flyte/remote/_client/auth/_token_client.py CHANGED
@@ -94,7 +94,7 @@ async def get_token(
94
94
  http_proxy_url: typing.Optional[str] = None,
95
95
  verify: typing.Optional[typing.Union[bool, str]] = None,
96
96
  refresh_token: typing.Optional[str] = None,
97
- ) -> typing.Tuple[str, str, int]:
97
+ ) -> typing.Tuple[str, str | None, int]:
98
98
  """
99
99
  Retrieves an access token from the specified token endpoint.
100
100
 
@@ -165,7 +165,7 @@ async def get_token(
165
165
  if "refresh_token" in j:
166
166
  new_refresh_token = j["refresh_token"]
167
167
  else:
168
- raise AuthenticationError("Token not yet available, try again in some time")
168
+ logger.info("No refresh token received, this is expected for client credentials flow")
169
169
 
170
170
  return j["access_token"], new_refresh_token, j["expires_in"]
171
171
 
@@ -213,7 +213,7 @@ async def poll_token_endpoint(
213
213
  scopes: typing.Optional[typing.List[str]] = None,
214
214
  http_proxy_url: typing.Optional[str] = None,
215
215
  verify: typing.Optional[typing.Union[bool, str]] = None,
216
- ) -> typing.Tuple[str, str, int]:
216
+ ) -> typing.Tuple[str, str | None, int]:
217
217
  """
218
218
  Polls the token endpoint until authentication is complete or times out.
219
219
 
flyte/remote/_client/controlplane.py CHANGED
@@ -39,21 +39,21 @@ class ClientSet:
39
39
 
40
40
  @classmethod
41
41
  async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
42
- if insecure:
43
- del kwargs["api_key"]
44
- del kwargs["auth_type"]
45
- del kwargs["headless"]
46
- del kwargs["command"]
47
- del kwargs["client_id"]
48
- del kwargs["client_credentials_secret"]
49
- del kwargs["client_config"]
50
- del kwargs["rpc_retries"]
51
- del kwargs["http_proxy_url"]
52
- return cls(await create_channel(endpoint, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs)
42
+ return cls(
43
+ await create_channel(endpoint, None, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs
44
+ )
53
45
 
54
46
  @classmethod
55
- async def for_api_key(cls, api_key: str, **kwargs) -> ClientSet:
56
- raise NotImplementedError
47
+ async def for_api_key(cls, api_key: str, *, insecure: bool = False, **kwargs) -> ClientSet:
48
+ from flyte.remote._client.auth._auth_utils import decode_api_key
49
+
50
+ # Parsing the API key is done in create_channel, but cleaner to redo it here rather than getting create_channel
51
+ # to return the endpoint
52
+ endpoint, _, _, _ = decode_api_key(api_key)
53
+
54
+ return cls(
55
+ await create_channel(None, api_key, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs
56
+ )
57
57
 
58
58
  @classmethod
59
59
  async def for_serverless(cls) -> ClientSet:
flyte/remote/_logs.py CHANGED
@@ -111,6 +111,6 @@ class Logs:
111
111
  log_source=cls.tail.aio(cls, action_id=action_id, attempt=attempt),
112
112
  max_lines=max_lines,
113
113
  show_ts=show_ts,
114
- name=f"{action_id.run.name}:{action_id.name}",
114
+ name=f"{action_id.run.name}:{action_id.name} ({attempt})",
115
115
  )
116
116
  await viewer.run()
flyte/remote/_run.py CHANGED
@@ -243,14 +243,10 @@ class Run:
243
243
  return
244
244
  yield ad
245
245
 
246
- async def show_logs(self, attempt: int = 1, max_lines: int = 100, show_ts: bool = False, raw: bool = False):
247
- return await Logs.create_viewer(
248
- action_id=self.action.action_id,
249
- attempt=attempt,
250
- max_lines=max_lines,
251
- show_ts=show_ts,
252
- raw=raw,
253
- )
246
+ async def show_logs(
247
+ self, attempt: int | None = None, max_lines: int = 100, show_ts: bool = False, raw: bool = False
248
+ ):
249
+ await self.action.show_logs(attempt, max_lines, show_ts, raw)
254
250
 
255
251
  async def details(self) -> RunDetails:
256
252
  """
flyte/remote/_task.py CHANGED
@@ -11,9 +11,9 @@ import flyte
11
11
  import flyte.errors
12
12
  from flyte._api_commons import syncer
13
13
  from flyte._context import internal_ctx
14
- from flyte._datastructures import NativeInterface
15
14
  from flyte._initialize import get_client, get_common_config
16
15
  from flyte._protos.workflow import task_definition_pb2, task_service_pb2
16
+ from flyte.models import NativeInterface
17
17
 
18
18
 
19
19
  class LazyEntity:
@@ -187,7 +187,7 @@ class Task:
187
187
  # We will also check if we are not initialized, It is not expected to be not initialized
188
188
  from flyte._internal.controllers import get_controller
189
189
 
190
- controller = await get_controller()
190
+ controller = get_controller()
191
191
  if controller:
192
192
  return await controller.submit_task_ref(self.pb2, *args, **kwargs)
193
193
  raise flyte.errors
flyte/storage/__init__.py CHANGED
@@ -1,4 +1,8 @@
1
1
  __all__ = [
2
+ "ABFS",
3
+ "GCS",
4
+ "S3",
5
+ "Storage",
2
6
  "get",
3
7
  "get_random_local_directory",
4
8
  "get_random_local_path",
@@ -11,6 +15,7 @@ __all__ = [
11
15
  "put_stream",
12
16
  ]
13
17
 
18
+ from ._config import ABFS, GCS, S3, Storage
14
19
  from ._storage import (
15
20
  get,
16
21
  get_random_local_directory,