flyte 0.2.0b7__py3-none-any.whl → 0.2.0b9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of flyte has been flagged as potentially problematic.
- flyte/__init__.py +2 -0
- flyte/_context.py +7 -1
- flyte/_environment.py +42 -1
- flyte/_group.py +1 -0
- flyte/_image.py +1 -2
- flyte/_internal/controllers/__init__.py +14 -1
- flyte/_internal/controllers/_local_controller.py +66 -1
- flyte/_internal/controllers/remote/_controller.py +48 -0
- flyte/_internal/controllers/remote/_informer.py +3 -3
- flyte/_internal/runtime/taskrunner.py +2 -1
- flyte/_map.py +215 -0
- flyte/_run.py +26 -26
- flyte/_task.py +101 -13
- flyte/_task_environment.py +48 -66
- flyte/_utils/coro_management.py +0 -2
- flyte/_utils/helpers.py +15 -0
- flyte/_version.py +2 -2
- flyte/cli/__init__.py +0 -7
- flyte/cli/_abort.py +1 -1
- flyte/cli/_common.py +4 -3
- flyte/cli/_create.py +69 -23
- flyte/cli/_delete.py +2 -2
- flyte/cli/_deploy.py +7 -4
- flyte/cli/_gen.py +163 -0
- flyte/cli/_get.py +62 -8
- flyte/cli/_run.py +22 -2
- flyte/cli/main.py +63 -13
- flyte/extras/_container.py +1 -1
- flyte/models.py +10 -1
- flyte/remote/_run.py +1 -0
- flyte/syncify/__init__.py +51 -0
- flyte/syncify/_api.py +48 -21
- {flyte-0.2.0b7.dist-info → flyte-0.2.0b9.dist-info}/METADATA +30 -4
- {flyte-0.2.0b7.dist-info → flyte-0.2.0b9.dist-info}/RECORD +37 -35
- {flyte-0.2.0b7.dist-info → flyte-0.2.0b9.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b7.dist-info → flyte-0.2.0b9.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b7.dist-info → flyte-0.2.0b9.dist-info}/top_level.txt +0 -0
flyte/__init__.py
CHANGED
@@ -39,6 +39,7 @@ __all__ = [
     "group",
     "init",
     "init_auto_from_config",
+    "map",
     "run",
     "trace",
     "with_runcontext",
@@ -51,6 +52,7 @@ from ._environment import Environment
 from ._group import group
 from ._image import Image
 from ._initialize import init, init_auto_from_config
+from ._map import map
 from ._resources import GPU, TPU, Device, Resources
 from ._retry import RetryStrategy
 from ._reusable_environment import ReusePolicy
flyte/_context.py
CHANGED
@@ -4,6 +4,7 @@ import contextvars
 from dataclasses import dataclass, replace
 from typing import TYPE_CHECKING, Awaitable, Callable, Optional, ParamSpec, TypeVar

+from flyte._logging import logger
 from flyte.models import GroupData, RawDataPath, TaskContext

 if TYPE_CHECKING:
@@ -49,6 +50,7 @@ class Context:
             raise ValueError("Cannot create a new context without contextdata.")
         self._data = data
         self._id = id(self)  # Immutable unique identifier
+        self._token = None  # Context variable token to restore the previous context

     @property
     def data(self) -> ContextData:
@@ -106,7 +108,11 @@ class Context:

     def __exit__(self, exc_type, exc_val, exc_tb):
         """Exit the context, restoring the previous context."""
-
+        try:
+            root_context_var.reset(self._token)
+        except Exception as e:
+            logger.warn(f"Failed to reset context: {e}")
+            raise e

     async def __aenter__(self):
         """Async version of context entry."""
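For reference, the new `__exit__` logic leans on the standard `contextvars` token protocol: `ContextVar.set()` returns a token, and `reset(token)` restores the previous value, raising `ValueError` if the token was created in a different context (for example, on another thread) — which is why the diff wraps the reset in a try/except. A minimal, self-contained sketch of that pattern (the variable and class names here are illustrative, not Flyte's):

```python
import contextvars

# Illustrative stand-in for flyte's root context variable (name is hypothetical).
current_ctx: contextvars.ContextVar[dict] = contextvars.ContextVar("current_ctx")


class Scope:
    """Save the reset token on entry, restore the previous value on exit."""

    def __init__(self, data: dict):
        self._data = data
        self._token = None  # token to restore the previous context, as in Context.__exit__

    def __enter__(self):
        self._token = current_ctx.set(self._data)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            current_ctx.reset(self._token)
        except ValueError as e:
            # Raised e.g. when the token was created in a different Context/thread.
            print(f"Failed to reset context: {e}")
            raise


with Scope({"run": "demo"}):
    print(current_ctx.get())  # {'run': 'demo'}
```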
flyte/_environment.py
CHANGED
@@ -1,7 +1,8 @@
 from __future__ import annotations

+import re
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

 import rich.repr

@@ -14,6 +15,10 @@ if TYPE_CHECKING:
     from kubernetes.client import V1PodTemplate


+def is_snake_or_kebab_with_numbers(s: str) -> bool:
+    return re.fullmatch(r"^[a-z0-9]+([_-][a-z0-9]+)*$", s) is not None
+
+
 @rich.repr.auto
 @dataclass(init=True, repr=True)
 class Environment:
@@ -36,8 +41,44 @@ class Environment:
     resources: Optional[Resources] = None
     image: Union[str, Image, Literal["auto"]] = "auto"

+    def __post_init__(self):
+        if not is_snake_or_kebab_with_numbers(self.name):
+            raise ValueError(f"Environment name '{self.name}' must be in snake_case or kebab-case format.")
+
     def add_dependency(self, *env: Environment):
         """
         Add a dependency to the environment.
         """
         self.env_dep_hints.extend(env)
+
+    def clone_with(
+        self,
+        name: str,
+        image: Optional[Union[str, Image, Literal["auto"]]] = None,
+        resources: Optional[Resources] = None,
+        env: Optional[Dict[str, str]] = None,
+        secrets: Optional[SecretRequest] = None,
+        env_dep_hints: Optional[List[Environment]] = None,
+        **kwargs: Any,
+    ) -> Environment:
+        raise NotImplementedError
+
+    def _get_kwargs(self) -> Dict[str, Any]:
+        """
+        Get the keyword arguments for the environment.
+        """
+        kwargs: Dict[str, Any] = {
+            "env_dep_hints": self.env_dep_hints,
+            "image": self.image,
+        }
+        if self.resources is not None:
+            kwargs["resources"] = self.resources
+        if self.secrets is not None:
+            kwargs["secrets"] = self.secrets
+        if self.env is not None:
+            kwargs["env"] = self.env
+        if self.pod_template is not None:
+            kwargs["pod_template"] = self.pod_template
+        if self.description is not None:
+            kwargs["description"] = self.description
+        return kwargs
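The practical effect of the new `__post_init__` check is that environment names must be lowercase snake_case or kebab-case (digits allowed), and anything else now raises `ValueError` at construction time. A quick sketch of what the regex from this hunk accepts and rejects:

```python
import re


def is_snake_or_kebab_with_numbers(s: str) -> bool:
    # Same pattern as in flyte/_environment.py: lowercase alphanumeric runs
    # separated by single underscores or hyphens.
    return re.fullmatch(r"^[a-z0-9]+([_-][a-z0-9]+)*$", s) is not None


assert is_snake_or_kebab_with_numbers("my_env_2")
assert is_snake_or_kebab_with_numbers("my-env-2")
assert not is_snake_or_kebab_with_numbers("MyEnv")    # uppercase rejected
assert not is_snake_or_kebab_with_numbers("my__env")  # double separator rejected
assert not is_snake_or_kebab_with_numbers("my env")   # whitespace rejected
```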
flyte/_group.py
CHANGED
flyte/_image.py
CHANGED
@@ -444,8 +444,7 @@ class Image:
     ```

     For more information on the uv script format, see the documentation:
-
-    UV: Declaring script dependencies</href>
+    [UV: Declaring script dependencies](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies)

     :param name: name of the image
     :param registry: registry to use for the image
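For context, the docstring fix above just repairs the link to uv's documentation on inline script metadata (PEP 723). A minimal example of the script header that documentation describes, with placeholder package names, looks roughly like this:

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "pandas",
#     "requests<3",
# ]
# ///
import pandas  # `uv run script.py` resolves the declared dependencies before executing
```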
flyte/_internal/controllers/__init__.py
CHANGED
@@ -1,5 +1,6 @@
+import concurrent.futures
 import threading
-from typing import Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar

 from flyte._task import TaskTemplate
 from flyte.models import ActionID, NativeInterface
@@ -10,6 +11,9 @@ __all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "ge

 from ..._protos.workflow import task_definition_pb2

+if TYPE_CHECKING:
+    import concurrent.futures
+
 ControllerType = Literal["local", "remote"]

 R = TypeVar("R")
@@ -28,6 +32,15 @@ class Controller(Protocol):
         """
         ...

+    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
+        """
+        This should call the async submit method above, but return a concurrent Future object that can be
+        used in a blocking wait or wrapped in an async future. This is called when
+        a) a synchronous task is kicked off locally,
+        b) a running task (of either kind) kicks off a downstream synchronous task.
+        """
+        ...
+
     async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
         """
         Submit a task reference to the controller asynchronously and wait for the result. This is async and will block
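The new `submit_sync` protocol method hands callers a `concurrent.futures.Future`, which can be consumed either by blocking on `.result()` from synchronous code or by wrapping it with `asyncio.wrap_future` from async code (the latter is exactly what `_Runner._run_local` does later in this diff). A self-contained sketch using a plain thread pool as a stand-in for a controller:

```python
import asyncio
import concurrent.futures

# Stand-in for Controller.submit_sync: any call that hands back a
# concurrent.futures.Future for work running on another thread/loop.
pool = concurrent.futures.ThreadPoolExecutor()


def submit_sync_stub(x: int) -> concurrent.futures.Future:
    return pool.submit(lambda: x * x)


# (a) Synchronous caller: block until the result is ready.
print(submit_sync_stub(3).result())  # 9


# (b) Async caller: wrap the concurrent Future so it can be awaited.
async def main() -> None:
    fut = submit_sync_stub(4)
    print(await asyncio.wrap_future(fut))  # 16


asyncio.run(main())
```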
flyte/_internal/controllers/_local_controller.py
CHANGED
@@ -1,3 +1,8 @@
+import asyncio
+import atexit
+import concurrent.futures
+import os
+import threading
 from typing import Any, Callable, Tuple, TypeVar

 import flyte.errors
@@ -8,19 +13,67 @@ from flyte._internal.runtime.entrypoints import direct_dispatch
 from flyte._logging import log, logger
 from flyte._protos.workflow import task_definition_pb2
 from flyte._task import TaskTemplate
+from flyte._utils.helpers import _selector_policy
 from flyte.models import ActionID, NativeInterface, RawDataPath

 R = TypeVar("R")


+class _TaskRunner:
+    """A task runner that runs an asyncio event loop on a background thread."""
+
+    def __init__(self) -> None:
+        self.__loop: asyncio.AbstractEventLoop | None = None
+        self.__runner_thread: threading.Thread | None = None
+        self.__lock = threading.Lock()
+        atexit.register(self._close)
+
+    def _close(self) -> None:
+        if self.__loop:
+            self.__loop.stop()
+
+    def _execute(self) -> None:
+        loop = self.__loop
+        assert loop is not None
+        try:
+            loop.run_forever()
+        finally:
+            loop.close()
+
+    def get_exc_handler(self):
+        def exc_handler(loop, context):
+            logger.error(
+                f"Taskrunner for {self.__runner_thread.name if self.__runner_thread else 'no thread'} caught"
+                f" exception in {loop}: {context}"
+            )
+
+        return exc_handler
+
+    def get_run_future(self, coro: Any) -> concurrent.futures.Future:
+        """Synchronously run a coroutine on a background thread."""
+        name = f"{threading.current_thread().name} : loop-runner"
+        with self.__lock:
+            if self.__loop is None:
+                with _selector_policy():
+                    self.__loop = asyncio.new_event_loop()

+                exc_handler = self.get_exc_handler()
+                self.__loop.set_exception_handler(exc_handler)
+                self.__runner_thread = threading.Thread(target=self._execute, daemon=True, name=name)
+                self.__runner_thread.start()
+        fut = asyncio.run_coroutine_threadsafe(coro, self.__loop)
+        return fut
+
+
 class LocalController:
     def __init__(self):
         logger.debug("LocalController init")
+        self._runner_map: dict[str, _TaskRunner] = {}

     @log
     async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
         """
-
+        Main entrypoint for submitting a task to the local controller.
         """
         ctx = internal_ctx()
         tctx = ctx.data.task_context
@@ -58,6 +111,18 @@ class LocalController:
             return result
         return out

+    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
+        name = threading.current_thread().name + f"PID:{os.getpid()}"
+        coro = self.submit(_task, *args, **kwargs)
+        if name not in self._runner_map:
+            if len(self._runner_map) > 100:
+                logger.warning(
+                    "More than 100 event loop runners created!!! This could be a case of runaway recursion..."
+                )
+            self._runner_map[name] = _TaskRunner()
+
+        return self._runner_map[name].get_run_future(coro)
+
     async def finalize_parent_action(self, action: ActionID):
         pass
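The `_TaskRunner` added here follows a common pattern: keep one asyncio event loop alive on a daemon thread and feed it coroutines from synchronous code via `asyncio.run_coroutine_threadsafe`, which returns a `concurrent.futures.Future`. Flyte's version additionally installs a loop exception handler and an `atexit` hook, but a minimal sketch of the core idea is:

```python
import asyncio
import concurrent.futures
import threading

# One event loop kept alive on a background daemon thread.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True, name="loop-runner").start()


async def work(x: int) -> int:
    await asyncio.sleep(0.1)
    return x + 1


# Schedule a coroutine onto the background loop from synchronous code.
fut: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(work(41), loop)
print(fut.result())  # 42, blocking wait from the calling (sync) thread
loop.call_soon_threadsafe(loop.stop)  # shut the background loop down
```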
flyte/_internal/controllers/remote/_controller.py
CHANGED
@@ -1,6 +1,9 @@
 from __future__ import annotations

 import asyncio
+import concurrent.futures
+import os
+import threading
 from collections import defaultdict
 from collections.abc import Callable
 from pathlib import Path
@@ -21,6 +24,7 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
 from flyte._logging import logger
 from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
 from flyte._task import TaskTemplate
+from flyte._utils.helpers import _selector_policy
 from flyte.models import ActionID, NativeInterface, SerializationContext

 R = TypeVar("R")
@@ -119,6 +123,8 @@ class RemoteController(Controller):
         self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
             lambda: defaultdict(int)
         )
+        self._submit_loop: asyncio.AbstractEventLoop | None = None
+        self._submit_thread: threading.Thread | None = None

     def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
         """
@@ -235,6 +241,48 @@ class RemoteController(Controller):
         async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
             return await self._submit(task_call_seq, _task, *args, **kwargs)

+    def _sync_thread_loop_runner(self) -> None:
+        """This method runs the event loop and should be invoked in a separate thread."""
+
+        loop = self._submit_loop
+        assert loop is not None
+        try:
+            loop.run_forever()
+        finally:
+            loop.close()
+
+    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
+        """
+        This function creates a cached thread and loop for the purpose of calling the submit method synchronously,
+        returning a concurrent Future that can be awaited. There's no need for a lock because this function itself is
+        single threaded and non-async. This pattern here is basically the trivial/degenerate case of the thread pool
+        in the LocalController.
+        Please see additional comments in protocol.
+
+        :param _task:
+        :param args:
+        :param kwargs:
+        :return:
+        """
+        if self._submit_thread is None:
+            # Please see LocalController for the general implementation of this pattern.
+            def exc_handler(loop, context):
+                logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")
+
+            with _selector_policy():
+                self._submit_loop = asyncio.new_event_loop()
+            self._submit_loop.set_exception_handler(exc_handler)
+
+            self._submit_thread = threading.Thread(
+                name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner
+            )
+            self._submit_thread.start()
+
+        coro = self.submit(_task, *args, **kwargs)
+        assert self._submit_loop is not None, "Submit loop should always have been initialized by now"
+        fut = asyncio.run_coroutine_threadsafe(coro, self._submit_loop)
+        return fut
+
     async def finalize_parent_action(self, action_id: ActionID):
         """
         This method is invoked when the parent action is finished. It will finalize the run and upload the outputs
flyte/_internal/controllers/remote/_informer.py
CHANGED
@@ -38,11 +38,11 @@ class ActionCache:
         """
         Add an action to the cache if it doesn't exist. This is invoked by the watch.
         """
-        logger.
+        logger.debug(f"Observing phase {run_definition_pb2.Phase.Name(state.phase)} for {state.action_id.name}")
         if state.output_uri:
-            logger.
+            logger.debug(f"Output URI: {state.output_uri}")
         else:
-            logger.
+            logger.warning(f"{state.action_id.name} has no output URI")
         if state.phase == run_definition_pb2.Phase.PHASE_FAILED:
             logger.error(
                 f"Action {state.action_id.name} failed with error (msg):"
flyte/_internal/runtime/taskrunner.py
CHANGED
@@ -136,8 +136,9 @@ async def convert_and_run(
         raw_data_path=raw_data_path,
         compiled_image_cache=image_cache,
         report=flyte.report.Report(name=action.name),
+        mode="remote" if not ctx.data.task_context else ctx.data.task_context.mode,
     )
-
+    with ctx.replace_task_context(tctx):
         inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
         out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
         if err is not None:
flyte/_map.py
ADDED
@@ -0,0 +1,215 @@
+import asyncio
+from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
+
+from flyte.syncify import syncify
+
+from ._group import group
+from ._logging import logger
+from ._task import P, R, TaskTemplate
+
+
+class MapAsyncIterator(Generic[P, R]):
+    """AsyncIterator implementation for the map function results"""
+
+    def __init__(self, func: TaskTemplate[P, R], args: tuple, name: str, concurrency: int, return_exceptions: bool):
+        self.func = func
+        self.args = args
+        self.name = name
+        self.concurrency = concurrency
+        self.return_exceptions = return_exceptions
+        self._tasks: List[asyncio.Task] = []
+        self._current_index = 0
+        self._completed_count = 0
+        self._exception_count = 0
+        self._task_count = 0
+        self._initialized = False
+
+    def __aiter__(self) -> AsyncIterator[Union[R, Exception]]:
+        """Return self as the async iterator"""
+        return self
+
+    async def __anext__(self) -> Union[R, Exception]:
+        """Get the next result"""
+        # Initialize on first call
+        if not self._initialized:
+            await self._initialize()
+
+        # Check if we've exhausted all tasks
+        if self._current_index >= self._task_count:
+            raise StopAsyncIteration
+
+        # Get the next task result
+        task = self._tasks[self._current_index]
+        self._current_index += 1
+
+        try:
+            result = await task
+            self._completed_count += 1
+            logger.debug(f"Task {self._current_index - 1} completed successfully")
+            return result
+        except Exception as e:
+            self._exception_count += 1
+            logger.debug(f"Task {self._current_index - 1} failed with exception: {e}")
+            if self.return_exceptions:
+                return e
+            else:
+                # Cancel remaining tasks
+                for remaining_task in self._tasks[self._current_index + 1 :]:
+                    remaining_task.cancel()
+                raise e
+
+    async def _initialize(self):
+        """Initialize the tasks - called lazily on first iteration"""
+        # Create all tasks at once
+        tasks = []
+        task_count = 0
+
+        for arg_tuple in zip(*self.args):
+            task = asyncio.create_task(self.func.aio(*arg_tuple))
+            tasks.append(task)
+            task_count += 1
+
+        if task_count == 0:
+            logger.info(f"Group '{self.name}' has no tasks to process")
+            self._tasks = []
+            self._task_count = 0
+        else:
+            logger.info(f"Starting {task_count} tasks in group '{self.name}' with unlimited concurrency")
+            self._tasks = tasks
+            self._task_count = task_count
+
+        self._initialized = True
+
+    async def collect(self) -> List[Union[R, Exception]]:
+        """Convenience method to collect all results into a list"""
+        results = []
+        async for result in self:
+            results.append(result)
+        return results
+
+    def __repr__(self):
+        return f"MapAsyncIterator(group_name='{self.name}', concurrency={self.concurrency})"
+
+
+class _Mapper(Generic[P, R]):
+    """
+    Internal mapper class to handle the mapping logic
+
+    NOTE: The reason why we do not use the `@syncify` decorator here is because, in `syncify` we cannot use
+    context managers like `group` directly in the function body. This is because the `__exit__` method of the
+    context manager is called after the function returns. An for `_context` the `__exit__` method releases the
+    token (for contextvar), which was created in a separate thread. This leads to an exception like:
+
+    """
+
+    @classmethod
+    def _get_name(cls, task_name: str, group_name: str | None) -> str:
+        """Get the name of the group, defaulting to 'map' if not provided."""
+        return f"{task_name}_{group_name or 'map'}"
+
+    def __call__(
+        self,
+        func: TaskTemplate[P, R],
+        *args: Iterable[Any],
+        group_name: str | None = None,
+        concurrency: int = 0,
+        return_exceptions: bool = True,
+    ) -> Iterator[Union[R, Exception]]:
+        """
+        Map a function over the provided arguments with concurrent execution.
+
+        :param func: The async function to map.
+        :param args: Positional arguments to pass to the function (iterables that will be zipped).
+        :param group_name: The name of the group for the mapped tasks.
+        :param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
+        :param return_exceptions: If True, yield exceptions instead of raising them.
+        :return: AsyncIterator yielding results in order.
+        """
+        if not args:
+            return
+
+        name = self._get_name(func.name, group_name)
+        logger.debug(f"Blocking Map for {name}")
+        with group(name):
+            import flyte
+
+            tctx = flyte.ctx()
+            if tctx is None or tctx.mode == "local":
+                logger.warning("Running map in local mode, which will run every task sequentially.")
+                for v in zip(*args):
+                    try:
+                        yield func(*v)  # type: ignore
+                    except Exception as e:
+                        if return_exceptions:
+                            yield e
+                        else:
+                            raise e
+                return
+
+            i = 0
+            for x in cast(
+                Iterator[R],
+                _map(
+                    func,
+                    *args,
+                    name=name,
+                    concurrency=concurrency,
+                    return_exceptions=True,
+                ),
+            ):
+                logger.debug(f"Mapped {x}, task {i}")
+                i += 1
+                yield x
+
+    async def aio(
+        self,
+        func: TaskTemplate[P, R],
+        *args: Iterable[Any],
+        group_name: str | None = None,
+        concurrency: int = 0,
+        return_exceptions: bool = True,
+    ) -> AsyncGenerator[Union[R, Exception], None]:
+        if not args:
+            return
+        name = self._get_name(func.name, group_name)
+        with group(name):
+            import flyte
+
+            tctx = flyte.ctx()
+            if tctx is None or tctx.mode == "local":
+                logger.warning("Running map in local mode, which will run every task sequentially.")
+                for v in zip(*args):
+                    try:
+                        yield func(*v)  # type: ignore
+                    except Exception as e:
+                        if return_exceptions:
+                            yield e
+                        else:
+                            raise e
+                return
+            async for x in _map.aio(
+                func,
+                *args,
+                name=name,
+                concurrency=concurrency,
+                return_exceptions=return_exceptions,
+            ):
+                yield cast(Union[R, Exception], x)
+
+
+@syncify
+async def _map(
+    func: TaskTemplate[P, R],
+    *args: Iterable[Any],
+    name: str = "map",
+    concurrency: int = 0,
+    return_exceptions: bool = True,
+) -> AsyncIterator[Union[R, Exception]]:
+    iter = MapAsyncIterator(
+        func=func, args=args, name=name, concurrency=concurrency, return_exceptions=return_exceptions
+    )
+    async for result in iter:
+        yield result
+
+
+map: _Mapper = _Mapper()
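Based on the signatures above, `flyte.map` takes a task plus one or more iterables (zipped into per-call arguments), runs the calls under a named group, and yields results in input order, returning exceptions as values when `return_exceptions=True` (the default). A hedged usage sketch — the `TaskEnvironment`/`@env.task` setup is assumed boilerplate outside this diff, not something it defines:

```python
import flyte

env = flyte.TaskEnvironment(name="map-demo")  # assumed task setup; not part of this diff


@env.task
async def square(x: int) -> int:
    return x * x


@env.task
async def sum_of_squares(xs: list[int]) -> int:
    total = 0
    # Results arrive in input order; failures come back as Exception objects
    # because return_exceptions defaults to True.
    async for r in flyte.map.aio(square, xs, group_name="squares"):
        if isinstance(r, Exception):
            raise r
        total += r
    return total
```

From a synchronous task body the blocking form `for r in flyte.map(square, xs): ...` works the same way; in local mode both variants fall back to running each call sequentially, as the warning in `_Mapper` notes.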
flyte/_run.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import pathlib
 import uuid
 from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union, cast
@@ -287,40 +288,39 @@

     async def _run_local(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
         from flyte._internal.controllers import create_controller
-        from flyte._internal.
-
-
-
-        )
-        from flyte._internal.runtime.entrypoints import direct_dispatch
+        from flyte._internal.controllers._local_controller import LocalController
+        from flyte.report import Report
+
+        controller = cast(LocalController, create_controller("local"))

-        controller = create_controller("local")
-        inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
         if self._name is None:
             action = ActionID.create_random()
         else:
             action = ActionID(name=self._name)
-
-
+
+        ctx = internal_ctx()
+        tctx = TaskContext(
             action=action,
-            raw_data_path=internal_ctx().raw_data,
-            version="na",
-            controller=controller,
-            inputs=inputs,
-            output_path=self._metadata_path,
-            run_base_dir=self._metadata_path,
             checkpoints=Checkpoints(
                 prev_checkpoint_path=internal_ctx().raw_data.path, checkpoint_path=internal_ctx().raw_data.path
             ),
-
-
-
-
-
-
-
-
-
+            code_bundle=None,
+            output_path=self._metadata_path,
+            run_base_dir=self._metadata_path,
+            version="na",
+            raw_data_path=internal_ctx().raw_data,
+            compiled_image_cache=None,
+            report=Report(name=action.name),
+            mode="local",
+        )
+        async with ctx.replace_task_context(tctx):
+            # make the local version always runs on a different thread, returns a wrapped future.
+            if obj._call_as_synchronous:
+                fut = controller.submit_sync(obj, *args, **kwargs)
+                awaitable = asyncio.wrap_future(fut)
+                return await awaitable
+            else:
+                return await controller.submit(obj, *args, **kwargs)

     @syncify
     async def run(self, task: TaskTemplate[P, Union[R, Run]], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
@@ -351,7 +351,7 @@
             return await self._run_hybrid(task, *args, **kwargs)

         # TODO We could use this for remote as well and users could simply pass flyte:// or s3:// or file://
-
+        with internal_ctx().new_raw_data_path(
             raw_data_path=RawDataPath.from_local_folder(local_folder=self._raw_data_path)
         ):
             return await self._run_local(task, *args, **kwargs)