flyte 0.2.0b8__py3-none-any.whl → 0.2.0b10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +4 -2
- flyte/_context.py +7 -1
- flyte/_deploy.py +3 -0
- flyte/_group.py +1 -0
- flyte/_initialize.py +15 -5
- flyte/_internal/controllers/__init__.py +13 -2
- flyte/_internal/controllers/_local_controller.py +67 -5
- flyte/_internal/controllers/remote/_controller.py +47 -2
- flyte/_internal/runtime/taskrunner.py +2 -1
- flyte/_map.py +215 -0
- flyte/_run.py +109 -64
- flyte/_task.py +56 -7
- flyte/_utils/helpers.py +15 -0
- flyte/_version.py +2 -2
- flyte/cli/__init__.py +0 -7
- flyte/cli/_abort.py +1 -1
- flyte/cli/_common.py +7 -7
- flyte/cli/_create.py +44 -29
- flyte/cli/_delete.py +2 -2
- flyte/cli/_deploy.py +3 -3
- flyte/cli/_gen.py +12 -4
- flyte/cli/_get.py +35 -27
- flyte/cli/_params.py +1 -1
- flyte/cli/main.py +32 -29
- flyte/extras/_container.py +29 -32
- flyte/io/__init__.py +17 -1
- flyte/io/_file.py +2 -0
- flyte/io/{structured_dataset → _structured_dataset}/basic_dfs.py +1 -1
- flyte/io/{structured_dataset → _structured_dataset}/structured_dataset.py +1 -1
- flyte/models.py +11 -1
- flyte/syncify/_api.py +43 -15
- flyte/types/__init__.py +23 -0
- flyte/{io/pickle/transformer.py → types/_pickle.py} +2 -1
- flyte/types/_type_engine.py +4 -4
- {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/METADATA +7 -6
- {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/RECORD +40 -41
- flyte/io/_dataframe.py +0 -0
- flyte/io/pickle/__init__.py +0 -0
- /flyte/io/{structured_dataset → _structured_dataset}/__init__.py +0 -0
- {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/top_level.txt +0 -0
flyte/__init__.py
CHANGED
|
@@ -38,7 +38,8 @@ __all__ = [
|
|
|
38
38
|
"deploy",
|
|
39
39
|
"group",
|
|
40
40
|
"init",
|
|
41
|
-
"init_auto_from_config",
|
|
41
|
+
"init_from_config",
|
|
42
|
+
"map",
|
|
42
43
|
"run",
|
|
43
44
|
"trace",
|
|
44
45
|
"with_runcontext",
|
|
@@ -50,7 +51,8 @@ from ._deploy import deploy
|
|
|
50
51
|
from ._environment import Environment
|
|
51
52
|
from ._group import group
|
|
52
53
|
from ._image import Image
|
|
53
|
-
from ._initialize import init, init_auto_from_config
|
|
54
|
+
from ._initialize import init, init_from_config
|
|
55
|
+
from ._map import map
|
|
54
56
|
from ._resources import GPU, TPU, Device, Resources
|
|
55
57
|
from ._retry import RetryStrategy
|
|
56
58
|
from ._reusable_environment import ReusePolicy
|
flyte/_context.py
CHANGED
|
@@ -4,6 +4,7 @@ import contextvars
|
|
|
4
4
|
from dataclasses import dataclass, replace
|
|
5
5
|
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, ParamSpec, TypeVar
|
|
6
6
|
|
|
7
|
+
from flyte._logging import logger
|
|
7
8
|
from flyte.models import GroupData, RawDataPath, TaskContext
|
|
8
9
|
|
|
9
10
|
if TYPE_CHECKING:
|
|
@@ -49,6 +50,7 @@ class Context:
|
|
|
49
50
|
raise ValueError("Cannot create a new context without contextdata.")
|
|
50
51
|
self._data = data
|
|
51
52
|
self._id = id(self) # Immutable unique identifier
|
|
53
|
+
self._token = None # Context variable token to restore the previous context
|
|
52
54
|
|
|
53
55
|
@property
|
|
54
56
|
def data(self) -> ContextData:
|
|
@@ -106,7 +108,11 @@ class Context:
|
|
|
106
108
|
|
|
107
109
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
108
110
|
"""Exit the context, restoring the previous context."""
|
|
109
|
-
|
|
111
|
+
try:
|
|
112
|
+
root_context_var.reset(self._token)
|
|
113
|
+
except Exception as e:
|
|
114
|
+
logger.warn(f"Failed to reset context: {e}")
|
|
115
|
+
raise e
|
|
110
116
|
|
|
111
117
|
async def __aenter__(self):
|
|
112
118
|
"""Async version of context entry."""
|
flyte/_deploy.py
CHANGED
|
@@ -128,6 +128,9 @@ async def apply(deployment: DeploymentPlan, copy_style: CopyFiles, dryrun: bool
|
|
|
128
128
|
else:
|
|
129
129
|
code_bundle = await build_code_bundle(from_dir=cfg.root_dir, dryrun=dryrun, copy_style=copy_style)
|
|
130
130
|
deployment.version = code_bundle.computed_version
|
|
131
|
+
# TODO we should update the version to include the image cache digest and code bundle digest. This is
|
|
132
|
+
# to ensure that changes in image dependencies, cause an update to the deployment version.
|
|
133
|
+
# TODO Also hash the environment and tasks to ensure that changes in the environment or tasks
|
|
131
134
|
|
|
132
135
|
sc = SerializationContext(
|
|
133
136
|
project=cfg.project,
|
flyte/_group.py
CHANGED
flyte/_initialize.py
CHANGED
|
@@ -138,14 +138,13 @@ async def init(
|
|
|
138
138
|
:param project: Optional project name (not used in this implementation)
|
|
139
139
|
:param domain: Optional domain name (not used in this implementation)
|
|
140
140
|
:param root_dir: Optional root directory from which to determine how to load files, and find paths to files.
|
|
141
|
+
This is useful for determining the root directory for the current project, and for locating files like config etc.
|
|
142
|
+
also use to determine all the code that needs to be copied to the remote location.
|
|
141
143
|
defaults to the editable install directory if the cwd is in a Python editable install, else just the cwd.
|
|
142
144
|
:param log_level: Optional logging level for the logger, default is set using the default initialization policies
|
|
143
145
|
:param api_key: Optional API key for authentication
|
|
144
146
|
:param endpoint: Optional API endpoint URL
|
|
145
147
|
:param headless: Optional Whether to run in headless mode
|
|
146
|
-
:param mode: Optional execution model (local, remote). Default is local. When local is used,
|
|
147
|
-
the execution will be done locally. When remote is used, the execution will be sent to a remote server,
|
|
148
|
-
In the remote case, the endpoint or api_key must be set.
|
|
149
148
|
:param insecure_skip_verify: Whether to skip SSL certificate verification
|
|
150
149
|
:param auth_client_config: Optional client configuration for authentication
|
|
151
150
|
:param auth_type: The authentication type to use (Pkce, ClientSecret, ExternalCommand, DeviceFlow)
|
|
@@ -177,6 +176,9 @@ async def init(
|
|
|
177
176
|
|
|
178
177
|
global _init_config # noqa: PLW0603
|
|
179
178
|
|
|
179
|
+
if endpoint and "://" not in endpoint:
|
|
180
|
+
endpoint = f"dns:///{endpoint}"
|
|
181
|
+
|
|
180
182
|
with _init_lock:
|
|
181
183
|
client = None
|
|
182
184
|
if endpoint or api_key:
|
|
@@ -209,12 +211,16 @@ async def init(
|
|
|
209
211
|
|
|
210
212
|
|
|
211
213
|
@syncify
|
|
212
|
-
async def init_auto_from_config(path_or_config: str | Config | None = None) -> None:
|
|
214
|
+
async def init_from_config(path_or_config: str | Config | None = None, root_dir: Path | None = None) -> None:
|
|
213
215
|
"""
|
|
214
216
|
Initialize the Flyte system using a configuration file or Config object. This method should be called before any
|
|
215
217
|
other Flyte remote API methods are called. Thread-safe implementation.
|
|
216
218
|
|
|
217
219
|
:param path_or_config: Path to the configuration file or Config object
|
|
220
|
+
:param root_dir: Optional root directory from which to determine how to load files, and find paths to
|
|
221
|
+
files like config etc. For example if one uses the copy-style=="all", it is essential to determine the
|
|
222
|
+
root directory for the current project. If not provided, it defaults to the editable install directory or
|
|
223
|
+
if not available, the current working directory.
|
|
218
224
|
:return: None
|
|
219
225
|
"""
|
|
220
226
|
import flyte.config as config
|
|
@@ -222,7 +228,10 @@ async def init_auto_from_config(path_or_config: str | Config | None = None) -> N
|
|
|
222
228
|
cfg: config.Config
|
|
223
229
|
if path_or_config is None or isinstance(path_or_config, str):
|
|
224
230
|
# If a string is passed, treat it as a path to the config file
|
|
225
|
-
|
|
231
|
+
if root_dir and path_or_config:
|
|
232
|
+
cfg = config.auto(str(root_dir / path_or_config))
|
|
233
|
+
else:
|
|
234
|
+
cfg = config.auto(path_or_config)
|
|
226
235
|
else:
|
|
227
236
|
# If a Config object is passed, use it directly
|
|
228
237
|
cfg = path_or_config
|
|
@@ -241,6 +250,7 @@ async def init_auto_from_config(path_or_config: str | Config | None = None) -> N
|
|
|
241
250
|
proxy_command=cfg.platform.proxy_command,
|
|
242
251
|
client_id=cfg.platform.client_id,
|
|
243
252
|
client_credentials_secret=cfg.platform.client_credentials_secret,
|
|
253
|
+
root_dir=root_dir,
|
|
244
254
|
)
|
|
245
255
|
|
|
246
256
|
|
flyte/_internal/controllers/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
|
+
import concurrent.futures
|
|
1
2
|
import threading
|
|
2
|
-
from typing import Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
|
|
3
4
|
|
|
4
5
|
from flyte._task import TaskTemplate
|
|
5
6
|
from flyte.models import ActionID, NativeInterface
|
|
@@ -10,6 +11,9 @@ __all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "ge
|
|
|
10
11
|
|
|
11
12
|
from ..._protos.workflow import task_definition_pb2
|
|
12
13
|
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import concurrent.futures
|
|
16
|
+
|
|
13
17
|
ControllerType = Literal["local", "remote"]
|
|
14
18
|
|
|
15
19
|
R = TypeVar("R")
|
|
@@ -28,7 +32,14 @@ class Controller(Protocol):
|
|
|
28
32
|
"""
|
|
29
33
|
...
|
|
30
34
|
|
|
31
|
-
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) ->
|
|
35
|
+
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
|
|
36
|
+
"""
|
|
37
|
+
This should call the async submit method above, but return a concurrent Future object that can be
|
|
38
|
+
used in a blocking wait or wrapped in an async future. This is called when
|
|
39
|
+
a) a synchronous task is kicked off locally,
|
|
40
|
+
b) a running task (of either kind) kicks off a downstream synchronous task.
|
|
41
|
+
"""
|
|
42
|
+
...
|
|
32
43
|
|
|
33
44
|
async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
|
|
34
45
|
"""
|
flyte/_internal/controllers/_local_controller.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import atexit
|
|
3
|
+
import concurrent.futures
|
|
4
|
+
import os
|
|
5
|
+
import threading
|
|
1
6
|
from typing import Any, Callable, Tuple, TypeVar
|
|
2
7
|
|
|
3
8
|
import flyte.errors
|
|
@@ -8,20 +13,67 @@ from flyte._internal.runtime.entrypoints import direct_dispatch
|
|
|
8
13
|
from flyte._logging import log, logger
|
|
9
14
|
from flyte._protos.workflow import task_definition_pb2
|
|
10
15
|
from flyte._task import TaskTemplate
|
|
11
|
-
from flyte._utils.
|
|
12
|
-
from flyte.models import ActionID, NativeInterface
|
|
16
|
+
from flyte._utils.helpers import _selector_policy
|
|
17
|
+
from flyte.models import ActionID, NativeInterface
|
|
13
18
|
|
|
14
19
|
R = TypeVar("R")
|
|
15
20
|
|
|
16
21
|
|
|
22
|
+
class _TaskRunner:
|
|
23
|
+
"""A task runner that runs an asyncio event loop on a background thread."""
|
|
24
|
+
|
|
25
|
+
def __init__(self) -> None:
|
|
26
|
+
self.__loop: asyncio.AbstractEventLoop | None = None
|
|
27
|
+
self.__runner_thread: threading.Thread | None = None
|
|
28
|
+
self.__lock = threading.Lock()
|
|
29
|
+
atexit.register(self._close)
|
|
30
|
+
|
|
31
|
+
def _close(self) -> None:
|
|
32
|
+
if self.__loop:
|
|
33
|
+
self.__loop.stop()
|
|
34
|
+
|
|
35
|
+
def _execute(self) -> None:
|
|
36
|
+
loop = self.__loop
|
|
37
|
+
assert loop is not None
|
|
38
|
+
try:
|
|
39
|
+
loop.run_forever()
|
|
40
|
+
finally:
|
|
41
|
+
loop.close()
|
|
42
|
+
|
|
43
|
+
def get_exc_handler(self):
|
|
44
|
+
def exc_handler(loop, context):
|
|
45
|
+
logger.error(
|
|
46
|
+
f"Taskrunner for {self.__runner_thread.name if self.__runner_thread else 'no thread'} caught"
|
|
47
|
+
f" exception in {loop}: {context}"
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
return exc_handler
|
|
51
|
+
|
|
52
|
+
def get_run_future(self, coro: Any) -> concurrent.futures.Future:
|
|
53
|
+
"""Synchronously run a coroutine on a background thread."""
|
|
54
|
+
name = f"{threading.current_thread().name} : loop-runner"
|
|
55
|
+
with self.__lock:
|
|
56
|
+
if self.__loop is None:
|
|
57
|
+
with _selector_policy():
|
|
58
|
+
self.__loop = asyncio.new_event_loop()
|
|
59
|
+
|
|
60
|
+
exc_handler = self.get_exc_handler()
|
|
61
|
+
self.__loop.set_exception_handler(exc_handler)
|
|
62
|
+
self.__runner_thread = threading.Thread(target=self._execute, daemon=True, name=name)
|
|
63
|
+
self.__runner_thread.start()
|
|
64
|
+
fut = asyncio.run_coroutine_threadsafe(coro, self.__loop)
|
|
65
|
+
return fut
|
|
66
|
+
|
|
67
|
+
|
|
17
68
|
class LocalController:
|
|
18
69
|
def __init__(self):
|
|
19
70
|
logger.debug("LocalController init")
|
|
71
|
+
self._runner_map: dict[str, _TaskRunner] = {}
|
|
20
72
|
|
|
21
73
|
@log
|
|
22
74
|
async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
|
|
23
75
|
"""
|
|
24
|
-
|
|
76
|
+
Main entrypoint for submitting a task to the local controller.
|
|
25
77
|
"""
|
|
26
78
|
ctx = internal_ctx()
|
|
27
79
|
tctx = ctx.data.task_context
|
|
@@ -34,7 +86,7 @@ class LocalController:
|
|
|
34
86
|
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
35
87
|
tctx, _task.name, serialized_inputs, 0
|
|
36
88
|
)
|
|
37
|
-
sub_action_raw_data_path =
|
|
89
|
+
sub_action_raw_data_path = tctx.raw_data_path
|
|
38
90
|
|
|
39
91
|
out, err = await direct_dispatch(
|
|
40
92
|
_task,
|
|
@@ -59,7 +111,17 @@ class LocalController:
|
|
|
59
111
|
return result
|
|
60
112
|
return out
|
|
61
113
|
|
|
62
|
-
submit_sync
|
|
114
|
+
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
|
|
115
|
+
name = threading.current_thread().name + f"PID:{os.getpid()}"
|
|
116
|
+
coro = self.submit(_task, *args, **kwargs)
|
|
117
|
+
if name not in self._runner_map:
|
|
118
|
+
if len(self._runner_map) > 100:
|
|
119
|
+
logger.warning(
|
|
120
|
+
"More than 100 event loop runners created!!! This could be a case of runaway recursion..."
|
|
121
|
+
)
|
|
122
|
+
self._runner_map[name] = _TaskRunner()
|
|
123
|
+
|
|
124
|
+
return self._runner_map[name].get_run_future(coro)
|
|
63
125
|
|
|
64
126
|
async def finalize_parent_action(self, action: ActionID):
|
|
65
127
|
pass
|
flyte/_internal/controllers/remote/_controller.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import concurrent.futures
|
|
5
|
+
import os
|
|
6
|
+
import threading
|
|
4
7
|
from collections import defaultdict
|
|
5
8
|
from collections.abc import Callable
|
|
6
9
|
from pathlib import Path
|
|
@@ -21,7 +24,7 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
|
|
|
21
24
|
from flyte._logging import logger
|
|
22
25
|
from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
|
|
23
26
|
from flyte._task import TaskTemplate
|
|
24
|
-
from flyte._utils.
|
|
27
|
+
from flyte._utils.helpers import _selector_policy
|
|
25
28
|
from flyte.models import ActionID, NativeInterface, SerializationContext
|
|
26
29
|
|
|
27
30
|
R = TypeVar("R")
|
|
@@ -120,6 +123,8 @@ class RemoteController(Controller):
|
|
|
120
123
|
self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
|
|
121
124
|
lambda: defaultdict(int)
|
|
122
125
|
)
|
|
126
|
+
self._submit_loop: asyncio.AbstractEventLoop | None = None
|
|
127
|
+
self._submit_thread: threading.Thread | None = None
|
|
123
128
|
|
|
124
129
|
def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
|
|
125
130
|
"""
|
|
@@ -236,7 +241,47 @@ class RemoteController(Controller):
|
|
|
236
241
|
async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
|
|
237
242
|
return await self._submit(task_call_seq, _task, *args, **kwargs)
|
|
238
243
|
|
|
239
|
-
|
|
244
|
+
def _sync_thread_loop_runner(self) -> None:
|
|
245
|
+
"""This method runs the event loop and should be invoked in a separate thread."""
|
|
246
|
+
|
|
247
|
+
loop = self._submit_loop
|
|
248
|
+
assert loop is not None
|
|
249
|
+
try:
|
|
250
|
+
loop.run_forever()
|
|
251
|
+
finally:
|
|
252
|
+
loop.close()
|
|
253
|
+
|
|
254
|
+
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
|
|
255
|
+
"""
|
|
256
|
+
This function creates a cached thread and loop for the purpose of calling the submit method synchronously,
|
|
257
|
+
returning a concurrent Future that can be awaited. There's no need for a lock because this function itself is
|
|
258
|
+
single threaded and non-async. This pattern here is basically the trivial/degenerate case of the thread pool
|
|
259
|
+
in the LocalController.
|
|
260
|
+
Please see additional comments in protocol.
|
|
261
|
+
|
|
262
|
+
:param _task:
|
|
263
|
+
:param args:
|
|
264
|
+
:param kwargs:
|
|
265
|
+
:return:
|
|
266
|
+
"""
|
|
267
|
+
if self._submit_thread is None:
|
|
268
|
+
# Please see LocalController for the general implementation of this pattern.
|
|
269
|
+
def exc_handler(loop, context):
|
|
270
|
+
logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")
|
|
271
|
+
|
|
272
|
+
with _selector_policy():
|
|
273
|
+
self._submit_loop = asyncio.new_event_loop()
|
|
274
|
+
self._submit_loop.set_exception_handler(exc_handler)
|
|
275
|
+
|
|
276
|
+
self._submit_thread = threading.Thread(
|
|
277
|
+
name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner
|
|
278
|
+
)
|
|
279
|
+
self._submit_thread.start()
|
|
280
|
+
|
|
281
|
+
coro = self.submit(_task, *args, **kwargs)
|
|
282
|
+
assert self._submit_loop is not None, "Submit loop should always have been initialized by now"
|
|
283
|
+
fut = asyncio.run_coroutine_threadsafe(coro, self._submit_loop)
|
|
284
|
+
return fut
|
|
240
285
|
|
|
241
286
|
async def finalize_parent_action(self, action_id: ActionID):
|
|
242
287
|
"""
|
flyte/_internal/runtime/taskrunner.py
CHANGED
|
@@ -136,8 +136,9 @@ async def convert_and_run(
|
|
|
136
136
|
raw_data_path=raw_data_path,
|
|
137
137
|
compiled_image_cache=image_cache,
|
|
138
138
|
report=flyte.report.Report(name=action.name),
|
|
139
|
+
mode="remote" if not ctx.data.task_context else ctx.data.task_context.mode,
|
|
139
140
|
)
|
|
140
|
-
|
|
141
|
+
with ctx.replace_task_context(tctx):
|
|
141
142
|
inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
|
|
142
143
|
out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
|
|
143
144
|
if err is not None:
|
flyte/_map.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
|
|
3
|
+
|
|
4
|
+
from flyte.syncify import syncify
|
|
5
|
+
|
|
6
|
+
from ._group import group
|
|
7
|
+
from ._logging import logger
|
|
8
|
+
from ._task import P, R, TaskTemplate
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MapAsyncIterator(Generic[P, R]):
|
|
12
|
+
"""AsyncIterator implementation for the map function results"""
|
|
13
|
+
|
|
14
|
+
def __init__(self, func: TaskTemplate[P, R], args: tuple, name: str, concurrency: int, return_exceptions: bool):
|
|
15
|
+
self.func = func
|
|
16
|
+
self.args = args
|
|
17
|
+
self.name = name
|
|
18
|
+
self.concurrency = concurrency
|
|
19
|
+
self.return_exceptions = return_exceptions
|
|
20
|
+
self._tasks: List[asyncio.Task] = []
|
|
21
|
+
self._current_index = 0
|
|
22
|
+
self._completed_count = 0
|
|
23
|
+
self._exception_count = 0
|
|
24
|
+
self._task_count = 0
|
|
25
|
+
self._initialized = False
|
|
26
|
+
|
|
27
|
+
def __aiter__(self) -> AsyncIterator[Union[R, Exception]]:
|
|
28
|
+
"""Return self as the async iterator"""
|
|
29
|
+
return self
|
|
30
|
+
|
|
31
|
+
async def __anext__(self) -> Union[R, Exception]:
|
|
32
|
+
"""Get the next result"""
|
|
33
|
+
# Initialize on first call
|
|
34
|
+
if not self._initialized:
|
|
35
|
+
await self._initialize()
|
|
36
|
+
|
|
37
|
+
# Check if we've exhausted all tasks
|
|
38
|
+
if self._current_index >= self._task_count:
|
|
39
|
+
raise StopAsyncIteration
|
|
40
|
+
|
|
41
|
+
# Get the next task result
|
|
42
|
+
task = self._tasks[self._current_index]
|
|
43
|
+
self._current_index += 1
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
result = await task
|
|
47
|
+
self._completed_count += 1
|
|
48
|
+
logger.debug(f"Task {self._current_index - 1} completed successfully")
|
|
49
|
+
return result
|
|
50
|
+
except Exception as e:
|
|
51
|
+
self._exception_count += 1
|
|
52
|
+
logger.debug(f"Task {self._current_index - 1} failed with exception: {e}")
|
|
53
|
+
if self.return_exceptions:
|
|
54
|
+
return e
|
|
55
|
+
else:
|
|
56
|
+
# Cancel remaining tasks
|
|
57
|
+
for remaining_task in self._tasks[self._current_index + 1 :]:
|
|
58
|
+
remaining_task.cancel()
|
|
59
|
+
raise e
|
|
60
|
+
|
|
61
|
+
async def _initialize(self):
|
|
62
|
+
"""Initialize the tasks - called lazily on first iteration"""
|
|
63
|
+
# Create all tasks at once
|
|
64
|
+
tasks = []
|
|
65
|
+
task_count = 0
|
|
66
|
+
|
|
67
|
+
for arg_tuple in zip(*self.args):
|
|
68
|
+
task = asyncio.create_task(self.func.aio(*arg_tuple))
|
|
69
|
+
tasks.append(task)
|
|
70
|
+
task_count += 1
|
|
71
|
+
|
|
72
|
+
if task_count == 0:
|
|
73
|
+
logger.info(f"Group '{self.name}' has no tasks to process")
|
|
74
|
+
self._tasks = []
|
|
75
|
+
self._task_count = 0
|
|
76
|
+
else:
|
|
77
|
+
logger.info(f"Starting {task_count} tasks in group '{self.name}' with unlimited concurrency")
|
|
78
|
+
self._tasks = tasks
|
|
79
|
+
self._task_count = task_count
|
|
80
|
+
|
|
81
|
+
self._initialized = True
|
|
82
|
+
|
|
83
|
+
async def collect(self) -> List[Union[R, Exception]]:
|
|
84
|
+
"""Convenience method to collect all results into a list"""
|
|
85
|
+
results = []
|
|
86
|
+
async for result in self:
|
|
87
|
+
results.append(result)
|
|
88
|
+
return results
|
|
89
|
+
|
|
90
|
+
def __repr__(self):
|
|
91
|
+
return f"MapAsyncIterator(group_name='{self.name}', concurrency={self.concurrency})"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class _Mapper(Generic[P, R]):
|
|
95
|
+
"""
|
|
96
|
+
Internal mapper class to handle the mapping logic
|
|
97
|
+
|
|
98
|
+
NOTE: The reason why we do not use the `@syncify` decorator here is because, in `syncify` we cannot use
|
|
99
|
+
context managers like `group` directly in the function body. This is because the `__exit__` method of the
|
|
100
|
+
context manager is called after the function returns. An for `_context` the `__exit__` method releases the
|
|
101
|
+
token (for contextvar), which was created in a separate thread. This leads to an exception like:
|
|
102
|
+
|
|
103
|
+
"""
|
|
104
|
+
|
|
105
|
+
@classmethod
|
|
106
|
+
def _get_name(cls, task_name: str, group_name: str | None) -> str:
|
|
107
|
+
"""Get the name of the group, defaulting to 'map' if not provided."""
|
|
108
|
+
return f"{task_name}_{group_name or 'map'}"
|
|
109
|
+
|
|
110
|
+
def __call__(
|
|
111
|
+
self,
|
|
112
|
+
func: TaskTemplate[P, R],
|
|
113
|
+
*args: Iterable[Any],
|
|
114
|
+
group_name: str | None = None,
|
|
115
|
+
concurrency: int = 0,
|
|
116
|
+
return_exceptions: bool = True,
|
|
117
|
+
) -> Iterator[Union[R, Exception]]:
|
|
118
|
+
"""
|
|
119
|
+
Map a function over the provided arguments with concurrent execution.
|
|
120
|
+
|
|
121
|
+
:param func: The async function to map.
|
|
122
|
+
:param args: Positional arguments to pass to the function (iterables that will be zipped).
|
|
123
|
+
:param group_name: The name of the group for the mapped tasks.
|
|
124
|
+
:param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
|
|
125
|
+
:param return_exceptions: If True, yield exceptions instead of raising them.
|
|
126
|
+
:return: AsyncIterator yielding results in order.
|
|
127
|
+
"""
|
|
128
|
+
if not args:
|
|
129
|
+
return
|
|
130
|
+
|
|
131
|
+
name = self._get_name(func.name, group_name)
|
|
132
|
+
logger.debug(f"Blocking Map for {name}")
|
|
133
|
+
with group(name):
|
|
134
|
+
import flyte
|
|
135
|
+
|
|
136
|
+
tctx = flyte.ctx()
|
|
137
|
+
if tctx is None or tctx.mode == "local":
|
|
138
|
+
logger.warning("Running map in local mode, which will run every task sequentially.")
|
|
139
|
+
for v in zip(*args):
|
|
140
|
+
try:
|
|
141
|
+
yield func(*v) # type: ignore
|
|
142
|
+
except Exception as e:
|
|
143
|
+
if return_exceptions:
|
|
144
|
+
yield e
|
|
145
|
+
else:
|
|
146
|
+
raise e
|
|
147
|
+
return
|
|
148
|
+
|
|
149
|
+
i = 0
|
|
150
|
+
for x in cast(
|
|
151
|
+
Iterator[R],
|
|
152
|
+
_map(
|
|
153
|
+
func,
|
|
154
|
+
*args,
|
|
155
|
+
name=name,
|
|
156
|
+
concurrency=concurrency,
|
|
157
|
+
return_exceptions=True,
|
|
158
|
+
),
|
|
159
|
+
):
|
|
160
|
+
logger.debug(f"Mapped {x}, task {i}")
|
|
161
|
+
i += 1
|
|
162
|
+
yield x
|
|
163
|
+
|
|
164
|
+
async def aio(
|
|
165
|
+
self,
|
|
166
|
+
func: TaskTemplate[P, R],
|
|
167
|
+
*args: Iterable[Any],
|
|
168
|
+
group_name: str | None = None,
|
|
169
|
+
concurrency: int = 0,
|
|
170
|
+
return_exceptions: bool = True,
|
|
171
|
+
) -> AsyncGenerator[Union[R, Exception], None]:
|
|
172
|
+
if not args:
|
|
173
|
+
return
|
|
174
|
+
name = self._get_name(func.name, group_name)
|
|
175
|
+
with group(name):
|
|
176
|
+
import flyte
|
|
177
|
+
|
|
178
|
+
tctx = flyte.ctx()
|
|
179
|
+
if tctx is None or tctx.mode == "local":
|
|
180
|
+
logger.warning("Running map in local mode, which will run every task sequentially.")
|
|
181
|
+
for v in zip(*args):
|
|
182
|
+
try:
|
|
183
|
+
yield func(*v) # type: ignore
|
|
184
|
+
except Exception as e:
|
|
185
|
+
if return_exceptions:
|
|
186
|
+
yield e
|
|
187
|
+
else:
|
|
188
|
+
raise e
|
|
189
|
+
return
|
|
190
|
+
async for x in _map.aio(
|
|
191
|
+
func,
|
|
192
|
+
*args,
|
|
193
|
+
name=name,
|
|
194
|
+
concurrency=concurrency,
|
|
195
|
+
return_exceptions=return_exceptions,
|
|
196
|
+
):
|
|
197
|
+
yield cast(Union[R, Exception], x)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@syncify
|
|
201
|
+
async def _map(
|
|
202
|
+
func: TaskTemplate[P, R],
|
|
203
|
+
*args: Iterable[Any],
|
|
204
|
+
name: str = "map",
|
|
205
|
+
concurrency: int = 0,
|
|
206
|
+
return_exceptions: bool = True,
|
|
207
|
+
) -> AsyncIterator[Union[R, Exception]]:
|
|
208
|
+
iter = MapAsyncIterator(
|
|
209
|
+
func=func, args=args, name=name, concurrency=concurrency, return_exceptions=return_exceptions
|
|
210
|
+
)
|
|
211
|
+
async for result in iter:
|
|
212
|
+
yield result
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
map: _Mapper = _Mapper()
|