flyte 0.2.0b1__py3-none-any.whl → 2.0.0b46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266) hide show
  1. flyte/__init__.py +83 -30
  2. flyte/_bin/connect.py +61 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +87 -19
  5. flyte/_bin/serve.py +351 -0
  6. flyte/_build.py +3 -2
  7. flyte/_cache/cache.py +6 -5
  8. flyte/_cache/local_cache.py +216 -0
  9. flyte/_code_bundle/_ignore.py +31 -5
  10. flyte/_code_bundle/_packaging.py +42 -11
  11. flyte/_code_bundle/_utils.py +57 -34
  12. flyte/_code_bundle/bundle.py +130 -27
  13. flyte/_constants.py +1 -0
  14. flyte/_context.py +21 -5
  15. flyte/_custom_context.py +73 -0
  16. flyte/_debug/constants.py +37 -0
  17. flyte/_debug/utils.py +17 -0
  18. flyte/_debug/vscode.py +315 -0
  19. flyte/_deploy.py +396 -75
  20. flyte/_deployer.py +109 -0
  21. flyte/_environment.py +94 -11
  22. flyte/_excepthook.py +37 -0
  23. flyte/_group.py +2 -1
  24. flyte/_hash.py +1 -16
  25. flyte/_image.py +544 -231
  26. flyte/_initialize.py +456 -316
  27. flyte/_interface.py +40 -5
  28. flyte/_internal/controllers/__init__.py +22 -8
  29. flyte/_internal/controllers/_local_controller.py +159 -35
  30. flyte/_internal/controllers/_trace.py +18 -10
  31. flyte/_internal/controllers/remote/__init__.py +38 -9
  32. flyte/_internal/controllers/remote/_action.py +82 -12
  33. flyte/_internal/controllers/remote/_client.py +6 -2
  34. flyte/_internal/controllers/remote/_controller.py +290 -64
  35. flyte/_internal/controllers/remote/_core.py +155 -95
  36. flyte/_internal/controllers/remote/_informer.py +40 -20
  37. flyte/_internal/controllers/remote/_service_protocol.py +2 -2
  38. flyte/_internal/imagebuild/__init__.py +2 -10
  39. flyte/_internal/imagebuild/docker_builder.py +391 -84
  40. flyte/_internal/imagebuild/image_builder.py +111 -55
  41. flyte/_internal/imagebuild/remote_builder.py +409 -0
  42. flyte/_internal/imagebuild/utils.py +79 -0
  43. flyte/_internal/resolvers/_app_env_module.py +92 -0
  44. flyte/_internal/resolvers/_task_module.py +5 -38
  45. flyte/_internal/resolvers/app_env.py +26 -0
  46. flyte/_internal/resolvers/common.py +8 -1
  47. flyte/_internal/resolvers/default.py +2 -2
  48. flyte/_internal/runtime/convert.py +319 -36
  49. flyte/_internal/runtime/entrypoints.py +106 -18
  50. flyte/_internal/runtime/io.py +71 -23
  51. flyte/_internal/runtime/resources_serde.py +21 -7
  52. flyte/_internal/runtime/reuse.py +125 -0
  53. flyte/_internal/runtime/rusty.py +196 -0
  54. flyte/_internal/runtime/task_serde.py +239 -66
  55. flyte/_internal/runtime/taskrunner.py +48 -8
  56. flyte/_internal/runtime/trigger_serde.py +162 -0
  57. flyte/_internal/runtime/types_serde.py +7 -16
  58. flyte/_keyring/file.py +115 -0
  59. flyte/_link.py +30 -0
  60. flyte/_logging.py +241 -42
  61. flyte/_map.py +312 -0
  62. flyte/_metrics.py +59 -0
  63. flyte/_module.py +74 -0
  64. flyte/_pod.py +30 -0
  65. flyte/_resources.py +296 -33
  66. flyte/_retry.py +1 -7
  67. flyte/_reusable_environment.py +72 -7
  68. flyte/_run.py +462 -132
  69. flyte/_secret.py +47 -11
  70. flyte/_serve.py +333 -0
  71. flyte/_task.py +245 -56
  72. flyte/_task_environment.py +219 -97
  73. flyte/_task_plugins.py +47 -0
  74. flyte/_tools.py +8 -8
  75. flyte/_trace.py +15 -24
  76. flyte/_trigger.py +1027 -0
  77. flyte/_utils/__init__.py +12 -1
  78. flyte/_utils/asyn.py +3 -1
  79. flyte/_utils/async_cache.py +139 -0
  80. flyte/_utils/coro_management.py +5 -4
  81. flyte/_utils/description_parser.py +19 -0
  82. flyte/_utils/docker_credentials.py +173 -0
  83. flyte/_utils/helpers.py +45 -19
  84. flyte/_utils/module_loader.py +123 -0
  85. flyte/_utils/org_discovery.py +57 -0
  86. flyte/_utils/uv_script_parser.py +8 -1
  87. flyte/_version.py +16 -3
  88. flyte/app/__init__.py +27 -0
  89. flyte/app/_app_environment.py +362 -0
  90. flyte/app/_connector_environment.py +40 -0
  91. flyte/app/_deploy.py +130 -0
  92. flyte/app/_parameter.py +343 -0
  93. flyte/app/_runtime/__init__.py +3 -0
  94. flyte/app/_runtime/app_serde.py +383 -0
  95. flyte/app/_types.py +113 -0
  96. flyte/app/extras/__init__.py +9 -0
  97. flyte/app/extras/_auth_middleware.py +217 -0
  98. flyte/app/extras/_fastapi.py +93 -0
  99. flyte/app/extras/_model_loader/__init__.py +3 -0
  100. flyte/app/extras/_model_loader/config.py +7 -0
  101. flyte/app/extras/_model_loader/loader.py +288 -0
  102. flyte/cli/__init__.py +12 -0
  103. flyte/cli/_abort.py +28 -0
  104. flyte/cli/_build.py +114 -0
  105. flyte/cli/_common.py +493 -0
  106. flyte/cli/_create.py +371 -0
  107. flyte/cli/_delete.py +45 -0
  108. flyte/cli/_deploy.py +401 -0
  109. flyte/cli/_gen.py +316 -0
  110. flyte/cli/_get.py +446 -0
  111. flyte/cli/_option.py +33 -0
  112. flyte/{_cli → cli}/_params.py +57 -17
  113. flyte/cli/_plugins.py +209 -0
  114. flyte/cli/_prefetch.py +292 -0
  115. flyte/cli/_run.py +690 -0
  116. flyte/cli/_serve.py +338 -0
  117. flyte/cli/_update.py +86 -0
  118. flyte/cli/_user.py +20 -0
  119. flyte/cli/main.py +246 -0
  120. flyte/config/__init__.py +2 -167
  121. flyte/config/_config.py +215 -163
  122. flyte/config/_internal.py +10 -1
  123. flyte/config/_reader.py +225 -0
  124. flyte/connectors/__init__.py +11 -0
  125. flyte/connectors/_connector.py +330 -0
  126. flyte/connectors/_server.py +194 -0
  127. flyte/connectors/utils.py +159 -0
  128. flyte/errors.py +134 -2
  129. flyte/extend.py +24 -0
  130. flyte/extras/_container.py +69 -56
  131. flyte/git/__init__.py +3 -0
  132. flyte/git/_config.py +279 -0
  133. flyte/io/__init__.py +8 -1
  134. flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
  135. flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
  136. flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
  137. flyte/io/_dir.py +575 -113
  138. flyte/io/_file.py +587 -141
  139. flyte/io/_hashing_io.py +342 -0
  140. flyte/io/extend.py +7 -0
  141. flyte/models.py +635 -0
  142. flyte/prefetch/__init__.py +22 -0
  143. flyte/prefetch/_hf_model.py +563 -0
  144. flyte/remote/__init__.py +14 -3
  145. flyte/remote/_action.py +879 -0
  146. flyte/remote/_app.py +346 -0
  147. flyte/remote/_auth_metadata.py +42 -0
  148. flyte/remote/_client/_protocols.py +62 -4
  149. flyte/remote/_client/auth/_auth_utils.py +19 -0
  150. flyte/remote/_client/auth/_authenticators/base.py +8 -2
  151. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  152. flyte/remote/_client/auth/_authenticators/factory.py +4 -0
  153. flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
  154. flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
  155. flyte/remote/_client/auth/_channel.py +47 -18
  156. flyte/remote/_client/auth/_client_config.py +5 -3
  157. flyte/remote/_client/auth/_keyring.py +15 -2
  158. flyte/remote/_client/auth/_token_client.py +3 -3
  159. flyte/remote/_client/controlplane.py +206 -18
  160. flyte/remote/_common.py +66 -0
  161. flyte/remote/_data.py +107 -22
  162. flyte/remote/_logs.py +116 -33
  163. flyte/remote/_project.py +21 -19
  164. flyte/remote/_run.py +164 -631
  165. flyte/remote/_secret.py +72 -29
  166. flyte/remote/_task.py +387 -46
  167. flyte/remote/_trigger.py +368 -0
  168. flyte/remote/_user.py +43 -0
  169. flyte/report/_report.py +10 -6
  170. flyte/storage/__init__.py +13 -1
  171. flyte/storage/_config.py +237 -0
  172. flyte/storage/_parallel_reader.py +289 -0
  173. flyte/storage/_storage.py +268 -59
  174. flyte/syncify/__init__.py +56 -0
  175. flyte/syncify/_api.py +414 -0
  176. flyte/types/__init__.py +39 -0
  177. flyte/types/_interface.py +22 -7
  178. flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
  179. flyte/types/_string_literals.py +8 -9
  180. flyte/types/_type_engine.py +226 -126
  181. flyte/types/_utils.py +1 -1
  182. flyte-2.0.0b46.data/scripts/debug.py +38 -0
  183. flyte-2.0.0b46.data/scripts/runtime.py +194 -0
  184. flyte-2.0.0b46.dist-info/METADATA +352 -0
  185. flyte-2.0.0b46.dist-info/RECORD +221 -0
  186. flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
  187. flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
  188. flyte/_api_commons.py +0 -3
  189. flyte/_cli/_common.py +0 -299
  190. flyte/_cli/_create.py +0 -42
  191. flyte/_cli/_delete.py +0 -23
  192. flyte/_cli/_deploy.py +0 -140
  193. flyte/_cli/_get.py +0 -235
  194. flyte/_cli/_run.py +0 -174
  195. flyte/_cli/main.py +0 -98
  196. flyte/_datastructures.py +0 -342
  197. flyte/_internal/controllers/pbhash.py +0 -39
  198. flyte/_protos/common/authorization_pb2.py +0 -66
  199. flyte/_protos/common/authorization_pb2.pyi +0 -108
  200. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  201. flyte/_protos/common/identifier_pb2.py +0 -71
  202. flyte/_protos/common/identifier_pb2.pyi +0 -82
  203. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  204. flyte/_protos/common/identity_pb2.py +0 -48
  205. flyte/_protos/common/identity_pb2.pyi +0 -72
  206. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  207. flyte/_protos/common/list_pb2.py +0 -36
  208. flyte/_protos/common/list_pb2.pyi +0 -69
  209. flyte/_protos/common/list_pb2_grpc.py +0 -4
  210. flyte/_protos/common/policy_pb2.py +0 -37
  211. flyte/_protos/common/policy_pb2.pyi +0 -27
  212. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  213. flyte/_protos/common/role_pb2.py +0 -37
  214. flyte/_protos/common/role_pb2.pyi +0 -53
  215. flyte/_protos/common/role_pb2_grpc.py +0 -4
  216. flyte/_protos/common/runtime_version_pb2.py +0 -28
  217. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  218. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  219. flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
  220. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
  221. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  222. flyte/_protos/secret/definition_pb2.py +0 -49
  223. flyte/_protos/secret/definition_pb2.pyi +0 -93
  224. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  225. flyte/_protos/secret/payload_pb2.py +0 -62
  226. flyte/_protos/secret/payload_pb2.pyi +0 -94
  227. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  228. flyte/_protos/secret/secret_pb2.py +0 -38
  229. flyte/_protos/secret/secret_pb2.pyi +0 -6
  230. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  231. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  232. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  233. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  234. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  235. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  236. flyte/_protos/workflow/queue_service_pb2.py +0 -106
  237. flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
  238. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  239. flyte/_protos/workflow/run_definition_pb2.py +0 -128
  240. flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
  241. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  242. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  243. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  244. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  245. flyte/_protos/workflow/run_service_pb2.py +0 -133
  246. flyte/_protos/workflow/run_service_pb2.pyi +0 -175
  247. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
  248. flyte/_protos/workflow/state_service_pb2.py +0 -58
  249. flyte/_protos/workflow/state_service_pb2.pyi +0 -71
  250. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  251. flyte/_protos/workflow/task_definition_pb2.py +0 -72
  252. flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
  253. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  254. flyte/_protos/workflow/task_service_pb2.py +0 -44
  255. flyte/_protos/workflow/task_service_pb2.pyi +0 -31
  256. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
  257. flyte/io/_dataframe.py +0 -0
  258. flyte/io/pickle/__init__.py +0 -0
  259. flyte/remote/_console.py +0 -18
  260. flyte-0.2.0b1.dist-info/METADATA +0 -179
  261. flyte-0.2.0b1.dist-info/RECORD +0 -204
  262. flyte-0.2.0b1.dist-info/entry_points.txt +0 -3
  263. /flyte/{_cli → _debug}/__init__.py +0 -0
  264. /flyte/{_protos → _keyring}/__init__.py +0 -0
  265. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
  266. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
