flyte-2.0.0b32-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic; see the package's registry page for more details.
- flyte/__init__.py +108 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +195 -0
- flyte/_bin/serve.py +178 -0
- flyte/_build.py +26 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +147 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/local_cache.py +216 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +121 -0
- flyte/_code_bundle/_packaging.py +218 -0
- flyte/_code_bundle/_utils.py +347 -0
- flyte/_code_bundle/bundle.py +266 -0
- flyte/_constants.py +1 -0
- flyte/_context.py +155 -0
- flyte/_custom_context.py +73 -0
- flyte/_debug/__init__.py +0 -0
- flyte/_debug/constants.py +38 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +307 -0
- flyte/_deploy.py +408 -0
- flyte/_deployer.py +109 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +122 -0
- flyte/_excepthook.py +37 -0
- flyte/_group.py +32 -0
- flyte/_hash.py +8 -0
- flyte/_image.py +1055 -0
- flyte/_initialize.py +628 -0
- flyte/_interface.py +119 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +129 -0
- flyte/_internal/controllers/_local_controller.py +239 -0
- flyte/_internal/controllers/_trace.py +48 -0
- flyte/_internal/controllers/remote/__init__.py +58 -0
- flyte/_internal/controllers/remote/_action.py +211 -0
- flyte/_internal/controllers/remote/_client.py +47 -0
- flyte/_internal/controllers/remote/_controller.py +583 -0
- flyte/_internal/controllers/remote/_core.py +465 -0
- flyte/_internal/controllers/remote/_informer.py +381 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +3 -0
- flyte/_internal/imagebuild/docker_builder.py +706 -0
- flyte/_internal/imagebuild/image_builder.py +277 -0
- flyte/_internal/imagebuild/remote_builder.py +386 -0
- flyte/_internal/imagebuild/utils.py +78 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +21 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +486 -0
- flyte/_internal/runtime/entrypoints.py +204 -0
- flyte/_internal/runtime/io.py +188 -0
- flyte/_internal/runtime/resources_serde.py +152 -0
- flyte/_internal/runtime/reuse.py +125 -0
- flyte/_internal/runtime/rusty.py +193 -0
- flyte/_internal/runtime/task_serde.py +362 -0
- flyte/_internal/runtime/taskrunner.py +209 -0
- flyte/_internal/runtime/trigger_serde.py +160 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_keyring/__init__.py +0 -0
- flyte/_keyring/file.py +115 -0
- flyte/_logging.py +300 -0
- flyte/_map.py +312 -0
- flyte/_module.py +72 -0
- flyte/_pod.py +30 -0
- flyte/_resources.py +473 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +102 -0
- flyte/_run.py +724 -0
- flyte/_secret.py +96 -0
- flyte/_task.py +550 -0
- flyte/_task_environment.py +316 -0
- flyte/_task_plugins.py +47 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +119 -0
- flyte/_trigger.py +1000 -0
- flyte/_utils/__init__.py +30 -0
- flyte/_utils/asyn.py +121 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +27 -0
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +134 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/module_loader.py +104 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +34 -0
- flyte/app/__init__.py +22 -0
- flyte/app/_app_environment.py +157 -0
- flyte/app/_deploy.py +125 -0
- flyte/app/_input.py +160 -0
- flyte/app/_runtime/__init__.py +3 -0
- flyte/app/_runtime/app_serde.py +347 -0
- flyte/app/_types.py +101 -0
- flyte/app/extras/__init__.py +3 -0
- flyte/app/extras/_fastapi.py +151 -0
- flyte/cli/__init__.py +12 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_build.py +114 -0
- flyte/cli/_common.py +468 -0
- flyte/cli/_create.py +371 -0
- flyte/cli/_delete.py +45 -0
- flyte/cli/_deploy.py +293 -0
- flyte/cli/_gen.py +176 -0
- flyte/cli/_get.py +370 -0
- flyte/cli/_option.py +33 -0
- flyte/cli/_params.py +554 -0
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_run.py +597 -0
- flyte/cli/_serve.py +64 -0
- flyte/cli/_update.py +37 -0
- flyte/cli/_user.py +17 -0
- flyte/cli/main.py +221 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +248 -0
- flyte/config/_internal.py +73 -0
- flyte/config/_reader.py +225 -0
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +270 -0
- flyte/connectors/_server.py +197 -0
- flyte/connectors/utils.py +135 -0
- flyte/errors.py +243 -0
- flyte/extend.py +19 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +286 -0
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +21 -0
- flyte/io/__init__.py +29 -0
- flyte/io/_dataframe/__init__.py +131 -0
- flyte/io/_dataframe/basic_dfs.py +223 -0
- flyte/io/_dataframe/dataframe.py +1026 -0
- flyte/io/_dir.py +910 -0
- flyte/io/_file.py +914 -0
- flyte/io/_hashing_io.py +342 -0
- flyte/models.py +479 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +35 -0
- flyte/remote/_action.py +738 -0
- flyte/remote/_app.py +57 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +189 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +403 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +213 -0
- flyte/remote/_client/auth/_client_config.py +85 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +152 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +128 -0
- flyte/remote/_common.py +30 -0
- flyte/remote/_console.py +19 -0
- flyte/remote/_data.py +161 -0
- flyte/remote/_logs.py +185 -0
- flyte/remote/_project.py +88 -0
- flyte/remote/_run.py +386 -0
- flyte/remote/_secret.py +142 -0
- flyte/remote/_task.py +527 -0
- flyte/remote/_trigger.py +306 -0
- flyte/remote/_user.py +33 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +182 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +36 -0
- flyte/storage/_config.py +237 -0
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +456 -0
- flyte/storage/_utils.py +5 -0
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +375 -0
- flyte/types/__init__.py +52 -0
- flyte/types/_interface.py +40 -0
- flyte/types/_pickle.py +145 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +119 -0
- flyte/types/_type_engine.py +2254 -0
- flyte/types/_utils.py +80 -0
- flyte-2.0.0b32.data/scripts/debug.py +38 -0
- flyte-2.0.0b32.data/scripts/runtime.py +195 -0
- flyte-2.0.0b32.dist-info/METADATA +351 -0
- flyte-2.0.0b32.dist-info/RECORD +204 -0
- flyte-2.0.0b32.dist-info/WHEEL +5 -0
- flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
- flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
- flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/_interface.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import typing
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Dict, Generator, Literal, Tuple, Type, TypeVar, Union, cast, get_args, get_origin, get_type_hints
|
|
7
|
+
|
|
8
|
+
from flyte._logging import logger
|
|
9
|
+
|
|
10
|
+
# Class name given to the Enum types that literal_to_enum() builds from Literal[...] annotations.
LITERAL_ENUM = "LiteralEnum"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def default_output_name(index: int = 0) -> str:
    """
    Build the auto-generated name of the output at position ``index``.

    >>> default_output_name(2)
    'o2'
    """
    return "o{}".format(index)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def output_name_generator(length: int) -> Generator[str, None, None]:
    """Yield the default output names for positions ``0 .. length - 1``, in order."""
    yield from map(default_output_name, range(length))
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Dict[str, Type]:
    """
    Map a function's return annotation to an ordered dict of ``{output_name: output_type}``.

    The input to this function should be sig.return_annotation where sig = inspect.signature(some_func)
    The purpose of this function is to sort out whether a function is returning one thing, or multiple things, and
    to name the outputs accordingly, either by using our default name function, or from a typing.NamedTuple.

    # Option 1
    nt1 = typing.NamedTuple("NT1", x_str=str, y_int=int)
    def t(a: int, b: str) -> nt1: ...

    # Option 2
    def t(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): ...

    # Option 3
    def t(a: int, b: str) -> typing.Tuple[int, str]: ...

    # Option 4
    def t(a: int, b: str) -> (int, str): ...

    # Option 5
    def t(a: int, b: str) -> str: ...

    # Option 6
    def t(a: int, b: str) -> None: ...

    # Options 7/8
    def t(a: int, b: str) -> List[int]: ...
    def t(a: int, b: str) -> Dict[str, int]: ...

    Note that Options 1 and 2 are identical, just syntactic sugar. In the NamedTuple case, we'll use the names in the
    definition. In all other cases, we'll automatically generate output names, indexed starting at 0, and a
    ``Literal[...]`` output type (top-level, or a tuple element) is replaced by ``literal_to_enum(...)``.

    :param return_annotation: The return annotation to inspect.
    :return: Output names mapped to output types; empty when the function returns nothing (Option 6).
    :raises TypeError: For string (unresolved) annotations, single-element tuples, and variable-length tuples.
    """
    if isinstance(return_annotation, str):
        raise TypeError("String return annotations are not supported.")

    # Handle Option 6
    # We can think about whether we should add a default output name with type None in the future.
    if return_annotation in (None, type(None), inspect.Signature.empty):
        return {}

    # Classes (including typing.NamedTuple subclasses) pass this check. Even though a NamedTuple is multi-valued
    # for us, it is a single value for Python, so it has to be detected explicitly (Options 1 and 2).
    if hasattr(return_annotation, "__bases__") and (
        isinstance(return_annotation, type) or isinstance(return_annotation, TypeVar)
    ):
        # isinstance / issubclass does not work for Namedtuple.
        bases = return_annotation.__bases__  # type: ignore
        if len(bases) == 1 and bases[0] is tuple and hasattr(return_annotation, "_fields"):
            # Task returns named tuple
            return dict(get_type_hints(cast(Type, return_annotation), include_extras=True))

    if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple:  # type: ignore
        # Handle option 3: task returns unnamed typing.Tuple
        if len(return_annotation.__args__) == 1:  # type: ignore
            raise TypeError("Tuples should be used to indicate multiple return values, found only one return variable.")
        return _indexed_tuple_outputs(get_args(return_annotation))

    if isinstance(return_annotation, tuple):
        # Handle option 4: task returns unnamed native tuple. Elements get the same per-element handling as
        # typing.Tuple so that both spellings of a multi-valued return behave identically.
        if len(return_annotation) == 1:
            raise TypeError("Please don't use a tuple if you're just returning one thing.")
        return _indexed_tuple_outputs(return_annotation)

    # Handle all other single return types (Options 5, 7 and 8)
    if get_origin(return_annotation) is Literal:
        return {default_output_name(): literal_to_enum(cast(Type, return_annotation))}
    return {default_output_name(): cast(Type, return_annotation)}


