flyte 2.0.0b20__py3-none-any.whl → 2.0.0b22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. flyte/_bin/runtime.py +3 -3
  2. flyte/_code_bundle/_ignore.py +11 -3
  3. flyte/_code_bundle/_packaging.py +9 -5
  4. flyte/_code_bundle/_utils.py +2 -2
  5. flyte/_deploy.py +16 -8
  6. flyte/_image.py +5 -1
  7. flyte/_initialize.py +21 -8
  8. flyte/_interface.py +37 -2
  9. flyte/_internal/imagebuild/docker_builder.py +61 -8
  10. flyte/_internal/imagebuild/image_builder.py +8 -7
  11. flyte/_internal/imagebuild/remote_builder.py +9 -26
  12. flyte/_internal/runtime/task_serde.py +16 -6
  13. flyte/_keyring/__init__.py +0 -0
  14. flyte/_keyring/file.py +85 -0
  15. flyte/_logging.py +19 -8
  16. flyte/_task_environment.py +1 -1
  17. flyte/_utils/coro_management.py +2 -1
  18. flyte/_version.py +3 -3
  19. flyte/cli/_common.py +2 -2
  20. flyte/cli/_deploy.py +11 -1
  21. flyte/cli/_run.py +13 -3
  22. flyte/config/_config.py +6 -4
  23. flyte/config/_reader.py +19 -4
  24. flyte/git/_config.py +2 -0
  25. flyte/io/_dataframe/dataframe.py +3 -2
  26. flyte/io/_dir.py +72 -72
  27. flyte/models.py +6 -2
  28. flyte/remote/_action.py +9 -8
  29. flyte/remote/_client/auth/_authenticators/device_code.py +3 -4
  30. flyte/remote/_data.py +2 -3
  31. flyte/remote/_run.py +17 -1
  32. flyte/storage/_config.py +5 -1
  33. flyte/types/_pickle.py +18 -4
  34. flyte/types/_type_engine.py +13 -0
  35. {flyte-2.0.0b20.data → flyte-2.0.0b22.data}/scripts/runtime.py +3 -3
  36. {flyte-2.0.0b20.dist-info → flyte-2.0.0b22.dist-info}/METADATA +1 -1
  37. {flyte-2.0.0b20.dist-info → flyte-2.0.0b22.dist-info}/RECORD +42 -40
  38. {flyte-2.0.0b20.dist-info → flyte-2.0.0b22.dist-info}/entry_points.txt +3 -0
  39. {flyte-2.0.0b20.data → flyte-2.0.0b22.data}/scripts/debug.py +0 -0
  40. {flyte-2.0.0b20.dist-info → flyte-2.0.0b22.dist-info}/WHEEL +0 -0
  41. {flyte-2.0.0b20.dist-info → flyte-2.0.0b22.dist-info}/licenses/LICENSE +0 -0
  42. {flyte-2.0.0b20.dist-info → flyte-2.0.0b22.dist-info}/top_level.txt +0 -0
flyte/_keyring/file.py ADDED
@@ -0,0 +1,85 @@
1
+ from base64 import decodebytes, encodebytes
2
+ from configparser import ConfigParser, NoOptionError, NoSectionError
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ from keyring.backend import KeyringBackend
7
+ from keyring.errors import PasswordDeleteError
8
+
9
_FLYTE_KEYRING_PATH: Path = Path.home() / ".flyte" / "keyring.cfg"