flyte/_logging.py CHANGED
@@ -1,12 +1,36 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import json
3
4
  import logging
4
5
  import os
5
- from typing import Optional
6
+ from datetime import datetime
7
+ from typing import Literal, Optional
6
8
 
7
- from ._tools import ipython_check, is_in_cluster
9
+ import flyte
8
10
 
9
- DEFAULT_LOG_LEVEL = logging.INFO
11
+ from ._tools import ipython_check
12
+
13
# The log output formats this module understands.
LogFormat = Literal["console", "json"]
# Lower-case level names accepted in the LOG_LEVEL environment variable, mapped to
# the corresponding stdlib logging level constants ("warn" is an alias of "warning").
_LOG_LEVEL_MAP = {
    "critical": logging.CRITICAL,  # 50
    "error": logging.ERROR,  # 40
    "warning": logging.WARNING,  # 30
    "warn": logging.WARNING,  # 30
    "info": logging.INFO,  # 20
    "debug": logging.DEBUG,  # 10
}
# Level returned by get_env_log_level() when LOG_LEVEL is unset or not recognised.
DEFAULT_LOG_LEVEL = logging.WARNING
23
+
24
+
25
def make_hyperlink(label: str, url: str):
    """
    Wrap ``label`` in terminal escape sequences so it is emitted as a blue hyperlink to ``url``.

    The returned string is, in order: the bright-blue colour code, the OSC 8 sequence that
    opens a link to ``url``, the label, the colour reset code and the OSC 8 sequence that
    closes the link.
    """
    esc = "\033"
    terminator = f"{esc}\\"
    segments = (
        f"{esc}[94m",  # bright blue foreground
        f"{esc}]8;;{url}{terminator}",  # open hyperlink
        f"{label}",
        f"{esc}[0m",  # reset colour
        f"{esc}]8;;{terminator}",  # close hyperlink
    )
    return "".join(segments)
