flyte 2.0.0b32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +108 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +195 -0
- flyte/_bin/serve.py +178 -0
- flyte/_build.py +26 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +147 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/local_cache.py +216 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +121 -0
- flyte/_code_bundle/_packaging.py +218 -0
- flyte/_code_bundle/_utils.py +347 -0
- flyte/_code_bundle/bundle.py +266 -0
- flyte/_constants.py +1 -0
- flyte/_context.py +155 -0
- flyte/_custom_context.py +73 -0
- flyte/_debug/__init__.py +0 -0
- flyte/_debug/constants.py +38 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +307 -0
- flyte/_deploy.py +408 -0
- flyte/_deployer.py +109 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +122 -0
- flyte/_excepthook.py +37 -0
- flyte/_group.py +32 -0
- flyte/_hash.py +8 -0
- flyte/_image.py +1055 -0
- flyte/_initialize.py +628 -0
- flyte/_interface.py +119 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +129 -0
- flyte/_internal/controllers/_local_controller.py +239 -0
- flyte/_internal/controllers/_trace.py +48 -0
- flyte/_internal/controllers/remote/__init__.py +58 -0
- flyte/_internal/controllers/remote/_action.py +211 -0
- flyte/_internal/controllers/remote/_client.py +47 -0
- flyte/_internal/controllers/remote/_controller.py +583 -0
- flyte/_internal/controllers/remote/_core.py +465 -0
- flyte/_internal/controllers/remote/_informer.py +381 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +3 -0
- flyte/_internal/imagebuild/docker_builder.py +706 -0
- flyte/_internal/imagebuild/image_builder.py +277 -0
- flyte/_internal/imagebuild/remote_builder.py +386 -0
- flyte/_internal/imagebuild/utils.py +78 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +21 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +486 -0
- flyte/_internal/runtime/entrypoints.py +204 -0
- flyte/_internal/runtime/io.py +188 -0
- flyte/_internal/runtime/resources_serde.py +152 -0
- flyte/_internal/runtime/reuse.py +125 -0
- flyte/_internal/runtime/rusty.py +193 -0
- flyte/_internal/runtime/task_serde.py +362 -0
- flyte/_internal/runtime/taskrunner.py +209 -0
- flyte/_internal/runtime/trigger_serde.py +160 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_keyring/__init__.py +0 -0
- flyte/_keyring/file.py +115 -0
- flyte/_logging.py +300 -0
- flyte/_map.py +312 -0
- flyte/_module.py +72 -0
- flyte/_pod.py +30 -0
- flyte/_resources.py +473 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +102 -0
- flyte/_run.py +724 -0
- flyte/_secret.py +96 -0
- flyte/_task.py +550 -0
- flyte/_task_environment.py +316 -0
- flyte/_task_plugins.py +47 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +119 -0
- flyte/_trigger.py +1000 -0
- flyte/_utils/__init__.py +30 -0
- flyte/_utils/asyn.py +121 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +27 -0
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +134 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/module_loader.py +104 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +34 -0
- flyte/app/__init__.py +22 -0
- flyte/app/_app_environment.py +157 -0
- flyte/app/_deploy.py +125 -0
- flyte/app/_input.py +160 -0
- flyte/app/_runtime/__init__.py +3 -0
- flyte/app/_runtime/app_serde.py +347 -0
- flyte/app/_types.py +101 -0
- flyte/app/extras/__init__.py +3 -0
- flyte/app/extras/_fastapi.py +151 -0
- flyte/cli/__init__.py +12 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_build.py +114 -0
- flyte/cli/_common.py +468 -0
- flyte/cli/_create.py +371 -0
- flyte/cli/_delete.py +45 -0
- flyte/cli/_deploy.py +293 -0
- flyte/cli/_gen.py +176 -0
- flyte/cli/_get.py +370 -0
- flyte/cli/_option.py +33 -0
- flyte/cli/_params.py +554 -0
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_run.py +597 -0
- flyte/cli/_serve.py +64 -0
- flyte/cli/_update.py +37 -0
- flyte/cli/_user.py +17 -0
- flyte/cli/main.py +221 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +248 -0
- flyte/config/_internal.py +73 -0
- flyte/config/_reader.py +225 -0
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +270 -0
- flyte/connectors/_server.py +197 -0
- flyte/connectors/utils.py +135 -0
- flyte/errors.py +243 -0
- flyte/extend.py +19 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +286 -0
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +21 -0
- flyte/io/__init__.py +29 -0
- flyte/io/_dataframe/__init__.py +131 -0
- flyte/io/_dataframe/basic_dfs.py +223 -0
- flyte/io/_dataframe/dataframe.py +1026 -0
- flyte/io/_dir.py +910 -0
- flyte/io/_file.py +914 -0
- flyte/io/_hashing_io.py +342 -0
- flyte/models.py +479 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +35 -0
- flyte/remote/_action.py +738 -0
- flyte/remote/_app.py +57 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +189 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +403 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +213 -0
- flyte/remote/_client/auth/_client_config.py +85 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +152 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +128 -0
- flyte/remote/_common.py +30 -0
- flyte/remote/_console.py +19 -0
- flyte/remote/_data.py +161 -0
- flyte/remote/_logs.py +185 -0
- flyte/remote/_project.py +88 -0
- flyte/remote/_run.py +386 -0
- flyte/remote/_secret.py +142 -0
- flyte/remote/_task.py +527 -0
- flyte/remote/_trigger.py +306 -0
- flyte/remote/_user.py +33 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +182 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +36 -0
- flyte/storage/_config.py +237 -0
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +456 -0
- flyte/storage/_utils.py +5 -0
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +375 -0
- flyte/types/__init__.py +52 -0
- flyte/types/_interface.py +40 -0
- flyte/types/_pickle.py +145 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +119 -0
- flyte/types/_type_engine.py +2254 -0
- flyte/types/_utils.py +80 -0
- flyte-2.0.0b32.data/scripts/debug.py +38 -0
- flyte-2.0.0b32.data/scripts/runtime.py +195 -0
- flyte-2.0.0b32.dist-info/METADATA +351 -0
- flyte-2.0.0b32.dist-info/RECORD +204 -0
- flyte-2.0.0b32.dist-info/WHEEL +5 -0
- flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
- flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
- flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/_logging.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Literal, Optional
|
|
8
|
+
|
|
9
|
+
import flyte
|
|
10
|
+
|
|
11
|
+
from ._tools import ipython_check
|
|
12
|
+
|
|
13
|
+
# The two supported serialization styles for log output: human-readable
# console text, or one JSON object per record.
LogFormat = Literal["console", "json"]