class SimplePlainTextKeyring(KeyringBackend):
    """
    Simple plain-text keyring backend.

    Passwords are stored base64-encoded (encoded, NOT encrypted) in an INI file managed with ``ConfigParser``:
    one section per service and one option per username. See ``file_path`` for where the file lives.

    NOTE(review): ``ConfigParser`` lowercases option names by default, so usernames are matched
    case-insensitively by this backend.
    """

    priority = 0.5

    def get_password(self, service: str, username: str) -> Optional[str]:
        """
        Return the stored password for ``username`` under ``service``.

        :param service: Section name in the keyring file.
        :param username: Option name within that section.
        :return: The decoded password, or ``None`` if the file, the section or the option does not exist.
        """
        # Resolve the path once so the existence check and the read target the same file.
        file_path = self.file_path
        if not file_path.exists():
            return None

        config = ConfigParser(interpolation=None)
        config.read(file_path, encoding="utf-8")

        try:
            password_base64 = config.get(service, username).encode("utf-8")
        except (NoOptionError, NoSectionError):
            return None
        return decodebytes(password_base64).decode("utf-8")

    def delete_password(self, service: str, username: str) -> None:
        """
        Remove the stored password for ``username`` under ``service`` and rewrite the keyring file.

        :param service: Section name in the keyring file.
        :param username: Option name within that section.
        :raises PasswordDeleteError: If the keyring file does not exist, or the section/option is missing.
        """
        file_path = self.file_path
        if not file_path.exists():
            raise PasswordDeleteError("Config file does not exist")

        config = ConfigParser(interpolation=None)
        config.read(file_path, encoding="utf-8")

        try:
            removed = config.remove_option(service, username)
        except NoSectionError:
            raise PasswordDeleteError("Password not found") from None
        if not removed:
            raise PasswordDeleteError("Password not found")

        with file_path.open("w", encoding="utf-8") as config_file:
            config.write(config_file)

    def set_password(self, service: str, username: str, password: str) -> None:
        """
        Store ``password`` for ``username`` under ``service``, creating the keyring file if needed.

        :param service: Section name in the keyring file (created if absent).
        :param username: Option name within that section; must be non-empty.
        :param password: Plain-text password; it is base64-encoded before being written.
        :raises ValueError: If ``username`` is empty.
        """
        if not username:
            raise ValueError("Username must be provided")

        file_path = self._ensure_file_path()
        value = encodebytes(password.encode("utf-8")).decode("utf-8")

        config = ConfigParser(interpolation=None)
        config.read(file_path, encoding="utf-8")

        if not config.has_section(service):
            config.add_section(service)

        config.set(service, username, value)

        with file_path.open("w", encoding="utf-8") as config_file:
            config.write(config_file)

    def _ensure_file_path(self) -> Path:
        """
        Make sure the keyring file and its parent directories exist, and return the file's path.

        A newly created file gets mode ``0o600`` (subject to the process umask); existing files are left as-is.
        """
        file_path = self.file_path
        file_path.parent.mkdir(exist_ok=True, parents=True)
        if not file_path.is_file():
            file_path.touch(0o600)
        return file_path

    @property
    def file_path(self) -> Path:
        """
        Path of the keyring file.

        If the active Flyte config file sits in a directory named ``.flyte`` (relative or absolute, e.g. a
        repository's ``.flyte/config.yaml``), the keyring is stored next to it as ``keyring.cfg``. Otherwise
        ``~/.flyte/keyring.cfg`` is used.
        """
        from flyte._initialize import get_common_config

        config_path = get_common_config().source_config_path
        # Compare the directory *name* rather than the whole path string: an absolute config path such as
        # "/repo/.flyte/config.yaml" has parent "/repo/.flyte", which would never equal ".flyte".
        if config_path and config_path.parent.name == ".flyte":
            return config_path.parent / "keyring.cfg"
        # otherwise use the default path
        return _FLYTE_KEYRING_PATH

    def __repr__(self):
        return f"<{self.__class__.__name__} at {self.file_path}>"
flyte/_logging.py CHANGED
@@ -40,6 +40,21 @@ def log_format_from_env() -> str:
40
40
  return os.environ.get("LOG_FORMAT", "json")
41
41
 
42
42
 
43
def _get_console():
    """
    Build a rich ``Console`` whose width matches the current terminal.

    If the terminal size cannot be queried, the failure is logged at debug level and a width of
    160 columns is used instead.
    """
    from rich.console import Console

    columns = 160  # fallback width when the terminal size is unavailable
    try:
        columns = os.get_terminal_size().columns
    except Exception as e:
        logger.debug(f"Failed to get terminal size: {e}")
    return Console(width=columns)
