flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (204) hide show
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/_logging.py ADDED
@@ -0,0 +1,300 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ from datetime import datetime
7
+ from typing import Literal, Optional
8
+
9
+ import flyte
10
+
11
+ from ._tools import ipython_check
12
+
13
# Output formats understood by the logging setup in this module.
LogFormat = Literal["console", "json"]

# Fallback level used when the LOG_LEVEL environment variable is not set.
DEFAULT_LOG_LEVEL = logging.WARNING
16
+
17
+
18
def make_hyperlink(label: str, url: str):
    """
    Render ``label`` as a clickable terminal hyperlink that points at ``url``.

    The label is wrapped in an OSC 8 hyperlink escape sequence and colored bright blue with an
    ANSI SGR code. The color is reset before the hyperlink is closed.
    """
    esc = "\033"
    color_on = f"{esc}[94m"
    color_off = f"{esc}[0m"
    link_open = f"{esc}]8;;{url}{esc}\\"
    link_close = f"{esc}]8;;{esc}\\"
    return f"{color_on}{link_open}{label}{color_off}{link_close}"
27
+
28
+
29
def is_rich_logging_disabled() -> bool:
    """
    Return True when rich logging has been turned off via the DISABLE_RICH_LOGGING environment variable.

    Only the presence of the variable matters: any value, including an empty string or ``"0"``,
    counts as disabled.
    """
    return "DISABLE_RICH_LOGGING" in os.environ
34
+
35
+
36
def get_env_log_level() -> int:
    """
    Read the log level from the LOG_LEVEL environment variable.

    Accepts either a numeric level (e.g. ``"10"``) or a standard level name in any case
    (e.g. ``"debug"``, ``"INFO"``). Surrounding whitespace is ignored.

    :return: The parsed level. Returns DEFAULT_LOG_LEVEL when the variable is unset, blank, or
        not a recognised level. A bad value must not raise, because this runs at import time
        (via _create_flyte_logger) and an exception here would break ``import flyte``.
    """
    raw = os.environ.get("LOG_LEVEL")
    if raw is None:
        return DEFAULT_LOG_LEVEL
    value = raw.strip()
    if not value:
        return DEFAULT_LOG_LEVEL
    try:
        return int(value)
    except ValueError:
        pass
    # getLevelName maps a registered level name to its int value; unknown names yield a "Level X" string.
    level = logging.getLevelName(value.upper())
    return level if isinstance(level, int) else DEFAULT_LOG_LEVEL
38
+
39
+
40
def log_format_from_env() -> LogFormat:
    """
    Get the log output format from the LOG_FORMAT environment variable.

    The value is matched case-insensitively after trimming surrounding whitespace. This mirrors
    how log level names are accepted. An unset variable, or any value other than ``console`` or
    ``json``, falls back to ``"console"``.
    """
    format_str = os.environ.get("LOG_FORMAT", "console").strip().lower()
    if format_str not in ("console", "json"):
        return "console"
    return format_str  # type: ignore[return-value]
48
+
49
+
50
def _get_console():
    """
    Build a rich ``Console`` whose width matches the current terminal.

    If ``os.get_terminal_size()`` raises for any reason, the failure is logged at debug level
    and a width of 160 columns is used instead.
    """
    from rich.console import Console

    fallback_width = 160
    try:
        columns = os.get_terminal_size().columns
    except Exception as e:
        logger.debug(f"Failed to get terminal size: {e}")
        columns = fallback_width

    return Console(width=columns)
63
+
64
+
65
def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
    """
    Create a rich ``RichHandler`` for local/CLI logging, or return None when rich output should not be used.

    Rich output is skipped (None is returned) when:
      * the current Flyte context reports that it is running in a cluster, or
      * we are not inside IPython and DISABLE_RICH_LOGGING is set.

    :param log_level: Level assigned to the returned handler.
    :return: A configured handler whose formatter prefixes messages with ``filename:lineno``, or None.
    """
    current = flyte.ctx()
    if current and current.is_in_cluster():
        return None

    # Inside IPython rich output is always used; elsewhere DISABLE_RICH_LOGGING opts out of it.
    in_ipython = ipython_check()
    if not in_ipython and is_rich_logging_disabled():
        return None

    import click
    from rich.highlighter import NullHighlighter
    from rich.logging import RichHandler

    rich_handler = RichHandler(
        tracebacks_suppress=[click],
        rich_tracebacks=False,
        omit_repeated_times=False,
        show_path=False,
        log_time_format="%H:%M:%S.%f",
        console=_get_console(),
        level=log_level,
        highlighter=NullHighlighter(),
        markup=True,
    )
    rich_handler.setFormatter(logging.Formatter(fmt="%(filename)s:%(lineno)d - %(message)s"))
    return rich_handler