def _indexed_tuple_outputs(types: Tuple) -> Dict[str, Type]:
    """Name the elements of a fixed-length tuple annotation o0, o1, ..., converting ``Literal[...]`` elements."""
    annotations: Dict[str, Type] = {}
    for i, r in enumerate(types):
        if r is Ellipsis:
            raise TypeError("Variable length tuples are not supported as return types.")
        if get_origin(r) is Literal:
            annotations[default_output_name(i)] = literal_to_enum(cast(Type, r))
        else:
            annotations[default_output_name(i)] = r
    return annotations
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def literal_to_enum(literal_type: Type) -> Type[Enum | typing.Any]:
    """
    Convert a ``Literal[...]`` annotation into a dynamically created Enum class.

    Each Literal value becomes one member whose value is the original string. Member names are the upper-cased
    values. If upper-casing would give two distinct values the same name (e.g. ``Literal["a", "A"]``), the raw
    values are used as member names instead, so that no value is silently dropped.

    :param literal_type: A ``typing.Literal[...]`` annotation.
    :return: An Enum class named ``LITERAL_ENUM``, or ``typing.Any`` when any Literal value is not a string.
    :raises TypeError: If ``literal_type`` is not a ``Literal``.
    """

    if get_origin(literal_type) is not Literal:
        raise TypeError(f"{literal_type} is not a Literal")

    values = get_args(literal_type)
    if not all(isinstance(v, str) for v in values):
        logger.warning(f"Literal type {literal_type} contains non-string values, using Any instead of Enum")
        return typing.Any

    # Keep declaration order. Keying a dict by the upper-cased name alone would keep only the last of two values
    # that differ only by case, so fall back to the (distinct) raw values as names when that happens.
    names = [str(v).upper() for v in values]
    if len(set(names)) != len(names):
        names = [str(v) for v in values]
    enum_dict = dict(zip(names, values))

    # Dynamically create an Enum
    literal_enum = Enum(LITERAL_ENUM, enum_dict)  # type: ignore

    return literal_enum  # type: ignore
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import concurrent.futures
|
|
2
|
+
import threading
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
|
|
4
|
+
|
|
5
|
+
from flyte._task import TaskTemplate
|
|
6
|
+
from flyte.models import ActionID, NativeInterface
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from flyte.remote._task import TaskDetails
|
|
10
|
+
|
|
11
|
+
from ._trace import TraceInfo
|
|
12
|
+
|
|
13
|
+
__all__ = [
    "Controller",
    "ControllerType",
    "TraceInfo",
    "create_controller",
    "get_controller",
]
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
import concurrent.futures
|
|
17
|
+
|
|
18
|
+
# Execution modes accepted by create_controller(). "hybrid" is accepted there too (it is routed to the remote
# controller), so it is listed here to keep the type and the runtime dispatch in agreement.
ControllerType = Literal["local", "remote", "hybrid"]

