flyte 2.0.0b17__py3-none-any.whl → 2.0.0b19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/_bin/runtime.py +3 -0
- flyte/_debug/vscode.py +4 -2
- flyte/_deploy.py +3 -1
- flyte/_environment.py +15 -6
- flyte/_hash.py +1 -16
- flyte/_image.py +6 -1
- flyte/_initialize.py +15 -16
- flyte/_internal/controllers/__init__.py +4 -5
- flyte/_internal/controllers/_local_controller.py +5 -5
- flyte/_internal/controllers/remote/_controller.py +21 -28
- flyte/_internal/controllers/remote/_core.py +1 -1
- flyte/_internal/imagebuild/docker_builder.py +31 -23
- flyte/_internal/imagebuild/remote_builder.py +37 -10
- flyte/_internal/imagebuild/utils.py +2 -1
- flyte/_internal/runtime/convert.py +69 -2
- flyte/_internal/runtime/taskrunner.py +4 -1
- flyte/_logging.py +110 -26
- flyte/_map.py +90 -12
- flyte/_pod.py +2 -1
- flyte/_run.py +6 -1
- flyte/_task.py +3 -0
- flyte/_task_environment.py +5 -1
- flyte/_trace.py +5 -0
- flyte/_version.py +3 -3
- flyte/cli/_create.py +4 -1
- flyte/cli/_deploy.py +4 -5
- flyte/cli/_params.py +18 -4
- flyte/cli/_run.py +2 -2
- flyte/config/_config.py +2 -2
- flyte/config/_reader.py +14 -8
- flyte/errors.py +3 -1
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +17 -0
- flyte/io/_dataframe/basic_dfs.py +16 -7
- flyte/io/_dataframe/dataframe.py +84 -123
- flyte/io/_dir.py +35 -4
- flyte/io/_file.py +61 -15
- flyte/io/_hashing_io.py +342 -0
- flyte/models.py +12 -4
- flyte/remote/_action.py +4 -2
- flyte/remote/_task.py +52 -22
- flyte/report/_report.py +1 -1
- flyte/storage/_storage.py +16 -1
- flyte/types/_type_engine.py +1 -51
- {flyte-2.0.0b17.data → flyte-2.0.0b19.data}/scripts/runtime.py +3 -0
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/METADATA +1 -1
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/RECORD +52 -49
- {flyte-2.0.0b17.data → flyte-2.0.0b19.data}/scripts/debug.py +0 -0
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/entry_points.txt +0 -0
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/top_level.txt +0 -0
|
@@ -308,15 +308,82 @@ def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
|
|
|
308
308
|
return hash_data(serialized_inputs)
|
|
309
309
|
|
|
310
310
|
|
|
311
|
+
def generate_inputs_repr_for_literal(literal: literals_pb2.Literal) -> bytes:
    """
    Produce the bytes that stand in for a single literal when an Action's cache key is hashed.

    Resolution order:

    1. A literal with a non-empty ``hash`` is represented by that hash, UTF-8 encoded; its payload
       is not inspected.
    2. A ``collection`` is the concatenation of its elements' representations, in element order.
    3. A ``map`` is the concatenation of each key (UTF-8 encoded) followed by its value's
       representation, visiting keys in sorted order so the result is independent of map ordering.
    4. Anything else is the literal's deterministic protobuf serialization.

    Nested literals are resolved by recursion through the same rules, so a precomputed hash is
    honoured at any nesting depth.

    NOTE(review): pieces are concatenated without delimiters or length prefixes, so different
    nestings can produce the same bytes (e.g. element hashes "ab","c" versus "a","bc"). The encoding
    is left untouched here because changing it would change every derived cache key.

    :param literal: The literal to get a hashable representation for.
    :return: byte representation of the literal that can be fed into a hash function.
    """
    # Rule 1: a precomputed hash short-circuits everything else.
    if literal.hash:
        return literal.hash.encode("utf-8")

    # Rule 2: element order is significant, so no sorting here.
    if literal.HasField("collection"):
        return b"".join(generate_inputs_repr_for_literal(item) for item in literal.collection.literals)

    # Rule 3: sort the keys so the representation is deterministic.
    if literal.HasField("map"):
        entries = literal.map.literals
        return b"".join(
            name.encode("utf-8") + generate_inputs_repr_for_literal(entries[name]) for name in sorted(entries.keys())
        )

    # Rule 4: scalars and every other literal kind.
    return literal.SerializeToString(deterministic=True)
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def generate_inputs_hash_for_named_literals(inputs: list[run_definition_pb2.NamedLiteral]) -> str:
    """
    Hash a list of named literals, reusing any hash values the literals already carry.

    Every input contributes ``<name>:<literal representation>;`` where the representation comes from
    ``generate_inputs_repr_for_literal``. Contributions are concatenated in the order the inputs are
    given (the list is not sorted), and the combined bytes are passed to ``hash_data``.

    :param inputs: List of NamedLiteral inputs to hash.
    :return: The string returned by ``hash_data`` for the combined bytes, or ``""`` when ``inputs``
        is empty or ``None``.
    """
    if not inputs:
        return ""

    # The ":" and ";" separators keep a name from running into its value or into the next input.
    payload = b"".join(
        entry.name.encode("utf-8") + b":" + generate_inputs_repr_for_literal(entry.value) + b";" for entry in inputs
    )
    return hash_data(payload)