56
+
57
+
43
58
  def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
44
59
  """
45
60
  Upgrades the global loggers to use Rich logging.
@@ -51,23 +66,19 @@ def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
51
66
  return None
52
67
 
53
68
  import click
54
- from rich.console import Console
69
+ from rich.highlighter import NullHighlighter
55
70
  from rich.logging import RichHandler
56
71
 
57
- try:
58
- width = os.get_terminal_size().columns
59
- except Exception as e:
60
- logger.debug(f"Failed to get terminal size: {e}")
61
- width = 160
62
-
63
72
  handler = RichHandler(
64
73
  tracebacks_suppress=[click],
65
74
  rich_tracebacks=True,
66
75
  omit_repeated_times=False,
67
76
  show_path=False,
68
77
  log_time_format="%H:%M:%S.%f",
69
- console=Console(width=width),
78
+ console=_get_console(),
70
79
  level=log_level,
80
+ highlighter=NullHighlighter(),
81
+ markup=True,
71
82
  )
72
83
 
73
84
  formatter = logging.Formatter(fmt="%(filename)s:%(lineno)d - %(message)s")
@@ -135,7 +135,7 @@ class TaskEnvironment(Environment):
135
135
 
136
136
  def task(
137
137
  self,
138
- _func=None,
138
+ _func: Callable[P, R] | None = None,
139
139
  *,
140
140
  short_name: Optional[str] = None,
141
141
  cache: CacheRequest | None = None,
@@ -11,7 +11,8 @@ async def run_coros(*coros: typing.Coroutine, return_when: str = asyncio.FIRST_C
11
11
  :param return_when:
12
12
  :return:
13
13
  """
14
- tasks: typing.List[asyncio.Task[typing.Never]] = [asyncio.create_task(c) for c in coros]
14
+ # tasks: typing.List[asyncio.Task[typing.Never]] = [asyncio.create_task(c) for c in coros] # Python 3.11+
15
+ tasks: typing.List[asyncio.Task] = [asyncio.create_task(c) for c in coros]
15
16
  done, pending = await asyncio.wait(tasks, return_when=return_when)
16
17
  # TODO we might want to handle asyncio.CancelledError here, for cases when the `action` is cancelled
17
18
  # and we want to propagate it to all tasks. Though the backend will handle it anyway,
flyte/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.0.0b20'
32
- __version_tuple__ = version_tuple = (2, 0, 0, 'b20')
31
+ __version__ = version = '2.0.0b22'
32
+ __version_tuple__ = version_tuple = (2, 0, 0, 'b22')
33
33
 