10
34
 
11
35
 
12
36
  def is_rich_logging_disabled() -> bool:
@@ -17,43 +41,69 @@ def is_rich_logging_disabled() -> bool:
17
41
 
18
42
 
19
43
  def get_env_log_level() -> int:
20
- return int(os.environ.get("LOG_LEVEL", DEFAULT_LOG_LEVEL))
44
+ value = os.getenv("LOG_LEVEL")
45
+ if value is None:
46
+ return DEFAULT_LOG_LEVEL
47
+ # Case 1: numeric value ("10", "20", "5", etc.)
48
+ if value.isdigit():
49
+ return int(value)
21
50
 
51
+ # Case 2: named log level ("info", "debug", ...)
52
+ if value.lower() in _LOG_LEVEL_MAP:
53
+ return _LOG_LEVEL_MAP[value.lower()]
22
54
 
23
- def log_format_from_env() -> str:
55
+ return DEFAULT_LOG_LEVEL
56
+
57
+
58
def log_format_from_env() -> LogFormat:
    """
    Get the log format from the ``LOG_FORMAT`` environment variable.

    The value is matched case-insensitively and with surrounding whitespace ignored,
    mirroring how named log levels are handled.

    :return: ``"json"`` when the variable requests JSON; ``"console"`` when it is unset
        or holds any other value.
    """
    requested = os.environ.get("LOG_FORMAT", "console").strip().lower()
    if requested == "json":
        return "json"
    # "console", as well as any unrecognised value, falls back to console output.
    return "console"