94
+
95
+
96
class JSONFormatter(logging.Formatter):
    """
    Formatter that renders each log record as a single-line JSON object.

    Always-present keys: ``timestamp`` (ISO-8601 including the local UTC offset), ``level``,
    ``logger``, ``message``, ``filename``, ``lineno`` and ``funcName``.

    Optional keys are added only when the record carries a truthy value for them:
      * ``run_name`` / ``action_name`` (set by ContextFilter),
      * ``is_flyte_internal`` (set by FlyteInternalFilter),
      * ``exc_info`` and ``stack_info``.
    """

    def format(self, record: logging.LogRecord) -> str:
        log_data = {
            # astimezone() attaches the local UTC offset. A naive timestamp is ambiguous once logs
            # from hosts in different timezones (or across DST changes) are aggregated.
            "timestamp": datetime.fromtimestamp(record.created).astimezone().isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "filename": record.filename,
            "lineno": record.lineno,
            "funcName": record.funcName,
        }

        # Add context fields if present
        if getattr(record, "run_name", None):
            log_data["run_name"] = record.run_name  # type: ignore[attr-defined]
        if getattr(record, "action_name", None):
            log_data["action_name"] = record.action_name  # type: ignore[attr-defined]
        if getattr(record, "is_flyte_internal", False):
            log_data["is_flyte_internal"] = True

        # Add exception / stack info if present (logger.exception(...) and stack_info=True)
        if record.exc_info:
            log_data["exc_info"] = self.formatException(record.exc_info)
        if record.stack_info:
            log_data["stack_info"] = self.formatStack(record.stack_info)

        # default=str keeps a log call from failing if a context attribute is not JSON-serialisable.
        return json.dumps(log_data, default=str)
125
+
126
+
127
def initialize_logger(log_level: int | None = None, log_format: LogFormat | None = None, enable_rich: bool = False):
    """
    Rebuild the handlers of the root logger and the ``flyte`` logger from scratch.

    :param log_level: Level for the ``flyte`` logger and its handler; read from LOG_LEVEL when None.
    :param log_format: ``"console"`` or ``"json"``; read from LOG_FORMAT when None.
    :param enable_rich: Use rich handlers (intended for local CLI usage). Ignored when the format is
        ``"json"``, because JSON takes precedence. When get_rich_handler returns None, a plain
        StreamHandler is used instead.

    Handler layout:
      * The root logger gets exactly one handler, at level DEBUG, with ContextFilter.
      * The ``flyte`` logger gets exactly one handler, with FlyteInternalFilter and ContextFilter.
      * The ``flyte`` logger stops propagating to the root logger, so its records are not printed twice.

    The module-level ``logger`` is rebound to the ``flyte`` logger.
    """
    global logger  # noqa: PLW0603

    level = get_env_log_level() if log_level is None else log_level
    chosen_format = log_format_from_env() if log_format is None else log_format

    # Start from a clean slate so repeated calls never stack handlers.
    root = logging.getLogger()
    root.handlers.clear()
    flyte_logger = logging.getLogger("flyte")
    flyte_logger.handlers.clear()

    use_json = chosen_format == "json"
    use_rich = enable_rich and not use_json  # JSON output takes precedence over rich

    # Root logger: receives everything that is not routed through the flyte logger.
    root_handler: logging.Handler
    if use_json:
        root_handler = logging.StreamHandler()
        root_handler.setFormatter(JSONFormatter())
    else:
        candidate = get_rich_handler(level) if use_rich else None
        root_handler = candidate if candidate is not None else logging.StreamHandler()

    # Context filter on the root handler so all logging carries run/action info.
    root_handler.addFilter(ContextFilter())
    root_handler.setLevel(logging.DEBUG)
    root.addHandler(root_handler)

    # Flyte logger: Flyte-internal logging.
    flyte_handler: logging.Handler
    if use_json:
        flyte_handler = logging.StreamHandler()
        flyte_handler.setLevel(level)
        flyte_handler.setFormatter(JSONFormatter())
    else:
        candidate = get_rich_handler(level) if use_rich else None
        if candidate is None:
            candidate = logging.StreamHandler()
            candidate.setLevel(level)
            candidate.setFormatter(logging.Formatter(fmt="%(message)s"))
        flyte_handler = candidate

    # Filter order matters for console output: "[flyte] " is prepended first, then the
    # "[run][action] " prefix, giving "[run][action] [flyte] message".
    flyte_handler.addFilter(FlyteInternalFilter())
    flyte_handler.addFilter(ContextFilter())

    flyte_logger.addHandler(flyte_handler)
    flyte_logger.setLevel(level)
    flyte_logger.propagate = False  # Prevent double logging

    logger = flyte_logger