34
- __commit_id__ = commit_id = 'g5109b02e4'
34
+ __commit_id__ = commit_id = 'gce9a6bede'
flyte/cli/_common.py CHANGED
@@ -111,7 +111,7 @@ class CLIConfig:
111
111
  """
112
112
  return replace(self, **kwargs)
113
113
 
114
- def init(self, project: str | None = None, domain: str | None = None):
114
+ def init(self, project: str | None = None, domain: str | None = None, root_dir: str | None = None):
115
115
  from flyte.config._config import TaskConfig
116
116
 
117
117
  task_cfg = TaskConfig(
@@ -131,7 +131,7 @@ class CLIConfig:
131
131
 
132
132
  updated_config = self.config.with_params(platform_cfg, task_cfg)
133
133
 
134
- flyte.init_from_config(updated_config, log_level=self.log_level)
134
+ flyte.init_from_config(updated_config, log_level=self.log_level, root_dir=root_dir)
135
135
 
136
136
 
137
137
  class InvokeBaseMixin:
flyte/cli/_deploy.py CHANGED
@@ -43,6 +43,16 @@ class DeployArguments:
43
43
  )
44
44
  },
45
45
  )
46
+ root_dir: str | None = field(
47
+ default=None,
48
+ metadata={
49
+ "click.option": click.Option(
50
+ ["--root-dir"],
51
+ type=str,
52
+ help="Override the root source directory, helpful when working with monorepos.",
53
+ )
54
+ },
55
+ )
46
56
  recursive: bool = field(
47
57
  default=False,
48
58
  metadata={
@@ -99,7 +109,7 @@ class DeployEnvCommand(click.RichCommand):
99
109
  console = Console()
100
110
  console.print(f"Deploying root - environment: {self.env_name}")
101
111
  obj: CLIConfig = ctx.obj
102
- obj.init(self.deploy_args.project, self.deploy_args.domain)
112
+ obj.init(self.deploy_args.project, self.deploy_args.domain, root_dir=self.deploy_args.root_dir)
103
113
  with console.status("Deploying...", spinner="dots"):
104
114
  deployment = flyte.deploy(
105
115
  self.env,
flyte/cli/_run.py CHANGED
@@ -23,14 +23,14 @@ RUN_REMOTE_CMD = "deployed-task"
23
23
 
24
24
 
25
25
  @lru_cache()
26
- def _initialize_config(ctx: click.Context, project: str, domain: str):
26
+ def _initialize_config(ctx: click.Context, project: str, domain: str, root_dir: str | None = None):
27
27
  obj: CLIConfig | None = ctx.obj
28
28
  if obj is None:
29
29
  import flyte.config
30
30
 
31
31
  obj = CLIConfig(flyte.config.auto(), ctx)
32
32
 
33
- obj.init(project, domain)
33
+ obj.init(project, domain, root_dir)
34
34
  return obj
35
35
 
36
36
 
@@ -77,6 +77,16 @@ class RunArguments:
77
77
  )
78
78
  },
79
79
  )
80
+ root_dir: str | None = field(
81
+ default=None,
82
+ metadata={
83
+ "click.option": click.Option(
84
+ ["--root-dir"],
85
+ type=str,
86
+ help="Override the root source directory, helpful when working with monorepos.",
87
+ )
88
+ },
89
+ )
80
90
  name: str | None = field(
81
91
  default=None,
82
92
  metadata={
@@ -121,7 +131,7 @@ class RunTaskCommand(click.RichCommand):
121
131
  super().__init__(obj_name, *args, **kwargs)
122
132
 
123
133
  def invoke(self, ctx: click.Context):
124
- obj: CLIConfig = _initialize_config(ctx, self.run_args.project, self.run_args.domain)
134
+ obj: CLIConfig = _initialize_config(ctx, self.run_args.project, self.run_args.domain, self.run_args.root_dir)
125
135
 
126
136
  async def _run():
127
137
  import flyte
flyte/config/_config.py CHANGED
@@ -231,10 +231,12 @@ def auto(config_file: typing.Union[str, pathlib.Path, ConfigFile, None] = None)
231
231
  1. If specified, read the config from the provided file path.
232
232
  2. If not specified, the config file is searched in the default locations.
233
233
  a. ./config.yaml if it exists (current working directory)
234
- b. `UCTL_CONFIG` environment variable
235
- c. `FLYTECTL_CONFIG` environment variable
236
- d. ~/.union/config.yaml if it exists
237
- e. ~/.flyte/config.yaml if it exists
234
+ b. ./.flyte/config.yaml if it exists (current working directory)
235
+ c. <git_root>/.flyte/config.yaml if it exists
236
+ d. `UCTL_CONFIG` environment variable
237
+ e. `FLYTECTL_CONFIG` environment variable
238
+ f. ~/.union/config.yaml if it exists
239
+ g. ~/.flyte/config.yaml if it exists
238
240
  3. If any value is not found in the config file, the default value is used.
239
241
  4. For any value there are environment variables that match the config variable names, those will override
240
242
 
flyte/config/_reader.py CHANGED
@@ -135,15 +135,25 @@ class ConfigFile(object):
135
135
  return self._yaml_config
136
136
 
137
137
 
138
def _config_path_from_git_root() -> pathlib.Path | None:
    """
    Return the source path of the config loaded by ``flyte.git.config_from_root``.

    ``config_from_root`` looks for ``.flyte/config.yaml`` under the git repository root and raises
    ``RuntimeError`` when the root cannot be determined or the file is missing; in that case this
    returns ``None`` so the caller can fall through to the next config location.
    """
    from flyte.git import config_from_root

    try:
        git_root_path = config_from_root().source
    except RuntimeError:
        git_root_path = None
    return git_root_path
145
+
146
+
138
147
  def resolve_config_path() -> pathlib.Path | None:
139
148
  """
