flyte 2.0.0b18__py3-none-any.whl → 2.0.0b20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

flyte/_map.py CHANGED
@@ -1,4 +1,6 @@
1
1
  import asyncio
2
+ import functools
3
+ import logging
2
4
  from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
3
5
 
4
6
  from flyte.syncify import syncify
@@ -11,7 +13,14 @@ from ._task import P, R, TaskTemplate
11
13
  class MapAsyncIterator(Generic[P, R]):
12
14
  """AsyncIterator implementation for the map function results"""
13
15
 
14
- def __init__(self, func: TaskTemplate[P, R], args: tuple, name: str, concurrency: int, return_exceptions: bool):
16
+ def __init__(
17
+ self,
18
+ func: TaskTemplate[P, R] | functools.partial[R],
19
+ args: tuple,
20
+ name: str,
21
+ concurrency: int,
22
+ return_exceptions: bool,
23
+ ):
15
24
  self.func = func
16
25
  self.args = args
17
26
  self.name = name
@@ -49,13 +58,16 @@ class MapAsyncIterator(Generic[P, R]):
49
58
  return result
50
59
  except Exception as e:
51
60
  self._exception_count += 1
52
- logger.debug(f"Task {self._current_index - 1} failed with exception: {e}")
61
+ logger.debug(
62
+ f"Task {self._current_index - 1} failed with exception: {e}, return_exceptions={self.return_exceptions}"
63
+ )
53
64
  if self.return_exceptions:
54
65
  return e
55
66
  else:
56
67
  # Cancel remaining tasks
57
68
  for remaining_task in self._tasks[self._current_index + 1 :]:
58
69
  remaining_task.cancel()
70
+ logger.warning("Exception raising is `ON`, raising exception and cancelling remaining tasks")
59
71
  raise e
60
72
 
61
73
  async def _initialize(self):
@@ -64,10 +76,26 @@ class MapAsyncIterator(Generic[P, R]):
64
76
  tasks = []
65
77
  task_count = 0
66
78
 
67
- for arg_tuple in zip(*self.args):
68
- task = asyncio.create_task(self.func.aio(*arg_tuple))
69
- tasks.append(task)
70
- task_count += 1
79
+ if isinstance(self.func, functools.partial):
80
+ # Handle partial functions by merging bound args/kwargs with mapped args
81
+ base_func = cast(TaskTemplate, self.func.func)
82
+ bound_args = self.func.args
83
+ bound_kwargs = self.func.keywords or {}
84
+
85
+ for arg_tuple in zip(*self.args):
86
+ # Merge bound positional args with mapped args
87
+ merged_args = bound_args + arg_tuple
88
+ if logger.isEnabledFor(logging.DEBUG):
89
+ logger.debug(f"Running {base_func.name} with args: {merged_args} and kwargs: {bound_kwargs}")
90
+ task = asyncio.create_task(base_func.aio(*merged_args, **bound_kwargs))
91
+ tasks.append(task)
92
+ task_count += 1
93
+ else:
94
+ # Handle regular TaskTemplate functions
95
+ for arg_tuple in zip(*self.args):
96
+ task = asyncio.create_task(self.func.aio(*arg_tuple))
97
+ tasks.append(task)
98
+ task_count += 1
71
99
 
72
100
  if task_count == 0:
73
101
  logger.info(f"Group '{self.name}' has no tasks to process")
@@ -107,9 +135,46 @@ class _Mapper(Generic[P, R]):
107
135
  """Get the name of the group, defaulting to 'map' if not provided."""
108
136
  return f"{task_name}_{group_name or 'map'}"
109
137
 
138
+ @staticmethod
139
+ def validate_partial(func: functools.partial[R]):
140
+ """
141
+ This method validates that the provided partial function is valid for mapping, i.e. exactly one argument
142
+ is left for mapping and the rest are provided as keywords or args.
143
+
144
+ :param func: partial function to validate
145
+ :raises TypeError: if the partial function is not valid for mapping
146
+ """
147
+ f = cast(TaskTemplate, func.func)
148
+ inputs = f.native_interface.inputs
149
+ params = list(inputs.keys())
150
+ total_params = len(params)
151
+ provided_args = len(func.args)
152
+ provided_kwargs = len(func.keywords or {})
153
+
154
+ # Calculate how many parameters are left unspecified
155
+ unspecified_count = total_params - provided_args - provided_kwargs
156
+
157
+ # Exactly one parameter should be left for mapping
158
+ if unspecified_count != 1:
159
+ raise TypeError(
160
+ f"Partial function must leave exactly one parameter unspecified for mapping. "
161
+ f"Found {unspecified_count} unspecified parameters in {f.name}, "
162
+ f"params: {inputs.keys()}"
163
+ )
164
+
165
+ # Validate that no parameter is both in args and keywords
166
+ if func.keywords:
167
+ param_names = list(inputs.keys())
168
+ for i, arg_name in enumerate(param_names[: provided_args + 1]):
169
+ if arg_name in func.keywords:
170
+ raise TypeError(
171
+ f"Parameter '{arg_name}' is provided both as positional argument and keyword argument "
172
+ f"in partial function {f.name}."
173
+ )
174
+
110
175
  def __call__(
111
176
  self,
112
- func: TaskTemplate[P, R],
177
+ func: TaskTemplate[P, R] | functools.partial[R],
113
178
  *args: Iterable[Any],
114
179
  group_name: str | None = None,
115
180
  concurrency: int = 0,
@@ -128,7 +193,13 @@ class _Mapper(Generic[P, R]):
128
193
  if not args:
129
194
  return
130
195
 
131
- name = self._get_name(func.name, group_name)
196
+ if isinstance(func, functools.partial):
197
+ f = cast(TaskTemplate, func.func)
198
+ self.validate_partial(func)
199
+ else:
200
+ f = cast(TaskTemplate, func)
201
+
202
+ name = self._get_name(f.name, group_name)
132
203
  logger.debug(f"Blocking Map for {name}")
133
204
  with group(name):
134
205
  import flyte
@@ -154,7 +225,7 @@ class _Mapper(Generic[P, R]):
154
225
  *args,
155
226
  name=name,
156
227
  concurrency=concurrency,
157
- return_exceptions=True,
228
+ return_exceptions=return_exceptions,
158
229
  ),
159
230
  ):
160
231
  logger.debug(f"Mapped {x}, task {i}")
@@ -163,7 +234,7 @@ class _Mapper(Generic[P, R]):
163
234
 
164
235
  async def aio(
165
236
  self,
166
- func: TaskTemplate[P, R],
237
+ func: TaskTemplate[P, R] | functools.partial[R],
167
238
  *args: Iterable[Any],
168
239
  group_name: str | None = None,
169
240
  concurrency: int = 0,
@@ -171,7 +242,14 @@ class _Mapper(Generic[P, R]):
171
242
  ) -> AsyncGenerator[Union[R, Exception], None]:
172
243
  if not args:
173
244
  return
174
- name = self._get_name(func.name, group_name)
245
+
246
+ if isinstance(func, functools.partial):
247
+ f = cast(TaskTemplate, func.func)
248
+ self.validate_partial(func)
249
+ else:
250
+ f = cast(TaskTemplate, func)
251
+
252
+ name = self._get_name(f.name, group_name)
175
253
  with group(name):
176
254
  import flyte
177
255
 
@@ -199,7 +277,7 @@ class _Mapper(Generic[P, R]):
199
277
 
200
278
  @syncify
201
279
  async def _map(
202
- func: TaskTemplate[P, R],
280
+ func: TaskTemplate[P, R] | functools.partial[R],
203
281
  *args: Iterable[Any],
204
282
  name: str = "map",
205
283
  concurrency: int = 0,
flyte/_task.py CHANGED
@@ -258,6 +258,9 @@ class TaskTemplate(Generic[P, R]):
258
258
  else:
259
259
  raise RuntimeSystemError("BadContext", "Controller is not initialized.")
260
260
  else:
261
+ from flyte._logging import logger
262
+
263
+ logger.warning(f"Task {self.name} running aio outside of a task context.")
261
264
  # Local execute, just stay out of the way, but because .aio is used, we want to return an awaitable,
262
265
  # even for synchronous tasks. This is to support migration.
263
266
  return self.forward(*args, **kwargs)
flyte/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.0.0b18'
32
- __version_tuple__ = version_tuple = (2, 0, 0, 'b18')
31
+ __version__ = version = '2.0.0b20'
32
+ __version_tuple__ = version_tuple = (2, 0, 0, 'b20')
33
33
 
34
- __commit_id__ = commit_id = 'g930faeaea'
34
+ __commit_id__ = commit_id = 'g5109b02e4'
flyte/cli/_create.py CHANGED
@@ -98,7 +98,7 @@ def secret(
98
98
  "-o",
99
99
  "--output",
100
100
  type=click.Path(exists=False, writable=True),
101
- default=Path.cwd() / "config.yaml",
101
+ default=Path.cwd() / ".flyte" / "config.yaml",
102
102
  help="Path to the output directory where the configuration will be saved. Defaults to current directory.",
103
103
  show_default=True,
104
104
  )
@@ -147,6 +147,9 @@ def config(
147
147
 
148
148
  output_path = Path(output)
149
149
 
150
+ if not output_path.parent.exists():
151
+ output_path.parent.mkdir(parents=True)
152
+
150
153
  if output_path.exists() and not force:
151
154
  force = click.confirm(f"Overwrite [{output_path}]?", default=False)
152
155
  if not force:
flyte/cli/_deploy.py CHANGED
@@ -4,8 +4,7 @@ from pathlib import Path
4
4
  from types import ModuleType
5
5
  from typing import Any, Dict, List, cast, get_args
6
6
 
7
- import click
8
- from click import Context
7
+ import rich_click as click
9
8
 
10
9
  import flyte
11
10
 
@@ -87,14 +86,14 @@ class DeployArguments:
87
86
  return [common.get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata]
88
87
 
89
88
 
90
- class DeployEnvCommand(click.Command):
89
+ class DeployEnvCommand(click.RichCommand):
91
90
  def __init__(self, env_name: str, env: Any, deploy_args: DeployArguments, *args, **kwargs):
92
91
  self.env_name = env_name
93
92
  self.env = env
94
93
  self.deploy_args = deploy_args
95
94
  super().__init__(*args, **kwargs)
96
95
 
97
- def invoke(self, ctx: Context):
96
+ def invoke(self, ctx: click.Context):
98
97
  from rich.console import Console
99
98
 
100
99
  console = Console()
@@ -125,7 +124,7 @@ class DeployEnvRecursiveCommand(click.Command):
125
124
  self.deploy_args = deploy_args
126
125
  super().__init__(*args, **kwargs)
127
126
 
128
- def invoke(self, ctx: Context):
127
+ def invoke(self, ctx: click.Context):
129
128
  from rich.console import Console
130
129
 
131
130
  from flyte._environment import list_loaded_environments
flyte/cli/_params.py CHANGED
@@ -283,13 +283,17 @@ class UnionParamType(click.ParamType):
283
283
  A composite type that allows for multiple types to be specified. This is used for union types.
284
284
  """
285
285
 
286
- def __init__(self, types: typing.List[click.ParamType]):
286
+ def __init__(self, types: typing.List[click.ParamType | None]):
287
287
  super().__init__()
288
288
  self._types = self._sort_precedence(types)
289
- self.name = "|".join([t.name for t in self._types])
289
+ self.name = "|".join([t.name for t in self._types if t is not None])
290
+ self.optional = False
291
+ if None in types:
292
+ self.name = f"Optional[{self.name}]"
293
+ self.optional = True
290
294
 
291
295
  @staticmethod
292
- def _sort_precedence(tp: typing.List[click.ParamType]) -> typing.List[click.ParamType]:
296
+ def _sort_precedence(tp: typing.List[click.ParamType | None]) -> typing.List[click.ParamType]:
293
297
  unprocessed = []
294
298
  str_types = []
295
299
  others = []
@@ -311,6 +315,8 @@ class UnionParamType(click.ParamType):
311
315
  """
312
316
  for p in self._types:
313
317
  try:
318
+ if p is None and value is None:
319
+ return None
314
320
  return p.convert(value, param, ctx)
315
321
  except Exception as e:
316
322
  logger.debug(f"Ignoring conversion error for type {p} trying other variants in Union. Error: {e}")
@@ -433,7 +439,10 @@ def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> cli
433
439
  for i in range(len(lt.union_type.variants)):
434
440
  variant = lt.union_type.variants[i]
435
441
  variant_python_type = typing.get_args(python_type)[i]
436
- cts.append(literal_type_to_click_type(variant, variant_python_type))
442
+ if variant_python_type is type(None):
443
+ cts.append(None)
444
+ else:
445
+ cts.append(literal_type_to_click_type(variant, variant_python_type))
437
446
  return UnionParamType(cts)
438
447
 
439
448
  if lt.HasField("enum_type"):
@@ -461,6 +470,9 @@ class FlyteLiteralConverter(object):
461
470
  def is_bool(self) -> bool:
462
471
  return self.click_type == click.BOOL
463
472
 
473
+ def is_optional(self) -> bool:
474
+ return isinstance(self.click_type, UnionParamType) and self.click_type.optional
475
+
464
476
  def convert(
465
477
  self, ctx: click.Context, param: typing.Optional[click.Parameter], value: typing.Any
466
478
  ) -> typing.Union[Literal, typing.Any]:
@@ -525,6 +537,8 @@ def to_click_option(
525
537
  if literal_converter.is_bool():
526
538
  required = False
527
539
  is_flag = True
540
+ if literal_converter.is_optional():
541
+ required = False
528
542
 
529
543
  return click.Option(
530
544
  param_decls=[f"--{input_name}"],
flyte/cli/_run.py CHANGED
@@ -112,7 +112,7 @@ class RunArguments:
112
112
  return [common.get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata]
113
113
 
114
114
 
115
- class RunTaskCommand(click.Command):
115
+ class RunTaskCommand(click.RichCommand):
116
116
  def __init__(self, obj_name: str, obj: Any, run_args: RunArguments, *args, **kwargs):
117
117
  self.obj_name = obj_name
118
118
  self.obj = cast(TaskTemplate, obj)
@@ -196,7 +196,7 @@ class TaskPerFileGroup(common.ObjectsPerFileGroup):
196
196
  )
197
197
 
198
198
 
199
- class RunReferenceTaskCommand(click.Command):
199
+ class RunReferenceTaskCommand(click.RichCommand):
200
200
  def __init__(self, task_name: str, run_args: RunArguments, version: str | None, *args, **kwargs):
201
201
  self.task_name = task_name
202
202
  self.run_args = run_args
flyte/config/_config.py CHANGED
@@ -192,7 +192,7 @@ class Config(object):
192
192
  )
193
193
 
194
194
  @classmethod
195
- def auto(cls, config_file: typing.Union[str, ConfigFile, None] = None) -> "Config":
195
+ def auto(cls, config_file: typing.Union[str, pathlib.Path, ConfigFile, None] = None) -> "Config":
196
196
  """
197
197
  Automatically constructs the Config Object. The order of precedence is as follows
198
198
  1. first try to find any env vars that match the config vars specified in the FLYTE_CONFIG format.
@@ -225,7 +225,7 @@ def set_if_exists(d: dict, k: str, val: typing.Any) -> dict:
225
225
  return d
226
226
 
227
227
 
228
- def auto(config_file: typing.Union[str, ConfigFile, None] = None) -> Config:
228
+ def auto(config_file: typing.Union[str, pathlib.Path, ConfigFile, None] = None) -> Config:
229
229
  """
230
230
  Automatically constructs the Config Object. The order of precedence is as follows
231
231
  1. If specified, read the config from the provided file path.
flyte/config/_reader.py CHANGED
@@ -108,7 +108,7 @@ class ConfigFile(object):
108
108
  return pathlib.Path(self._location)
109
109
 
110
110
  @staticmethod
111
- def _read_yaml_config(location: str) -> typing.Optional[typing.Dict[str, typing.Any]]:
111
+ def _read_yaml_config(location: str | pathlib.Path) -> typing.Optional[typing.Dict[str, typing.Any]]:
112
112
  with open(location, "r") as fh:
113
113
  try:
114
114
  yaml_contents = yaml.safe_load(fh)
@@ -139,16 +139,22 @@ def resolve_config_path() -> pathlib.Path | None:
139
139
  """
140
140
  Config is read from the following locations in order of precedence:
141
141
  1. ./config.yaml if it exists
142
- 2. `UCTL_CONFIG` environment variable
143
- 3. `FLYTECTL_CONFIG` environment variable
144
- 4. ~/.union/config.yaml if it exists
145
- 5. ~/.flyte/config.yaml if it exists
142
+ 2. ./.flyte/config.yaml if it exists
143
+ 3. `UCTL_CONFIG` environment variable
144
+ 4. `FLYTECTL_CONFIG` environment variable
145
+ 5. ~/.union/config.yaml if it exists
146
+ 6. ~/.flyte/config.yaml if it exists
146
147
  """
147
148
  current_location_config = Path("config.yaml")
148
149
  if current_location_config.exists():
149
150
  return current_location_config
150
151
  logger.debug("No ./config.yaml found")
151
152
 
153
+ dot_flyte_config = Path(".flyte", "config.yaml")
154
+ if dot_flyte_config.exists():
155
+ return dot_flyte_config
156
+ logger.debug("No ./.flyte/config.yaml found")
157
+
152
158
  uctl_path_from_env = getenv(UCTL_CONFIG_ENV_VAR, None)
153
159
  if uctl_path_from_env:
154
160
  return pathlib.Path(uctl_path_from_env)
@@ -173,13 +179,13 @@ def resolve_config_path() -> pathlib.Path | None:
173
179
 
174
180
 
175
181
  @lru_cache
176
- def get_config_file(c: typing.Union[str, ConfigFile, None]) -> ConfigFile | None:
182
+ def get_config_file(c: typing.Union[str, pathlib.Path, ConfigFile, None]) -> ConfigFile | None:
177
183
  """
178
184
  Checks if the given argument is a file or a configFile and returns a loaded configFile else returns None
179
185
  """
180
- if isinstance(c, str):
186
+ if isinstance(c, (str, pathlib.Path)):
181
187
  logger.debug(f"Using specified config file at {c}")
182
- return ConfigFile(c)
188
+ return ConfigFile(str(c))
183
189
  elif isinstance(c, ConfigFile):
184
190
  return c
185
191
  config_path = resolve_config_path()
flyte/errors.py CHANGED
@@ -132,7 +132,9 @@ class CustomError(RuntimeUserError):
132
132
  Create a CustomError from an exception. The exception's class name is used as the error code and the exception
133
133
  message is used as the error message.
134
134
  """
135
- return cls(e.__class__.__name__, str(e))
135
+ new_exc = cls(e.__class__.__name__, str(e))
136
+ new_exc.__cause__ = e
137
+ return new_exc
136
138
 
137
139
 
138
140
  class NotInTaskContextError(RuntimeUserError):
@@ -221,3 +223,12 @@ class RunAbortedError(RuntimeUserError):
221
223
 
222
224
  def __init__(self, message: str):
223
225
  super().__init__("RunAbortedError", message, "user")
226
+
227
+
228
+ class SlowDownError(RuntimeUserError):
229
+ """
230
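+ This error is presumably raised when the caller is asked to slow down, e.g. because requests are being throttled. NOTE(review): the earlier wording described a missing or invalid resource and looks copy-pasted from another error class; confirm against the code that raises this error.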
231
+ """
232
+
233
+ def __init__(self, message: str):
234
+ super().__init__("SlowDownError", message, "user")
flyte/git/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from ._config import config_from_root
2
+
3
+ __all__ = ["config_from_root"]
flyte/git/_config.py ADDED
@@ -0,0 +1,17 @@
1
+ import pathlib
2
+ import subprocess
3
+
4
+ import flyte.config
5
+
6
+
7
+ def config_from_root(path: pathlib.Path | str = ".flyte/config.yaml") -> flyte.config.Config:
8
+ """Get the config file from the git root directory.
9
+
10
+ By default, the config file is expected to be in `.flyte/config.yaml` in the git root directory.
11
+ """
12
+
13
+ result = subprocess.run(["git", "rev-parse", "--show-toplevel"], check=False, capture_output=True, text=True)
14
+ if result.returncode != 0:
15
+ raise RuntimeError(f"Failed to get git root directory: {result.stderr}")
16
+ root = pathlib.Path(result.stdout.strip())
17
+ return flyte.config.auto(root / path)
@@ -58,16 +58,16 @@ class PandasToCSVEncodingHandler(DataFrameEncoder):
58
58
 
59
59
  if not storage.is_remote(uri):
60
60
  Path(uri).mkdir(parents=True, exist_ok=True)
61
- path = os.path.join(uri, ".csv")
61
+ csv_file = storage.join(uri, "data.csv")
62
62
  df = typing.cast(pd.DataFrame, dataframe.val)
63
63
  df.to_csv(
64
- path,
64
+ csv_file,
65
65
  index=False,
66
- storage_options=get_pandas_storage_options(uri=path),
66
+ storage_options=get_pandas_storage_options(uri=csv_file),
67
67
  )
68
68
  structured_dataset_type.format = CSV
69
69
  return literals_pb2.StructuredDataset(
70
- uri=uri, metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type)
70
+ uri=uri, metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type)
71
71
  )
72
72
 
73
73
 
@@ -83,16 +83,25 @@ class CSVToPandasDecodingHandler(DataFrameDecoder):
83
83
  uri = proto_value.uri
84
84
  columns = None
85
85
  kwargs = get_pandas_storage_options(uri=uri)
86
- path = os.path.join(uri, ".csv")
86
+ csv_file = storage.join(uri, "data.csv")
87
87
  if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
88
88
  columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
89
89
  try:
90
- return pd.read_csv(path, usecols=columns, storage_options=kwargs)
90
+ import io
91
+
92
+ # Buffer the whole stream into memory before parsing: obstore has issues with CSV and fails with an early EOF error otherwise.
93
+ buf = io.BytesIO()
94
+ async for chunk in storage.get_stream(csv_file):
95
+ buf.write(chunk)
96
+ buf.seek(0)
97
+ df = pd.read_csv(buf)
98
+ return df
99
+
91
100
  except Exception as exc:
92
101
  if exc.__class__.__name__ == "NoCredentialsError":
93
102
  logger.debug("S3 source detected, attempting anonymous S3 access")
94
103
  kwargs = get_pandas_storage_options(uri=uri, anonymous=True)
95
- return pd.read_csv(path, usecols=columns, storage_options=kwargs)
104
+ return pd.read_csv(csv_file, usecols=columns, storage_options=kwargs)
96
105
  else:
97
106
  raise
98
107