190
+
191
+
192
def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
    """
    Decorator that logs calls to the decorated function at ``level``.

    Usable bare (``@log``) or with options (``@log(level=logging.INFO, exit=False)``).

    :param fn: The function being decorated when used bare; None when called with options.
    :param level: Logging level for the entry/exit messages.
    :param entry: Log the positional and keyword arguments before invoking the function.
    :param exit: Log a "completed" message after the function returns or raises.
    :return: The wrapped function, or a decorator when ``fn`` is None.

    Whether to wrap is decided once, at decoration time, via ``logger.isEnabledFor(level)``. If the
    level is disabled at that moment, the function is returned unwrapped and never logs, even if
    the level is enabled later.

    NOTE: the wrapper calls ``func`` synchronously. For a coroutine function, the "completed" message
    is emitted when the coroutine object is created, not when it finishes.
    """
    import functools

    def decorator(func):
        if not logger.isEnabledFor(level):
            return func

        # wraps() keeps the decorated function's __name__, __doc__, __module__ and __wrapped__.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if entry:
                logger.log(level, f"[{func.__name__}] with args: {args} and kwargs: {kwargs}")
            try:
                return func(*args, **kwargs)
            finally:
                # Runs on normal return and when func raises.
                if exit:
                    logger.log(level, f"[{func.__name__}] completed")

        return wrapper

    if fn is None:
        return decorator
    return decorator(fn)
215
+
216
+
217
class ContextFilter(logging.Filter):
    """
    Logging filter that annotates every record with the current action's run name and action name.

    When ``flyte._context.ctx()`` returns a context:
      * ``record.run_name`` and ``record.action_name`` are set from its ``action`` (for structured/JSON output);
      * the message is prefixed with ``[<run_name>][<action_name>] `` (for console/rich output).

    Otherwise both attributes are set to None and the message is left as is.

    Records are never dropped. The prefix is applied again each time this filter runs on the same record.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        from flyte._context import ctx

        current = ctx()
        if not current:
            record.run_name = None
            record.action_name = None
            return True

        action = current.action
        # Attributes for structured logging (JSON)
        record.run_name = action.run_name
        record.action_name = action.name
        # Prefix for console/rich output
        record.msg = f"[{record.run_name}][{record.action_name}] {record.msg}"
        return True
238
+
239
+
240
class FlyteInternalFilter(logging.Filter):
    """
    Logging filter that marks records emitted by Flyte's own loggers and prefixes their message with ``[flyte]``.

    A record counts as internal when its logger is ``flyte`` or a descendant of it (``flyte.*``).
    ``record.is_flyte_internal`` is always set, for structured/JSON output. Records are never dropped.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        # Match the logger hierarchy instead of a bare string prefix. Otherwise unrelated loggers
        # whose names merely start with "flyte" (e.g. "flytekit", "flyteidl") would be tagged internal.
        name = record.name
        is_internal = name == "flyte" or name.startswith("flyte.")
        # Add as attribute for structured logging (JSON)
        record.is_flyte_internal = is_internal
        # Also modify message for console/Rich output
        if is_internal:
            record.msg = f"[flyte] {record.msg}"
        return True