140
149
  Config is read from the following locations in order of precedence:
141
150
  1. ./config.yaml if it exists
142
151
  2. ./.flyte/config.yaml if it exists
143
- 3. `UCTL_CONFIG` environment variable
144
- 4. `FLYTECTL_CONFIG` environment variable
145
- 5. ~/.union/config.yaml if it exists
146
- 6. ~/.flyte/config.yaml if it exists
152
+ 3. <git_root>/.flyte/config.yaml if it exists
153
+ 4. `UCTL_CONFIG` environment variable
154
+ 5. `FLYTECTL_CONFIG` environment variable
155
+ 6. ~/.union/config.yaml if it exists
156
+ 7. ~/.flyte/config.yaml if it exists
147
157
  """
148
158
  current_location_config = Path("config.yaml")
149
159
  if current_location_config.exists():
@@ -155,6 +165,11 @@ def resolve_config_path() -> pathlib.Path | None:
155
165
  return dot_flyte_config
156
166
  logger.debug("No ./.flyte/config.yaml found")
157
167
 
168
+ git_root_config = _config_path_from_git_root()
169
+ if git_root_config:
170
+ return git_root_config
171
+ logger.debug("No .flyte/config.yaml found in git repo root")
172
+
158
173
  uctl_path_from_env = getenv(UCTL_CONFIG_ENV_VAR, None)
159
174
  if uctl_path_from_env:
160
175
  return pathlib.Path(uctl_path_from_env)
flyte/git/_config.py CHANGED
@@ -14,4 +14,6 @@ def config_from_root(path: pathlib.Path | str = ".flyte/config.yaml") -> flyte.c
14
14
  if result.returncode != 0:
15
15
  raise RuntimeError(f"Failed to get git root directory: {result.stderr}")
16
16
  root = pathlib.Path(result.stdout.strip())
17
+ if not (root / path).exists():
18
+ raise RuntimeError(f"Config file {root / path} does not exist")
17
19
  return flyte.config.auto(root / path)
@@ -904,7 +904,8 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
904
904
  # t1(input_a: DataFrame) # or
905
905
  # t1(input_a: Annotated[DataFrame, my_cols])
906
906
  if issubclass(expected_python_type, DataFrame):
907
- fdf = DataFrame(format=metad.structured_dataset_type.format)
907
+ fdf = DataFrame(format=metad.structured_dataset_type.format, uri=lv.scalar.structured_dataset.uri)
908
+ fdf._already_uploaded = True
908
909
  fdf._literal_sd = lv.scalar.structured_dataset
909
910
  fdf._metadata = metad
910
911
  return fdf
@@ -1012,7 +1013,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
1012
1013
  def guess_python_type(self, literal_type: types_pb2.LiteralType) -> Type[DataFrame]:
1013
1014
  # todo: technically we should return the dataframe type specified in the constructor, but to do that,
1014
1015
  # we'd have to store that, which we don't do today. See possibly #1363
1015
- if literal_type.HasField("dataframe_type"):
1016
+ if literal_type.HasField("structured_dataset_type"):
1016
1017
  return DataFrame
1017
1018
  raise ValueError(f"DataFrameTransformerEngine cannot reverse {literal_type}")
1018
1019
 
flyte/io/_dir.py CHANGED
@@ -27,21 +27,21 @@ class Dir(BaseModel, Generic[T], SerializableType):
27
27
  The generic type T represents the format of the files in the directory.
28
28
 
29
29
  Example:
30
- ```python
31
- # Async usage
32
- from pandas import DataFrame
33
- data_dir = Dir[DataFrame](path="s3://my-bucket/data/")
34
-
35
- # Walk through files
36
- async for file in data_dir.walk():
37
- async with file.open() as f:
38
- content = await f.read()
39
-
40
- # Sync alternative
41
- for file in data_dir.walk_sync():
42
- with file.open_sync() as f:
43
- content = f.read()
44
- ```
30
+ ```python
31
+ # Async usage
32
+ from pandas import DataFrame
33
+ data_dir = Dir[DataFrame](path="s3://my-bucket/data/")
34
+
35
+ # Walk through files
36
+ async for file in data_dir.walk():
37
+ async with file.open() as f:
38
+ content = await f.read()
39
+
40
+ # Sync alternative
41
+ for file in data_dir.walk_sync():
42
+ with file.open_sync() as f:
43
+ content = f.read()
44
+ ```
45
45
  """
