flyte 0.2.0b8__py3-none-any.whl → 0.2.0b9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +2 -0
- flyte/_context.py +7 -1
- flyte/_group.py +1 -0
- flyte/_internal/controllers/__init__.py +13 -2
- flyte/_internal/controllers/_local_controller.py +65 -3
- flyte/_internal/controllers/remote/_controller.py +47 -2
- flyte/_internal/runtime/taskrunner.py +2 -1
- flyte/_map.py +215 -0
- flyte/_run.py +26 -26
- flyte/_task.py +56 -7
- flyte/_utils/helpers.py +15 -0
- flyte/_version.py +2 -2
- flyte/cli/__init__.py +0 -7
- flyte/cli/_abort.py +1 -1
- flyte/cli/_common.py +3 -3
- flyte/cli/_create.py +44 -29
- flyte/cli/_delete.py +2 -2
- flyte/cli/_deploy.py +3 -3
- flyte/cli/_gen.py +12 -4
- flyte/cli/_get.py +35 -27
- flyte/cli/main.py +32 -29
- flyte/models.py +10 -1
- flyte/syncify/_api.py +43 -15
- {flyte-0.2.0b8.dist-info → flyte-0.2.0b9.dist-info}/METADATA +3 -1
- {flyte-0.2.0b8.dist-info → flyte-0.2.0b9.dist-info}/RECORD +28 -27
- {flyte-0.2.0b8.dist-info → flyte-0.2.0b9.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b8.dist-info → flyte-0.2.0b9.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b8.dist-info → flyte-0.2.0b9.dist-info}/top_level.txt +0 -0
flyte/__init__.py
CHANGED
|
@@ -39,6 +39,7 @@ __all__ = [
|
|
|
39
39
|
"group",
|
|
40
40
|
"init",
|
|
41
41
|
"init_auto_from_config",
|
|
42
|
+
"map",
|
|
42
43
|
"run",
|
|
43
44
|
"trace",
|
|
44
45
|
"with_runcontext",
|
|
@@ -51,6 +52,7 @@ from ._environment import Environment
|
|
|
51
52
|
from ._group import group
|
|
52
53
|
from ._image import Image
|
|
53
54
|
from ._initialize import init, init_auto_from_config
|
|
55
|
+
from ._map import map
|
|
54
56
|
from ._resources import GPU, TPU, Device, Resources
|
|
55
57
|
from ._retry import RetryStrategy
|
|
56
58
|
from ._reusable_environment import ReusePolicy
|
flyte/_context.py
CHANGED
|
@@ -4,6 +4,7 @@ import contextvars
|
|
|
4
4
|
from dataclasses import dataclass, replace
|
|
5
5
|
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, ParamSpec, TypeVar
|
|
6
6
|
|
|
7
|
+
from flyte._logging import logger
|
|
7
8
|
from flyte.models import GroupData, RawDataPath, TaskContext
|
|
8
9
|
|
|
9
10
|
if TYPE_CHECKING:
|
|
@@ -49,6 +50,7 @@ class Context:
|
|
|
49
50
|
raise ValueError("Cannot create a new context without contextdata.")
|
|
50
51
|
self._data = data
|
|
51
52
|
self._id = id(self) # Immutable unique identifier
|
|
53
|
+
self._token = None # Context variable token to restore the previous context
|
|
52
54
|
|
|
53
55
|
@property
|
|
54
56
|
def data(self) -> ContextData:
|
|
@@ -106,7 +108,11 @@ class Context:
|
|
|
106
108
|
|
|
107
109
|
def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Exit the context, restoring the previous context.

    Resets ``root_context_var`` with the token held in ``self._token`` (presumably stored by ``__enter__`` —
    it starts out as ``None`` in ``__init__``). If the reset fails, the error is logged and re-raised.
    """
    try:
        root_context_var.reset(self._token)
    except Exception as e:
        # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
        logger.warning(f"Failed to reset context: {e}")
        # Bare ``raise`` re-raises the active exception without adding this frame to the traceback.
        raise
|
|
110
116
|
|
|
111
117
|
async def __aenter__(self):
|
|
112
118
|
"""Async version of context entry."""
|
flyte/_group.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
|
+
import concurrent.futures
|
|
1
2
|
import threading
|
|
2
|
-
from typing import Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
|
|
3
4
|
|
|
4
5
|
from flyte._task import TaskTemplate
|
|
5
6
|
from flyte.models import ActionID, NativeInterface
|
|
@@ -10,6 +11,9 @@ __all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "ge
|
|
|
10
11
|
|
|
11
12
|
from ..._protos.workflow import task_definition_pb2
|
|
12
13
|
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import concurrent.futures
|
|
16
|
+
|
|
13
17
|
ControllerType = Literal["local", "remote"]
|
|
14
18
|
|
|
15
19
|
R = TypeVar("R")
|
|
@@ -28,7 +32,14 @@ class Controller(Protocol):
|
|
|
28
32
|
"""
|
|
29
33
|
...
|
|
30
34
|
|
|
31
|
-
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) ->
|
|
35
|
+
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
    """
    Synchronous counterpart of the async ``submit`` method.

    Implementations should drive ``submit`` and hand back a :class:`concurrent.futures.Future`, which the
    caller can either block on or wrap into an asyncio future. This entry point is used when:

      a) a synchronous task is kicked off locally, or
      b) a running task (sync or async) kicks off a downstream synchronous task.
    """
    ...