253
+
254
+
255
def _setup_root_logger():
    """
    Replace the root logger's handlers with a single StreamHandler that carries a ContextFilter.

    This ensures both user code and Flyte internal logging get run/action context. The handler's
    own level is DEBUG, so which records get through is decided by the logger levels.
    """
    root_logger = logging.getLogger()
    # Drop any previously installed handlers to prevent double logging.
    root_logger.handlers.clear()

    context_handler = logging.StreamHandler()
    context_handler.addFilter(ContextFilter())
    context_handler.setLevel(logging.DEBUG)

    # No formatter is set, so the logging module's default formatter is used; the filter supplies the prefix.
    root_logger.addHandler(context_handler)
271
+
272
+
273
def _create_flyte_logger() -> logging.Logger:
    """
    Configure and return the ``flyte`` logger used for Flyte-internal messages.

    Configuration:
      * The level is read once from LOG_LEVEL (via get_env_log_level) and applied to both the logger and its handler.
      * Output goes through a single StreamHandler with FlyteInternalFilter and ContextFilter.
      * Records are formatted as the bare message; the filters add the ``[flyte]`` and run/action prefixes.
      * Propagation to the root logger is disabled so records are not emitted twice.

    Handlers already attached to the ``flyte`` logger are removed first. Calling this again (e.g. on a
    module reload) therefore does not duplicate every line of output.
    """
    level = get_env_log_level()

    flyte_logger = logging.getLogger("flyte")
    flyte_logger.setLevel(level)

    # A handler specifically for flyte logging with the prefix filters
    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.addFilter(FlyteInternalFilter())
    handler.addFilter(ContextFilter())
    handler.setFormatter(logging.Formatter(fmt="%(message)s"))

    # Prevent propagation to root to avoid double logging
    flyte_logger.propagate = False
    # Match _setup_root_logger/initialize_logger, which also clear existing handlers before adding theirs.
    flyte_logger.handlers.clear()
    flyte_logger.addHandler(handler)

    return flyte_logger
294
+
295
+
296
# Import-time configuration: give the root logger a context-aware handler so that user code and
# Flyte internals both get run/action information.
_setup_root_logger()