66
+
67
+
68
def _get_console():
    """
    Build a Rich ``Console`` sized to the current terminal.

    Falls back to a width of 160 columns (and logs the reason at debug level) when the
    terminal size cannot be determined.
    """
    from rich.console import Console

    columns = 160
    try:
        columns = os.get_terminal_size().columns
    except Exception as e:
        logger.debug(f"Failed to get terminal size: {e}")

    return Console(width=columns)
28
81
 
29
82
 
30
83
  def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
31
84
  """
32
85
  Upgrades the global loggers to use Rich logging.
33
86
  """
34
- if is_in_cluster():
87
+ ctx = flyte.ctx()
88
+ if ctx and ctx.is_in_cluster():
35
89
  return None
36
90
  if not ipython_check() and is_rich_logging_disabled():
37
91
  return None
38
92
 
39
93
  import click
40
- from rich.console import Console
94
+ from rich.highlighter import NullHighlighter
41
95
  from rich.logging import RichHandler
42
96
 
43
- try:
44
- width = os.get_terminal_size().columns
45
- except Exception as e:
46
- logger.debug(f"Failed to get terminal size: {e}")
47
- width = 160
48
-
49
97
  handler = RichHandler(
50
98
  tracebacks_suppress=[click],
51
- rich_tracebacks=True,
99
+ rich_tracebacks=False,
52
100
  omit_repeated_times=False,
53
101
  show_path=False,
54
102
  log_time_format="%H:%M:%S.%f",
55
- console=Console(width=width),
103
+ console=_get_console(),
56
104
  level=log_level,
105
+ highlighter=NullHighlighter(),
106
+ markup=True,
57
107
  )