46
46
 
47
47
  # Represents either a local or remote path.
@@ -94,11 +94,11 @@ class Dir(BaseModel, Generic[T], SerializableType):
94
94
  File objects for each file found in the directory
95
95
 
96
96
  Example:
97
- ```python
98
- async for file in directory.walk():
99
- local_path = await file.download()
100
- # Process the file
101
- ```
97
+ ```python
98
+ async for file in directory.walk():
99
+ local_path = await file.download()
100
+ # Process the file
101
+ ```
102
102
  """
103
103
  fs = storage.get_underlying_filesystem(path=self.path)
104
104
  if recursive is False:
@@ -134,11 +134,11 @@ class Dir(BaseModel, Generic[T], SerializableType):
134
134
  File objects for each file found in the directory
135
135
 
136
136
  Example:
137
- ```python
138
- for file in directory.walk_sync():
139
- local_path = file.download_sync()
140
- # Process the file
141
- ```
137
+ ```python
138
+ for file in directory.walk_sync():
139
+ local_path = file.download_sync()
140
+ # Process the file
141
+ ```
142
142
  """
143
143
  fs = storage.get_underlying_filesystem(path=self.path)
144
144
  for parent, _, files in fs.walk(self.path, maxdepth=max_depth):
@@ -157,11 +157,11 @@ class Dir(BaseModel, Generic[T], SerializableType):
157
157
  A list of File objects
158
158
 
159
159
  Example:
160
- ```python
161
- files = await directory.list_files()
162
- for file in files:
163
- # Process the file
164
- ```
160
+ ```python
161
+ files = await directory.list_files()
162
+ for file in files:
163
+ # Process the file
164
+ ```
165
165
  """
166
166
  # todo: this should probably also just defer to fsspec.find()
167
167
  files = []
@@ -177,11 +177,11 @@ class Dir(BaseModel, Generic[T], SerializableType):
177
177
  A list of File objects
178
178
 
179
179
  Example:
180
- ```python
181
- files = directory.list_files_sync()
182
- for file in files:
183
- # Process the file
184
- ```
180
+ ```python
181
+ files = directory.list_files_sync()
182
+ for file in files:
183
+ # Process the file
184
+ ```
185
185
  """
186
186
  return list(self.walk_sync(recursive=False))
187
187
 
@@ -197,9 +197,9 @@ class Dir(BaseModel, Generic[T], SerializableType):
197
197
  The path to the downloaded directory
198
198
 
199
199
  Example:
200
- ```python
201
- local_dir = await directory.download('/tmp/my_data/')
202
- ```
200
+ ```python
201
+ local_dir = await directory.download('/tmp/my_data/')
202
+ ```
203
203
  """
204
204
  local_dest = str(local_path) if local_path else str(storage.get_random_local_path())
205
205
  if not storage.is_remote(self.path):
@@ -230,9 +230,9 @@ class Dir(BaseModel, Generic[T], SerializableType):
230
230
  The path to the downloaded directory
231
231
 
232
232
  Example:
233
- ```python
234
- local_dir = directory.download_sync('/tmp/my_data/')
235
- ```
233
+ ```python
234
+ local_dir = directory.download_sync('/tmp/my_data/')
235
+ ```
236
236
  """
237
237
  local_dest = str(local_path) if local_path else str(storage.get_random_local_path())