# Module-level handle to the Flyte internal logger ("flyte"); initialize_logger() rebinds it on reconfiguration.
logger = _create_flyte_logger()
flyte/_map.py ADDED
@@ -0,0 +1,312 @@
1
+ import asyncio
2
+ import functools
3
+ import logging
4
+ from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast, overload
5
+
6
+ from flyte.syncify import syncify
7
+
8
+ from ._group import group
9
+ from ._logging import logger
10
+ from ._task import AsyncFunctionTaskTemplate, F, P, R
11
+
12
+
13
class MapAsyncIterator(Generic[P, R]):
    """
    Async iterator that runs a task over zipped argument iterables and yields results in input order.

    Scheduling: all invocations are scheduled on the first ``__anext__`` call. When ``concurrency``
    is greater than zero, at most that many invocations run at the same time. ``0`` (or a negative
    value) means every invocation starts immediately.

    Failures: if an invocation raises and ``return_exceptions`` is True, the exception is returned
    in place of its result. Otherwise every not-yet-consumed invocation is cancelled and the
    exception is re-raised.
    """

    def __init__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        args: tuple,
        name: str,
        concurrency: int,
        return_exceptions: bool,
    ):
        """
        :param func: The task to map, or a ``functools.partial`` of it with all but one parameter bound.
        :param args: Tuple of iterables. They are zipped, and each resulting tuple supplies the positional
            (mapped) arguments of one call.
        :param name: Group name, used in log messages and ``__repr__``.
        :param concurrency: Maximum number of concurrently running invocations; ``<= 0`` means unlimited.
        :param return_exceptions: Return exceptions as results instead of raising them.
        """
        self.func = func
        self.args = args
        self.name = name
        self.concurrency = concurrency
        self.return_exceptions = return_exceptions
        self._tasks: List[asyncio.Task] = []
        self._current_index = 0
        self._completed_count = 0
        self._exception_count = 0
        self._task_count = 0
        self._initialized = False

    def __aiter__(self) -> AsyncIterator[Union[R, Exception]]:
        """Return self as the async iterator"""
        return self

    async def __anext__(self) -> Union[R, Exception]:
        """Await and return the result of the next invocation, in input order."""
        # Initialize on first call
        if not self._initialized:
            await self._initialize()

        # Check if we've exhausted all tasks
        if self._current_index >= self._task_count:
            raise StopAsyncIteration

        index = self._current_index
        task = self._tasks[index]
        self._current_index += 1

        try:
            result = await task
        except Exception as e:
            self._exception_count += 1
            logger.debug(f"Task {index} failed with exception: {e}, return_exceptions={self.return_exceptions}")
            if self.return_exceptions:
                return e
            # Cancel every task that has not been consumed yet. self._current_index already points at the
            # next task, so the slice starts there; starting one past it would leave that task running.
            for remaining_task in self._tasks[self._current_index :]:
                remaining_task.cancel()
            logger.warning("Exception raising is `ON`, raising exception and cancelling remaining tasks")
            raise

        self._completed_count += 1
        logger.debug(f"Task {index} completed successfully")
        return result

    async def _initialize(self):
        """Schedule every invocation as an asyncio task - called lazily on first iteration."""
        if isinstance(self.func, functools.partial):
            # Partial: bound positional args come first, then the mapped values; bound keywords are passed through.
            base_func = cast(AsyncFunctionTaskTemplate, self.func.func)
            bound_args = self.func.args
            bound_kwargs = self.func.keywords or {}
        else:
            base_func = cast(AsyncFunctionTaskTemplate, self.func)
            bound_args = ()
            bound_kwargs = {}

        semaphore = asyncio.Semaphore(self.concurrency) if self.concurrency > 0 else None

        async def _bounded(call_args: tuple):
            # Only start the invocation once a slot is free, so at most `concurrency` run at once.
            async with semaphore:
                return await base_func.aio(*call_args, **bound_kwargs)

        tasks: List[asyncio.Task] = []
        for arg_tuple in zip(*self.args):
            call_args = bound_args + arg_tuple
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Running {base_func.name} with args: {call_args} and kwargs: {bound_kwargs}")
            if semaphore is None:
                coro = base_func.aio(*call_args, **bound_kwargs)
            else:
                coro = _bounded(call_args)
            tasks.append(asyncio.create_task(coro))

        self._tasks = tasks
        self._task_count = len(tasks)
        if not tasks:
            logger.info(f"Group '{self.name}' has no tasks to process")
        elif semaphore is None:
            logger.info(f"Starting {len(tasks)} tasks in group '{self.name}' with unlimited concurrency")
        else:
            logger.info(
                f"Starting {len(tasks)} tasks in group '{self.name}' with concurrency {self.concurrency}"
            )

        self._initialized = True

    async def collect(self) -> List[Union[R, Exception]]:
        """Convenience method to collect all results into a list"""
        results = []
        async for result in self:
            results.append(result)
        return results

    def __repr__(self):
        return f"MapAsyncIterator(group_name='{self.name}', concurrency={self.concurrency})"