58
108
 
59
109
  formatter = logging.Formatter(fmt="%(filename)s:%(lineno)d - %(message)s")
@@ -61,39 +111,99 @@ def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
61
111
  return handler
62
112
 
63
113
 
64
- def get_default_handler(log_level: int) -> logging.Handler:
65
- handler = logging.StreamHandler()
66
- handler.setLevel(log_level)
67
- formatter = logging.Formatter(fmt="[%(name)s] %(message)s")
68
- if log_format_from_env() == "json":
69
- pass
70
- # formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
71
- handler.setFormatter(formatter)
72
- return handler
114
class JSONFormatter(logging.Formatter):
    """
    Formatter that outputs JSON strings for each log record.
    """

    def format(self, record: logging.LogRecord) -> str:
        """
        Serialize ``record`` to a single JSON object string.

        Always emitted: timestamp, level, logger, message, filename, lineno, funcName.
        Optional keys are added only when the record carries them: ``run_name``,
        ``action_name``, ``is_flyte_internal``, the metric fields and ``exc_info``.

        :param record: The log record to format.
        :return: The JSON-encoded log entry.
        """
        log_data = {
            "timestamp": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "filename": record.filename,
            "lineno": record.lineno,
            "funcName": record.funcName,
        }

        # Add context fields if present
        for field in ("run_name", "action_name"):
            value = getattr(record, field, None)
            if value:
                log_data[field] = value
        if getattr(record, "is_flyte_internal", False):
            log_data["is_flyte_internal"] = True

        # Add metric fields if present. The companion fields are read defensively so a
        # record that only carries metric_type does not raise inside the formatter.
        metric_type = getattr(record, "metric_type", None)
        if metric_type:
            log_data["metric_type"] = metric_type
            log_data["metric_name"] = getattr(record, "metric_name", None)
            log_data["duration_seconds"] = getattr(record, "duration_seconds", None)

        # Add exception info if present
        if record.exc_info:
            log_data["exc_info"] = self.formatException(record.exc_info)

        # default=str: a context/metric value that is not JSON-serializable is rendered
        # via str() instead of making the whole log line fail to serialize.
        return json.dumps(log_data, default=str)
149
+
150
+
151
def initialize_logger(
    log_level: int | None = None,
    log_format: LogFormat | None = None,
    enable_rich: bool = False,
    reset_root_logger: bool = False,
):
    """
    Initializes the global loggers to the default configuration.
    When enable_rich=True, upgrades to Rich handler for local CLI usage.

    :param log_level: Level for the "flyte" logger and its handler. When None, it is read
        via get_env_log_level().
    :param log_format: "console" or "json". When None, it is read via log_format_from_env().
        JSON takes precedence over Rich.
    :param enable_rich: Use the Rich handler (when available) unless JSON output is selected.
    :param reset_root_logger: Wipe and reconfigure the root logger as well. Also enabled by
        setting the FLYTE_RESET_ROOT_LOGGER environment variable to "1".
    """
    global logger  # noqa: PLW0603

    if log_level is None:
        log_level = get_env_log_level()
    if log_format is None:
        log_format = log_format_from_env()

    flyte_logger = logging.getLogger("flyte")
    flyte_logger.handlers.clear()

    # Determine log format (JSON takes precedence over Rich)
    use_json = log_format == "json"
    use_rich = enable_rich and not use_json

    reset_root_logger = reset_root_logger or os.environ.get("FLYTE_RESET_ROOT_LOGGER") == "1"
    if reset_root_logger:
        _setup_root_logger(use_json=use_json, use_rich=use_rich, log_level=log_level)
    else:
        # Leave the root handlers in place and only attach the context filter. Skip
        # handlers that already carry one: addFilter() only de-duplicates the same
        # filter object, so calling this function repeatedly would otherwise stack
        # filters and prepend the context prefix once per call.
        for h in logging.getLogger().handlers:
            if not any(isinstance(existing, ContextFilter) for existing in h.filters):
                h.addFilter(ContextFilter())

    # Set up Flyte logger handler
    flyte_handler: logging.Handler | None = None
    if use_json:
        flyte_handler = logging.StreamHandler()
        flyte_handler.setLevel(log_level)
        flyte_handler.setFormatter(JSONFormatter())
    elif use_rich:
        # May return None, in which case the plain handler below is used.
        flyte_handler = get_rich_handler(log_level)

    if flyte_handler is None:
        flyte_handler = logging.StreamHandler()
        flyte_handler.setLevel(log_level)
        flyte_handler.setFormatter(logging.Formatter(fmt="%(message)s"))

    # Add both filters to Flyte handler
    flyte_handler.addFilter(FlyteInternalFilter())
    flyte_handler.addFilter(ContextFilter())

    flyte_logger.addHandler(flyte_handler)
    flyte_logger.setLevel(log_level)
    flyte_logger.propagate = False  # Prevent double logging

    logger = flyte_logger