|
|
32
43
|
|
|
33
44
|
async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
|
|
34
45
|
"""
|
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import atexit
|
|
3
|
+
import concurrent.futures
|
|
4
|
+
import os
|
|
5
|
+
import threading
|
|
1
6
|
from typing import Any, Callable, Tuple, TypeVar
|
|
2
7
|
|
|
3
8
|
import flyte.errors
|
|
@@ -8,20 +13,67 @@ from flyte._internal.runtime.entrypoints import direct_dispatch
|
|
|
8
13
|
from flyte._logging import log, logger
|
|
9
14
|
from flyte._protos.workflow import task_definition_pb2
|
|
10
15
|
from flyte._task import TaskTemplate
|
|
11
|
-
from flyte._utils.
|
|
16
|
+
from flyte._utils.helpers import _selector_policy
|
|
12
17
|
from flyte.models import ActionID, NativeInterface, RawDataPath
|
|
13
18
|
|
|
14
19
|
R = TypeVar("R")
|
|
15
20
|
|
|
16
21
|
|
|
22
|
+
class _TaskRunner:
|
|
23
|
+
"""A task runner that runs an asyncio event loop on a background thread."""
|
|
24
|
+
|
|
25
|
+
def __init__(self) -> None:
|
|
26
|
+
self.__loop: asyncio.AbstractEventLoop | None = None
|
|
27
|
+
self.__runner_thread: threading.Thread | None = None
|
|
28
|
+
self.__lock = threading.Lock()
|
|
29
|
+
atexit.register(self._close)
|
|
30
|
+
|
|
31
|
+
def _close(self) -> None:
|
|
32
|
+
if self.__loop:
|
|
33
|
+
self.__loop.stop()
|
|
34
|
+
|
|
35
|
+
def _execute(self) -> None:
|
|
36
|
+
loop = self.__loop
|
|
37
|
+
assert loop is not None
|
|
38
|
+
try:
|
|
39
|
+
loop.run_forever()
|
|
40
|
+
finally:
|
|
41
|
+
loop.close()
|
|
42
|
+
|
|
43
|
+
def get_exc_handler(self):
|
|
44
|
+
def exc_handler(loop, context):
|
|
45
|
+
logger.error(
|
|
46
|
+
f"Taskrunner for {self.__runner_thread.name if self.__runner_thread else 'no thread'} caught"
|
|
47
|
+
f" exception in {loop}: {context}"
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
return exc_handler
|
|
51
|
+
|
|
52
|
+
def get_run_future(self, coro: Any) -> concurrent.futures.Future:
|
|
53
|
+
"""Synchronously run a coroutine on a background thread."""
|
|
54
|
+
name = f"{threading.current_thread().name} : loop-runner"
|
|
55
|
+
with self.__lock:
|
|
56
|
+
if self.__loop is None:
|
|
57
|
+
with _selector_policy():
|
|
58
|
+
self.__loop = asyncio.new_event_loop()
|
|
59
|
+
|
|
60
|
+
exc_handler = self.get_exc_handler()
|
|
61
|
+
self.__loop.set_exception_handler(exc_handler)
|
|
62
|
+
self.__runner_thread = threading.Thread(target=self._execute, daemon=True, name=name)
|
|
63
|
+
self.__runner_thread.start()
|
|
64
|
+
fut = asyncio.run_coroutine_threadsafe(coro, self.__loop)
|
|
65
|
+
return fut
|
|
66
|
+
|
|
67
|
+
|
|
17
68
|
class LocalController:
|
|
18
69
|
def __init__(self):
    logger.debug("LocalController init")
    # One background loop runner per calling thread (keyed by thread name + PID), filled by submit_sync.
    self._runner_map: dict[str, _TaskRunner] = dict()
|
|
20
72
|
|
|
21
73
|
@log
|
|
22
74
|
async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
|
|
23
75
|
"""
|
|
24
|
-
|
|
76
|
+
Main entrypoint for submitting a task to the local controller.
|
|
25
77
|
"""
|
|
26
78
|
ctx = internal_ctx()
|
|
27
79
|
tctx = ctx.data.task_context
|
|
@@ -59,7 +111,17 @@ class LocalController:
|
|
|
59
111
|
return result
|
|
60
112
|
return out
|
|
61
113
|
|
|
62
|
-
submit_sync
|
|
114
|
+
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
    """
    Schedule ``submit`` for ``_task`` on a background event loop and return a concurrent future.

    A ``_TaskRunner`` is cached per calling thread. The key is the thread name followed by ``PID:<pid>``.
    Runner threads are named after their caller, so a call made from inside a runner thread is handed a
    different runner. A warning is logged once more than 100 runners exist, because that usually points to
    runaway recursion.
    """
    runner_key = f"{threading.current_thread().name}PID:{os.getpid()}"
    coro = self.submit(_task, *args, **kwargs)

    runner = self._runner_map.get(runner_key)
    if runner is None:
        if len(self._runner_map) > 100:
            logger.warning(
                "More than 100 event loop runners created!!! This could be a case of runaway recursion..."
            )
        runner = _TaskRunner()
        self._runner_map[runner_key] = runner

    return runner.get_run_future(coro)
|
|
63
125
|
|
|
64
126
|
async def finalize_parent_action(self, action: ActionID):
|
|
65
127
|
pass
|
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import concurrent.futures
|
|
5
|
+
import os
|
|
6
|
+
import threading
|
|
4
7
|
from collections import defaultdict
|
|
5
8
|
from collections.abc import Callable
|
|
6
9
|
from pathlib import Path
|
|
@@ -21,7 +24,7 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
|
|
|
21
24
|
from flyte._logging import logger
|
|
22
25
|
from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
|
|
23
26
|
from flyte._task import TaskTemplate
|
|
24
|
-
from flyte._utils.
|
|
27
|
+
from flyte._utils.helpers import _selector_policy
|
|
25
28
|
from flyte.models import ActionID, NativeInterface, SerializationContext
|
|
26
29
|
|
|
27
30
|
R = TypeVar("R")
|
|
@@ -120,6 +123,8 @@ class RemoteController(Controller):
|
|
|
120
123
|
self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
|
|
121
124
|
lambda: defaultdict(int)
|
|
122
125
|
)
|
|
126
|
+
self._submit_loop: asyncio.AbstractEventLoop | None = None
|
|
127
|
+
self._submit_thread: threading.Thread | None = None
|
|
123
128
|
|
|
124
129
|
def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
|
|
125
130
|
"""
|
|
@@ -236,7 +241,47 @@ class RemoteController(Controller):
|
|
|
236
241
|
async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
|
|
237
242
|
return await self._submit(task_call_seq, _task, *args, **kwargs)
|
|
238
243
|
|
|
239
|
-
|
|
244
|
+
def _sync_thread_loop_runner(self) -> None:
    """Thread target for the submitter thread: run ``self._submit_loop`` until stopped, then close it."""
    submit_loop = self._submit_loop
    assert submit_loop is not None
    try:
        submit_loop.run_forever()
    finally:
        submit_loop.close()