# Level used when the LOG_LEVEL environment variable is not set.
DEFAULT_LOG_LEVEL = logging.WARNING
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def make_hyperlink(label: str, url: str):
    """Render *label* as a clickable terminal hyperlink pointing at *url*.

    Uses the OSC-8 escape sequence, wrapped in blue for visibility.
    """
    blue, reset = "\033[94m", "\033[0m"
    osc8_open = f"\033]8;;{url}\033\\"
    osc8_close = "\033]8;;\033\\"
    return "".join((blue, osc8_open, label, reset, osc8_close))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def is_rich_logging_disabled() -> bool:
    """
    Check whether rich logging has been explicitly disabled.

    Returns True when the DISABLE_RICH_LOGGING environment variable is set
    (to any value, including the empty string), False otherwise.
    """
    # Docstring fixed: the original said "enabled" while the function
    # actually reports the *disabled* state.
    return os.environ.get("DISABLE_RICH_LOGGING") is not None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_env_log_level() -> int:
    """
    Read the desired log level from the LOG_LEVEL environment variable.

    Accepts either a numeric level (e.g. "10") or a standard level name
    (e.g. "DEBUG", case-insensitive). Falls back to DEFAULT_LOG_LEVEL when
    the variable is unset or unparseable (the original raised ValueError
    for level names).
    """
    raw = os.environ.get("LOG_LEVEL")
    if raw is None:
        return DEFAULT_LOG_LEVEL
    try:
        return int(raw)
    except ValueError:
        # Allow symbolic names like "DEBUG"/"warning"; getLevelName returns
        # an int for known names and a string for unknown ones.
        level = logging.getLevelName(raw.strip().upper())
        return level if isinstance(level, int) else DEFAULT_LOG_LEVEL
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def log_format_from_env() -> LogFormat:
    """Resolve the output format from the LOG_FORMAT environment variable.

    Any value other than the supported formats silently falls back to
    "console".
    """
    configured = os.environ.get("LOG_FORMAT", "console")
    return configured if configured in ("console", "json") else "console"  # type: ignore[return-value]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _get_console():
    """Build a Rich Console sized to the current terminal.

    Falls back to a fixed 160-column console when the terminal size cannot
    be determined (e.g. not attached to a TTY).
    """
    from rich.console import Console

    try:
        columns = os.get_terminal_size().columns
    except Exception as e:
        logger.debug(f"Failed to get terminal size: {e}")
        columns = 160

    return Console(width=columns)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
    """Build a Rich-based logging handler for local/interactive use.

    Returns None when running inside a cluster, or when rich output has
    been disabled and we are not in an IPython session.
    """
    ctx = flyte.ctx()
    if ctx and ctx.is_in_cluster():
        # Rich control sequences are noise in captured pod logs.
        return None
    if is_rich_logging_disabled() and not ipython_check():
        return None

    import click
    from rich.highlighter import NullHighlighter
    from rich.logging import RichHandler

    rich_handler = RichHandler(
        console=_get_console(),
        level=log_level,
        show_path=False,
        omit_repeated_times=False,
        rich_tracebacks=False,
        tracebacks_suppress=[click],
        log_time_format="%H:%M:%S.%f",
        highlighter=NullHighlighter(),
        markup=True,
    )
    rich_handler.setFormatter(logging.Formatter(fmt="%(filename)s:%(lineno)d - %(message)s"))
    return rich_handler
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class JSONFormatter(logging.Formatter):
    """
    Formatter that outputs JSON strings for each log record.
    """

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "timestamp": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "filename": record.filename,
            "lineno": record.lineno,
            "funcName": record.funcName,
        }

        # Optional context injected by ContextFilter / FlyteInternalFilter;
        # only emitted when present and truthy.
        run_name = getattr(record, "run_name", None)
        if run_name:
            payload["run_name"] = run_name
        action_name = getattr(record, "action_name", None)
        if action_name:
            payload["action_name"] = action_name
        if getattr(record, "is_flyte_internal", False):
            payload["is_flyte_internal"] = True

        # Include formatted traceback when the record carries one.
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)

        return json.dumps(payload)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def initialize_logger(log_level: int | None = None, log_format: LogFormat | None = None, enable_rich: bool = False):
    """
    Initializes the global loggers to the default configuration.
    When enable_rich=True, upgrades to Rich handler for local CLI usage.
    """
    global logger  # noqa: PLW0603

    log_level = get_env_log_level() if log_level is None else log_level
    log_format = log_format_from_env() if log_format is None else log_format

    # Reconfigure from scratch: drop whatever handlers are already attached.
    root = logging.getLogger()
    root.handlers.clear()
    flyte_logger = logging.getLogger("flyte")
    flyte_logger.handlers.clear()

    # JSON output wins over Rich when both are requested.
    use_json = log_format == "json"
    use_rich = enable_rich and not use_json

    # --- root logger: receives everything, tagged with run/action context ---
    root_handler: logging.Handler | None = None
    if use_json:
        root_handler = logging.StreamHandler()
        root_handler.setFormatter(JSONFormatter())
    elif use_rich:
        root_handler = get_rich_handler(log_level)
    if root_handler is None:
        # get_rich_handler may decline (cluster / disabled) - plain stream.
        root_handler = logging.StreamHandler()
    root_handler.addFilter(ContextFilter())
    root_handler.setLevel(logging.DEBUG)
    root.addHandler(root_handler)

    # --- flyte logger: internal records, prefixed and level-gated ---
    flyte_handler: logging.Handler | None = None
    if use_json:
        flyte_handler = logging.StreamHandler()
        flyte_handler.setLevel(log_level)
        flyte_handler.setFormatter(JSONFormatter())
    elif use_rich:
        flyte_handler = get_rich_handler(log_level)
    if flyte_handler is None:
        flyte_handler = logging.StreamHandler()
        flyte_handler.setLevel(log_level)
        flyte_handler.setFormatter(logging.Formatter(fmt="%(message)s"))
    flyte_handler.addFilter(FlyteInternalFilter())
    flyte_handler.addFilter(ContextFilter())
    flyte_logger.addHandler(flyte_handler)
    flyte_logger.setLevel(log_level)
    flyte_logger.propagate = False  # Prevent double logging

    # Rebind the module-level logger so existing imports see the new config.
    logger = flyte_logger
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
    """
    Decorator to log function calls.

    Usable bare (``@log``) or with options (``@log(level=...)``). The level
    check happens once at decoration time: when the logger is not enabled
    for ``level``, the function is returned unwrapped so there is zero
    per-call overhead.

    :param fn: The function to wrap (when used without parentheses).
    :param level: Logging level for the entry/exit messages.
    :param entry: Log a message (with args) when the function is entered.
    :param exit: Log a message when the function returns or raises.
    """
    import functools

    def decorator(func):
        if logger.isEnabledFor(level):

            # functools.wraps preserves __name__/__doc__ etc. - without it
            # every decorated function reported itself as "wrapper".
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if entry:
                    logger.log(level, f"[{func.__name__}] with args: {args} and kwargs: {kwargs}")
                try:
                    return func(*args, **kwargs)
                finally:
                    # finally: the completion message fires even on raise.
                    if exit:
                        logger.log(level, f"[{func.__name__}] completed")

            return wrapper
        return func

    if fn is None:
        return decorator
    return decorator(fn)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class ContextFilter(logging.Filter):
    """
    A logging filter that adds the current action's run name and name to all log records.
    Applied globally to capture context for both user and Flyte internal logging.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        from flyte._context import ctx

        current = ctx()
        if not current:
            # No active action: still set the attributes so the JSON
            # formatter can rely on them existing.
            record.run_name = None
            record.action_name = None
            return True

        action = current.action
        # Structured fields for the JSON formatter ...
        record.run_name = action.run_name
        record.action_name = action.name
        # ... and an inline prefix for console/Rich output.
        record.msg = f"[{action.run_name}][{action.name}] {record.msg}"
        return True
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class FlyteInternalFilter(logging.Filter):
    """
    A logging filter that adds [flyte] prefix to internal Flyte logging only.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        internal = record.name.startswith("flyte")
        # Structured flag for the JSON formatter.
        record.is_flyte_internal = internal
        # Inline prefix for console/Rich output.
        if internal:
            record.msg = f"[flyte] {record.msg}"
        return True
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _setup_root_logger():
    """
    Configure the root logger to capture all logging with context information.
    This ensures both user code and Flyte internal logging get the context.
    """
    root = logging.getLogger()
    # Drop pre-existing handlers so records are not emitted twice.
    root.handlers.clear()

    stream_handler = logging.StreamHandler()
    stream_handler.addFilter(ContextFilter())  # context prefix on ALL records
    stream_handler.setLevel(logging.DEBUG)

    # No explicit formatter: the filters already prepend the prefixes.
    root.addHandler(stream_handler)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def _create_flyte_logger() -> logging.Logger:
    """
    Create the internal Flyte logger with [flyte] prefix.
    """
    env_level = get_env_log_level()

    flyte_logger = logging.getLogger("flyte")
    flyte_logger.setLevel(env_level)

    # Dedicated handler so flyte-internal records get the prefix filters.
    handler = logging.StreamHandler()
    handler.setLevel(env_level)
    handler.addFilter(FlyteInternalFilter())
    handler.addFilter(ContextFilter())
    handler.setFormatter(logging.Formatter(fmt="%(message)s"))

    # Keep records out of the root logger to avoid double logging.
    flyte_logger.propagate = False
    flyte_logger.addHandler(handler)

    return flyte_logger
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# Initialize root logger for global context (runs at import time).
_setup_root_logger()