97
207
 
98
208
 
99
209
  def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
@@ -121,4 +231,93 @@ def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
121
231
  return decorator(fn)
122
232
 
123
233
 
124
- logger = _create_logger("flyte", get_env_log_level())
234
class ContextFilter(logging.Filter):
    """
    A logging filter that adds the current action's run name and name to all log records.
    Applied globally to capture context for both user and Flyte internal logging.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """
        Annotate ``record`` with ``run_name`` / ``action_name`` and prefix its message.

        When no context is active both attributes are set to None and the message is
        left untouched. The record is never dropped.

        :param record: The log record being processed.
        :return: Always True.
        """
        from flyte._context import ctx

        c = ctx()
        if not c:
            record.run_name = None
            record.action_name = None
            return True

        action = c.action
        # Add as attributes for structured logging (JSON)
        record.run_name = action.run_name
        record.action_name = action.name
        # Also modify message for console/Rich output. The same LogRecord object is
        # handed to every handler, so when several handlers carry this filter the
        # prefix must only be applied once.
        if not getattr(record, "_flyte_context_prefixed", False):
            record.msg = f"[{action.run_name}][{action.name}] {record.msg}"
            record._flyte_context_prefixed = True
        return True
255
+
256
+
257
class FlyteInternalFilter(logging.Filter):
    """
    Logging filter that tags records coming from loggers whose name starts with "flyte"
    and prepends "[flyte] " to their message.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Mark ``record`` as internal or not, prefix internal messages, and keep the record."""
        # Exposed as an attribute for structured (JSON) output.
        record.is_flyte_internal = record.name.startswith("flyte")
        if not record.is_flyte_internal:
            return True
        # Console/Rich output gets the prefix baked into the message itself.
        record.msg = f"[flyte] {record.msg}"
        return True
270
+
271
+
272
def _setup_root_logger(use_json: bool, use_rich: bool, log_level: int):
    """
    Replace every handler on the root logger with a single freshly configured one, so
    user/library logging and Flyte internal logging share context information and look.

    :param use_json: Install a stream handler that formats records with JSONFormatter.
    :param use_rich: Otherwise, try the Rich handler.
    :param log_level: Level applied to both the new handler and the root logger.
    """
    root = logging.getLogger()
    root.handlers.clear()  # Remove any existing handlers to prevent double logging

    if use_json:
        chosen = logging.StreamHandler()
        chosen.setFormatter(JSONFormatter())
    else:
        chosen = get_rich_handler(log_level) if use_rich else None

    # get_rich_handler can return None in some environments
    chosen = chosen or logging.StreamHandler()

    # Add context filter to ALL logging
    chosen.addFilter(ContextFilter())
    chosen.setLevel(log_level)

    root.addHandler(chosen)
    root.setLevel(log_level)
297
+
298
+
299
def _create_flyte_logger() -> logging.Logger:
    """
    Build the "flyte" logger: one stream handler that prints the bare message, with the
    [flyte] prefix filter and the context filter attached, at the environment's log level.
    """
    level = get_env_log_level()

    # Handler dedicated to flyte logging, carrying the prefix and context filters
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level)
    for log_filter in (FlyteInternalFilter(), ContextFilter()):
        stream_handler.addFilter(log_filter)
    stream_handler.setFormatter(logging.Formatter(fmt="%(message)s"))

    internal_logger = logging.getLogger("flyte")
    internal_logger.setLevel(level)
    # Prevent propagation to root to avoid double logging
    internal_logger.propagate = False
    internal_logger.addHandler(stream_handler)

    return internal_logger
320
+
321
+
322
# Create the Flyte internal logger when this module is imported.
# initialize_logger() rebinds this module-level name when it is called.
logger = _create_flyte_logger()
flyte/_map.py ADDED
@@ -0,0 +1,312 @@
1
+ import asyncio
2
+ import functools
3
+ import logging
4
+ from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast, overload
5
+
6
+ from flyte.syncify import syncify
7
+
8
+ from ._group import group
9
+ from ._logging import logger
10
+ from ._task import AsyncFunctionTaskTemplate, F, P, R
11
+
12
+
13
class MapAsyncIterator(Generic[P, R]):
    """
    AsyncIterator implementation for the map function results.

    One asyncio task is created per zipped argument tuple on the first iteration, and
    results are yielded in input order.
    """

    def __init__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        args: tuple,
        name: str,
        concurrency: int,
        return_exceptions: bool,
    ):
        """
        :param func: The task (or a functools.partial wrapping one) to invoke per item.
        :param args: Iterables that are zipped together; each tuple is one invocation.
        :param name: Group name, used in log messages and repr.
        :param concurrency: Maximum number of invocations running at once; 0 (or less)
            means no limit.
        :param return_exceptions: If True, exceptions are returned as results; otherwise
            the first one is raised and the remaining tasks are cancelled.
        """
        self.func = func
        self.args = args
        self.name = name
        self.concurrency = concurrency
        self.return_exceptions = return_exceptions
        self._tasks: List[asyncio.Task] = []
        self._current_index = 0
        self._completed_count = 0
        self._exception_count = 0
        self._task_count = 0
        self._initialized = False

    def __aiter__(self) -> AsyncIterator[Union[R, Exception]]:
        """Return self as the async iterator"""
        return self

    async def __anext__(self) -> Union[R, Exception]:
        """
        Await and return the result of the next task, in input order.

        :raises StopAsyncIteration: when every task has been consumed.
        :raises Exception: the task's exception, when return_exceptions is False.
        """
        # Initialize on first call
        if not self._initialized:
            await self._initialize()

        # Check if we've exhausted all tasks
        if self._current_index >= self._task_count:
            raise StopAsyncIteration

        # Get the next task result
        index = self._current_index
        task = self._tasks[index]
        self._current_index += 1

        try:
            result = await task
        except Exception as e:
            self._exception_count += 1
            logger.debug(
                f"Task {index} failed with exception: {e}, return_exceptions={self.return_exceptions}"
            )
            if self.return_exceptions:
                return e
            # Cancel remaining tasks. _current_index already points at the first task
            # after the failed one, so the slice must start there (not one further).
            for remaining_task in self._tasks[self._current_index :]:
                remaining_task.cancel()
            logger.warning("Exception raising is `ON`, raising exception and cancelling remaining tasks")
            raise

        self._completed_count += 1
        logger.debug(f"Task {index} completed successfully")
        return result

    async def _initialize(self):
        """Initialize the tasks - called lazily on first iteration"""
        if isinstance(self.func, functools.partial):
            # Handle partial functions by merging bound args/kwargs with mapped args
            base_func = cast(AsyncFunctionTaskTemplate, self.func.func)
            bound_args = self.func.args
            bound_kwargs = self.func.keywords or {}
        else:
            # Regular TaskTemplate functions have nothing pre-bound
            base_func = cast(AsyncFunctionTaskTemplate, self.func)
            bound_args = ()
            bound_kwargs = {}

        limit = self.concurrency if self.concurrency and self.concurrency > 0 else 0
        semaphore = asyncio.Semaphore(limit) if limit else None

        async def _run_limited(call_args: tuple):
            # Only `limit` invocations may be in flight at any time.
            async with semaphore:
                return await base_func.aio(*call_args, **bound_kwargs)

        tasks = []
        for arg_tuple in zip(*self.args):
            # Merge bound positional args with mapped args
            merged_args = bound_args + arg_tuple
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Running {base_func.name} with args: {merged_args} and kwargs: {bound_kwargs}")
            if semaphore is None:
                coro = base_func.aio(*merged_args, **bound_kwargs)
            else:
                coro = _run_limited(merged_args)
            tasks.append(asyncio.create_task(coro))

        if not tasks:
            logger.info(f"Group '{self.name}' has no tasks to process")
        else:
            limit_desc = f"a concurrency limit of {limit}" if limit else "unlimited concurrency"
            logger.info(f"Starting {len(tasks)} tasks in group '{self.name}' with {limit_desc}")

        self._tasks = tasks
        self._task_count = len(tasks)
        self._initialized = True

    async def collect(self) -> List[Union[R, Exception]]:
        """Convenience method to collect all results into a list"""
        return [result async for result in self]

    def __repr__(self):
        return f"MapAsyncIterator(group_name='{self.name}', concurrency={self.concurrency})"
120
+
121
+
122
class _Mapper(Generic[P, R]):
    """
    Internal mapper class to handle the mapping logic.

    NOTE: The `@syncify` decorator is not used here because, in `syncify`, context managers
    like `group` cannot be used directly in the function body: the `__exit__` method of the
    context manager is called after the function returns. And for `_context`, the `__exit__`
    method releases the token (for the contextvar), which was created in a separate thread,
    and that raises an exception.
    """

    @classmethod
    def _get_name(cls, task_name: str, group_name: str | None) -> str:
        """Get the name of the group, defaulting to 'map' if not provided."""
        return f"{task_name}_{group_name or 'map'}"

    @classmethod
    def _resolve_group_name(cls, func: Any, group_name: str | None) -> str:
        """Validate `func` when it is a partial and return the group name for this map call."""
        if isinstance(func, functools.partial):
            f = cast(AsyncFunctionTaskTemplate, func.func)
            cls.validate_partial(func)
        else:
            f = cast(AsyncFunctionTaskTemplate, func)
        return cls._get_name(f.name, group_name)

    @staticmethod
    def validate_partial(func: functools.partial[R]):
        """
        This method validates that the provided partial function is valid for mapping, i.e. only the one argument
        is left for mapping and the rest are provided as keywords or args.

        :param func: partial function to validate
        :raises TypeError: if the partial function is not valid for mapping
        """
        f = cast(AsyncFunctionTaskTemplate, func.func)
        inputs = f.native_interface.inputs
        param_names = list(inputs.keys())
        provided_args = len(func.args)
        keywords = func.keywords or {}

        # Calculate how many parameters are left unspecified
        unspecified_count = len(param_names) - provided_args - len(keywords)

        # Exactly one parameter should be left for mapping
        if unspecified_count != 1:
            raise TypeError(
                f"Partial function must leave exactly one parameter unspecified for mapping. "
                f"Found {unspecified_count} unspecified parameters in {f.name}, "
                f"params: {inputs.keys()}"
            )

        # Validate that no parameter is both in args and keywords. The slice covers the
        # positionally bound parameters plus the one after them, so that next parameter
        # may not be bound by keyword either.
        for arg_name in param_names[: provided_args + 1]:
            if arg_name in keywords:
                raise TypeError(
                    f"Parameter '{arg_name}' is provided both as positional argument and keyword argument "
                    f"in partial function {f.name}."
                )

    @overload
    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
    ) -> Iterator[R]: ...

    @overload
    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> Iterator[Union[R, Exception]]: ...

    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> Iterator[Union[R, Exception]]:
        """
        Map a function over the provided arguments with concurrent execution.

        When there is no task context, or the context is in "local" mode, the calls are
        made one after another instead.

        :param func: The async function to map.
        :param args: Positional arguments to pass to the function (iterables that will be zipped).
        :param group_name: The name of the group for the mapped tasks.
        :param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
        :param return_exceptions: If True, yield exceptions instead of raising them.
        :return: Iterator yielding results in order.
        :raises TypeError: if `func` is a partial that is not valid for mapping.
        """
        if not args:
            return

        name = self._resolve_group_name(func, group_name)
        logger.debug(f"Blocking Map for {name}")
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        yield func(*v)  # type: ignore
                    except Exception as e:
                        if return_exceptions:
                            yield e
                        else:
                            raise e
                return

            mapped = cast(
                Iterator[R],
                _map(
                    func,
                    *args,
                    name=name,
                    concurrency=concurrency,
                    return_exceptions=return_exceptions,
                ),
            )
            for i, x in enumerate(mapped):
                logger.debug(f"Mapped {x}, task {i}")
                yield x

    async def aio(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> AsyncGenerator[Union[R, Exception], None]:
        """
        Async counterpart of `__call__`: map a function over the zipped arguments and
        yield the results in order.

        :param func: The async function to map.
        :param args: Positional arguments to pass to the function (iterables that will be zipped).
        :param group_name: The name of the group for the mapped tasks.
        :param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
        :param return_exceptions: If True, yield exceptions instead of raising them.
        :raises TypeError: if `func` is a partial that is not valid for mapping.
        """
        if not args:
            return

        name = self._resolve_group_name(func, group_name)
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        yield func(*v)  # type: ignore
                    except Exception as e:
                        if return_exceptions:
                            yield e
                        else:
                            raise e
                return
            async for x in _map.aio(
                func,
                *args,
                name=name,
                concurrency=concurrency,
                return_exceptions=return_exceptions,
            ):
                yield cast(Union[R, Exception], x)
295
+
296
+
297
@syncify
async def _map(
    func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
    *args: Iterable[Any],
    name: str = "map",
    concurrency: int = 0,
    return_exceptions: bool = True,
) -> AsyncIterator[Union[R, Exception]]:
    """
    Drive a MapAsyncIterator built from the given arguments and re-yield each of its results.
    """
    async for item in MapAsyncIterator(
        func=func,
        args=args,
        name=name,
        concurrency=concurrency,
        return_exceptions=return_exceptions,
    ):
        yield item
310
+
311
+
312
# Module-level mapper instance; call it directly for blocking iteration or use
# `map.aio(...)` from async code. NOTE: this name shadows the builtin `map`
# inside this module.
map: _Mapper = _Mapper()