238
238
  if not storage.is_remote(self.path):
@@ -268,11 +268,11 @@ class Dir(BaseModel, Generic[T], SerializableType):
268
268
  A new Dir instance pointing to the uploaded directory
269
269
 
270
270
  Example:
271
- ```python
272
- remote_dir = await Dir[DataFrame].from_local('/tmp/data_dir/', 's3://bucket/data/')
273
- # With a known hash value you want to use for cache key calculation
274
- remote_dir = await Dir[DataFrame].from_local('/tmp/data_dir/', 's3://bucket/data/', dir_cache_key='abc123')
275
- ```
271
+ ```python
272
+ remote_dir = await Dir[DataFrame].from_local('/tmp/data_dir/', 's3://bucket/data/')
273
+ # With a known hash value you want to use for cache key calculation
274
+ remote_dir = await Dir[DataFrame].from_local('/tmp/data_dir/', 's3://bucket/data/', dir_cache_key='abc123')
275
+ ```
276
276
  """
277
277
  local_path_str = str(local_path)
278
278
  dirname = os.path.basename(os.path.normpath(local_path_str))
@@ -291,11 +291,11 @@ class Dir(BaseModel, Generic[T], SerializableType):
291
291
  the cache key will be computed based on this object's attributes.
292
292
 
293
293
  Example:
294
- ```python
295
- remote_dir = Dir.from_existing_remote("s3://bucket/data/")
296
- # With a known hash
297
- remote_dir = Dir.from_existing_remote("s3://bucket/data/", dir_cache_key="abc123")
298
- ```
294
+ ```python
295
+ remote_dir = Dir.from_existing_remote("s3://bucket/data/")
296
+ # With a known hash
297
+ remote_dir = Dir.from_existing_remote("s3://bucket/data/", dir_cache_key="abc123")
298
+ ```
299
299
  """
300
300
  return cls(path=remote_path, hash=dir_cache_key)
301
301
 
@@ -312,9 +312,9 @@ class Dir(BaseModel, Generic[T], SerializableType):
312
312
  A new Dir instance pointing to the uploaded directory
313
313
 
314
314
  Example:
315
- ```python
316
- remote_dir = Dir[DataFrame].from_local_sync('/tmp/data_dir/', 's3://bucket/data/')
317
- ```
315
+ ```python
316
+ remote_dir = Dir[DataFrame].from_local_sync('/tmp/data_dir/', 's3://bucket/data/')
317
+ ```
318
318
  """
319
319
  # Implement this after we figure out the final sync story
320
320
  raise NotImplementedError("Sync upload is not implemented for remote paths")
@@ -327,10 +327,10 @@ class Dir(BaseModel, Generic[T], SerializableType):
327
327
  True if the directory exists, False otherwise
328
328
 
329
329
  Example:
330
- ```python
331
- if await directory.exists():
332
- # Process the directory
333
- ```
330
+ ```python
331
+ if await directory.exists():
332
+ # Process the directory
333
+ ```
334
334
  """
335
335
  fs = storage.get_underlying_filesystem(path=self.path)
336
336
  if isinstance(fs, AsyncFileSystem):
@@ -346,10 +346,10 @@ class Dir(BaseModel, Generic[T], SerializableType):
346
346
  True if the directory exists, False otherwise
347
347
 
348
348
  Example:
349
- ```python
350
- if directory.exists_sync():
351
- # Process the directory
352
- ```
349
+ ```python
350
+ if directory.exists_sync():
351
+ # Process the directory
352
+ ```
353
353
  """
354
354
  fs = storage.get_underlying_filesystem(path=self.path)
355
355
  return fs.exists(self.path)
@@ -365,11 +365,11 @@ class Dir(BaseModel, Generic[T], SerializableType):
365
365
  A File instance if the file exists, None otherwise
366
366
 
367
367
  Example:
368
- ```python
369
- file = await directory.get_file("data.csv")
370
- if file:
371
- # Process the file
372
- ```
368
+ ```python
369
+ file = await directory.get_file("data.csv")
370
+ if file:
371
+ # Process the file
372
+ ```
373
373
  """