# Create the Flyte internal logger; this binding is later replaced when
# initialize_logger() is called.
logger = _create_flyte_logger()
|
flyte/_map.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast, overload
|
|
5
|
+
|
|
6
|
+
from flyte.syncify import syncify
|
|
7
|
+
|
|
8
|
+
from ._group import group
|
|
9
|
+
from ._logging import logger
|
|
10
|
+
from ._task import AsyncFunctionTaskTemplate, F, P, R
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MapAsyncIterator(Generic[P, R]):
|
|
14
|
+
"""AsyncIterator implementation for the map function results"""
|
|
15
|
+
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
19
|
+
args: tuple,
|
|
20
|
+
name: str,
|
|
21
|
+
concurrency: int,
|
|
22
|
+
return_exceptions: bool,
|
|
23
|
+
):
|
|
24
|
+
self.func = func
|
|
25
|
+
self.args = args
|
|
26
|
+
self.name = name
|
|
27
|
+
self.concurrency = concurrency
|
|
28
|
+
self.return_exceptions = return_exceptions
|
|
29
|
+
self._tasks: List[asyncio.Task] = []
|
|
30
|
+
self._current_index = 0
|
|
31
|
+
self._completed_count = 0
|
|
32
|
+
self._exception_count = 0
|
|
33
|
+
self._task_count = 0
|
|
34
|
+
self._initialized = False
|
|
35
|
+
|
|
36
|
+
def __aiter__(self) -> AsyncIterator[Union[R, Exception]]:
|
|
37
|
+
"""Return self as the async iterator"""
|
|
38
|
+
return self
|
|
39
|
+
|
|
40
|
+
async def __anext__(self) -> Union[R, Exception]:
|
|
41
|
+
"""Get the next result"""
|
|
42
|
+
# Initialize on first call
|
|
43
|
+
if not self._initialized:
|
|
44
|
+
await self._initialize()
|
|
45
|
+
|
|
46
|
+
# Check if we've exhausted all tasks
|
|
47
|
+
if self._current_index >= self._task_count:
|
|
48
|
+
raise StopAsyncIteration
|
|
49
|
+
|
|
50
|
+
# Get the next task result
|
|
51
|
+
task = self._tasks[self._current_index]
|
|
52
|
+
self._current_index += 1
|
|
53
|
+
|
|
54
|
+
try:
|
|
55
|
+
result = await task
|
|
56
|
+
self._completed_count += 1
|
|
57
|
+
logger.debug(f"Task {self._current_index - 1} completed successfully")
|
|
58
|
+
return result
|
|
59
|
+
except Exception as e:
|
|
60
|
+
self._exception_count += 1
|
|
61
|
+
logger.debug(
|
|
62
|
+
f"Task {self._current_index - 1} failed with exception: {e}, return_exceptions={self.return_exceptions}"
|
|
63
|
+
)
|
|
64
|
+
if self.return_exceptions:
|
|
65
|
+
return e
|
|
66
|
+
else:
|
|
67
|
+
# Cancel remaining tasks
|
|
68
|
+
for remaining_task in self._tasks[self._current_index + 1 :]:
|
|
69
|
+
remaining_task.cancel()
|
|
70
|
+
logger.warning("Exception raising is `ON`, raising exception and cancelling remaining tasks")
|
|
71
|
+
raise e
|
|
72
|
+
|
|
73
|
+
async def _initialize(self):
|
|
74
|
+
"""Initialize the tasks - called lazily on first iteration"""
|
|
75
|
+
# Create all tasks at once
|
|
76
|
+
tasks = []
|
|
77
|
+
task_count = 0
|
|
78
|
+
|
|
79
|
+
if isinstance(self.func, functools.partial):
|
|
80
|
+
# Handle partial functions by merging bound args/kwargs with mapped args
|
|
81
|
+
base_func = cast(AsyncFunctionTaskTemplate, self.func.func)
|
|
82
|
+
bound_args = self.func.args
|
|
83
|
+
bound_kwargs = self.func.keywords or {}
|
|
84
|
+
|
|
85
|
+
for arg_tuple in zip(*self.args):
|
|
86
|
+
# Merge bound positional args with mapped args
|
|
87
|
+
merged_args = bound_args + arg_tuple
|
|
88
|
+
if logger.isEnabledFor(logging.DEBUG):
|
|
89
|
+
logger.debug(f"Running {base_func.name} with args: {merged_args} and kwargs: {bound_kwargs}")
|
|
90
|
+
task = asyncio.create_task(base_func.aio(*merged_args, **bound_kwargs))
|
|
91
|
+
tasks.append(task)
|
|
92
|
+
task_count += 1
|
|
93
|
+
else:
|
|
94
|
+
# Handle regular TaskTemplate functions
|
|
95
|
+
for arg_tuple in zip(*self.args):
|
|
96
|
+
task = asyncio.create_task(self.func.aio(*arg_tuple))
|
|
97
|
+
tasks.append(task)
|
|
98
|
+
task_count += 1
|
|
99
|
+
|
|
100
|
+
if task_count == 0:
|
|
101
|
+
logger.info(f"Group '{self.name}' has no tasks to process")
|
|
102
|
+
self._tasks = []
|
|
103
|
+
self._task_count = 0
|
|
104
|
+
else:
|
|
105
|
+
logger.info(f"Starting {task_count} tasks in group '{self.name}' with unlimited concurrency")
|
|
106
|
+
self._tasks = tasks
|
|
107
|
+
self._task_count = task_count
|
|
108
|
+
|
|
109
|
+
self._initialized = True
|
|
110
|
+
|
|
111
|
+
async def collect(self) -> List[Union[R, Exception]]:
|
|
112
|
+
"""Convenience method to collect all results into a list"""
|
|
113
|
+
results = []
|
|
114
|
+
async for result in self:
|
|
115
|
+
results.append(result)
|
|
116
|
+
return results
|
|
117
|
+
|
|
118
|
+
def __repr__(self):
|
|
119
|
+
return f"MapAsyncIterator(group_name='{self.name}', concurrency={self.concurrency})"
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class _Mapper(Generic[P, R]):
    """
    Internal mapper class to handle the mapping logic.

    NOTE: The reason why we do not use the `@syncify` decorator here is because, in `syncify` we cannot use
    context managers like `group` directly in the function body. This is because the `__exit__` method of the
    context manager is called after the function returns. And for `_context` the `__exit__` method releases the
    token (for contextvar), which was created in a separate thread. This leads to an exception (presumably the
    contextvars "Token was created in a different Context" error - confirm against `_context`).
    """

    @classmethod
    def _get_name(cls, task_name: str, group_name: str | None) -> str:
        """Get the name of the group, defaulting to 'map' if not provided."""
        return f"{task_name}_{group_name or 'map'}"

    @staticmethod
    def validate_partial(func: functools.partial[R]):
        """
        This method validates that the provided partial function is valid for mapping, i.e. only the one argument
        is left for mapping and the rest are provided as keywords or args.

        :param func: partial function to validate
        :raises TypeError: if the partial function is not valid for mapping
        """
        f = cast(AsyncFunctionTaskTemplate, func.func)
        inputs = f.native_interface.inputs
        params = list(inputs.keys())
        total_params = len(params)
        provided_args = len(func.args)
        provided_kwargs = len(func.keywords or {})

        # Calculate how many parameters are left unspecified
        unspecified_count = total_params - provided_args - provided_kwargs

        # Exactly one parameter should be left for mapping
        if unspecified_count != 1:
            raise TypeError(
                f"Partial function must leave exactly one parameter unspecified for mapping. "
                f"Found {unspecified_count} unspecified parameters in {f.name}, "
                f"params: {inputs.keys()}"
            )

        # Validate that no parameter is both in args and keywords
        if func.keywords:
            param_names = list(inputs.keys())
            # NOTE(review): the slice covers the positionally-consumed params
            # plus one more (the mapped param) - confirm the `+ 1` is intended.
            for i, arg_name in enumerate(param_names[: provided_args + 1]):
                if arg_name in func.keywords:
                    raise TypeError(
                        f"Parameter '{arg_name}' is provided both as positional argument and keyword argument "
                        f"in partial function {f.name}."
                    )

    @overload
    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
    ) -> Iterator[R]: ...

    @overload
    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> Iterator[Union[R, Exception]]: ...

    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> Iterator[Union[R, Exception]]:
        """
        Map a function over the provided arguments with concurrent execution.

        :param func: The async function to map.
        :param args: Positional arguments to pass to the function (iterables that will be zipped).
        :param group_name: The name of the group for the mapped tasks.
        :param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
        :param return_exceptions: If True, yield exceptions instead of raising them.
        :return: AsyncIterator yielding results in order.
        """
        if not args:
            return

        if isinstance(func, functools.partial):
            f = cast(AsyncFunctionTaskTemplate, func.func)
            self.validate_partial(func)
        else:
            f = cast(AsyncFunctionTaskTemplate, func)

        name = self._get_name(f.name, group_name)
        logger.debug(f"Blocking Map for {name}")
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                # No remote controller available: execute each invocation
                # inline, one at a time.
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        yield func(*v)  # type: ignore
                    except Exception as e:
                        if return_exceptions:
                            yield e
                        else:
                            raise e
                return

            # Remote/hybrid mode: delegate to the syncified async mapper and
            # stream its results through this (synchronous) generator.
            i = 0
            for x in cast(
                Iterator[R],
                _map(
                    func,
                    *args,
                    name=name,
                    concurrency=concurrency,
                    return_exceptions=return_exceptions,
                ),
            ):
                logger.debug(f"Mapped {x}, task {i}")
                i += 1
                yield x

    async def aio(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> AsyncGenerator[Union[R, Exception], None]:
        """
        Async variant of ``__call__``: map ``func`` over the zipped ``args``
        and yield results (or exceptions, when ``return_exceptions=True``)
        in submission order.
        """
        if not args:
            return

        if isinstance(func, functools.partial):
            f = cast(AsyncFunctionTaskTemplate, func.func)
            self.validate_partial(func)
        else:
            f = cast(AsyncFunctionTaskTemplate, func)

        name = self._get_name(f.name, group_name)
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                # Local mode: run sequentially, mirroring __call__.
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        yield func(*v)  # type: ignore
                    except Exception as e:
                        if return_exceptions:
                            yield e
                        else:
                            raise e
                return
            async for x in _map.aio(
                func,
                *args,
                name=name,
                concurrency=concurrency,
                return_exceptions=return_exceptions,
            ):
                yield cast(Union[R, Exception], x)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
@syncify
async def _map(
    func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
    *args: Iterable[Any],
    name: str = "map",
    concurrency: int = 0,
    return_exceptions: bool = True,
) -> AsyncIterator[Union[R, Exception]]:
    """
    Drive a MapAsyncIterator over the zipped *args and yield each result.

    :param func: The async task template (or partial of one) to map.
    :param args: Iterables zipped into per-invocation argument tuples.
    :param name: Group name used for logging/grouping.
    :param concurrency: Maximum number of concurrent tasks (0 = unlimited).
    :param return_exceptions: If True, yield exceptions instead of raising.
    """
    # Renamed from `iter` to avoid shadowing the builtin.
    results = MapAsyncIterator(
        func=func, args=args, name=name, concurrency=concurrency, return_exceptions=return_exceptions
    )
    async for result in results:
        yield result
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
# Public entry point (used as `flyte.map(...)`). Deliberately shadows the
# builtin `map` within this module only.
map: _Mapper = _Mapper()
|
flyte/_module.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import os
|
|
3
|
+
import pathlib
|
|
4
|
+
import sys
|
|
5
|
+
from types import ModuleType
|
|
6
|
+
from typing import Tuple
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def extract_obj_module(obj: object, /, source_dir: pathlib.Path | None = None) -> Tuple[str, ModuleType]:
|
|
10
|
+
"""
|
|
11
|
+
Extract the module from the given object. If source_dir is provided, the module will be relative to the source_dir.
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
obj: The object to extract the module from.
|
|
15
|
+
source_dir: The source directory to use for relative paths.
|
|
16
|
+
|
|
17
|
+
Returns:
|
|
18
|
+
The module name as a string.
|
|
19
|
+
"""
|
|
20
|
+
if source_dir is None:
|
|
21
|
+
raise ValueError("extract_obj_module: source_dir cannot be None - specify root-dir")
|
|
22
|
+
# Get the module containing the object
|
|
23
|
+
entity_module = inspect.getmodule(obj)
|
|
24
|
+
if entity_module is None:
|
|
25
|
+
obj_name = getattr(obj, "__name__", str(obj))
|
|
26
|
+
raise ValueError(f"Object {obj_name} has no module.")
|
|
27
|
+
|
|
28
|
+
fp = entity_module.__file__
|
|
29
|
+
if fp is None:
|
|
30
|
+
obj_name = getattr(obj, "__name__", str(obj))
|
|
31
|
+
raise ValueError(f"Object {obj_name} has no module.")
|
|
32
|
+
|
|
33
|
+
file_path = pathlib.Path(fp)
|
|
34
|
+
try:
|
|
35
|
+
# Get the relative path to the current directory
|
|
36
|
+
# Will raise ValueError if the file is not in the source directory
|
|
37
|
+
relative_path = file_path.relative_to(str(pathlib.Path(source_dir).absolute()))
|
|
38
|
+
|
|
39
|
+
if relative_path == pathlib.Path("_internal/resolvers"):
|
|
40
|
+
entity_module_name = entity_module.__name__
|
|
41
|
+
else:
|
|
42
|
+
# Replace file separators with dots and remove the '.py' extension
|
|
43
|
+
dotted_path = os.path.splitext(str(relative_path))[0].replace(os.sep, ".")
|
|
44
|
+
entity_module_name = dotted_path
|
|
45
|
+
except ValueError:
|
|
46
|
+
# If source_dir is not provided or file is not in source_dir, fallback to module name
|
|
47
|
+
# File is not relative to source_dir - check if it's an installed package
|
|
48
|
+
file_path_str = str(file_path)
|
|
49
|
+
if "site-packages" in file_path_str or "dist-packages" in file_path_str:
|
|
50
|
+
# It's an installed package - use the module's __name__ directly
|
|
51
|
+
# This will be importable via importlib.import_module()
|
|
52
|
+
entity_module_name = entity_module.__name__
|
|
53
|
+
else:
|
|
54
|
+
# File is not in source_dir and not in site-packages - re-raise the error
|
|
55
|
+
obj_name = getattr(obj, "__name__", str(obj))
|
|
56
|
+
raise ValueError(
|
|
57
|
+
f"Object {obj_name} module file {file_path} is not relative to "
|
|
58
|
+
f"source directory {source_dir} and is not an installed package."
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
if entity_module_name == "__main__":
|
|
62
|
+
"""
|
|
63
|
+
This case is for the case in which the object is run from the main module.
|
|
64
|
+
"""
|
|
65
|
+
fp = sys.modules["__main__"].__file__
|
|
66
|
+
if fp is None:
|
|
67
|
+
obj_name = getattr(obj, "__name__", str(obj))
|
|
68
|
+
raise ValueError(f"Object {obj_name} has no module.")
|
|
69
|
+
main_path = pathlib.Path(fp)
|
|
70
|
+
entity_module_name = main_path.stem
|
|
71
|
+
|
|
72
|
+
return entity_module_name, entity_module
|