R = TypeVar("R")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Controller(Protocol):
    """
    Controller interface, that is used to execute tasks. Implementations of this interface
    can execute tasks in different ways, such as locally, remotely etc.
    """

    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a node to the controller asynchronously and wait for the result. This is async and will block
        the current coroutine until the result is available.

        :param _task: The task to run.
        :param args: Positional inputs for the task.
        :param kwargs: Keyword inputs for the task.
        :return: The task's native outputs.
        """
        ...

    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
        """
        This should call the async submit method above, but return a concurrent Future object that can be
        used in a blocking wait or wrapped in an async future. This is called when
        a) a synchronous task is kicked off locally,
        b) a running task (of either kind) kicks off a downstream synchronous task.
        """
        ...

    async def submit_task_ref(self, _task: "TaskDetails", *args, **kwargs) -> Any:
        """
        Submit a task reference to the controller asynchronously and wait for the result. This is async and will block
        the current coroutine until the result is available.

        NOTE(review): implementations may take extra leading positional parameters that are covered here by
        ``*args`` (the local controller's version takes ``max_inline_io_bytes`` first).
        """
        ...

    async def finalize_parent_action(self, action: ActionID):
        """
        Finalize the parent action. This can be called to cleanup the action and should be called after the parent
        task completes

        :param action: Action ID
        :return:
        """
        ...

    async def watch_for_errors(self):
        """
        Hook for surfacing errors detected by the controller.

        NOTE(review): the intended contract is not spelled out by this interface; presumably it runs in the
        background and raises when the controller hits a failure.
        """
        ...

    async def get_action_outputs(
        self, _interface: NativeInterface, _func: Callable, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        This method returns the outputs of the action, if it is available.

        :param _interface: NativeInterface of the traced function.
        :param _func: The traced callable itself (not its name).
        :param args: Positional arguments the callable was invoked with.
        :param kwargs: Keyword arguments the callable was invoked with.
        :return: TraceInfo object and a boolean indicating if the action was found.
            if boolean is False, it means the action is not found and the TraceInfo object will have only min info
        """
        ...

    async def record_trace(self, info: TraceInfo):
        """
        Record a trace action. This is used to record the trace of the action and should be called when the action
        is completed.

        :param info: Trace information
        :return:
        """
        ...

    async def stop(self):
        """
        Stops the engine and should be called when the engine is no longer needed.
        """
        ...
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Internal state holder shared by get_controller() and create_controller().
class _ControllerState:
    """Module-wide registry of the active controller."""

    # The registered controller; None until create_controller() has run.
    controller: Optional[Controller] = None
    # Held by create_controller() while it replaces ``controller``.
    lock = threading.Lock()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_controller() -> Controller:
    """
    Return the controller registered by the most recent ``create_controller()`` call.

    :raises RuntimeError: If no controller has been created yet.
    """
    current = _ControllerState.controller
    if current is None:
        raise RuntimeError("Controller is not initialized. Please call create_controller() first.")
    return current
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def create_controller(
    ct: ControllerType,
    **kwargs,
) -> Controller:
    """
    Build a controller for the requested execution mode and register it as the active controller.

    :param ct: ``"local"`` builds a LocalController; ``"remote"`` and ``"hybrid"`` both build a remote controller.
    :param kwargs: Forwarded unchanged to ``create_remote_controller`` (not used for ``"local"``).
    :return: The new controller, which ``get_controller()`` returns from then on.
    :raises ValueError: For any other controller type.
    """
    if ct == "local":
        from ._local_controller import LocalController

        new_controller: Controller = LocalController()
    elif ct in ("remote", "hybrid"):
        from flyte._internal.controllers.remote import create_remote_controller

        new_controller = create_remote_controller(**kwargs)
    else:
        raise ValueError(f"{ct} is not a valid controller type.")

    # Swap the registered controller while holding the state lock.
    with _ControllerState.lock:
        _ControllerState.controller = new_controller
    return new_controller
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import atexit
|
|
3
|
+
import concurrent.futures
|
|
4
|
+
import os
|
|
5
|
+
import pathlib
|
|
6
|
+
import threading
|
|
7
|
+
from typing import Any, Callable, Tuple, TypeVar
|
|
8
|
+
|
|
9
|
+
import flyte.errors
|
|
10
|
+
from flyte._cache.cache import VersionParameters, cache_from_request
|
|
11
|
+
from flyte._cache.local_cache import LocalTaskCache
|
|
12
|
+
from flyte._context import internal_ctx
|
|
13
|
+
from flyte._internal.controllers import TraceInfo
|
|
14
|
+
from flyte._internal.runtime import convert
|
|
15
|
+
from flyte._internal.runtime.entrypoints import direct_dispatch
|
|
16
|
+
from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
|
|
17
|
+
from flyte._logging import log, logger
|
|
18
|
+
from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
|
|
19
|
+
from flyte._utils.helpers import _selector_policy
|
|
20
|
+
from flyte.models import ActionID, NativeInterface
|
|
21
|
+
from flyte.remote._task import TaskDetails
|
|
22
|
+
|
|
23
|
+
# Generic type variable; it does not appear to be referenced anywhere else in this module.
R = TypeVar("R")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class _TaskRunner:
    """A task runner that runs an asyncio event loop on a background thread."""

    def __init__(self) -> None:
        # Created lazily by get_run_future(); the loop runs on __runner_thread.
        self.__loop: asyncio.AbstractEventLoop | None = None
        self.__runner_thread: threading.Thread | None = None
        # Guards the one-time creation of the loop and its thread.
        self.__lock = threading.Lock()
        atexit.register(self._close)

    def _close(self) -> None:
        """
        Ask the background loop to stop; the runner thread then closes it (see ``_execute``).

        ``loop.stop()`` is not thread-safe. When it is called from another thread (for example this atexit handler
        on the main thread), it only sets a flag and does not wake a loop that is blocked waiting for I/O. Scheduling
        ``stop`` through ``call_soon_threadsafe`` wakes the loop up. The call is a no-op if the loop was never
        created or is already closed, so it is safe to call more than once.
        """
        loop = self.__loop
        if loop is None or loop.is_closed():
            return
        try:
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop was closed between the check above and scheduling the callback: nothing left to stop.
            pass

    def _execute(self) -> None:
        """Thread target: run the loop until it is stopped, then close it."""
        loop = self.__loop
        assert loop is not None
        try:
            loop.run_forever()
        finally:
            loop.close()

    def get_exc_handler(self):
        """Return an asyncio exception handler that logs unhandled loop errors together with the runner thread name."""

        def exc_handler(loop, context):
            logger.error(
                f"Taskrunner for {self.__runner_thread.name if self.__runner_thread else 'no thread'} caught"
                f" exception in {loop}: {context}"
            )

        return exc_handler

    def get_run_future(self, coro: Any) -> concurrent.futures.Future:
        """
        Schedule ``coro`` on the background loop and return a concurrent future for its result.

        The loop and its daemon thread are created on first use. This method does not wait for the coroutine to
        finish; callers block on (or wrap) the returned future.
        """
        name = f"{threading.current_thread().name} : loop-runner"
        with self.__lock:
            if self.__loop is None:
                with _selector_policy():
                    self.__loop = asyncio.new_event_loop()

                exc_handler = self.get_exc_handler()
                self.__loop.set_exception_handler(exc_handler)
                self.__runner_thread = threading.Thread(target=self._execute, daemon=True, name=name)
                self.__runner_thread.start()
        fut = asyncio.run_coroutine_threadsafe(coro, self.__loop)
        return fut
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class LocalController:
    """
    Controller that executes tasks inside the current process.

    Sub-tasks are run through ``direct_dispatch``. Their outputs are read from and written to ``LocalTaskCache``
    according to the task's cache settings. Synchronous submissions run on a background event loop
    (``_TaskRunner``) that is kept per calling thread.
    """

    def __init__(self):
        logger.debug("LocalController init")
        # Background loop runners keyed by "<calling thread name>PID:<pid>"; see submit_sync().
        self._runner_map: dict[str, _TaskRunner] = {}

    @log
    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Main entrypoint for submitting a task to the local controller.

        The inputs are converted and hashed, and the cache is consulted when the cache behavior is "auto".
        Otherwise the task is dispatched directly, and a successful result is stored in the cache if caching is
        enabled.

        :return: The task's native outputs, or None when the task declares no outputs.
        :raises flyte.errors.RuntimeSystemError: When there is no task context, the task fails with an error that
            cannot be converted, or a task that declares outputs produced none.
        """
        ctx = internal_ctx()
        tctx = ctx.data.task_context
        if not tctx:
            raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")

        inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
        inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
        task_interface = transform_native_to_typed_interface(_task.interface)

        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, _task.name, inputs_hash, 0
        )
        sub_action_raw_data_path = tctx.raw_data_path
        # Make sure the output path exists
        pathlib.Path(sub_action_output_path).mkdir(parents=True, exist_ok=True)
        pathlib.Path(sub_action_raw_data_path.path).mkdir(parents=True, exist_ok=True)

        task_cache = cache_from_request(_task.cache)
        cache_enabled = task_cache.is_enabled()
        if isinstance(_task, AsyncFunctionTaskTemplate):
            version_parameters = VersionParameters(func=_task.func, image=_task.image)
        else:
            version_parameters = VersionParameters(func=None, image=_task.image)
        cache_version = task_cache.get_version(version_parameters)
        cache_key = convert.generate_cache_key_hash(
            _task.name,
            inputs_hash,
            task_interface,
            cache_version,
            list(task_cache.get_ignored_inputs()),
            inputs.proto_inputs,
        )

        out = None
        # We only get output from cache if the cache behavior is set to auto
        if task_cache.behavior == "auto":
            out = await LocalTaskCache.get(cache_key)
            if out is not None:
                logger.info(
                    f"Cache hit for task '{_task.name}' (version: {cache_version}), getting result from cache..."
                )

        if out is None:
            out, err = await direct_dispatch(
                _task,
                controller=self,
                action=sub_action_id,
                raw_data_path=sub_action_raw_data_path,
                inputs=inputs,
                version=cache_version,
                checkpoints=tctx.checkpoints,
                code_bundle=tctx.code_bundle,
                output_path=sub_action_output_path,
                run_base_dir=tctx.run_base_dir,
            )

            if err:
                exc = convert.convert_error_to_native(err)
                if exc:
                    raise exc
                else:
                    raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")

            # store into cache
            if cache_enabled and out is not None:
                await LocalTaskCache.set(cache_key, out)

        if _task.native_interface.outputs:
            if out is None:
                raise flyte.errors.RuntimeSystemError("BadOutput", "Task output not captured.")
            result = await convert.convert_outputs_to_native(_task.native_interface, out)
            return result
        return None

    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
        """
        Run ``submit`` on a background event loop and return a concurrent future for its result.

        Each calling thread (per process) gets its own runner, so a sync task that is started from a thread
        already driving a runner gets a fresh loop instead of blocking that thread's loop.
        """
        name = threading.current_thread().name + f"PID:{os.getpid()}"
        coro = self.submit(_task, *args, **kwargs)
        if name not in self._runner_map:
            if len(self._runner_map) > 100:
                logger.warning(
                    "More than 100 event loop runners created!!! This could be a case of runaway recursion..."
                )
            self._runner_map[name] = _TaskRunner()

        return self._runner_map[name].get_run_future(coro)

    async def finalize_parent_action(self, action: ActionID):
        """No-op: the local controller keeps no per-action state to finalize."""
        pass

    async def stop(self):
        """Close the local task cache."""
        await LocalTaskCache.close()

    async def watch_for_errors(self):
        """No-op for the local controller."""
        pass

    async def get_action_outputs(
        self, _interface: NativeInterface, _func: Callable, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        Build the TraceInfo for a traced call of ``_func`` with the given arguments.

        The sub-action id and output path are derived from the task context and a hash of the converted inputs.
        Locally no previously recorded outputs are looked up, so the returned flag is always True.

        :param _interface: NativeInterface of the traced function.
        :param _func: The traced callable; its ``__name__`` names the action.
        :return: The TraceInfo and True.
        :raises flyte.errors.NotInTaskContextError: When called outside a task context.
        """
        ctx = internal_ctx()
        tctx = ctx.data.task_context
        if not tctx:
            raise flyte.errors.NotInTaskContextError("BadContext", "Task context not initialized")

        converted_inputs = convert.Inputs.empty()
        if _interface.inputs:
            converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
            assert converted_inputs

        inputs_hash = convert.generate_inputs_hash_from_proto(converted_inputs.proto_inputs)
        action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx,
            _func.__name__,
            inputs_hash,
            0,
        )
        assert action_output_path
        return (
            TraceInfo(
                name=_func.__name__,
                action=action_id,
                interface=_interface,
                inputs_path=action_output_path,
            ),
            True,
        )

    async def record_trace(self, info: TraceInfo):
        """
        This method records the trace of the action.

        Locally nothing is persisted; outputs (or the error) are only converted to check that the conversion
        succeeds.

        :param info: Trace information
        :raises flyte.errors.NotInTaskContextError: When called outside a task context.
        """
        ctx = internal_ctx()
        tctx = ctx.data.task_context
        if not tctx:
            raise flyte.errors.NotInTaskContextError("BadContext", "Task context not initialized")

        # Compare against None rather than testing truthiness: falsy results (0, "", empty containers) must still
        # be converted. Objects with an ambiguous truth value (e.g. DataFrames, arrays) would raise on bool().
        if info.interface.outputs and info.output is not None:
            converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface, info.name)
            assert converted_outputs
        elif info.error:
            # If there is an error, convert it to its serialized form
            converted_error = convert.convert_from_native_to_error(info.error)
            assert converted_error
        assert info.action
        assert info.start_time
        assert info.end_time

    async def submit_task_ref(self, _task: TaskDetails, max_inline_io_bytes: int, *args, **kwargs) -> Any:
        """Reference tasks are not supported locally; always raises ``flyte.errors.ReferenceTaskError``."""
        raise flyte.errors.ReferenceTaskError(
            f"Reference tasks cannot be executed locally, only remotely. Found remote task {_task.name}"
        )
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Any, Optional
|
|
3
|
+
|
|
4
|
+
from flyteidl2.core import interface_pb2
|
|
5
|
+
|
|
6
|
+
from flyte.models import ActionID, NativeInterface
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class TraceInfo:
    """
    Bookkeeping for one traced action.

    The identifying fields (``name``, ``action``, ``interface``, ``inputs_path``) are passed to the constructor.
    The result is filled in afterwards: ``add_outputs`` stores a successful result and ``add_error`` stores the
    raised exception. Both also store when the action started and finished.
    """

    name: str
    action: ActionID
    interface: NativeInterface
    inputs_path: str
    # Timing is not a constructor argument; it is filled in by add_outputs() / add_error().
    start_time: float = field(init=False, default=0.0)
    end_time: float = field(init=False, default=0.0)
    output: Optional[Any] = None
    error: Optional[Exception] = None
    typed_interface: Optional[interface_pb2.TypedInterface] = None

    def _set_window(self, start_time: float, end_time: float) -> None:
        """Store the start and end timestamps of the action."""
        self.start_time, self.end_time = start_time, end_time

    def add_outputs(self, output: Any, start_time: float, end_time: float):
        """
        Record the successful result of the action.

        :param output: Value produced by the action
        :param start_time: When the action started
        :param end_time: When the action finished
        """
        self.output = output
        self._set_window(start_time, end_time)

    def add_error(self, error: Exception, start_time: float, end_time: float):
        """
        Record the exception raised by the action.

        :param error: Exception raised by the action
        :param start_time: When the action started
        :param end_time: When the action finished
        """
        self.error = error
        self._set_window(start_time, end_time)
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from flyte.remote._client.auth import AuthType, ClientConfig
|
|
4
|
+
|
|
5
|
+
from ._controller import RemoteController
|
|
6
|
+
|
|
7
|
+
__all__ = [
    "RemoteController",
    "create_remote_controller",
]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def create_remote_controller(
    *,
    api_key: str | None = None,
    endpoint: str | None = None,
    insecure: bool = False,
    insecure_skip_verify: bool = False,
    ca_cert_file_path: str | None = None,
    client_config: ClientConfig | None = None,
    auth_type: AuthType = "Pkce",
    headless: bool = False,
    command: List[str] | None = None,
    proxy_command: List[str] | None = None,
    client_id: str | None = None,
    client_credentials_secret: str | None = None,
    rpc_retries: int = 3,
    http_proxy_url: str | None = None,
) -> RemoteController:
    """
    Create a new instance of the remote controller.

    The controller's client is built from ``endpoint`` when it is given, otherwise from ``api_key``. These
    settings are forwarded to the client factory: ``insecure``, ``insecure_skip_verify``, ``ca_cert_file_path``,
    ``client_id``, ``client_credentials_secret`` and ``auth_type``. These parameters are accepted but not used by
    this function: ``client_config``, ``headless``, ``command``, ``proxy_command``, ``rpc_retries`` and
    ``http_proxy_url``.

    :return: A RemoteController wrapping the (not yet awaited) client coroutine.
    :raises ValueError: If neither ``endpoint`` nor ``api_key`` is provided.
    """
    # Validate explicitly. An ``assert`` is stripped under ``python -O``, and a missing endpoint/api_key would then
    # surface as an UnboundLocalError on ``client_coro`` below.
    if not (endpoint or api_key):
        raise ValueError("Either endpoint or api_key must be provided when initializing remote controller")

    # Imported lazily; presumably so the client stack is only loaded when a controller is actually built.
    from ._client import ControllerClient

    client_kwargs = dict(
        insecure=insecure,
        insecure_skip_verify=insecure_skip_verify,
        ca_cert_file_path=ca_cert_file_path,
        client_id=client_id,
        client_credentials_secret=client_credentials_secret,
        auth_type=auth_type,
    )
    if endpoint:
        client_coro = ControllerClient.for_endpoint(endpoint, **client_kwargs)
    else:
        # api_key is guaranteed to be set here by the validation above.
        client_coro = ControllerClient.for_api_key(api_key, **client_kwargs)

    return RemoteController(client_coro=client_coro)
|