flyte 2.0.0b17__py3-none-any.whl → 2.0.0b19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (52)
  1. flyte/_bin/runtime.py +3 -0
  2. flyte/_debug/vscode.py +4 -2
  3. flyte/_deploy.py +3 -1
  4. flyte/_environment.py +15 -6
  5. flyte/_hash.py +1 -16
  6. flyte/_image.py +6 -1
  7. flyte/_initialize.py +15 -16
  8. flyte/_internal/controllers/__init__.py +4 -5
  9. flyte/_internal/controllers/_local_controller.py +5 -5
  10. flyte/_internal/controllers/remote/_controller.py +21 -28
  11. flyte/_internal/controllers/remote/_core.py +1 -1
  12. flyte/_internal/imagebuild/docker_builder.py +31 -23
  13. flyte/_internal/imagebuild/remote_builder.py +37 -10
  14. flyte/_internal/imagebuild/utils.py +2 -1
  15. flyte/_internal/runtime/convert.py +69 -2
  16. flyte/_internal/runtime/taskrunner.py +4 -1
  17. flyte/_logging.py +110 -26
  18. flyte/_map.py +90 -12
  19. flyte/_pod.py +2 -1
  20. flyte/_run.py +6 -1
  21. flyte/_task.py +3 -0
  22. flyte/_task_environment.py +5 -1
  23. flyte/_trace.py +5 -0
  24. flyte/_version.py +3 -3
  25. flyte/cli/_create.py +4 -1
  26. flyte/cli/_deploy.py +4 -5
  27. flyte/cli/_params.py +18 -4
  28. flyte/cli/_run.py +2 -2
  29. flyte/config/_config.py +2 -2
  30. flyte/config/_reader.py +14 -8
  31. flyte/errors.py +3 -1
  32. flyte/git/__init__.py +3 -0
  33. flyte/git/_config.py +17 -0
  34. flyte/io/_dataframe/basic_dfs.py +16 -7
  35. flyte/io/_dataframe/dataframe.py +84 -123
  36. flyte/io/_dir.py +35 -4
  37. flyte/io/_file.py +61 -15
  38. flyte/io/_hashing_io.py +342 -0
  39. flyte/models.py +12 -4
  40. flyte/remote/_action.py +4 -2
  41. flyte/remote/_task.py +52 -22
  42. flyte/report/_report.py +1 -1
  43. flyte/storage/_storage.py +16 -1
  44. flyte/types/_type_engine.py +1 -51
  45. {flyte-2.0.0b17.data → flyte-2.0.0b19.data}/scripts/runtime.py +3 -0
  46. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/METADATA +1 -1
  47. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/RECORD +52 -49
  48. {flyte-2.0.0b17.data → flyte-2.0.0b19.data}/scripts/debug.py +0 -0
  49. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/WHEEL +0 -0
  50. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/entry_points.txt +0 -0
  51. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/licenses/LICENSE +0 -0
  52. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/top_level.txt +0 -0
@@ -308,15 +308,82 @@ def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
308
308
  return hash_data(serialized_inputs)
309
309
 
310
310
 
