flyte 0.2.0b1__py3-none-any.whl → 2.0.0b46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyte/__init__.py +83 -30
- flyte/_bin/connect.py +61 -0
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +87 -19
- flyte/_bin/serve.py +351 -0
- flyte/_build.py +3 -2
- flyte/_cache/cache.py +6 -5
- flyte/_cache/local_cache.py +216 -0
- flyte/_code_bundle/_ignore.py +31 -5
- flyte/_code_bundle/_packaging.py +42 -11
- flyte/_code_bundle/_utils.py +57 -34
- flyte/_code_bundle/bundle.py +130 -27
- flyte/_constants.py +1 -0
- flyte/_context.py +21 -5
- flyte/_custom_context.py +73 -0
- flyte/_debug/constants.py +37 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +315 -0
- flyte/_deploy.py +396 -75
- flyte/_deployer.py +109 -0
- flyte/_environment.py +94 -11
- flyte/_excepthook.py +37 -0
- flyte/_group.py +2 -1
- flyte/_hash.py +1 -16
- flyte/_image.py +544 -231
- flyte/_initialize.py +456 -316
- flyte/_interface.py +40 -5
- flyte/_internal/controllers/__init__.py +22 -8
- flyte/_internal/controllers/_local_controller.py +159 -35
- flyte/_internal/controllers/_trace.py +18 -10
- flyte/_internal/controllers/remote/__init__.py +38 -9
- flyte/_internal/controllers/remote/_action.py +82 -12
- flyte/_internal/controllers/remote/_client.py +6 -2
- flyte/_internal/controllers/remote/_controller.py +290 -64
- flyte/_internal/controllers/remote/_core.py +155 -95
- flyte/_internal/controllers/remote/_informer.py +40 -20
- flyte/_internal/controllers/remote/_service_protocol.py +2 -2
- flyte/_internal/imagebuild/__init__.py +2 -10
- flyte/_internal/imagebuild/docker_builder.py +391 -84
- flyte/_internal/imagebuild/image_builder.py +111 -55
- flyte/_internal/imagebuild/remote_builder.py +409 -0
- flyte/_internal/imagebuild/utils.py +79 -0
- flyte/_internal/resolvers/_app_env_module.py +92 -0
- flyte/_internal/resolvers/_task_module.py +5 -38
- flyte/_internal/resolvers/app_env.py +26 -0
- flyte/_internal/resolvers/common.py +8 -1
- flyte/_internal/resolvers/default.py +2 -2
- flyte/_internal/runtime/convert.py +319 -36
- flyte/_internal/runtime/entrypoints.py +106 -18
- flyte/_internal/runtime/io.py +71 -23
- flyte/_internal/runtime/resources_serde.py +21 -7
- flyte/_internal/runtime/reuse.py +125 -0
- flyte/_internal/runtime/rusty.py +196 -0
- flyte/_internal/runtime/task_serde.py +239 -66
- flyte/_internal/runtime/taskrunner.py +48 -8
- flyte/_internal/runtime/trigger_serde.py +162 -0
- flyte/_internal/runtime/types_serde.py +7 -16
- flyte/_keyring/file.py +115 -0
- flyte/_link.py +30 -0
- flyte/_logging.py +241 -42
- flyte/_map.py +312 -0
- flyte/_metrics.py +59 -0
- flyte/_module.py +74 -0
- flyte/_pod.py +30 -0
- flyte/_resources.py +296 -33
- flyte/_retry.py +1 -7
- flyte/_reusable_environment.py +72 -7
- flyte/_run.py +462 -132
- flyte/_secret.py +47 -11
- flyte/_serve.py +333 -0
- flyte/_task.py +245 -56
- flyte/_task_environment.py +219 -97
- flyte/_task_plugins.py +47 -0
- flyte/_tools.py +8 -8
- flyte/_trace.py +15 -24
- flyte/_trigger.py +1027 -0
- flyte/_utils/__init__.py +12 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +5 -4
- flyte/_utils/description_parser.py +19 -0
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/helpers.py +45 -19
- flyte/_utils/module_loader.py +123 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +8 -1
- flyte/_version.py +16 -3
- flyte/app/__init__.py +27 -0
- flyte/app/_app_environment.py +362 -0
- flyte/app/_connector_environment.py +40 -0
- flyte/app/_deploy.py +130 -0
- flyte/app/_parameter.py +343 -0
- flyte/app/_runtime/__init__.py +3 -0
- flyte/app/_runtime/app_serde.py +383 -0
- flyte/app/_types.py +113 -0
- flyte/app/extras/__init__.py +9 -0
- flyte/app/extras/_auth_middleware.py +217 -0
- flyte/app/extras/_fastapi.py +93 -0
- flyte/app/extras/_model_loader/__init__.py +3 -0
- flyte/app/extras/_model_loader/config.py +7 -0
- flyte/app/extras/_model_loader/loader.py +288 -0
- flyte/cli/__init__.py +12 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_build.py +114 -0
- flyte/cli/_common.py +493 -0
- flyte/cli/_create.py +371 -0
- flyte/cli/_delete.py +45 -0
- flyte/cli/_deploy.py +401 -0
- flyte/cli/_gen.py +316 -0
- flyte/cli/_get.py +446 -0
- flyte/cli/_option.py +33 -0
- flyte/{_cli → cli}/_params.py +57 -17
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_prefetch.py +292 -0
- flyte/cli/_run.py +690 -0
- flyte/cli/_serve.py +338 -0
- flyte/cli/_update.py +86 -0
- flyte/cli/_user.py +20 -0
- flyte/cli/main.py +246 -0
- flyte/config/__init__.py +2 -167
- flyte/config/_config.py +215 -163
- flyte/config/_internal.py +10 -1
- flyte/config/_reader.py +225 -0
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +330 -0
- flyte/connectors/_server.py +194 -0
- flyte/connectors/utils.py +159 -0
- flyte/errors.py +134 -2
- flyte/extend.py +24 -0
- flyte/extras/_container.py +69 -56
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +279 -0
- flyte/io/__init__.py +8 -1
- flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
- flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
- flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
- flyte/io/_dir.py +575 -113
- flyte/io/_file.py +587 -141
- flyte/io/_hashing_io.py +342 -0
- flyte/io/extend.py +7 -0
- flyte/models.py +635 -0
- flyte/prefetch/__init__.py +22 -0
- flyte/prefetch/_hf_model.py +563 -0
- flyte/remote/__init__.py +14 -3
- flyte/remote/_action.py +879 -0
- flyte/remote/_app.py +346 -0
- flyte/remote/_auth_metadata.py +42 -0
- flyte/remote/_client/_protocols.py +62 -4
- flyte/remote/_client/auth/_auth_utils.py +19 -0
- flyte/remote/_client/auth/_authenticators/base.py +8 -2
- flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
- flyte/remote/_client/auth/_authenticators/factory.py +4 -0
- flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
- flyte/remote/_client/auth/_channel.py +47 -18
- flyte/remote/_client/auth/_client_config.py +5 -3
- flyte/remote/_client/auth/_keyring.py +15 -2
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +206 -18
- flyte/remote/_common.py +66 -0
- flyte/remote/_data.py +107 -22
- flyte/remote/_logs.py +116 -33
- flyte/remote/_project.py +21 -19
- flyte/remote/_run.py +164 -631
- flyte/remote/_secret.py +72 -29
- flyte/remote/_task.py +387 -46
- flyte/remote/_trigger.py +368 -0
- flyte/remote/_user.py +43 -0
- flyte/report/_report.py +10 -6
- flyte/storage/__init__.py +13 -1
- flyte/storage/_config.py +237 -0
- flyte/storage/_parallel_reader.py +289 -0
- flyte/storage/_storage.py +268 -59
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +414 -0
- flyte/types/__init__.py +39 -0
- flyte/types/_interface.py +22 -7
- flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +226 -126
- flyte/types/_utils.py +1 -1
- flyte-2.0.0b46.data/scripts/debug.py +38 -0
- flyte-2.0.0b46.data/scripts/runtime.py +194 -0
- flyte-2.0.0b46.dist-info/METADATA +352 -0
- flyte-2.0.0b46.dist-info/RECORD +221 -0
- flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
- flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
- flyte/_api_commons.py +0 -3
- flyte/_cli/_common.py +0 -299
- flyte/_cli/_create.py +0 -42
- flyte/_cli/_delete.py +0 -23
- flyte/_cli/_deploy.py +0 -140
- flyte/_cli/_get.py +0 -235
- flyte/_cli/_run.py +0 -174
- flyte/_cli/main.py +0 -98
- flyte/_datastructures.py +0 -342
- flyte/_internal/controllers/pbhash.py +0 -39
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -71
- flyte/_protos/common/identifier_pb2.pyi +0 -82
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -69
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -106
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- flyte/_protos/workflow/run_definition_pb2.py +0 -128
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -133
- flyte/_protos/workflow/run_service_pb2.pyi +0 -175
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
- flyte/_protos/workflow/state_service_pb2.py +0 -58
- flyte/_protos/workflow/state_service_pb2.pyi +0 -71
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -72
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -44
- flyte/_protos/workflow/task_service_pb2.pyi +0 -31
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
- flyte/io/_dataframe.py +0 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/remote/_console.py +0 -18
- flyte-0.2.0b1.dist-info/METADATA +0 -179
- flyte-0.2.0b1.dist-info/RECORD +0 -204
- flyte-0.2.0b1.dist-info/entry_points.txt +0 -3
- /flyte/{_cli → _debug}/__init__.py +0 -0
- /flyte/{_protos → _keyring}/__init__.py +0 -0
- {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
flyte/_logging.py
CHANGED
```diff
@@ -1,12 +1,36 @@
 from __future__ import annotations

+import json
 import logging
 import os
-from
+from datetime import datetime
+from typing import Literal, Optional

-
+import flyte

-
+from ._tools import ipython_check
+
+LogFormat = Literal["console", "json"]
+_LOG_LEVEL_MAP = {
+    "critical": logging.CRITICAL,  # 50
+    "error": logging.ERROR,  # 40
+    "warning": logging.WARNING,  # 30
+    "warn": logging.WARNING,  # 30
+    "info": logging.INFO,  # 20
+    "debug": logging.DEBUG,  # 10
+}
+DEFAULT_LOG_LEVEL = logging.WARNING
+
+
+def make_hyperlink(label: str, url: str):
+    """
+    Create a hyperlink in the terminal output.
+    """
+    BLUE = "\033[94m"
+    RESET = "\033[0m"
+    OSC8_BEGIN = f"\033]8;;{url}\033\\"
+    OSC8_END = "\033]8;;\033\\"
+    return f"{BLUE}{OSC8_BEGIN}{label}{RESET}{OSC8_END}"


 def is_rich_logging_disabled() -> bool:
@@ -17,43 +41,69 @@ def is_rich_logging_disabled() -> bool:


 def get_env_log_level() -> int:
-
+    value = os.getenv("LOG_LEVEL")
+    if value is None:
+        return DEFAULT_LOG_LEVEL
+    # Case 1: numeric value ("10", "20", "5", etc.)
+    if value.isdigit():
+        return int(value)

+    # Case 2: named log level ("info", "debug", ...)
+    if value.lower() in _LOG_LEVEL_MAP:
+        return _LOG_LEVEL_MAP[value.lower()]

-
+    return DEFAULT_LOG_LEVEL
+
+
+def log_format_from_env() -> LogFormat:
     """
     Get the log format from the environment variable.
     """
-
+    format_str = os.environ.get("LOG_FORMAT", "console")
+    if format_str not in ("console", "json"):
+        return "console"
+    return format_str  # type: ignore[return-value]
+
+
+def _get_console():
+    """
+    Get the console.
+    """
+    from rich.console import Console
+
+    try:
+        width = os.get_terminal_size().columns
+    except Exception as e:
+        logger.debug(f"Failed to get terminal size: {e}")
+        width = 160
+
+    return Console(width=width)


 def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
     """
     Upgrades the global loggers to use Rich logging.
     """
-
+    ctx = flyte.ctx()
+    if ctx and ctx.is_in_cluster():
         return None
     if not ipython_check() and is_rich_logging_disabled():
         return None

     import click
-    from rich.
+    from rich.highlighter import NullHighlighter
     from rich.logging import RichHandler

-    try:
-        width = os.get_terminal_size().columns
-    except Exception as e:
-        logger.debug(f"Failed to get terminal size: {e}")
-        width = 160
-
     handler = RichHandler(
         tracebacks_suppress=[click],
-        rich_tracebacks=
+        rich_tracebacks=False,
         omit_repeated_times=False,
         show_path=False,
         log_time_format="%H:%M:%S.%f",
-        console=
+        console=_get_console(),
         level=log_level,
+        highlighter=NullHighlighter(),
+        markup=True,
     )

     formatter = logging.Formatter(fmt="%(filename)s:%(lineno)d - %(message)s")
@@ -61,39 +111,99 @@ def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
     return handler


-
-
-
-
-
-
-
-
-
+class JSONFormatter(logging.Formatter):
+    """
+    Formatter that outputs JSON strings for each log record.
+    """
+
+    def format(self, record: logging.LogRecord) -> str:
+        log_data = {
+            "timestamp": datetime.fromtimestamp(record.created).isoformat(),
+            "level": record.levelname,
+            "logger": record.name,
+            "message": record.getMessage(),
+            "filename": record.filename,
+            "lineno": record.lineno,
+            "funcName": record.funcName,
+        }
+
+        # Add context fields if present
+        if getattr(record, "run_name", None):
+            log_data["run_name"] = record.run_name  # type: ignore[attr-defined]
+        if getattr(record, "action_name", None):
+            log_data["action_name"] = record.action_name  # type: ignore[attr-defined]
+        if getattr(record, "is_flyte_internal", False):
+            log_data["is_flyte_internal"] = True

+        # Add metric fields if present
+        if getattr(record, "metric_type", None):
+            log_data["metric_type"] = record.metric_type  # type: ignore[attr-defined]
+            log_data["metric_name"] = record.metric_name  # type: ignore[attr-defined]
+            log_data["duration_seconds"] = record.duration_seconds  # type: ignore[attr-defined]

-
+        # Add exception info if present
+        if record.exc_info:
+            log_data["exc_info"] = self.formatException(record.exc_info)
+
+        return json.dumps(log_data)
+
+
+def initialize_logger(
+    log_level: int | None = None,
+    log_format: LogFormat | None = None,
+    enable_rich: bool = False,
+    reset_root_logger: bool = False,
+):
     """
     Initializes the global loggers to the default configuration.
+    When enable_rich=True, upgrades to Rich handler for local CLI usage.
     """
     global logger  # noqa: PLW0603
-    logger = _create_logger("flyte", log_level, enable_rich)

+    if log_level is None:
+        log_level = get_env_log_level()
+    if log_format is None:
+        log_format = log_format_from_env()

-
-
-
-
-
-
-
-
-    if
-
-
-
-
-
+    flyte_logger = logging.getLogger("flyte")
+    flyte_logger.handlers.clear()
+
+    # Determine log format (JSON takes precedence over Rich)
+    use_json = log_format == "json"
+    use_rich = enable_rich and not use_json
+
+    reset_root_logger = reset_root_logger or os.environ.get("FLYTE_RESET_ROOT_LOGGER") == "1"
+    if reset_root_logger:
+        _setup_root_logger(use_json=use_json, use_rich=use_rich, log_level=log_level)
+    else:
+        root_logger = logging.getLogger()
+        for h in root_logger.handlers:
+            h.addFilter(ContextFilter())
+
+    # Set up Flyte logger handler
+    flyte_handler: logging.Handler | None = None
+    if use_json:
+        flyte_handler = logging.StreamHandler()
+        flyte_handler.setLevel(log_level)
+        flyte_handler.setFormatter(JSONFormatter())
+    elif use_rich:
+        flyte_handler = get_rich_handler(log_level)
+
+    if flyte_handler is None:
+        flyte_handler = logging.StreamHandler()
+        flyte_handler.setLevel(log_level)
+        formatter = logging.Formatter(fmt="%(message)s")
+        flyte_handler.setFormatter(formatter)
+
+    # Add both filters to Flyte handler
+    flyte_handler.addFilter(FlyteInternalFilter())
+    flyte_handler.addFilter(ContextFilter())
+
+    flyte_logger.addHandler(flyte_handler)
+    flyte_logger.setLevel(log_level)
+    flyte_logger.propagate = False  # Prevent double logging
+
+    logger = flyte_logger


 def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
@@ -121,4 +231,93 @@ def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
     return decorator(fn)


-
+class ContextFilter(logging.Filter):
+    """
+    A logging filter that adds the current action's run name and name to all log records.
+    Applied globally to capture context for both user and Flyte internal logging.
+    """
+
+    def filter(self, record: logging.LogRecord) -> bool:
+        from flyte._context import ctx
+
+        c = ctx()
+        if c:
+            action = c.action
+            # Add as attributes for structured logging (JSON)
+            record.run_name = action.run_name
+            record.action_name = action.name
+            # Also modify message for console/Rich output
+            record.msg = f"[{action.run_name}][{action.name}] {record.msg}"
+        else:
+            record.run_name = None
+            record.action_name = None
+        return True
+
+
+class FlyteInternalFilter(logging.Filter):
+    """
+    A logging filter that adds [flyte] prefix to internal Flyte logging only.
+    """
+
+    def filter(self, record: logging.LogRecord) -> bool:
+        is_internal = record.name.startswith("flyte")
+        # Add as attribute for structured logging (JSON)
+        record.is_flyte_internal = is_internal
+        # Also modify message for console/Rich output
+        if is_internal:
+            record.msg = f"[flyte] {record.msg}"
+        return True
+
+
+def _setup_root_logger(use_json: bool, use_rich: bool, log_level: int):
+    """
+    Wipe all handlers from the root logger and reconfigure. This ensures
+    both user/library logging and Flyte internal logging get context information and look the same.
+    """
+    root = logging.getLogger()
+    root.handlers.clear()  # Remove any existing handlers to prevent double logging
+
+    root_handler: logging.Handler | None = None
+    if use_json:
+        root_handler = logging.StreamHandler()
+        root_handler.setFormatter(JSONFormatter())
+    elif use_rich:
+        root_handler = get_rich_handler(log_level)
+
+    # get_rich_handler can return None in some environments
+    if not root_handler:
+        root_handler = logging.StreamHandler()
+
+    # Add context filter to ALL logging
+    root_handler.addFilter(ContextFilter())
+    root_handler.setLevel(log_level)
+
+    root.addHandler(root_handler)
+    root.setLevel(log_level)
+
+
+def _create_flyte_logger() -> logging.Logger:
+    """
+    Create the internal Flyte logger with [flyte] prefix.
+    """
+    flyte_logger = logging.getLogger("flyte")
+    flyte_logger.setLevel(get_env_log_level())
+
+    # Add a handler specifically for flyte logging with the prefix filter
+    handler = logging.StreamHandler()
+    handler.setLevel(get_env_log_level())
+    handler.addFilter(FlyteInternalFilter())
+    handler.addFilter(ContextFilter())
+
+    formatter = logging.Formatter(fmt="%(message)s")
+    handler.setFormatter(formatter)
+
+    # Prevent propagation to root to avoid double logging
+    flyte_logger.propagate = False
+    flyte_logger.addHandler(handler)
+
+    return flyte_logger
+
+
+# Create the Flyte internal logger
+logger = _create_flyte_logger()
```
flyte/_map.py
ADDED
```python
import asyncio
import functools
import logging
from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast, overload

from flyte.syncify import syncify

from ._group import group
from ._logging import logger
from ._task import AsyncFunctionTaskTemplate, F, P, R


class MapAsyncIterator(Generic[P, R]):
    """AsyncIterator implementation for the map function results"""

    def __init__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        args: tuple,
        name: str,
        concurrency: int,
        return_exceptions: bool,
    ):
        self.func = func
        self.args = args
        self.name = name
        self.concurrency = concurrency
        self.return_exceptions = return_exceptions
        self._tasks: List[asyncio.Task] = []
        self._current_index = 0
        self._completed_count = 0
        self._exception_count = 0
        self._task_count = 0
        self._initialized = False

    def __aiter__(self) -> AsyncIterator[Union[R, Exception]]:
        """Return self as the async iterator"""
        return self

    async def __anext__(self) -> Union[R, Exception]:
        """Get the next result"""
        # Initialize on first call
        if not self._initialized:
            await self._initialize()

        # Check if we've exhausted all tasks
        if self._current_index >= self._task_count:
            raise StopAsyncIteration

        # Get the next task result
        task = self._tasks[self._current_index]
        self._current_index += 1

        try:
            result = await task
            self._completed_count += 1
            logger.debug(f"Task {self._current_index - 1} completed successfully")
            return result
        except Exception as e:
            self._exception_count += 1
            logger.debug(
                f"Task {self._current_index - 1} failed with exception: {e}, return_exceptions={self.return_exceptions}"
            )
            if self.return_exceptions:
                return e
            else:
                # Cancel remaining tasks
                for remaining_task in self._tasks[self._current_index + 1 :]:
                    remaining_task.cancel()
                logger.warning("Exception raising is `ON`, raising exception and cancelling remaining tasks")
                raise e

    async def _initialize(self):
        """Initialize the tasks - called lazily on first iteration"""
        # Create all tasks at once
        tasks = []
        task_count = 0

        if isinstance(self.func, functools.partial):
            # Handle partial functions by merging bound args/kwargs with mapped args
            base_func = cast(AsyncFunctionTaskTemplate, self.func.func)
            bound_args = self.func.args
            bound_kwargs = self.func.keywords or {}

            for arg_tuple in zip(*self.args):
                # Merge bound positional args with mapped args
                merged_args = bound_args + arg_tuple
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"Running {base_func.name} with args: {merged_args} and kwargs: {bound_kwargs}")
                task = asyncio.create_task(base_func.aio(*merged_args, **bound_kwargs))
                tasks.append(task)
                task_count += 1
        else:
            # Handle regular TaskTemplate functions
            for arg_tuple in zip(*self.args):
                task = asyncio.create_task(self.func.aio(*arg_tuple))
                tasks.append(task)
                task_count += 1

        if task_count == 0:
            logger.info(f"Group '{self.name}' has no tasks to process")
            self._tasks = []
            self._task_count = 0
        else:
            logger.info(f"Starting {task_count} tasks in group '{self.name}' with unlimited concurrency")
            self._tasks = tasks
            self._task_count = task_count

        self._initialized = True

    async def collect(self) -> List[Union[R, Exception]]:
        """Convenience method to collect all results into a list"""
        results = []
        async for result in self:
            results.append(result)
        return results

    def __repr__(self):
        return f"MapAsyncIterator(group_name='{self.name}', concurrency={self.concurrency})"


class _Mapper(Generic[P, R]):
    """
    Internal mapper class to handle the mapping logic

    NOTE: The reason why we do not use the `@syncify` decorator here is because, in `syncify` we cannot use
    context managers like `group` directly in the function body. This is because the `__exit__` method of the
    context manager is called after the function returns. And for `_context` the `__exit__` method releases the
    token (for contextvar), which was created in a separate thread. This leads to an exception like:

    """

    @classmethod
    def _get_name(cls, task_name: str, group_name: str | None) -> str:
        """Get the name of the group, defaulting to 'map' if not provided."""
        return f"{task_name}_{group_name or 'map'}"

    @staticmethod
    def validate_partial(func: functools.partial[R]):
        """
        This method validates that the provided partial function is valid for mapping, i.e. only the one argument
        is left for mapping and the rest are provided as keywords or args.

        :param func: partial function to validate
        :raises TypeError: if the partial function is not valid for mapping
        """
        f = cast(AsyncFunctionTaskTemplate, func.func)
        inputs = f.native_interface.inputs
        params = list(inputs.keys())
        total_params = len(params)
        provided_args = len(func.args)
        provided_kwargs = len(func.keywords or {})

        # Calculate how many parameters are left unspecified
        unspecified_count = total_params - provided_args - provided_kwargs

        # Exactly one parameter should be left for mapping
        if unspecified_count != 1:
            raise TypeError(
                f"Partial function must leave exactly one parameter unspecified for mapping. "
                f"Found {unspecified_count} unspecified parameters in {f.name}, "
                f"params: {inputs.keys()}"
            )

        # Validate that no parameter is both in args and keywords
        if func.keywords:
            param_names = list(inputs.keys())
            for i, arg_name in enumerate(param_names[: provided_args + 1]):
                if arg_name in func.keywords:
                    raise TypeError(
                        f"Parameter '{arg_name}' is provided both as positional argument and keyword argument "
                        f"in partial function {f.name}."
                    )

    @overload
    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
    ) -> Iterator[R]: ...

    @overload
    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> Iterator[Union[R, Exception]]: ...

    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> Iterator[Union[R, Exception]]:
        """
        Map a function over the provided arguments with concurrent execution.

        :param func: The async function to map.
        :param args: Positional arguments to pass to the function (iterables that will be zipped).
        :param group_name: The name of the group for the mapped tasks.
        :param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
        :param return_exceptions: If True, yield exceptions instead of raising them.
        :return: AsyncIterator yielding results in order.
        """
        if not args:
            return

        if isinstance(func, functools.partial):
            f = cast(AsyncFunctionTaskTemplate, func.func)
            self.validate_partial(func)
        else:
            f = cast(AsyncFunctionTaskTemplate, func)

        name = self._get_name(f.name, group_name)
        logger.debug(f"Blocking Map for {name}")
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        yield func(*v)  # type: ignore
                    except Exception as e:
                        if return_exceptions:
                            yield e
                        else:
                            raise e
                return

            i = 0
            for x in cast(
                Iterator[R],
                _map(
                    func,
                    *args,
                    name=name,
                    concurrency=concurrency,
                    return_exceptions=return_exceptions,
                ),
            ):
                logger.debug(f"Mapped {x}, task {i}")
                i += 1
                yield x

    async def aio(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> AsyncGenerator[Union[R, Exception], None]:
        if not args:
            return

        if isinstance(func, functools.partial):
            f = cast(AsyncFunctionTaskTemplate, func.func)
            self.validate_partial(func)
        else:
            f = cast(AsyncFunctionTaskTemplate, func)

        name = self._get_name(f.name, group_name)
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        yield func(*v)  # type: ignore
                    except Exception as e:
                        if return_exceptions:
                            yield e
                        else:
                            raise e
                return

            async for x in _map.aio(
                func,
                *args,
                name=name,
                concurrency=concurrency,
                return_exceptions=return_exceptions,
            ):
                yield cast(Union[R, Exception], x)


@syncify
async def _map(
    func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
    *args: Iterable[Any],
    name: str = "map",
    concurrency: int = 0,
    return_exceptions: bool = True,
) -> AsyncIterator[Union[R, Exception]]:
    iter = MapAsyncIterator(
        func=func, args=args, name=name, concurrency=concurrency, return_exceptions=return_exceptions
    )
    async for result in iter:
        yield result


map: _Mapper = _Mapper()
```