|
|
376
|
+
|
|
377
|
+
|
|
311
378
|
def generate_inputs_hash_from_proto(inputs: run_definition_pb2.Inputs) -> str:
|
|
312
379
|
"""
|
|
313
380
|
Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
|
|
314
381
|
:param inputs: The inputs to hash.
|
|
315
382
|
:return: A hexadecimal string representation of the hash.
|
|
316
383
|
"""
|
|
317
|
-
if not inputs:
|
|
384
|
+
if not inputs or not inputs.literals:
|
|
318
385
|
return ""
|
|
319
|
-
return
|
|
386
|
+
return generate_inputs_hash_for_named_literals(list(inputs.literals))
|
|
320
387
|
|
|
321
388
|
|
|
322
389
|
def generate_interface_hash(task_interface: interface_pb2.TypedInterface) -> str:
|
|
@@ -4,6 +4,7 @@ invoked within a context tree.
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
import pathlib
|
|
7
|
+
import time
|
|
7
8
|
from typing import Any, Dict, List, Optional, Tuple
|
|
8
9
|
|
|
9
10
|
import flyte.report
|
|
@@ -172,6 +173,8 @@ async def extract_download_run_upload(
|
|
|
172
173
|
This method is invoked from the CLI (urun) and is used to run a task. This assumes that the context tree
|
|
173
174
|
has already been created, and the task has been loaded. It also handles the loading of the task.
|
|
174
175
|
"""
|
|
176
|
+
t = time.time()
|
|
177
|
+
logger.warning(f"Task {action.name} started at {t}")
|
|
175
178
|
outputs, err = await convert_and_run(
|
|
176
179
|
task=task,
|
|
177
180
|
input_path=input_path,
|
|
@@ -194,4 +197,4 @@ async def extract_download_run_upload(
|
|
|
194
197
|
logger.info(f"Task {task.name} completed successfully, no outputs")
|
|
195
198
|
return
|
|
196
199
|
await upload_outputs(outputs, output_path) if output_path else None
|
|
197
|
-
logger.
|
|
200
|
+
logger.warning(f"Task {task.name} completed successfully, uploaded outputs to {output_path} in {time.time() - t}s")
|
flyte/_logging.py
CHANGED
|
@@ -75,39 +75,52 @@ def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
|
|
|
75
75
|
return handler
|
|
76
76
|
|
|
77
77
|
|
|
78
|
-
def get_default_handler(log_level: int) -> logging.Handler:
|
|
79
|
-
handler = logging.StreamHandler()
|
|
80
|
-
handler.setLevel(log_level)
|
|
81
|
-
formatter = logging.Formatter(fmt="[%(name)s] %(message)s")
|
|
82
|
-
if log_format_from_env() == "json":
|
|
83
|
-
pass
|
|
84
|
-
# formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
|
|
85
|
-
handler.setFormatter(formatter)
|
|
86
|
-
return handler
|
|
87
|
-
|
|
88
|
-
|
|
89
78
|
def initialize_logger(log_level: int = get_env_log_level(), enable_rich: bool = False):
|
|
90
79
|
"""
|
|
91
80
|
Initializes the global loggers to the default configuration.
|
|
81
|
+
When enable_rich=True, upgrades to Rich handler for local CLI usage.
|
|
92
82
|
"""
|
|
93
83
|
global logger # noqa: PLW0603
|
|
94
|
-
logger = _create_logger("flyte", log_level, enable_rich)
|
|
95
84
|
|
|
85
|
+
# Clear existing handlers to reconfigure
|
|
86
|
+
root = logging.getLogger()
|
|
87
|
+
root.handlers.clear()
|
|
96
88
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
89
|
+
flyte_logger = logging.getLogger("flyte")
|
|
90
|
+
flyte_logger.handlers.clear()
|
|
91
|
+
|
|
92
|
+
# Set up root logger handler
|
|
93
|
+
root_handler = None
|
|
94
|
+
if enable_rich:
|
|
95
|
+
root_handler = get_rich_handler(log_level)
|
|
96
|
+
|
|
97
|
+
if root_handler is None:
|
|
98
|
+
root_handler = logging.StreamHandler()
|
|
99
|
+
|
|
100
|
+
# Add context filter to root handler for all logging
|
|
101
|
+
root_handler.addFilter(ContextFilter())
|
|
102
|
+
root.addHandler(root_handler)
|
|
103
|
+
|
|
104
|
+
# Set up Flyte logger handler
|
|
105
|
+
flyte_handler = None
|
|
105
106
|
if enable_rich:
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
107
|
+
flyte_handler = get_rich_handler(log_level)
|
|
108
|
+
|
|
109
|
+
if flyte_handler is None:
|
|
110
|
+
flyte_handler = logging.StreamHandler()
|
|
111
|
+
flyte_handler.setLevel(log_level)
|
|
112
|
+
formatter = logging.Formatter(fmt="%(message)s")
|
|
113
|
+
flyte_handler.setFormatter(formatter)
|
|
114
|
+
|
|
115
|
+
# Add both filters to Flyte handler
|
|
116
|
+
flyte_handler.addFilter(FlyteInternalFilter())
|
|
117
|
+
flyte_handler.addFilter(ContextFilter())
|
|
118
|
+
|
|
119
|
+
flyte_logger.addHandler(flyte_handler)
|
|
120
|
+
flyte_logger.setLevel(log_level)
|
|
121
|
+
flyte_logger.propagate = False # Prevent double logging
|
|
122
|
+
|
|
123
|
+
logger = flyte_logger
|
|
111
124
|
|
|
112
125
|
|
|
113
126
|
def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
|
|
@@ -135,4 +148,75 @@ def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
|
|
|
135
148
|
return decorator(fn)
|
|
136
149
|
|
|
137
150
|
|
|
138
|
-
|
|
151
|
+
class ContextFilter(logging.Filter):
    """
    Logging filter that stamps each record with the run and action it was emitted under.

    When ``flyte._context.ctx()`` returns a truthy context, the record's ``msg`` is rewritten to
    ``[<run_name>][<action name>] <original msg>``. Without a context the record is left untouched.
    The filter never drops a record: it always returns ``True``.
    """

    def filter(self, record):
        # Resolved at call time rather than at module import.
        # NOTE(review): presumably this avoids an import cycle with flyte._context -- confirm.
        from flyte._context import ctx

        current = ctx()
        if not current:
            return True

        action = current.action
        # The record itself is mutated, so the prefix is visible to everything that handles it later.
        record.msg = f"[{action.run_name}][{action.name}] {record.msg}"
        return True
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class FlyteInternalFilter(logging.Filter):
    """
    Logging filter that marks Flyte's own records with a ``[flyte]`` prefix.

    A record counts as internal when its logger name starts with ``"flyte"``. Records from any other
    logger pass through unchanged. The filter never drops a record: it always returns ``True``.
    """

    def filter(self, record):
        if not record.name.startswith("flyte"):
            return True

        # Only ``msg`` is rewritten; ``args`` are untouched, so %-style formatting still applies later.
        record.msg = f"[flyte] {record.msg}"
        return True
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _setup_root_logger():
    """
    Replace the root logger's handlers with a single stream handler that carries ``ContextFilter``.

    Records that reach the root logger (user code as well as anything that propagates to it) are
    therefore prefixed with the current run/action whenever a context is active.
    """
    root_logger = logging.getLogger()
    # Start from a clean slate so previously installed handlers cannot emit the same record twice.
    root_logger.handlers.clear()

    stream_handler = logging.StreamHandler()
    # The filter is the only customisation: no level and no formatter are set on this handler.
    stream_handler.addFilter(ContextFilter())
    root_logger.addHandler(stream_handler)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _create_flyte_logger() -> logging.Logger:
    """
    Build the ``"flyte"`` logger used for Flyte's internal messages.

    The logger and its dedicated stream handler both take their level from ``get_env_log_level()``.
    The handler prints the bare message and runs ``FlyteInternalFilter`` followed by
    ``ContextFilter``, in that order.

    :return: The configured ``"flyte"`` logger. Each call attaches one more handler to it.
    """
    internal_logger = logging.getLogger("flyte")
    internal_logger.setLevel(get_env_log_level())
    # Keep records away from the root handler, which would print them a second time.
    internal_logger.propagate = False

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(get_env_log_level())
    stream_handler.setFormatter(logging.Formatter(fmt="%(message)s"))
    # Filter order matters: the [flyte] prefix is applied first, the context prefix second.
    stream_handler.addFilter(FlyteInternalFilter())
    stream_handler.addFilter(ContextFilter())

    internal_logger.addHandler(stream_handler)
    return internal_logger
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# Import-time side effect: the root logger's handlers are replaced with the context-aware one.
_setup_root_logger()

# Module-level logger for Flyte internals, built once when this module is imported.
logger = _create_flyte_logger()
|
flyte/_map.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
import logging
|
|
2
4
|
from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
|
|
3
5
|
|
|
4
6
|
from flyte.syncify import syncify
|
|
@@ -11,7 +13,14 @@ from ._task import P, R, TaskTemplate
|
|
|
11
13
|
class MapAsyncIterator(Generic[P, R]):
|
|
12
14
|
"""AsyncIterator implementation for the map function results"""
|
|
13
15
|
|
|
14
|
-
def __init__(
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
func: TaskTemplate[P, R] | functools.partial[R],
|
|
19
|
+
args: tuple,
|
|
20
|
+
name: str,
|
|
21
|
+
concurrency: int,
|
|
22
|
+
return_exceptions: bool,
|
|
23
|
+
):
|
|
15
24
|
self.func = func
|
|
16
25
|
self.args = args
|
|
17
26
|
self.name = name
|
|
@@ -49,13 +58,16 @@ class MapAsyncIterator(Generic[P, R]):
|
|
|
49
58
|
return result
|
|
50
59
|
except Exception as e:
|
|
51
60
|
self._exception_count += 1
|
|
52
|
-
logger.debug(
|
|
61
|
+
logger.debug(
|
|
62
|
+
f"Task {self._current_index - 1} failed with exception: {e}, return_exceptions={self.return_exceptions}"
|
|
63
|
+
)
|
|
53
64
|
if self.return_exceptions:
|
|
54
65
|
return e
|
|
55
66
|
else:
|
|
56
67
|
# Cancel remaining tasks
|
|
57
68
|
for remaining_task in self._tasks[self._current_index + 1 :]:
|
|
58
69
|
remaining_task.cancel()
|
|
70
|
+
logger.warning("Exception raising is `ON`, raising exception and cancelling remaining tasks")
|
|
59
71
|
raise e
|
|
60
72
|
|
|
61
73
|
async def _initialize(self):
|
|
@@ -64,10 +76,26 @@ class MapAsyncIterator(Generic[P, R]):
|
|
|
64
76
|
tasks = []
|
|
65
77
|
task_count = 0
|
|
66
78
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
79
|
+
if isinstance(self.func, functools.partial):
|
|
80
|
+
# Handle partial functions by merging bound args/kwargs with mapped args
|
|
81
|
+
base_func = cast(TaskTemplate, self.func.func)
|
|
82
|
+
bound_args = self.func.args
|
|
83
|
+
bound_kwargs = self.func.keywords or {}
|
|
84
|
+
|
|
85
|
+
for arg_tuple in zip(*self.args):
|
|
86
|
+
# Merge bound positional args with mapped args
|
|
87
|
+
merged_args = bound_args + arg_tuple
|
|
88
|
+
if logger.isEnabledFor(logging.DEBUG):
|
|
89
|
+
logger.debug(f"Running {base_func.name} with args: {merged_args} and kwargs: {bound_kwargs}")
|
|
90
|
+
task = asyncio.create_task(base_func.aio(*merged_args, **bound_kwargs))
|
|
91
|
+
tasks.append(task)
|
|
92
|
+
task_count += 1
|
|
93
|
+
else:
|
|
94
|
+
# Handle regular TaskTemplate functions
|
|
95
|
+
for arg_tuple in zip(*self.args):
|
|
96
|
+
task = asyncio.create_task(self.func.aio(*arg_tuple))
|
|
97
|
+
tasks.append(task)
|
|
98
|
+
task_count += 1
|
|
71
99
|
|
|
72
100
|
if task_count == 0:
|
|
73
101
|
logger.info(f"Group '{self.name}' has no tasks to process")
|
|
@@ -107,9 +135,46 @@ class _Mapper(Generic[P, R]):
|
|
|
107
135
|
"""Get the name of the group, defaulting to 'map' if not provided."""
|
|
108
136
|
return f"{task_name}_{group_name or 'map'}"
|
|
109
137
|
|
|
138
|
+
@staticmethod
def validate_partial(func: functools.partial[R]):
    """
    Check that a ``functools.partial`` wrapping a task can be used with map.

    Two conditions are enforced:

    * the number of task inputs minus the bound positional and keyword arguments must be exactly
      one, i.e. a single parameter is left over to receive the mapped values;
    * none of the leading parameters that will be filled positionally (the bound positional
      arguments plus the one slot directly after them) may also be bound by keyword.

    :param func: partial function to validate
    :raises TypeError: if the partial function is not valid for mapping
    """
    task = cast(TaskTemplate, func.func)
    inputs = task.native_interface.inputs
    names = list(inputs.keys())
    positional_count = len(func.args)
    keyword_bindings = func.keywords or {}

    # Parameters that neither the bound args nor the bound kwargs account for.
    remaining = len(names) - positional_count - len(keyword_bindings)
    if remaining != 1:
        raise TypeError(
            f"Partial function must leave exactly one parameter unspecified for mapping. "
            f"Found {remaining} unspecified parameters in {task.name}, "
            f"params: {inputs.keys()}"
        )

    if not keyword_bindings:
        return

    # The "+ 1" covers the slot right after the bound positionals, which is also checked.
    for candidate in names[: positional_count + 1]:
        if candidate in keyword_bindings:
            raise TypeError(
                f"Parameter '{candidate}' is provided both as positional argument and keyword argument "
                f"in partial function {task.name}."
            )
|
|
174
|
+
|
|
110
175
|
def __call__(
|
|
111
176
|
self,
|
|
112
|
-
func: TaskTemplate[P, R],
|
|
177
|
+
func: TaskTemplate[P, R] | functools.partial[R],
|
|
113
178
|
*args: Iterable[Any],
|
|
114
179
|
group_name: str | None = None,
|
|
115
180
|
concurrency: int = 0,
|
|
@@ -128,7 +193,13 @@ class _Mapper(Generic[P, R]):
|
|
|
128
193
|
if not args:
|
|
129
194
|
return
|
|
130
195
|
|
|
131
|
-
|
|
196
|
+
if isinstance(func, functools.partial):
|
|
197
|
+
f = cast(TaskTemplate, func.func)
|
|
198
|
+
self.validate_partial(func)
|
|
199
|
+
else:
|
|
200
|
+
f = cast(TaskTemplate, func)
|
|
201
|
+
|
|
202
|
+
name = self._get_name(f.name, group_name)
|
|
132
203
|
logger.debug(f"Blocking Map for {name}")
|
|
133
204
|
with group(name):
|
|
134
205
|
import flyte
|
|
@@ -154,7 +225,7 @@ class _Mapper(Generic[P, R]):
|
|
|
154
225
|
*args,
|
|
155
226
|
name=name,
|
|
156
227
|
concurrency=concurrency,
|
|
157
|
-
return_exceptions=
|
|
228
|
+
return_exceptions=return_exceptions,
|
|
158
229
|
),
|
|
159
230
|
):
|
|
160
231
|
logger.debug(f"Mapped {x}, task {i}")
|
|
@@ -163,7 +234,7 @@ class _Mapper(Generic[P, R]):
|
|
|
163
234
|
|
|
164
235
|
async def aio(
|
|
165
236
|
self,
|
|
166
|
-
func: TaskTemplate[P, R],
|
|
237
|
+
func: TaskTemplate[P, R] | functools.partial[R],
|
|
167
238
|
*args: Iterable[Any],
|
|
168
239
|
group_name: str | None = None,
|
|
169
240
|
concurrency: int = 0,
|
|
@@ -171,7 +242,14 @@ class _Mapper(Generic[P, R]):
|
|
|
171
242
|
) -> AsyncGenerator[Union[R, Exception], None]:
|
|
172
243
|
if not args:
|
|
173
244
|
return
|
|
174
|
-
|
|
245
|
+
|
|
246
|
+
if isinstance(func, functools.partial):
|
|
247
|
+
f = cast(TaskTemplate, func.func)
|
|
248
|
+
self.validate_partial(func)
|
|
249
|
+
else:
|
|
250
|
+
f = cast(TaskTemplate, func)
|
|
251
|
+
|
|
252
|
+
name = self._get_name(f.name, group_name)
|
|
175
253
|
with group(name):
|
|
176
254
|
import flyte
|
|
177
255
|
|
|
@@ -199,7 +277,7 @@ class _Mapper(Generic[P, R]):
|
|
|
199
277
|
|
|
200
278
|
@syncify
|
|
201
279
|
async def _map(
|
|
202
|
-
func: TaskTemplate[P, R],
|
|
280
|
+
func: TaskTemplate[P, R] | functools.partial[R],
|
|
203
281
|
*args: Iterable[Any],
|
|
204
282
|
name: str = "map",
|
|
205
283
|
concurrency: int = 0,
|
flyte/_pod.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict, Optional
|
|
|
3
3
|
|
|
4
4
|
if TYPE_CHECKING:
|
|
5
5
|
from flyteidl.core.tasks_pb2 import K8sPod
|
|
6
|
-
from kubernetes.client import
|
|
6
|
+
from kubernetes.client import V1PodSpec
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
|
|
@@ -21,6 +21,7 @@ class PodTemplate(object):
|
|
|
21
21
|
|
|
22
22
|
def to_k8s_pod(self) -> "K8sPod":
|
|
23
23
|
from flyteidl.core.tasks_pb2 import K8sObjectMetadata, K8sPod
|
|
24
|
+
from kubernetes.client import ApiClient
|
|
24
25
|
|
|
25
26
|
return K8sPod(
|
|
26
27
|
metadata=K8sObjectMetadata(labels=self.labels, annotations=self.annotations),
|
flyte/_run.py
CHANGED
|
@@ -161,7 +161,12 @@ class _Runner:
|
|
|
161
161
|
code_bundle = cached_value.code_bundle
|
|
162
162
|
image_cache = cached_value.image_cache
|
|
163
163
|
else:
|
|
164
|
-
|
|
164
|
+
if not self._dry_run:
|
|
165
|
+
image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
|
|
166
|
+
else:
|
|
167
|
+
from ._internal.imagebuild.image_builder import ImageCache
|
|
168
|
+
|
|
169
|
+
image_cache = ImageCache(image_lookup={})
|
|
165
170
|
|
|
166
171
|
if self._interactive_mode:
|
|
167
172
|
code_bundle = await build_pkl_bundle(
|
flyte/_task.py
CHANGED
|
@@ -258,6 +258,9 @@ class TaskTemplate(Generic[P, R]):
|
|
|
258
258
|
else:
|
|
259
259
|
raise RuntimeSystemError("BadContext", "Controller is not initialized.")
|
|
260
260
|
else:
|
|
261
|
+
from flyte._logging import logger
|
|
262
|
+
|
|
263
|
+
logger.warning(f"Task {self.name} running aio outside of a task context.")
|
|
261
264
|
# Local execute, just stay out of the way, but because .aio is used, we want to return an awaitable,
|
|
262
265
|
# even for synchronous tasks. This is to support migration.
|
|
263
266
|
return self.forward(*args, **kwargs)
|
flyte/_task_environment.py
CHANGED
|
@@ -18,7 +18,7 @@ from typing import (
|
|
|
18
18
|
|
|
19
19
|
import rich.repr
|
|
20
20
|
|
|
21
|
-
from ._cache import CacheRequest
|
|
21
|
+
from ._cache import Cache, CacheRequest
|
|
22
22
|
from ._doc import Documentation
|
|
23
23
|
from ._environment import Environment
|
|
24
24
|
from ._image import Image
|
|
@@ -74,6 +74,10 @@ class TaskEnvironment(Environment):
|
|
|
74
74
|
super().__post_init__()
|
|
75
75
|
if self.reusable is not None and self.plugin_config is not None:
|
|
76
76
|
raise ValueError("Cannot set plugin_config when environment is reusable.")
|
|
77
|
+
if self.reusable and not isinstance(self.reusable, ReusePolicy):
|
|
78
|
+
raise TypeError(f"Expected reusable to be of type ReusePolicy, got {type(self.reusable)}")
|
|
79
|
+
if self.cache and not isinstance(self.cache, (str, Cache)):
|
|
80
|
+
raise TypeError(f"Expected cache to be of type str or Cache, got {type(self.cache)}")
|
|
77
81
|
|
|
78
82
|
def clone_with(
|
|
79
83
|
self,
|
flyte/_trace.py
CHANGED
|
@@ -3,6 +3,7 @@ import inspect
|
|
|
3
3
|
import time
|
|
4
4
|
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Callable, TypeGuard, TypeVar, Union, cast
|
|
5
5
|
|
|
6
|
+
from flyte._logging import logger
|
|
6
7
|
from flyte.models import NativeInterface
|
|
7
8
|
|
|
8
9
|
T = TypeVar("T")
|
|
@@ -33,10 +34,13 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
|
|
|
33
34
|
iface = NativeInterface.from_callable(func)
|
|
34
35
|
info, ok = await controller.get_action_outputs(iface, func, *args, **kwargs)
|
|
35
36
|
if ok:
|
|
37
|
+
logger.info(f"Found existing trace info for {func}, {info}")
|
|
36
38
|
if info.output:
|
|
37
39
|
return info.output
|
|
38
40
|
elif info.error:
|
|
39
41
|
raise info.error
|
|
42
|
+
else:
|
|
43
|
+
logger.debug(f"No existing trace info found for {func}, proceeding to execute.")
|
|
40
44
|
start_time = time.time()
|
|
41
45
|
try:
|
|
42
46
|
# Cast to Awaitable to satisfy mypy
|
|
@@ -44,6 +48,7 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
|
|
|
44
48
|
results = await coroutine_result
|
|
45
49
|
info.add_outputs(results, start_time=start_time, end_time=time.time())
|
|
46
50
|
await controller.record_trace(info)
|
|
51
|
+
logger.debug(f"Finished trace for {func}, {info}")
|
|
47
52
|
return results
|
|
48
53
|
except Exception as e:
|
|
49
54
|
# If there is an error, we need to record it
|
flyte/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '2.0.
|
|
32
|
-
__version_tuple__ = version_tuple = (2, 0, 0, '
|
|
31
|
+
__version__ = version = '2.0.0b19'
|
|
32
|
+
__version_tuple__ = version_tuple = (2, 0, 0, 'b19')
|
|
33
33
|
|
|
34
|
-
__commit_id__ = commit_id = '
|
|
34
|
+
__commit_id__ = commit_id = 'g172a8c6b7'
|
flyte/cli/_create.py
CHANGED
|
@@ -98,7 +98,7 @@ def secret(
|
|
|
98
98
|
"-o",
|
|
99
99
|
"--output",
|
|
100
100
|
type=click.Path(exists=False, writable=True),
|
|
101
|
-
default=Path.cwd() / "config.yaml",
|
|
101
|
+
default=Path.cwd() / ".flyte" / "config.yaml",
|
|
102
102
|
help="Path to the output directory where the configuration will be saved. Defaults to current directory.",
|
|
103
103
|
show_default=True,
|
|
104
104
|
)
|
|
@@ -147,6 +147,9 @@ def config(
|
|
|
147
147
|
|
|
148
148
|
output_path = Path(output)
|
|
149
149
|
|
|
150
|
+
if not output_path.parent.exists():
|
|
151
|
+
output_path.parent.mkdir(parents=True)
|
|
152
|
+
|
|
150
153
|
if output_path.exists() and not force:
|
|
151
154
|
force = click.confirm(f"Overwrite [{output_path}]?", default=False)
|
|
152
155
|
if not force:
|
flyte/cli/_deploy.py
CHANGED
|
@@ -4,8 +4,7 @@ from pathlib import Path
|
|
|
4
4
|
from types import ModuleType
|
|
5
5
|
from typing import Any, Dict, List, cast, get_args
|
|
6
6
|
|
|
7
|
-
import click
|
|
8
|
-
from click import Context
|
|
7
|
+
import rich_click as click
|
|
9
8
|
|
|
10
9
|
import flyte
|
|
11
10
|
|
|
@@ -87,14 +86,14 @@ class DeployArguments:
|
|
|
87
86
|
return [common.get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata]
|
|
88
87
|
|
|
89
88
|
|
|
90
|
-
class DeployEnvCommand(click.
|
|
89
|
+
class DeployEnvCommand(click.RichCommand):
|
|
91
90
|
def __init__(self, env_name: str, env: Any, deploy_args: DeployArguments, *args, **kwargs):
|
|
92
91
|
self.env_name = env_name
|
|
93
92
|
self.env = env
|
|
94
93
|
self.deploy_args = deploy_args
|
|
95
94
|
super().__init__(*args, **kwargs)
|
|
96
95
|
|
|
97
|
-
def invoke(self, ctx: Context):
|
|
96
|
+
def invoke(self, ctx: click.Context):
|
|
98
97
|
from rich.console import Console
|
|
99
98
|
|
|
100
99
|
console = Console()
|
|
@@ -125,7 +124,7 @@ class DeployEnvRecursiveCommand(click.Command):
|
|
|
125
124
|
self.deploy_args = deploy_args
|
|
126
125
|
super().__init__(*args, **kwargs)
|
|
127
126
|
|
|
128
|
-
def invoke(self, ctx: Context):
|
|
127
|
+
def invoke(self, ctx: click.Context):
|
|
129
128
|
from rich.console import Console
|
|
130
129
|
|
|
131
130
|
from flyte._environment import list_loaded_environments
|
flyte/cli/_params.py
CHANGED
|
@@ -283,13 +283,17 @@ class UnionParamType(click.ParamType):
|
|
|
283
283
|
A composite type that allows for multiple types to be specified. This is used for union types.
|
|
284
284
|
"""
|
|
285
285
|
|
|
286
|
-
def __init__(self, types: typing.List[click.ParamType]):
|
|
286
|
+
def __init__(self, types: typing.List[click.ParamType | None]):
|
|
287
287
|
super().__init__()
|
|
288
288
|
self._types = self._sort_precedence(types)
|
|
289
|
-
self.name = "|".join([t.name for t in self._types])
|
|
289
|
+
self.name = "|".join([t.name for t in self._types if t is not None])
|
|
290
|
+
self.optional = False
|
|
291
|
+
if None in types:
|
|
292
|
+
self.name = f"Optional[{self.name}]"
|
|
293
|
+
self.optional = True
|
|
290
294
|
|
|
291
295
|
@staticmethod
|
|
292
|
-
def _sort_precedence(tp: typing.List[click.ParamType]) -> typing.List[click.ParamType]:
|
|
296
|
+
def _sort_precedence(tp: typing.List[click.ParamType | None]) -> typing.List[click.ParamType]:
|
|
293
297
|
unprocessed = []
|
|
294
298
|
str_types = []
|
|
295
299
|
others = []
|
|
@@ -311,6 +315,8 @@ class UnionParamType(click.ParamType):
|
|
|
311
315
|
"""
|
|
312
316
|
for p in self._types:
|
|
313
317
|
try:
|
|
318
|
+
if p is None and value is None:
|
|
319
|
+
return None
|
|
314
320
|
return p.convert(value, param, ctx)
|
|
315
321
|
except Exception as e:
|
|
316
322
|
logger.debug(f"Ignoring conversion error for type {p} trying other variants in Union. Error: {e}")
|
|
@@ -433,7 +439,10 @@ def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> cli
|
|
|
433
439
|
for i in range(len(lt.union_type.variants)):
|
|
434
440
|
variant = lt.union_type.variants[i]
|
|
435
441
|
variant_python_type = typing.get_args(python_type)[i]
|
|
436
|
-
|
|
442
|
+
if variant_python_type is type(None):
|
|
443
|
+
cts.append(None)
|
|
444
|
+
else:
|
|
445
|
+
cts.append(literal_type_to_click_type(variant, variant_python_type))
|
|
437
446
|
return UnionParamType(cts)
|
|
438
447
|
|
|
439
448
|
if lt.HasField("enum_type"):
|
|
@@ -461,6 +470,9 @@ class FlyteLiteralConverter(object):
|
|
|
461
470
|
def is_bool(self) -> bool:
|
|
462
471
|
return self.click_type == click.BOOL
|
|
463
472
|
|
|
473
|
+
def is_optional(self) -> bool:
    """Report whether this converter's click type is a ``UnionParamType`` flagged as optional."""
    click_type = self.click_type
    return isinstance(click_type, UnionParamType) and click_type.optional
|
|
475
|
+
|
|
464
476
|
def convert(
|
|
465
477
|
self, ctx: click.Context, param: typing.Optional[click.Parameter], value: typing.Any
|
|
466
478
|
) -> typing.Union[Literal, typing.Any]:
|
|
@@ -525,6 +537,8 @@ def to_click_option(
|
|
|
525
537
|
if literal_converter.is_bool():
|
|
526
538
|
required = False
|
|
527
539
|
is_flag = True
|
|
540
|
+
if literal_converter.is_optional():
|
|
541
|
+
required = False
|
|
528
542
|
|
|
529
543
|
return click.Option(
|
|
530
544
|
param_decls=[f"--{input_name}"],
|
flyte/cli/_run.py
CHANGED
|
@@ -112,7 +112,7 @@ class RunArguments:
|
|
|
112
112
|
return [common.get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata]
|
|
113
113
|
|
|
114
114
|
|
|
115
|
-
class RunTaskCommand(click.
|
|
115
|
+
class RunTaskCommand(click.RichCommand):
|
|
116
116
|
def __init__(self, obj_name: str, obj: Any, run_args: RunArguments, *args, **kwargs):
|
|
117
117
|
self.obj_name = obj_name
|
|
118
118
|
self.obj = cast(TaskTemplate, obj)
|
|
@@ -196,7 +196,7 @@ class TaskPerFileGroup(common.ObjectsPerFileGroup):
|
|
|
196
196
|
)
|
|
197
197
|
|
|
198
198
|
|
|
199
|
-
class RunReferenceTaskCommand(click.
|
|
199
|
+
class RunReferenceTaskCommand(click.RichCommand):
|
|
200
200
|
def __init__(self, task_name: str, run_args: RunArguments, version: str | None, *args, **kwargs):
|
|
201
201
|
self.task_name = task_name
|
|
202
202
|
self.run_args = run_args
|