311
+ def generate_inputs_repr_for_literal(literal: literals_pb2.Literal) -> bytes:
312
+ """
313
+ Generate a byte representation for a single literal that is meant to be hashed as part of the cache key
314
+ computation for an Action. This function should just serialize the literal deterministically, but will
315
+ use an existing hash value if present in the Literal. This is trivial, except we need to handle nested literals
316
+ (inside collections and maps) that may have the hash property set.
317
+
318
+ :param literal: The literal to get a hashable representation for.
319
+ :return: byte representation of the literal that can be fed into a hash function.
320
+ """
321
+ # If the literal has a hash value, use that instead of serializing the full literal
322
+ if literal.hash:
323
+ return literal.hash.encode("utf-8")
324
+
325
+ if literal.HasField("collection"):
326
+ buf = bytearray()
327
+ for nested_literal in literal.collection.literals:
328
+ if nested_literal.hash:
329
+ buf += nested_literal.hash.encode("utf-8")
330
+ else:
331
+ buf += generate_inputs_repr_for_literal(nested_literal)
332
+
333
+ b = bytes(buf)
334
+ return b
335
+
336
+ elif literal.HasField("map"):
337
+ buf = bytearray()
338
+ # Sort keys to ensure deterministic ordering
339
+ for key in sorted(literal.map.literals.keys()):
340
+ nested_literal = literal.map.literals[key]
341
+ buf += key.encode("utf-8")
342
+ if nested_literal.hash:
343
+ buf += nested_literal.hash.encode("utf-8")
344
+ else:
345
+ buf += generate_inputs_repr_for_literal(nested_literal)
346
+
347
+ b = bytes(buf)
348
+ return b
349
+
350
+ # For all other cases (scalars, etc.), just serialize the literal normally
351
+ return literal.SerializeToString(deterministic=True)
352
+
353
+
354
+ def generate_inputs_hash_for_named_literals(inputs: list[run_definition_pb2.NamedLiteral]) -> str:
355
+ """
356
+ Generate a hash for the inputs using the new literal representation approach that respects
357
+ hash values already present in literals. This is used to uniquely identify the inputs for a task
358
+ when some literals may have precomputed hash values.
359
+
360
+ :param inputs: List of NamedLiteral inputs to hash.
361
+ :return: A base64-encoded string representation of the hash.
362
+ """
363
+ if not inputs:
364
+ return ""
365
+
366
+ # Build the byte representation by concatenating each literal's representation
367
+ combined_bytes = b""
368
+ for named_literal in inputs:
369
+ # Add the name to ensure order matters
370
+ name_bytes = named_literal.name.encode("utf-8")
371
+ literal_bytes = generate_inputs_repr_for_literal(named_literal.value)
372
+ # Combine name and literal bytes with a separator to avoid collisions
373
+ combined_bytes += name_bytes + b":" + literal_bytes + b";"
374
+
375
+ return hash_data(combined_bytes)
376
+
377
+
311
378
  def generate_inputs_hash_from_proto(inputs: run_definition_pb2.Inputs) -> str:
312
379
  """
313
380
  Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
314
381
  :param inputs: The inputs to hash.
315
382
  :return: A hexadecimal string representation of the hash.
316
383
  """
317
- if not inputs:
384
+ if not inputs or not inputs.literals:
318
385
  return ""
319
- return generate_inputs_hash(inputs.SerializeToString(deterministic=True))
386
+ return generate_inputs_hash_for_named_literals(list(inputs.literals))
320
387
 
321
388
 
322
389
  def generate_interface_hash(task_interface: interface_pb2.TypedInterface) -> str:
@@ -4,6 +4,7 @@ invoked within a context tree.
4
4
  """
5
5
 
6
6
  import pathlib
7
+ import time
7
8
  from typing import Any, Dict, List, Optional, Tuple
8
9
 
9
10
  import flyte.report
@@ -172,6 +173,8 @@ async def extract_download_run_upload(
172
173
  This method is invoked from the CLI (urun) and is used to run a task. This assumes that the context tree
173
174
  has already been created, and the task has been loaded. It also handles the loading of the task.
174
175
  """