120
+
121
+
122
class _Mapper(Generic[P, R]):
    """
    Internal mapper class to handle the mapping logic.

    Calling the instance returns a blocking iterator; ``.aio(...)`` returns an async generator.
    Both yield results (or exceptions) in input order.

    NOTE: The reason why we do not use the `@syncify` decorator here is because, in `syncify` we cannot use
    context managers like `group` directly in the function body. This is because the `__exit__` method of the
    context manager is called after the function returns. And for `_context` the `__exit__` method releases the
    token (for contextvar), which was created in a separate thread. Resetting a contextvar token outside the
    Context it was created in fails, presumably surfacing as the ``ValueError`` that ``ContextVar.reset``
    raises for a token that "was created in a different Context".
    """

    @classmethod
    def _get_name(cls, task_name: str, group_name: str | None) -> str:
        """Build the group name as ``<task_name>_<group_name>``, using ``map`` when group_name is not given."""
        return f"{task_name}_{group_name or 'map'}"

    @staticmethod
    def validate_partial(func: functools.partial[R]):
        """
        Validate that the provided partial function can be mapped.

        Exactly one parameter must be left unbound. The mapped value is passed positionally right after
        the partial's bound positional arguments. Consequently:
          * the parameter in that slot must not also be bound by keyword, and
          * no positional parameter may be bound twice.

        :param func: partial function to validate
        :raises TypeError: if the partial function is not valid for mapping
        """
        f = cast(AsyncFunctionTaskTemplate, func.func)
        inputs = f.native_interface.inputs
        params = list(inputs.keys())
        provided_args = len(func.args)
        keywords = func.keywords or {}

        # A keyword that names no parameter would otherwise be counted as "bound" below and could hide
        # a genuinely unbound parameter, so reject it up front.
        known = set(params)
        unknown = [kw for kw in keywords if kw not in known]
        if unknown:
            raise TypeError(
                f"Partial function {f.name} binds unknown keyword argument(s) {unknown}, params: {inputs.keys()}"
            )

        # Calculate how many parameters are left unspecified; exactly one should be left for mapping
        unspecified_count = len(params) - provided_args - len(keywords)
        if unspecified_count != 1:
            raise TypeError(
                f"Partial function must leave exactly one parameter unspecified for mapping. "
                f"Found {unspecified_count} unspecified parameters in {f.name}, "
                f"params: {inputs.keys()}"
            )

        # params[:provided_args] are bound positionally; params[provided_args] receives the mapped value.
        for index, arg_name in enumerate(params[: provided_args + 1]):
            if arg_name not in keywords:
                continue
            if index < provided_args:
                raise TypeError(
                    f"Parameter '{arg_name}' is provided both as positional argument and keyword argument "
                    f"in partial function {f.name}."
                )
            raise TypeError(
                f"Parameter '{arg_name}' is bound as a keyword argument in partial function {f.name}, "
                f"but it is the next positional parameter and would also receive the mapped value."
            )

    @overload
    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
    ) -> Iterator[R]: ...

    @overload
    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> Iterator[Union[R, Exception]]: ...

    def __call__(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> Iterator[Union[R, Exception]]:
        """
        Map a function over the provided arguments with concurrent execution.

        Without a task context, or in local mode, every call runs sequentially instead.

        :param func: The task to map, or a ``functools.partial`` of it leaving exactly one parameter unbound.
        :param args: Positional arguments to pass to the function (iterables that will be zipped).
        :param group_name: The name of the group for the mapped tasks.
        :param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
        :param return_exceptions: If True, yield exceptions instead of raising them.
        :return: A (blocking) iterator yielding results in input order.
        """
        if not args:
            return

        if isinstance(func, functools.partial):
            f = cast(AsyncFunctionTaskTemplate, func.func)
            self.validate_partial(func)
        else:
            f = cast(AsyncFunctionTaskTemplate, func)

        name = self._get_name(f.name, group_name)
        logger.debug(f"Blocking Map for {name}")
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        yield func(*v)  # type: ignore
                    except Exception as e:
                        if return_exceptions:
                            yield e
                        else:
                            raise e
                return

            i = 0
            for x in cast(
                Iterator[R],
                _map(
                    func,
                    *args,
                    name=name,
                    concurrency=concurrency,
                    return_exceptions=return_exceptions,
                ),
            ):
                logger.debug(f"Mapped {x}, task {i}")
                i += 1
                yield x

    async def aio(
        self,
        func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> AsyncGenerator[Union[R, Exception], None]:
        """
        Async-generator variant of ``__call__``.

        It takes the same parameters and yields results (or exceptions) in input order.
        """
        if not args:
            return

        if isinstance(func, functools.partial):
            f = cast(AsyncFunctionTaskTemplate, func.func)
            self.validate_partial(func)
        else:
            f = cast(AsyncFunctionTaskTemplate, func)

        name = self._get_name(f.name, group_name)
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        yield func(*v)  # type: ignore
                    except Exception as e:
                        if return_exceptions:
                            yield e
                        else:
                            raise e
                return
            async for x in _map.aio(
                func,
                *args,
                name=name,
                concurrency=concurrency,
                return_exceptions=return_exceptions,
            ):
                yield cast(Union[R, Exception], x)
295
+
296
+
297
@syncify
async def _map(
    func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
    *args: Iterable[Any],
    name: str = "map",
    concurrency: int = 0,
    return_exceptions: bool = True,
) -> AsyncIterator[Union[R, Exception]]:
    """
    Stream the results of mapping ``func`` over the zipped ``args``, in input order.

    Decorated with ``@syncify``. _Mapper consumes it both as a blocking iterable (``_map(...)``) and
    asynchronously (``_map.aio(...)``).
    """
    results = MapAsyncIterator(
        func=func,
        args=args,
        name=name,
        concurrency=concurrency,
        return_exceptions=return_exceptions,
    )
    async for item in results:
        yield item
310
+
311
+
312
# Module-level mapper instance. Call ``map(task, *iterables)`` for a blocking iterator, or iterate
# ``map.aio(task, *iterables)`` asynchronously. This name shadows the builtin ``map`` within this module.
map: _Mapper = _Mapper()
flyte/_module.py ADDED
@@ -0,0 +1,72 @@
1
+ import inspect
2
+ import os
3
+ import pathlib
4
+ import sys
5
+ from types import ModuleType
6
+ from typing import Tuple
7
+
8
+
9
+ def extract_obj_module(obj: object, /, source_dir: pathlib.Path | None = None) -> Tuple[str, ModuleType]:
10
+ """
11
+ Extract the module from the given object. If source_dir is provided, the module will be relative to the source_dir.
12
+
13
+ Args:
14
+ obj: The object to extract the module from.
15
+ source_dir: The source directory to use for relative paths.
16
+
17
+ Returns:
18
+ The module name as a string.
19
+ """
20
+ if source_dir is None:
21
+ raise ValueError("extract_obj_module: source_dir cannot be None - specify root-dir")
22
+ # Get the module containing the object
23
+ entity_module = inspect.getmodule(obj)
24
+ if entity_module is None:
25
+ obj_name = getattr(obj, "__name__", str(obj))
26
+ raise ValueError(f"Object {obj_name} has no module.")
27
+
28
+ fp = entity_module.__file__
29
+ if fp is None:
30
+ obj_name = getattr(obj, "__name__", str(obj))
31
+ raise ValueError(f"Object {obj_name} has no module.")
32
+
33
+ file_path = pathlib.Path(fp)
34
+ try:
35
+ # Get the relative path to the current directory
36
+ # Will raise ValueError if the file is not in the source directory
37
+ relative_path = file_path.relative_to(str(pathlib.Path(source_dir).absolute()))
38
+
39
+ if relative_path == pathlib.Path("_internal/resolvers"):
40
+ entity_module_name = entity_module.__name__
41
+ else:
42
+ # Replace file separators with dots and remove the '.py' extension
43
+ dotted_path = os.path.splitext(str(relative_path))[0].replace(os.sep, ".")
44
+ entity_module_name = dotted_path
45
+ except ValueError:
46
+ # If source_dir is not provided or file is not in source_dir, fallback to module name
47
+ # File is not relative to source_dir - check if it's an installed package
48
+ file_path_str = str(file_path)
49
+ if "site-packages" in file_path_str or "dist-packages" in file_path_str:
50
+ # It's an installed package - use the module's __name__ directly
51
+ # This will be importable via importlib.import_module()
52
+ entity_module_name = entity_module.__name__
53
+ else:
54
+ # File is not in source_dir and not in site-packages - re-raise the error
55
+ obj_name = getattr(obj, "__name__", str(obj))
56
+ raise ValueError(
57
+ f"Object {obj_name} module file {file_path} is not relative to "
58
+ f"source directory {source_dir} and is not an installed package."
59
+ )
60
+
61
+ if entity_module_name == "__main__":
62
+ """
63
+ This case is for the case in which the object is run from the main module.
64
+ """
65
+ fp = sys.modules["__main__"].__file__
66
+ if fp is None:
67
+ obj_name = getattr(obj, "__name__", str(obj))
68
+ raise ValueError(f"Object {obj_name} has no module.")
69
+ main_path = pathlib.Path(fp)
70
+ entity_module_name = main_path.stem
71
+
72
+ return entity_module_name, entity_module