|
|
253
|
+
|
|
254
|
+
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
    """
    Call :meth:`submit` from synchronous code.

    On first use this lazily creates a dedicated event loop and a daemon thread
    (``remote-controller-<pid>-submitter``) that runs it. Both are cached on the controller. The ``submit``
    coroutine is scheduled on that loop, and the resulting :class:`concurrent.futures.Future` is returned so
    the caller can block on it or wrap it with :func:`asyncio.wrap_future`. This is the single-runner
    version of the per-thread runner pattern used by the LocalController.

    NOTE(review): no lock guards the lazy set-up. This assumes ``submit_sync`` is never entered concurrently
    from several threads; confirm against callers.

    :param _task: The task template to submit.
    :param args: Positional arguments forwarded to ``submit``.
    :param kwargs: Keyword arguments forwarded to ``submit``.
    :return: A concurrent future that resolves to the result of ``submit``.
    """
    if self._submit_thread is None:

        def _log_loop_exception(loop, context):
            logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")

        with _selector_policy():
            self._submit_loop = asyncio.new_event_loop()
        self._submit_loop.set_exception_handler(_log_loop_exception)

        submitter = threading.Thread(
            name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner
        )
        self._submit_thread = submitter
        submitter.start()

    coro = self.submit(_task, *args, **kwargs)
    submit_loop = self._submit_loop
    assert submit_loop is not None, "Submit loop should always have been initialized by now"
    return asyncio.run_coroutine_threadsafe(coro, submit_loop)