374
374
  fs = storage.get_underlying_filesystem(path=self.path)
375
375
  file_path = fs.sep.join([self.path, file_name])
@@ -390,11 +390,11 @@ class Dir(BaseModel, Generic[T], SerializableType):
390
390
  A File instance if the file exists, None otherwise
391
391
 
392
392
  Example:
393
- ```python
394
- file = directory.get_file_sync("data.csv")
395
- if file:
396
- # Process the file
397
- ```
393
+ ```python
394
+ file = directory.get_file_sync("data.csv")
395
+ if file:
396
+ # Process the file
397
+ ```
398
398
  """
399
399
  file_path = os.path.join(self.path, file_name)
400
400
  file = File[T](path=file_path)
flyte/models.py CHANGED
@@ -3,13 +3,14 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import os
5
5
  import pathlib
6
+ import typing
6
7
  from dataclasses import dataclass, field, replace
7
8
  from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Literal, Optional, Tuple, Type
8
9
 
9
10
  import rich.repr
10
11
 
11
12
  from flyte._docstring import Docstring
12
- from flyte._interface import extract_return_annotation
13
+ from flyte._interface import extract_return_annotation, literal_to_enum
13
14
  from flyte._logging import logger
14
15
 
15
16
  if TYPE_CHECKING:
@@ -329,7 +330,10 @@ class NativeInterface:
329
330
  logger.warning(
330
331
  f"Function {func.__name__} has parameter {name} without type annotation. Data will be pickled."
331
332
  )
332
- param_info[name] = (param.annotation, param.default)
333
+ if typing.get_origin(param.annotation) is Literal:
334
+ param_info[name] = (literal_to_enum(param.annotation), param.default)
335
+ else:
336
+ param_info[name] = (param.annotation, param.default)
333
337
 
334
338
  # Get return type
335
339
  outputs = extract_return_annotation(sig.return_annotation)
flyte/remote/_action.py CHANGED
@@ -369,14 +369,15 @@ class Action(ToJSONMixin):
369
369
  # If the action is done, handle the final state
370
370
  if ad.done():
371
371
  progress.stop_task(task_id)
372
- if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
373
- console.print(f"[bold green]Run '{self.run_name}' completed successfully.[/bold green]")
374
- else:
375
- error_message = ad.error_info.message if ad.error_info else ""
376
- console.print(
377
- f"[bold red]Run '{self.run_name}' exited unsuccessfully in state {ad.phase}"
378
- f" with error: {error_message}[/bold red]"
379
- )
372
+ if not quiet:
373
+ if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
374
+ console.print(f"[bold green]Run '{self.run_name}' completed successfully.[/bold green]")
375
+ else:
376
+ error_message = ad.error_info.message if ad.error_info else ""
377
+ console.print(
378
+ f"[bold red]Run '{self.run_name}' exited unsuccessfully in state {ad.phase}"
379
+ f" with error: {error_message}[/bold red]"
380
+ )
380
381
  break
381
382
  except asyncio.CancelledError:
382
383
  # Handle cancellation gracefully
@@ -1,4 +1,4 @@
1
- import click
1
+ from rich import print as rich_print
2
2
 
3
3
  from flyte._logging import logger
4
4
  from flyte.remote._client.auth import _token_client as token_client
@@ -94,10 +94,9 @@ class DeviceCodeAuthenticator(Authenticator):
94
94
 
95
95
  full_uri = f"{resp.verification_uri}?user_code={resp.user_code}"
96
96
  text = (
97
- f"To Authenticate, navigate in a browser to the following URL: "
98
- f"{click.style(full_uri, fg='blue', underline=True)}"
97
+ f"To Authenticate, navigate in a browser to the following URL: [blue link={full_uri}]{full_uri}[/blue link]"
99
98
  )
100
- click.secho(text)
99
+ rich_print(text)
101
100
  try:
102
101
  token, refresh_token, expires_in = await token_client.poll_token_endpoint(
103
102
  resp,