176
+ t = time.time()
177
+ logger.warning(f"Task {action.name} started at {t}")
175
178
  outputs, err = await convert_and_run(
176
179
  task=task,
177
180
  input_path=input_path,
@@ -194,4 +197,4 @@ async def extract_download_run_upload(
194
197
  logger.info(f"Task {task.name} completed successfully, no outputs")
195
198
  return
196
199
  await upload_outputs(outputs, output_path) if output_path else None
197
- logger.info(f"Task {task.name} completed successfully, uploaded outputs to {output_path}")
200
+ logger.warning(f"Task {task.name} completed successfully, uploaded outputs to {output_path} in {time.time() - t}s")
flyte/_logging.py CHANGED
@@ -75,39 +75,52 @@ def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
75
75
  return handler
76
76
 
77
77
 
78
- def get_default_handler(log_level: int) -> logging.Handler:
79
- handler = logging.StreamHandler()
80
- handler.setLevel(log_level)
81
- formatter = logging.Formatter(fmt="[%(name)s] %(message)s")
82
- if log_format_from_env() == "json":
83
- pass
84
- # formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
85
- handler.setFormatter(formatter)
86
- return handler
87
-
88
-
89
78
  def initialize_logger(log_level: int = get_env_log_level(), enable_rich: bool = False):
90
79
  """
91
80
  Initializes the global loggers to the default configuration.
81
+ When enable_rich=True, upgrades to Rich handler for local CLI usage.
92
82
  """
93
83
  global logger # noqa: PLW0603
94
- logger = _create_logger("flyte", log_level, enable_rich)
95
84
 
85
+ # Clear existing handlers to reconfigure
86
+ root = logging.getLogger()
87
+ root.handlers.clear()
96
88
 
97
- def _create_logger(name: str, log_level: int = DEFAULT_LOG_LEVEL, enable_rich: bool = False) -> logging.Logger:
98
- """
99
- Creates a logger with the given name and log level.
100
- """
101
- logger = logging.getLogger(name)
102
- logger.setLevel(log_level)
103
- handler = None
104
- logger.handlers = []
89
+ flyte_logger = logging.getLogger("flyte")
90
+ flyte_logger.handlers.clear()
91
+
92
+ # Set up root logger handler
93
+ root_handler = None
94
+ if enable_rich:
95
+ root_handler = get_rich_handler(log_level)
96
+
97
+ if root_handler is None:
98
+ root_handler = logging.StreamHandler()
99
+
100
+ # Add context filter to root handler for all logging
101
+ root_handler.addFilter(ContextFilter())
102
+ root.addHandler(root_handler)
103
+
104
+ # Set up Flyte logger handler
105
+ flyte_handler = None
105
106
  if enable_rich:
106
- handler = get_rich_handler(log_level)
107
- if handler is None:
108
- handler = get_default_handler(log_level)
109
- logger.addHandler(handler)
110
- return logger
107
+ flyte_handler = get_rich_handler(log_level)
108
+
109
+ if flyte_handler is None:
110
+ flyte_handler = logging.StreamHandler()
111
+ flyte_handler.setLevel(log_level)
112
+ formatter = logging.Formatter(fmt="%(message)s")
113
+ flyte_handler.setFormatter(formatter)
114
+
115
+ # Add both filters to Flyte handler
116
+ flyte_handler.addFilter(FlyteInternalFilter())
117
+ flyte_handler.addFilter(ContextFilter())
118
+
119
+ flyte_logger.addHandler(flyte_handler)
120
+ flyte_logger.setLevel(log_level)
121
+ flyte_logger.propagate = False # Prevent double logging
122
+
123
+ logger = flyte_logger
111
124
 
112
125
 
113
126
  def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
@@ -135,4 +148,75 @@ def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
135
148
  return decorator(fn)
136
149
 
137
150
 
138
- logger = _create_logger("flyte", get_env_log_level())
151
+ class ContextFilter(logging.Filter):
152
+ """
153
+ A logging filter that adds the current action's run name and name to all log records.
154
+ Applied globally to capture context for both user and Flyte internal logging.
155
+ """
156
+
157
+ def filter(self, record):
158
+ from flyte._context import ctx
159
+
160
+ c = ctx()
161
+ if c:
162
+ action = c.action
163
+ record.msg = f"[{action.run_name}][{action.name}] {record.msg}"
164
+ return True
165
+
166
+
167
+ class FlyteInternalFilter(logging.Filter):
168
+ """
169
+ A logging filter that adds [flyte] prefix to internal Flyte logging only.
170
+ """
171
+
172
+ def filter(self, record):
173
+ if record.name.startswith("flyte"):
174
+ record.msg = f"[flyte] {record.msg}"
175
+ return True
176
+
177
+
178
+ def _setup_root_logger():
179
+ """
180
+ Configure the root logger to capture all logging with context information.
181
+ This ensures both user code and Flyte internal logging get the context.
182
+ """
183
+ root = logging.getLogger()
184
+ root.handlers.clear() # Remove any existing handlers to prevent double logging
185
+
186
+ # Create a basic handler for the root logger
187
+ handler = logging.StreamHandler()
188
+ # Add context filter to ALL logging
189
+ handler.addFilter(ContextFilter())
190
+
191
+ # Simple formatter since filters handle prefixes
192
+ root.addHandler(handler)
193
+
194
+
195
+ def _create_flyte_logger() -> logging.Logger:
196
+ """
197
+ Create the internal Flyte logger with [flyte] prefix.
198
+ """
199
+ flyte_logger = logging.getLogger("flyte")
200
+ flyte_logger.setLevel(get_env_log_level())
201
+
202
+ # Add a handler specifically for flyte logging with the prefix filter
203
+ handler = logging.StreamHandler()
204
+ handler.setLevel(get_env_log_level())
205
+ handler.addFilter(FlyteInternalFilter())
206
+ handler.addFilter(ContextFilter())
207
+
208
+ formatter = logging.Formatter(fmt="%(message)s")
209
+ handler.setFormatter(formatter)
210
+
211
+ # Prevent propagation to root to avoid double logging
212
+ flyte_logger.propagate = False
213
+ flyte_logger.addHandler(handler)
214
+
215
+ return flyte_logger
216
+
217
+
218
+ # Initialize root logger for global context
219
+ _setup_root_logger()
220
+
221
+ # Create the Flyte internal logger
222
+ logger = _create_flyte_logger()
flyte/_map.py CHANGED
@@ -1,4 +1,6 @@
1
1
  import asyncio
2
+ import functools
3
+ import logging
2
4
  from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
3
5
 
4
6
  from flyte.syncify import syncify
@@ -11,7 +13,14 @@ from ._task import P, R, TaskTemplate
11
13
  class MapAsyncIterator(Generic[P, R]):
12
14
  """AsyncIterator implementation for the map function results"""
13
15
 
14
- def __init__(self, func: TaskTemplate[P, R], args: tuple, name: str, concurrency: int, return_exceptions: bool):
16
+ def __init__(
17
+ self,
18
+ func: TaskTemplate[P, R] | functools.partial[R],
19
+ args: tuple,
20
+ name: str,
21
+ concurrency: int,
22
+ return_exceptions: bool,
23
+ ):
15
24
  self.func = func
16
25
  self.args = args
17
26
  self.name = name
@@ -49,13 +58,16 @@ class MapAsyncIterator(Generic[P, R]):
49
58
  return result
50
59
  except Exception as e:
51
60
  self._exception_count += 1
52
- logger.debug(f"Task {self._current_index - 1} failed with exception: {e}")
61
+ logger.debug(
62
+ f"Task {self._current_index - 1} failed with exception: {e}, return_exceptions={self.return_exceptions}"
63
+ )
53
64
  if self.return_exceptions:
54
65
  return e
55
66
  else:
56
67
  # Cancel remaining tasks
57
68
  for remaining_task in self._tasks[self._current_index + 1 :]:
58
69
  remaining_task.cancel()
70
+ logger.warning("Exception raising is `ON`, raising exception and cancelling remaining tasks")
59
71
  raise e
60
72
 
61
73
  async def _initialize(self):
@@ -64,10 +76,26 @@ class MapAsyncIterator(Generic[P, R]):
64
76
  tasks = []
65
77
  task_count = 0
66
78
 
67
- for arg_tuple in zip(*self.args):
68
- task = asyncio.create_task(self.func.aio(*arg_tuple))
69
- tasks.append(task)
70
- task_count += 1
79
+ if isinstance(self.func, functools.partial):
80
+ # Handle partial functions by merging bound args/kwargs with mapped args
81
+ base_func = cast(TaskTemplate, self.func.func)
82
+ bound_args = self.func.args
83
+ bound_kwargs = self.func.keywords or {}
84
+
85
+ for arg_tuple in zip(*self.args):
86
+ # Merge bound positional args with mapped args
87
+ merged_args = bound_args + arg_tuple
88
+ if logger.isEnabledFor(logging.DEBUG):
89
+ logger.debug(f"Running {base_func.name} with args: {merged_args} and kwargs: {bound_kwargs}")
90
+ task = asyncio.create_task(base_func.aio(*merged_args, **bound_kwargs))
91
+ tasks.append(task)
92
+ task_count += 1
93
+ else:
94
+ # Handle regular TaskTemplate functions
95
+ for arg_tuple in zip(*self.args):
96
+ task = asyncio.create_task(self.func.aio(*arg_tuple))
97
+ tasks.append(task)
98
+ task_count += 1
71
99
 
72
100
  if task_count == 0:
73
101
  logger.info(f"Group '{self.name}' has no tasks to process")
@@ -107,9 +135,46 @@ class _Mapper(Generic[P, R]):
107
135
  """Get the name of the group, defaulting to 'map' if not provided."""
108
136
  return f"{task_name}_{group_name or 'map'}"
109
137
 
138
+ @staticmethod
139
+ def validate_partial(func: functools.partial[R]):
140
+ """
141
+ This method validates that the provided partial function is valid for mapping, i.e. exactly one argument
142
+ is left for mapping and the rest are provided as keywords or args.
143
+
144
+ :param func: partial function to validate
145
+ :raises TypeError: if the partial function is not valid for mapping
146
+ """
147
+ f = cast(TaskTemplate, func.func)
148
+ inputs = f.native_interface.inputs
149
+ params = list(inputs.keys())
150
+ total_params = len(params)
151
+ provided_args = len(func.args)
152
+ provided_kwargs = len(func.keywords or {})
153
+
154
+ # Calculate how many parameters are left unspecified
155
+ unspecified_count = total_params - provided_args - provided_kwargs
156
+
157
+ # Exactly one parameter should be left for mapping
158
+ if unspecified_count != 1:
159
+ raise TypeError(
160
+ f"Partial function must leave exactly one parameter unspecified for mapping. "
161
+ f"Found {unspecified_count} unspecified parameters in {f.name}, "
162
+ f"params: {inputs.keys()}"
163
+ )
164
+
165
+ # Validate that no parameter is both in args and keywords
166
+ if func.keywords:
167
+ param_names = list(inputs.keys())
168
+ for i, arg_name in enumerate(param_names[: provided_args + 1]):
169
+ if arg_name in func.keywords:
170
+ raise TypeError(
171
+ f"Parameter '{arg_name}' is provided both as positional argument and keyword argument "
172
+ f"in partial function {f.name}."
173
+ )
174
+
110
175
  def __call__(
111
176
  self,
112
- func: TaskTemplate[P, R],
177
+ func: TaskTemplate[P, R] | functools.partial[R],
113
178
  *args: Iterable[Any],
114
179
  group_name: str | None = None,
115
180
  concurrency: int = 0,
@@ -128,7 +193,13 @@ class _Mapper(Generic[P, R]):
128
193
  if not args:
129
194
  return
130
195
 
131
- name = self._get_name(func.name, group_name)
196
+ if isinstance(func, functools.partial):
197
+ f = cast(TaskTemplate, func.func)
198
+ self.validate_partial(func)
199
+ else:
200
+ f = cast(TaskTemplate, func)
201
+
202
+ name = self._get_name(f.name, group_name)
132
203
  logger.debug(f"Blocking Map for {name}")
133
204
  with group(name):
134
205
  import flyte
@@ -154,7 +225,7 @@ class _Mapper(Generic[P, R]):
154
225
  *args,
155
226
  name=name,
156
227
  concurrency=concurrency,
157
- return_exceptions=True,
228
+ return_exceptions=return_exceptions,
158
229
  ),
159
230
  ):
160
231
  logger.debug(f"Mapped {x}, task {i}")
@@ -163,7 +234,7 @@ class _Mapper(Generic[P, R]):
163
234
 
164
235
  async def aio(
165
236
  self,
166
- func: TaskTemplate[P, R],
237
+ func: TaskTemplate[P, R] | functools.partial[R],
167
238
  *args: Iterable[Any],
168
239
  group_name: str | None = None,
169
240
  concurrency: int = 0,
@@ -171,7 +242,14 @@ class _Mapper(Generic[P, R]):
171
242
  ) -> AsyncGenerator[Union[R, Exception], None]:
172
243
  if not args:
173
244
  return
174
- name = self._get_name(func.name, group_name)
245
+
246
+ if isinstance(func, functools.partial):
247
+ f = cast(TaskTemplate, func.func)
248
+ self.validate_partial(func)
249
+ else:
250
+ f = cast(TaskTemplate, func)
251
+
252
+ name = self._get_name(f.name, group_name)
175
253
  with group(name):
176
254
  import flyte
177
255
 
@@ -199,7 +277,7 @@ class _Mapper(Generic[P, R]):
199
277
 
200
278
  @syncify
201
279
  async def _map(
202
- func: TaskTemplate[P, R],
280
+ func: TaskTemplate[P, R] | functools.partial[R],
203
281
  *args: Iterable[Any],
204
282
  name: str = "map",
205
283
  concurrency: int = 0,
flyte/_pod.py CHANGED
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict, Optional
3
3
 
4
4
  if TYPE_CHECKING:
5
5
  from flyteidl.core.tasks_pb2 import K8sPod
6
- from kubernetes.client import ApiClient, V1PodSpec
6
+ from kubernetes.client import V1PodSpec
7
7
 
8
8
 
9
9
  _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
@@ -21,6 +21,7 @@ class PodTemplate(object):
21
21
 
22
22
  def to_k8s_pod(self) -> "K8sPod":
23
23
  from flyteidl.core.tasks_pb2 import K8sObjectMetadata, K8sPod
24
+ from kubernetes.client import ApiClient
24
25
 
25
26
  return K8sPod(
26
27
  metadata=K8sObjectMetadata(labels=self.labels, annotations=self.annotations),
flyte/_run.py CHANGED
@@ -161,7 +161,12 @@ class _Runner:
161
161
  code_bundle = cached_value.code_bundle
162
162
  image_cache = cached_value.image_cache
163
163
  else:
164
- image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
164
+ if not self._dry_run:
165
+ image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
166
+ else:
167
+ from ._internal.imagebuild.image_builder import ImageCache
168
+
169
+ image_cache = ImageCache(image_lookup={})
165
170
 
166
171
  if self._interactive_mode:
167
172
  code_bundle = await build_pkl_bundle(
flyte/_task.py CHANGED
@@ -258,6 +258,9 @@ class TaskTemplate(Generic[P, R]):
258
258
  else:
259
259
  raise RuntimeSystemError("BadContext", "Controller is not initialized.")
260
260
  else:
261
+ from flyte._logging import logger
262
+
263
+ logger.warning(f"Task {self.name} running aio outside of a task context.")
261
264
  # Local execute, just stay out of the way, but because .aio is used, we want to return an awaitable,
262
265
  # even for synchronous tasks. This is to support migration.
263
266
  return self.forward(*args, **kwargs)
@@ -18,7 +18,7 @@ from typing import (
18
18
 
19
19
  import rich.repr
20
20
 
21
- from ._cache import CacheRequest
21
+ from ._cache import Cache, CacheRequest
22
22
  from ._doc import Documentation
23
23
  from ._environment import Environment
24
24
  from ._image import Image
@@ -74,6 +74,10 @@ class TaskEnvironment(Environment):
74
74
  super().__post_init__()
75
75
  if self.reusable is not None and self.plugin_config is not None:
76
76
  raise ValueError("Cannot set plugin_config when environment is reusable.")
77
+ if self.reusable and not isinstance(self.reusable, ReusePolicy):
78
+ raise TypeError(f"Expected reusable to be of type ReusePolicy, got {type(self.reusable)}")
79
+ if self.cache and not isinstance(self.cache, (str, Cache)):
80
+ raise TypeError(f"Expected cache to be of type str or Cache, got {type(self.cache)}")
77
81
 
78
82
  def clone_with(
79
83
  self,
flyte/_trace.py CHANGED
@@ -3,6 +3,7 @@ import inspect
3
3
  import time
4
4
  from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Callable, TypeGuard, TypeVar, Union, cast
5
5
 
6
+ from flyte._logging import logger
6
7
  from flyte.models import NativeInterface
7
8
 
8
9
  T = TypeVar("T")
@@ -33,10 +34,13 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
33
34
  iface = NativeInterface.from_callable(func)
34
35
  info, ok = await controller.get_action_outputs(iface, func, *args, **kwargs)
35
36
  if ok:
37
+ logger.info(f"Found existing trace info for {func}, {info}")
36
38
  if info.output:
37
39
  return info.output
38
40
  elif info.error:
39
41
  raise info.error
42
+ else:
43
+ logger.debug(f"No existing trace info found for {func}, proceeding to execute.")
40
44
  start_time = time.time()
41
45
  try:
42
46
  # Cast to Awaitable to satisfy mypy
@@ -44,6 +48,7 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
44
48
  results = await coroutine_result
45
49
  info.add_outputs(results, start_time=start_time, end_time=time.time())
46
50
  await controller.record_trace(info)
51
+ logger.debug(f"Finished trace for {func}, {info}")
47
52
  return results
48
53
  except Exception as e:
49
54
  # If there is an error, we need to record it
flyte/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.0.0b17'
32
- __version_tuple__ = version_tuple = (2, 0, 0, 'b17')
31
+ __version__ = version = '2.0.0b19'
32
+ __version_tuple__ = version_tuple = (2, 0, 0, 'b19')
33
33
 
34
- __commit_id__ = commit_id = 'gfe0ca1266'
34
+ __commit_id__ = commit_id = 'g172a8c6b7'
flyte/cli/_create.py CHANGED
@@ -98,7 +98,7 @@ def secret(
98
98
  "-o",
99
99
  "--output",
100
100
  type=click.Path(exists=False, writable=True),
101
- default=Path.cwd() / "config.yaml",
101
+ default=Path.cwd() / ".flyte" / "config.yaml",
102
102
  help="Path to the output directory where the configuration will be saved. Defaults to current directory.",
103
103
  show_default=True,
104
104
  )
@@ -147,6 +147,9 @@ def config(
147
147
 
148
148
  output_path = Path(output)
149
149
 
150
+ if not output_path.parent.exists():
151
+ output_path.parent.mkdir(parents=True)
152
+
150
153
  if output_path.exists() and not force:
151
154
  force = click.confirm(f"Overwrite [{output_path}]?", default=False)
152
155
  if not force:
flyte/cli/_deploy.py CHANGED
@@ -4,8 +4,7 @@ from pathlib import Path
4
4
  from types import ModuleType
5
5
  from typing import Any, Dict, List, cast, get_args
6
6
 
7
- import click
8
- from click import Context
7
+ import rich_click as click
9
8
 
10
9
  import flyte
11
10
 
@@ -87,14 +86,14 @@ class DeployArguments:
87
86
  return [common.get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata]
88
87
 
89
88
 
90
- class DeployEnvCommand(click.Command):
89
+ class DeployEnvCommand(click.RichCommand):
91
90
  def __init__(self, env_name: str, env: Any, deploy_args: DeployArguments, *args, **kwargs):
92
91
  self.env_name = env_name
93
92
  self.env = env
94
93
  self.deploy_args = deploy_args
95
94
  super().__init__(*args, **kwargs)
96
95
 
97
- def invoke(self, ctx: Context):
96
+ def invoke(self, ctx: click.Context):
98
97
  from rich.console import Console
99
98
 
100
99
  console = Console()
@@ -125,7 +124,7 @@ class DeployEnvRecursiveCommand(click.Command):
125
124
  self.deploy_args = deploy_args
126
125
  super().__init__(*args, **kwargs)
127
126
 
128
- def invoke(self, ctx: Context):
127
+ def invoke(self, ctx: click.Context):
129
128
  from rich.console import Console
130
129
 
131
130
  from flyte._environment import list_loaded_environments
flyte/cli/_params.py CHANGED
@@ -283,13 +283,17 @@ class UnionParamType(click.ParamType):
283
283
  A composite type that allows for multiple types to be specified. This is used for union types.
284
284
  """
285
285
 
286
- def __init__(self, types: typing.List[click.ParamType]):
286
+ def __init__(self, types: typing.List[click.ParamType | None]):
287
287
  super().__init__()
288
288
  self._types = self._sort_precedence(types)
289
- self.name = "|".join([t.name for t in self._types])
289
+ self.name = "|".join([t.name for t in self._types if t is not None])
290
+ self.optional = False
291
+ if None in types:
292
+ self.name = f"Optional[{self.name}]"
293
+ self.optional = True
290
294
 
291
295
  @staticmethod
292
- def _sort_precedence(tp: typing.List[click.ParamType]) -> typing.List[click.ParamType]:
296
+ def _sort_precedence(tp: typing.List[click.ParamType | None]) -> typing.List[click.ParamType]:
293
297
  unprocessed = []
294
298
  str_types = []
295
299
  others = []
@@ -311,6 +315,8 @@ class UnionParamType(click.ParamType):
311
315
  """
312
316
  for p in self._types:
313
317
  try:
318
+ if p is None and value is None:
319
+ return None
314
320
  return p.convert(value, param, ctx)
315
321
  except Exception as e:
316
322
  logger.debug(f"Ignoring conversion error for type {p} trying other variants in Union. Error: {e}")
@@ -433,7 +439,10 @@ def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> cli
433
439
  for i in range(len(lt.union_type.variants)):
434
440
  variant = lt.union_type.variants[i]
435
441
  variant_python_type = typing.get_args(python_type)[i]
436
- cts.append(literal_type_to_click_type(variant, variant_python_type))
442
+ if variant_python_type is type(None):
443
+ cts.append(None)
444
+ else:
445
+ cts.append(literal_type_to_click_type(variant, variant_python_type))
437
446
  return UnionParamType(cts)
438
447
 
439
448
  if lt.HasField("enum_type"):
@@ -461,6 +470,9 @@ class FlyteLiteralConverter(object):
461
470
  def is_bool(self) -> bool:
462
471
  return self.click_type == click.BOOL
463
472
 
473
+ def is_optional(self) -> bool:
474
+ return isinstance(self.click_type, UnionParamType) and self.click_type.optional
475
+
464
476
  def convert(
465
477
  self, ctx: click.Context, param: typing.Optional[click.Parameter], value: typing.Any
466
478
  ) -> typing.Union[Literal, typing.Any]:
@@ -525,6 +537,8 @@ def to_click_option(
525
537
  if literal_converter.is_bool():
526
538
  required = False
527
539
  is_flag = True
540
+ if literal_converter.is_optional():
541
+ required = False
528
542
 
529
543
  return click.Option(
530
544
  param_decls=[f"--{input_name}"],
flyte/cli/_run.py CHANGED
@@ -112,7 +112,7 @@ class RunArguments:
112
112
  return [common.get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata]
113
113
 
114
114
 
115
- class RunTaskCommand(click.Command):
115
+ class RunTaskCommand(click.RichCommand):
116
116
  def __init__(self, obj_name: str, obj: Any, run_args: RunArguments, *args, **kwargs):
117
117
  self.obj_name = obj_name
118
118
  self.obj = cast(TaskTemplate, obj)
@@ -196,7 +196,7 @@ class TaskPerFileGroup(common.ObjectsPerFileGroup):
196
196
  )
197
197
 
198
198
 
199
- class RunReferenceTaskCommand(click.Command):
199
+ class RunReferenceTaskCommand(click.RichCommand):
200
200
  def __init__(self, task_name: str, run_args: RunArguments, version: str | None, *args, **kwargs):
201
201
  self.task_name = task_name
202
202
  self.run_args = run_args