|
|
240
285
|
|
|
241
286
|
async def finalize_parent_action(self, action_id: ActionID):
|
|
242
287
|
"""
|
|
@@ -136,8 +136,9 @@ async def convert_and_run(
|
|
|
136
136
|
raw_data_path=raw_data_path,
|
|
137
137
|
compiled_image_cache=image_cache,
|
|
138
138
|
report=flyte.report.Report(name=action.name),
|
|
139
|
+
mode="remote" if not ctx.data.task_context else ctx.data.task_context.mode,
|
|
139
140
|
)
|
|
140
|
-
|
|
141
|
+
with ctx.replace_task_context(tctx):
|
|
141
142
|
inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
|
|
142
143
|
out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
|
|
143
144
|
if err is not None:
|
flyte/_map.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
|
|
3
|
+
|
|
4
|
+
from flyte.syncify import syncify
|
|
5
|
+
|
|
6
|
+
from ._group import group
|
|
7
|
+
from ._logging import logger
|
|
8
|
+
from ._task import P, R, TaskTemplate
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MapAsyncIterator(Generic[P, R]):
    """
    AsyncIterator over the results of mapping a task over zipped argument iterables.

    One asyncio task per zipped argument tuple is created lazily, on the first ``__anext__`` call. Results are
    yielded in input order. ``concurrency`` bounds how many invocations may run at once; ``0`` or a negative
    value means unbounded. With ``return_exceptions=True`` a failed invocation yields its exception as a
    value. Otherwise the exception is raised and every not-yet-consumed invocation is cancelled.
    """

    def __init__(self, func: TaskTemplate[P, R], args: tuple, name: str, concurrency: int, return_exceptions: bool):
        self.func = func
        self.args = args
        self.name = name
        self.concurrency = concurrency
        self.return_exceptions = return_exceptions
        self._tasks: List[asyncio.Task] = []
        self._current_index = 0
        self._completed_count = 0
        self._exception_count = 0
        self._task_count = 0
        self._initialized = False

    def __aiter__(self) -> AsyncIterator[Union[R, Exception]]:
        """Return self as the async iterator"""
        return self

    async def __anext__(self) -> Union[R, Exception]:
        """Await and return the next result, in input order."""
        # Initialize on first call
        if not self._initialized:
            await self._initialize()

        # Check if we've exhausted all tasks
        if self._current_index >= self._task_count:
            raise StopAsyncIteration

        index = self._current_index
        task = self._tasks[index]
        self._current_index += 1

        try:
            result = await task
        except Exception as e:
            self._exception_count += 1
            logger.debug(f"Task {index} failed with exception: {e}")
            if self.return_exceptions:
                return e
            # Cancel every task that has not been consumed yet. ``_current_index`` already points at the
            # next task, so the slice must start there; starting one later would leave that task running.
            for remaining_task in self._tasks[self._current_index :]:
                remaining_task.cancel()
            raise

        self._completed_count += 1
        logger.debug(f"Task {index} completed successfully")
        return result

    async def _initialize(self):
        """Create one asyncio task per zipped argument tuple - called lazily on first iteration."""
        # A semaphore enforces the documented ``concurrency`` limit; None means unbounded.
        semaphore = asyncio.Semaphore(self.concurrency) if self.concurrency > 0 else None

        async def _invoke(arg_tuple: tuple) -> Any:
            if semaphore is None:
                return await self.func.aio(*arg_tuple)
            # Wait for a free slot before starting, so at most ``concurrency`` invocations are in flight.
            async with semaphore:
                return await self.func.aio(*arg_tuple)

        tasks = [asyncio.create_task(_invoke(arg_tuple)) for arg_tuple in zip(*self.args)]

        if not tasks:
            logger.info(f"Group '{self.name}' has no tasks to process")
        else:
            limit_desc = "unlimited concurrency" if semaphore is None else f"concurrency {self.concurrency}"
            logger.info(f"Starting {len(tasks)} tasks in group '{self.name}' with {limit_desc}")

        self._tasks = tasks
        self._task_count = len(tasks)
        self._initialized = True

    async def collect(self) -> List[Union[R, Exception]]:
        """Convenience method to collect all results into a list"""
        results = []
        async for result in self:
            results.append(result)
        return results

    def __repr__(self):
        return f"MapAsyncIterator(group_name='{self.name}', concurrency={self.concurrency})"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class _Mapper(Generic[P, R]):
    """
    Internal mapper class to handle the mapping logic.

    NOTE: The reason why we do not use the `@syncify` decorator here is because, in `syncify` we cannot use
    context managers like `group` directly in the function body. This is because the `__exit__` method of the
    context manager is called after the function returns. And for `_context` the `__exit__` method releases the
    token (for contextvar), which was created in a separate thread. This leads to an exception such as a
    ``ValueError`` from ``ContextVar.reset`` stating that the token was created in a different Context.
    """

    @classmethod
    def _get_name(cls, task_name: str, group_name: str | None) -> str:
        """Get the name of the group, defaulting to 'map' if not provided."""
        return f"{task_name}_{group_name or 'map'}"

    def __call__(
        self,
        func: TaskTemplate[P, R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> Iterator[Union[R, Exception]]:
        """
        Map a function over the provided arguments with concurrent execution.

        :param func: The task to map.
        :param args: Positional arguments to pass to the function (iterables that will be zipped).
        :param group_name: The name of the group for the mapped tasks.
        :param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
        :param return_exceptions: If True, yield exceptions instead of raising them.
        :return: Iterator yielding results in input order.
        """
        if not args:
            return

        name = self._get_name(func.name, group_name)
        logger.debug(f"Blocking Map for {name}")
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                logger.warning("Running map in local mode, which will run every task sequentially.")
                # NOTE(review): for an async task ``func(*v)`` returns a coroutine; it cannot be awaited from
                # this synchronous generator, so it is yielded as-is. Use ``map.aio`` for async tasks.
                for v in zip(*args):
                    # Only the task call sits in the try, so exceptions thrown into this generator at
                    # ``yield`` are not mistaken for task failures.
                    try:
                        result = func(*v)  # type: ignore
                    except Exception as e:
                        if not return_exceptions:
                            raise
                        result = e
                    yield result
                return

            # ``_map`` is always asked to return exceptions as values. ``return_exceptions=False`` is honoured
            # here by raising the first failure, so the caller sees it on this thread.
            # NOTE(review): in that case tasks that are already scheduled are not explicitly cancelled.
            mapped = cast(
                Iterator[Union[R, Exception]],
                _map(
                    func,
                    *args,
                    name=name,
                    concurrency=concurrency,
                    return_exceptions=True,
                ),
            )
            for i, x in enumerate(mapped):
                logger.debug(f"Mapped {x}, task {i}")
                if not return_exceptions and isinstance(x, Exception):
                    raise x
                yield x

    async def aio(
        self,
        func: TaskTemplate[P, R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> AsyncGenerator[Union[R, Exception], None]:
        """
        Async version of ``__call__``: map ``func`` over the zipped ``args`` and yield results in input order.

        In local mode (no task context, or ``mode == "local"``), calls run one at a time. Any awaitable
        returned by a call (for example, the coroutine returned for an async task) is awaited before its
        result is yielded.
        """
        if not args:
            return
        name = self._get_name(func.name, group_name)
        with group(name):
            import inspect

            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        result = func(*v)  # type: ignore
                        if inspect.isawaitable(result):
                            result = await result
                    except Exception as e:
                        if not return_exceptions:
                            raise
                        result = e
                    yield result
                return
            async for x in _map.aio(
                func,
                *args,
                name=name,
                concurrency=concurrency,
                return_exceptions=return_exceptions,
            ):
                yield cast(Union[R, Exception], x)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@syncify
async def _map(
    func: TaskTemplate[P, R],
    *args: Iterable[Any],
    name: str = "map",
    concurrency: int = 0,
    return_exceptions: bool = True,
) -> AsyncIterator[Union[R, Exception]]:
    """
    Stream the results of mapping ``func`` over the zipped ``args``, in input order.

    This is an async-generator wrapper around :class:`MapAsyncIterator`. In this module it is consumed both
    synchronously, as ``_map(...)``, and asynchronously, as ``_map.aio(...)``, through the ``@syncify``
    decorator.
    """
    results = MapAsyncIterator(
        func=func,
        args=args,
        name=name,
        concurrency=concurrency,
        return_exceptions=return_exceptions,
    )
    async for item in results:
        yield item
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# Module-level mapper singleton, re-exported as ``flyte.map`` (``from ._map import map`` in flyte/__init__.py).
# It deliberately shadows the builtin ``map`` inside this module.
map: _Mapper = _Mapper()
|
flyte/_run.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import pathlib
|
|
4
5
|
import uuid
|
|
5
6
|
from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union, cast
|
|
@@ -287,40 +288,39 @@ class _Runner:
|
|
|
287
288
|
|
|
288
289
|
async def _run_local(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
    """
    Execute ``obj`` in-process through the local controller, inside a fresh local ``TaskContext``.

    Tasks flagged ``_call_as_synchronous`` go through ``LocalController.submit_sync``, which runs them on a
    background event-loop thread; the returned concurrent future is wrapped with ``asyncio.wrap_future`` and
    awaited. All other tasks are awaited through ``LocalController.submit`` directly.
    """
    from flyte._internal.controllers import create_controller
    from flyte._internal.controllers._local_controller import LocalController
    from flyte.report import Report

    controller = cast(LocalController, create_controller("local"))

    action = ActionID.create_random() if self._name is None else ActionID(name=self._name)

    ctx = internal_ctx()
    raw_data = ctx.raw_data
    tctx = TaskContext(
        action=action,
        checkpoints=Checkpoints(prev_checkpoint_path=raw_data.path, checkpoint_path=raw_data.path),
        code_bundle=None,
        output_path=self._metadata_path,
        run_base_dir=self._metadata_path,
        version="na",
        raw_data_path=raw_data,
        compiled_image_cache=None,
        report=Report(name=action.name),
        mode="local",
    )
    async with ctx.replace_task_context(tctx):
        if not obj._call_as_synchronous:
            return await controller.submit(obj, *args, **kwargs)
        # Synchronous tasks run on a loop thread owned by the local controller; make the future awaitable here.
        return await asyncio.wrap_future(controller.submit_sync(obj, *args, **kwargs))
|
|
324
324
|
|
|
325
325
|
@syncify
|
|
326
326
|
async def run(self, task: TaskTemplate[P, Union[R, Run]], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
|
|
@@ -351,7 +351,7 @@ class _Runner:
|
|
|
351
351
|
return await self._run_hybrid(task, *args, **kwargs)
|
|
352
352
|
|
|
353
353
|
# TODO We could use this for remote as well and users could simply pass flyte:// or s3:// or file://
|
|
354
|
-
|
|
354
|
+
with internal_ctx().new_raw_data_path(
|
|
355
355
|
raw_data_path=RawDataPath.from_local_folder(local_folder=self._raw_data_path)
|
|
356
356
|
):
|
|
357
357
|
return await self._run_local(task, *args, **kwargs)
|
flyte/_task.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import weakref
|
|
4
5
|
from dataclasses import dataclass, field, replace
|
|
5
6
|
from functools import cached_property
|
|
@@ -216,6 +217,51 @@ class TaskTemplate(Generic[P, R]):
|
|
|
216
217
|
def native_interface(self) -> NativeInterface:
|
|
217
218
|
return self.interface
|
|
218
219
|
|
|
220
|
+
async def aio(self, *args: P.args, **kwargs: P.kwargs) -> R:
    """
    Execute this task from an async context, whether the task itself is sync or async.

    This helps with migrating v1-defined sync tasks so they can be used within an asyncio parent task.
    Exceptions from the underlying task are re-raised.

    Example:
    ```python
    @env.task
    def my_legacy_task(x: int) -> int:
        return x

    @env.task
    async def my_new_parent_task(n: int) -> List[int]:
        collect = []
        for x in range(n):
            collect.append(my_legacy_task.aio(x))
        return await asyncio.gather(*collect)
    ```
    :param args: Positional arguments for the task.
    :param kwargs: Keyword arguments for the task.
    :return: The task's result.
    :raises RuntimeSystemError: If called inside a task context while no controller is initialized.
    """
    ctx = internal_ctx()
    if ctx.is_task_context():
        from ._internal.controllers import get_controller

        # If we are in a task context, that implies we are executing a Run.
        # In this scenario, we should submit the task to the controller.
        controller = get_controller()
        if not controller:
            raise RuntimeSystemError("BadContext", "Controller is not initialized.")
        if self._call_as_synchronous:
            fut = controller.submit_sync(self, *args, **kwargs)
            # Wrap the concurrent future so it can be awaited without blocking the running event loop.
            return await asyncio.wrap_future(fut)
        return await controller.submit(self, *args, **kwargs)

    # Local execute: call the function directly. A sync task returns its value. If ``forward`` returns an
    # awaitable (presumably the coroutine of an async task), await it so that ``await task.aio(...)`` always
    # produces the result rather than an un-awaited coroutine.
    import inspect

    result = self.forward(*args, **kwargs)
    if inspect.isawaitable(result):
        return await result
    return result
|
|
264
|
+
|
|
219
265
|
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
|
|
220
266
|
"""
|
|
221
267
|
This is the entrypoint for an async function task at runtime. It will be called during an execution.
|
|
@@ -235,15 +281,18 @@ class TaskTemplate(Generic[P, R]):
|
|
|
235
281
|
from ._internal.controllers import get_controller
|
|
236
282
|
|
|
237
283
|
controller = get_controller()
|
|
238
|
-
if controller:
|
|
239
|
-
if self._call_as_synchronous:
|
|
240
|
-
return controller.submit_sync(self, *args, **kwargs)
|
|
241
|
-
else:
|
|
242
|
-
return controller.submit(self, *args, **kwargs)
|
|
243
|
-
else:
|
|
284
|
+
if not controller:
|
|
244
285
|
raise RuntimeSystemError("BadContext", "Controller is not initialized.")
|
|
245
286
|
|
|
246
|
-
|
|
287
|
+
if self._call_as_synchronous:
|
|
288
|
+
fut = controller.submit_sync(self, *args, **kwargs)
|
|
289
|
+
x = fut.result(None)
|
|
290
|
+
return x
|
|
291
|
+
else:
|
|
292
|
+
return controller.submit(self, *args, **kwargs)
|
|
293
|
+
else:
|
|
294
|
+
# If not in task context, purely function run, stay out of the way
|
|
295
|
+
return self.forward(*args, **kwargs)
|
|
247
296
|
except RuntimeSystemError:
|
|
248
297
|
raise
|
|
249
298
|
except RuntimeUserError:
|
flyte/_utils/helpers.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import string
|
|
3
3
|
import typing
|
|
4
|
+
from contextlib import contextmanager
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
|
|
6
7
|
|
|
@@ -106,3 +107,17 @@ def get_cwd_editable_install() -> typing.Optional[Path]:
|
|
|
106
107
|
return install # note we want the install folder, not the parent
|
|
107
108
|
|
|
108
109
|
return None
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@contextmanager
|
|
113
|
+
def _selector_policy():
|
|
114
|
+
import asyncio
|
|
115
|
+
|
|
116
|
+
original_policy = asyncio.get_event_loop_policy()
|
|
117
|
+
try:
|
|
118
|
+
if os.name == "nt" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
|
|
119
|
+
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
120
|
+
|
|
121
|
+
yield
|
|
122
|
+
finally:
|
|
123
|
+
asyncio.set_event_loop_policy(original_policy)
|
flyte/_version.py
CHANGED
|
@@ -17,5 +17,5 @@ __version__: str
|
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
|
18
18
|
version_tuple: VERSION_TUPLE
|
|
19
19
|
|
|
20
|
-
__version__ = version = '0.2.
|
|
21
|
-
__version_tuple__ = version_tuple = (0, 2, 0, '
|
|
20
|
+
__version__ = version = '0.2.0b9'
|
|
21
|
+
__version_tuple__ = version_tuple = (0, 2, 0, 'b9')
|