flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (204)
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/_interface.py ADDED
@@ -0,0 +1,119 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import typing
5
+ from enum import Enum
6
+ from typing import Dict, Generator, Literal, Tuple, Type, TypeVar, Union, cast, get_args, get_origin, get_type_hints
7
+
8
+ from flyte._logging import logger
9
+
10
# Class name given to every Enum that literal_to_enum() synthesizes from a Literal[...] annotation.
LITERAL_ENUM = "LiteralEnum"
11
+
12
+
13
def default_output_name(index: int = 0) -> str:
    """Return the auto-generated name of the output at position ``index`` ("o0", "o1", ...)."""
    return "o{}".format(index)
15
+
16
+
17
def output_name_generator(length: int) -> Generator[str, None, None]:
    """Lazily yield the default names for ``length`` positional outputs: o0, o1, ..., o{length-1}."""
    yield from map(default_output_name, range(length))
20
+
21
+
22
def _positional_output_annotations(types: Tuple) -> Dict[str, Type]:
    """
    Name a sequence of positional return types o0, o1, ... in order.

    ``Literal[...]`` members are converted with ``literal_to_enum``, matching how single return values are treated.

    :param types: The member types of a tuple return annotation.
    :return: Mapping of default output name to (possibly converted) type.
    :raises TypeError: If a member is ``...``; variable-length tuples cannot map to a fixed set of outputs.
    """
    annotations: Dict[str, Type] = {}
    for i, r in enumerate(types):
        if r is Ellipsis:
            raise TypeError("Variable length tuples are not supported as return types.")
        if get_origin(r) is Literal:
            annotations[default_output_name(i)] = literal_to_enum(cast(Type, r))
        else:
            annotations[default_output_name(i)] = r
    return annotations


def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Dict[str, Type]:
    """
    The input to this function should be sig.return_annotation where sig = inspect.signature(some_func)
    The purpose of this function is to sort out whether a function is returning one thing, or multiple things, and to
    name the outputs accordingly, either by using our default name function, or from a typing.NamedTuple.

    # Option 1
    nt1 = typing.NamedTuple("NT1", x_str=str, y_int=int)
    def t(a: int, b: str) -> nt1: ...

    # Option 2
    def t(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): ...

    # Option 3
    def t(a: int, b: str) -> typing.Tuple[int, str]: ...

    # Option 4
    def t(a: int, b: str) -> (int, str): ...

    # Option 5
    def t(a: int, b: str) -> str: ...

    # Option 6
    def t(a: int, b: str) -> None: ...

    # Options 7/8
    def t(a: int, b: str) -> List[int]: ...
    def t(a: int, b: str) -> Dict[str, int]: ...

    Options 1 and 2 are identical; the second is just syntactic sugar. In the NamedTuple case, the names in the
    definition are used. In all other cases, output names are generated automatically, indexed starting at 0.
    ``Literal[...]`` types (as a single return or as a member of a tuple of either form) are converted with
    ``literal_to_enum``.

    :raises TypeError: For string (forward-reference) annotations, single-element tuples, and variable-length
        tuples (``Tuple[int, ...]`` / ``(int, ...)``).
    """
    if isinstance(return_annotation, str):
        raise TypeError("String return annotations are not supported.")

    # Handle Option 6
    # We can think about whether we should add a default output name with type None in the future.
    if return_annotation in (None, type(None), inspect.Signature.empty):
        return {}

    # This condition is true for typing.NamedTuple as well as ordinary single types. Only a NamedTuple returns
    # here (Options 1 and 2). Although a NamedTuple is multi-valued for us, Python treats it as a single value.
    if hasattr(return_annotation, "__bases__") and (
        isinstance(return_annotation, type) or isinstance(return_annotation, TypeVar)
    ):
        # isinstance / issubclass does not work for NamedTuple, so detect it structurally.
        bases = return_annotation.__bases__  # type: ignore
        if len(bases) == 1 and bases[0] is tuple and hasattr(return_annotation, "_fields"):
            # Task returns a named tuple
            return dict(get_type_hints(cast(Type, return_annotation), include_extras=True))

    if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple:  # type: ignore
        # Handle Option 3: task returns an unnamed typing.Tuple
        if len(return_annotation.__args__) == 1:  # type: ignore
            raise TypeError("Tuples should be used to indicate multiple return values, found only one return variable.")
        return _positional_output_annotations(get_args(return_annotation))

    if isinstance(return_annotation, tuple):
        # Handle Option 4: task returns an unnamed native tuple, treated exactly like Option 3.
        if len(return_annotation) == 1:
            raise TypeError("Please don't use a tuple if you're just returning one thing.")
        return _positional_output_annotations(return_annotation)

    # Handle all other single return types (Options 5, 7, 8)
    if get_origin(return_annotation) is Literal:
        return {default_output_name(): literal_to_enum(cast(Type, return_annotation))}
    return {default_output_name(): cast(Type, return_annotation)}
101
+
102
+
103
def literal_to_enum(literal_type: Type) -> Type[Enum | typing.Any]:
    """
    Convert a ``Literal[...]`` annotation into a dynamically created ``Enum`` class.

    Each literal string becomes a member whose value is the string itself and whose name is the upper-cased string.
    For example, ``Literal["a", "b"]`` becomes ``Enum("LiteralEnum", {"A": "a", "B": "b"})``.

    Upper-casing can make two names collide (e.g. ``Literal["a", "A"]``). In that case later members get a numeric
    suffix (``A_1``, ``A_2``, ...), so that every literal value is still representable. Previously the colliding
    value was silently dropped.

    :param literal_type: A ``typing.Literal[...]`` type.
    :return: The generated Enum class, or ``typing.Any`` if any literal value is not a ``str``.
    :raises TypeError: If ``literal_type`` is not a ``Literal``.
    """
    if get_origin(literal_type) is not Literal:
        raise TypeError(f"{literal_type} is not a Literal")

    values = get_args(literal_type)
    if not all(isinstance(v, str) for v in values):
        logger.warning(f"Literal type {literal_type} contains non-string values, using Any instead of Enum")
        return typing.Any

    # Keep declaration order; give every value a unique member name.
    enum_dict: Dict[str, str] = {}
    for v in values:
        base = str(v).upper()
        name = base
        suffix = 0
        while name in enum_dict:
            # Upper-casing is lossy ("a" and "A" both map to "A"); disambiguate instead of overwriting.
            suffix += 1
            name = f"{base}_{suffix}"
        enum_dict[name] = v

    # Dynamically create an Enum
    literal_enum = Enum(LITERAL_ENUM, enum_dict)  # type: ignore

    return literal_enum  # type: ignore
@@ -0,0 +1,3 @@
1
+ from .controllers import Controller, ControllerType, create_controller
2
+
3
# Names re-exported as this package's public API.
__all__ = ["Controller", "ControllerType", "create_controller"]
@@ -0,0 +1,129 @@
1
+ import concurrent.futures
2
+ import threading
3
+ from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
4
+
5
+ from flyte._task import TaskTemplate
6
+ from flyte.models import ActionID, NativeInterface
7
+
8
+ if TYPE_CHECKING:
9
+ from flyte.remote._task import TaskDetails
10
+
11
+ from ._trace import TraceInfo
12
+
13
__all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "get_controller"]

if TYPE_CHECKING:
    # NOTE(review): redundant -- concurrent.futures is already imported unconditionally at the top of this module.
    import concurrent.futures

# Kinds of controller accepted by create_controller().
# NOTE(review): create_controller() also matches "hybrid" (routed to the remote controller), which this Literal
# does not list -- confirm whether "hybrid" should be part of the public type.
ControllerType = Literal["local", "remote"]

# Generic result type variable; not referenced elsewhere in this module.
R = TypeVar("R")
21
+
22
+
23
class Controller(Protocol):
    """
    Controller interface used to execute tasks. Implementations can execute tasks in different ways, such as
    locally or remotely.

    This is a ``typing.Protocol``, so implementations satisfy it structurally and need not subclass it. The active
    implementation is installed by ``create_controller()`` and retrieved with ``get_controller()``.
    """

    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a task to the controller asynchronously and wait for the result. This is async and will block
        the current coroutine until the result is available.

        :param _task: The task to run.
        :param args: Positional inputs for the task.
        :param kwargs: Keyword inputs for the task.
        :return: The task's result.
        """
        ...

    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
        """
        This should call the async submit method above, but return a concurrent Future object that can be
        used in a blocking wait or wrapped in an async future. This is called when
        a) a synchronous task is kicked off locally,
        b) a running task (of either kind) kicks off a downstream synchronous task.

        :param _task: The task to run.
        :return: A ``concurrent.futures.Future`` that resolves to the task's result.
        """
        ...

    async def submit_task_ref(self, _task: "TaskDetails", *args, **kwargs) -> Any:
        """
        Submit a task reference (an already-registered remote task, described by ``TaskDetails``) to the controller
        asynchronously and wait for the result. This is async and will block the current coroutine until the
        result is available.
        """
        ...

    async def finalize_parent_action(self, action: ActionID):
        """
        Finalize the parent action. This can be called to clean up the action and should be called after the
        parent task completes.

        :param action: Action ID of the parent action.
        """
        ...

    # NOTE(review): the contract is not documented at the protocol level; presumably implementations surface
    # errors from background work here -- confirm against the implementations.
    async def watch_for_errors(self): ...

    async def get_action_outputs(
        self, _interface: NativeInterface, _func: Callable, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        Look up the outputs of a traced action, if they are available.

        :param _interface: NativeInterface of the traced function.
        :param _func: The traced function (a callable, not its name).
        :param args: Positional arguments the function was called with.
        :param kwargs: Keyword arguments the function was called with.
        :return: A TraceInfo object and a boolean indicating whether the action was found.
            If the boolean is False, the action was not found and the TraceInfo object carries only minimal info.
        """

    async def record_trace(self, info: TraceInfo):
        """
        Record a trace action. This records the trace of the action and should be called once the action has
        completed.

        :param info: Trace information.
        """
        ...

    async def stop(self):
        """
        Stop the engine. Call this once the engine is no longer needed.
        """
        ...
90
+
91
+
92
+ # Internal state holder
93
class _ControllerState:
    """Process-wide holder for the active controller; create_controller() replaces it while holding ``lock``."""

    # Taken by create_controller() while swapping in a new controller.
    lock = threading.Lock()
    # The controller installed by the most recent create_controller() call, or None before the first call.
    controller: Optional[Controller] = None
96
+
97
+
98
def get_controller() -> Controller:
    """
    Return the controller installed by ``create_controller()``.

    :raises RuntimeError: If no controller has been created yet.
    """
    active = _ControllerState.controller
    if active is None:
        raise RuntimeError("Controller is not initialized. Please call create_controller() first.")
    return active
105
+
106
+
107
def create_controller(
    ct: ControllerType,
    **kwargs,
) -> Controller:
    """
    Build a controller of kind ``ct`` and install it as the process-wide controller.

    :param ct: "local" for an in-process LocalController; "remote" (or "hybrid") for a remote controller.
    :param kwargs: Forwarded to ``create_remote_controller``; ignored for the local controller.
    :return: The newly created controller, which ``get_controller()`` returns from now on.
    :raises ValueError: If ``ct`` is not a recognised controller kind.
    """
    new_controller: Controller
    if ct == "local":
        # Imported lazily so that only the selected implementation is loaded.
        from ._local_controller import LocalController

        new_controller = LocalController()
    elif ct in ("remote", "hybrid"):
        from flyte._internal.controllers.remote import create_remote_controller

        new_controller = create_remote_controller(**kwargs)
    else:
        raise ValueError(f"{ct} is not a valid controller type.")

    with _ControllerState.lock:
        _ControllerState.controller = new_controller
    return new_controller
@@ -0,0 +1,239 @@
1
+ import asyncio
2
+ import atexit
3
+ import concurrent.futures
4
+ import os
5
+ import pathlib
6
+ import threading
7
+ from typing import Any, Callable, Tuple, TypeVar
8
+
9
+ import flyte.errors
10
+ from flyte._cache.cache import VersionParameters, cache_from_request
11
+ from flyte._cache.local_cache import LocalTaskCache
12
+ from flyte._context import internal_ctx
13
+ from flyte._internal.controllers import TraceInfo
14
+ from flyte._internal.runtime import convert
15
+ from flyte._internal.runtime.entrypoints import direct_dispatch
16
+ from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
17
+ from flyte._logging import log, logger
18
+ from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
19
+ from flyte._utils.helpers import _selector_policy
20
+ from flyte.models import ActionID, NativeInterface
21
+ from flyte.remote._task import TaskDetails
22
+
23
# Generic type variable; not referenced elsewhere in this module.
R = TypeVar("R")
24
+
25
+
26
+ class _TaskRunner:
27
+ """A task runner that runs an asyncio event loop on a background thread."""
28
+
29
+ def __init__(self) -> None:
30
+ self.__loop: asyncio.AbstractEventLoop | None = None
31
+ self.__runner_thread: threading.Thread | None = None
32
+ self.__lock = threading.Lock()
33
+ atexit.register(self._close)
34
+
35
+ def _close(self) -> None:
36
+ if self.__loop:
37
+ self.__loop.stop()
38
+
39
+ def _execute(self) -> None:
40
+ loop = self.__loop
41
+ assert loop is not None
42
+ try:
43
+ loop.run_forever()
44
+ finally:
45
+ loop.close()
46
+
47
+ def get_exc_handler(self):
48
+ def exc_handler(loop, context):
49
+ logger.error(
50
+ f"Taskrunner for {self.__runner_thread.name if self.__runner_thread else 'no thread'} caught"
51
+ f" exception in {loop}: {context}"
52
+ )
53
+
54
+ return exc_handler
55
+
56
+ def get_run_future(self, coro: Any) -> concurrent.futures.Future:
57
+ """Synchronously run a coroutine on a background thread."""
58
+ name = f"{threading.current_thread().name} : loop-runner"
59
+ with self.__lock:
60
+ if self.__loop is None:
61
+ with _selector_policy():
62
+ self.__loop = asyncio.new_event_loop()
63
+
64
+ exc_handler = self.get_exc_handler()
65
+ self.__loop.set_exception_handler(exc_handler)
66
+ self.__runner_thread = threading.Thread(target=self._execute, daemon=True, name=name)
67
+ self.__runner_thread.start()
68
+ fut = asyncio.run_coroutine_threadsafe(coro, self.__loop)
69
+ return fut
70
+
71
+
72
class LocalController:
    """
    Controller that executes tasks in-process.

    Sub-task inputs are hashed into a sub-action ID and output path. Results may be served from, and stored in,
    ``LocalTaskCache``. Synchronous submissions run on background event loops (``_TaskRunner``), one per calling
    thread. Reference tasks cannot be executed locally.
    """

    def __init__(self):
        logger.debug("LocalController init")
        # One background loop runner per "<thread name>PID:<pid>" key; entries are never removed (see submit_sync).
        self._runner_map: dict[str, _TaskRunner] = {}

    @log
    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Main entrypoint for submitting a task to the local controller.

        The method works in four steps:
        1. Convert the native ``args``/``kwargs`` to task inputs.
        2. Derive a sub-action ID and output path from the inputs hash.
        3. Serve the outputs from the local cache (only when the cache behavior is ``"auto"``) or run the task
           with ``direct_dispatch``.
        4. Convert the outputs back to native values.

        :param _task: The task to execute.
        :return: The task's outputs as native values, or ``None`` if the task declares no outputs.
        :raises flyte.errors.RuntimeSystemError: If there is no task context; if dispatch returns an error that
            cannot be converted to a native exception; or if the task declares outputs but none were produced.
            A dispatch error that converts to a native exception is raised as that exception.
        """
        ctx = internal_ctx()
        tctx = ctx.data.task_context
        if not tctx:
            raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")

        inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
        inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
        task_interface = transform_native_to_typed_interface(_task.interface)

        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, _task.name, inputs_hash, 0
        )
        sub_action_raw_data_path = tctx.raw_data_path
        # Make sure the output path exists
        pathlib.Path(sub_action_output_path).mkdir(parents=True, exist_ok=True)
        pathlib.Path(sub_action_raw_data_path.path).mkdir(parents=True, exist_ok=True)

        task_cache = cache_from_request(_task.cache)
        cache_enabled = task_cache.is_enabled()
        # Only function-backed tasks contribute their function to the cache version.
        if isinstance(_task, AsyncFunctionTaskTemplate):
            version_parameters = VersionParameters(func=_task.func, image=_task.image)
        else:
            version_parameters = VersionParameters(func=None, image=_task.image)
        cache_version = task_cache.get_version(version_parameters)
        cache_key = convert.generate_cache_key_hash(
            _task.name,
            inputs_hash,
            task_interface,
            cache_version,
            list(task_cache.get_ignored_inputs()),
            inputs.proto_inputs,
        )

        out = None
        # We only get output from cache if the cache behavior is set to auto
        if task_cache.behavior == "auto":
            out = await LocalTaskCache.get(cache_key)
            if out is not None:
                logger.info(
                    f"Cache hit for task '{_task.name}' (version: {cache_version}), getting result from cache..."
                )

        if out is None:
            out, err = await direct_dispatch(
                _task,
                controller=self,
                action=sub_action_id,
                raw_data_path=sub_action_raw_data_path,
                inputs=inputs,
                version=cache_version,
                checkpoints=tctx.checkpoints,
                code_bundle=tctx.code_bundle,
                output_path=sub_action_output_path,
                run_base_dir=tctx.run_base_dir,
            )

            if err:
                exc = convert.convert_error_to_native(err)
                if exc:
                    raise exc
                else:
                    raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")

            # store into cache
            if cache_enabled and out is not None:
                await LocalTaskCache.set(cache_key, out)

        if _task.native_interface.outputs:
            if out is None:
                raise flyte.errors.RuntimeSystemError("BadOutput", "Task output not captured.")
            result = await convert.convert_outputs_to_native(_task.native_interface, out)
            return result
        return None

    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
        """
        Run ``submit`` for ``_task`` on a background event loop and return a ``concurrent.futures.Future``.

        Each calling thread, keyed by thread name plus process ID, gets its own ``_TaskRunner``, created on first use.
        """
        name = threading.current_thread().name + f"PID:{os.getpid()}"
        coro = self.submit(_task, *args, **kwargs)
        if name not in self._runner_map:
            if len(self._runner_map) > 100:
                # Runners are never discarded, so an ever-growing map hints at unbounded nested submissions.
                logger.warning(
                    "More than 100 event loop runners created!!! This could be a case of runaway recursion..."
                )
            self._runner_map[name] = _TaskRunner()

        return self._runner_map[name].get_run_future(coro)

    async def finalize_parent_action(self, action: ActionID):
        """No-op locally."""
        pass

    async def stop(self):
        """Close the local task cache."""
        await LocalTaskCache.close()

    async def watch_for_errors(self):
        """No-op locally."""
        pass

    async def get_action_outputs(
        self, _interface: NativeInterface, _func: Callable, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        Build the TraceInfo for a traced call of ``_func`` in the current task context.

        The sub-action ID and path are derived from ``_func.__name__`` and the hash of the converted inputs.
        In this implementation the returned flag is always ``True``, and the TraceInfo carries no recorded output.

        :return: The TraceInfo and ``True``.
        :raises flyte.errors.NotInTaskContextError: If there is no task context.
        """
        ctx = internal_ctx()
        tctx = ctx.data.task_context
        if not tctx:
            raise flyte.errors.NotInTaskContextError("BadContext", "Task context not initialized")

        converted_inputs = convert.Inputs.empty()
        if _interface.inputs:
            converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
            assert converted_inputs

        inputs_hash = convert.generate_inputs_hash_from_proto(converted_inputs.proto_inputs)
        action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx,
            _func.__name__,
            inputs_hash,
            0,
        )
        assert action_output_path
        return (
            TraceInfo(
                name=_func.__name__,
                action=action_id,
                interface=_interface,
                inputs_path=action_output_path,
            ),
            True,
        )

    async def record_trace(self, info: TraceInfo):
        """
        Validate a completed trace in the current task context.

        The output (if the interface declares outputs) or the error is converted, but the converted value is only
        asserted to be non-empty; nothing is persisted locally. NOTE(review): this presumably exists to surface
        conversion failures early -- confirm.

        :param info: Trace information.
        :raises flyte.errors.NotInTaskContextError: If there is no task context.
        """
        ctx = internal_ctx()
        tctx = ctx.data.task_context
        if not tctx:
            raise flyte.errors.NotInTaskContextError("BadContext", "Task context not initialized")

        if info.interface.outputs and info.output:
            # If the result is not an AsyncGenerator, convert it directly
            converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface, info.name)
            assert converted_outputs
        elif info.error:
            # If there is an error, convert the native exception into its serialized error form
            converted_error = convert.convert_from_native_to_error(info.error)
            assert converted_error
        assert info.action
        assert info.start_time
        assert info.end_time

    async def submit_task_ref(self, _task: TaskDetails, max_inline_io_bytes: int, *args, **kwargs) -> Any:
        """
        Always raise: reference tasks can only be executed remotely.

        :raises flyte.errors.ReferenceTaskError: Unconditionally.
        """
        raise flyte.errors.ReferenceTaskError(
            f"Reference tasks cannot be executed locally, only remotely. Found remote task {_task.name}"
        )
@@ -0,0 +1,48 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Optional
3
+
4
+ from flyteidl2.core import interface_pb2
5
+
6
+ from flyte.models import ActionID, NativeInterface
7
+
8
+
9
@dataclass
class TraceInfo:
    """
    Record of one traced action: its identity and interface, where its inputs live, when it ran, and its output
    or the exception it raised. It is handed to a controller's ``record_trace`` once the action has completed.
    """

    # Identity of the traced action.
    name: str
    action: ActionID
    interface: NativeInterface
    inputs_path: str
    # Timing is filled in by add_outputs()/add_error(), never through the constructor.
    start_time: float = field(init=False, default=0.0)
    end_time: float = field(init=False, default=0.0)
    # Outcome of the action.
    output: Optional[Any] = None
    error: Optional[Exception] = None
    typed_interface: Optional[interface_pb2.TypedInterface] = None

    def add_outputs(self, output: Any, start_time: float, end_time: float):
        """
        Record a successful result for the action.

        :param output: Value produced by the action
        :param start_time: When the action started
        :param end_time: When the action finished
        """
        self.output = output
        self._stamp(start_time, end_time)

    def add_error(self, error: Exception, start_time: float, end_time: float):
        """
        Record the exception raised by the action.

        :param error: Exception raised by the action
        :param start_time: When the action started
        :param end_time: When the action finished
        """
        self.error = error
        self._stamp(start_time, end_time)

    def _stamp(self, start_time: float, end_time: float) -> None:
        """Store the start and end times of the action."""
        self.start_time = start_time
        self.end_time = end_time
@@ -0,0 +1,58 @@
1
+ from typing import List
2
+
3
+ from flyte.remote._client.auth import AuthType, ClientConfig
4
+
5
+ from ._controller import RemoteController
6
+
7
# Names re-exported as this package's public API.
__all__ = ["RemoteController", "create_remote_controller"]
8
+
9
+
10
def create_remote_controller(
    *,
    api_key: str | None = None,
    endpoint: str | None = None,
    insecure: bool = False,
    insecure_skip_verify: bool = False,
    ca_cert_file_path: str | None = None,
    client_config: ClientConfig | None = None,
    auth_type: AuthType = "Pkce",
    headless: bool = False,
    command: List[str] | None = None,
    proxy_command: List[str] | None = None,
    client_id: str | None = None,
    client_credentials_secret: str | None = None,
    rpc_retries: int = 3,
    http_proxy_url: str | None = None,
) -> RemoteController:
    """
    Create a new instance of the remote controller.

    The client is chosen as follows:
    - If ``endpoint`` is given, the client is built with ``ControllerClient.for_endpoint`` and ``api_key`` is not
      used.
    - Otherwise the client is built from ``api_key`` with ``ControllerClient.for_api_key``.

    The connection/auth options ``insecure``, ``insecure_skip_verify``, ``ca_cert_file_path``, ``client_id``,
    ``client_credentials_secret`` and ``auth_type`` are forwarded to whichever factory is used.

    NOTE(review): ``client_config``, ``headless``, ``command``, ``proxy_command``, ``rpc_retries`` and
    ``http_proxy_url`` are accepted but not used by this function -- confirm whether they should be forwarded.

    :return: A ``RemoteController`` wrapping the client returned by the chosen factory.
    :raises ValueError: If neither ``endpoint`` nor ``api_key`` is provided.
    """
    from ._client import ControllerClient

    if endpoint:
        client_coro = ControllerClient.for_endpoint(
            endpoint,
            insecure=insecure,
            insecure_skip_verify=insecure_skip_verify,
            ca_cert_file_path=ca_cert_file_path,
            client_id=client_id,
            client_credentials_secret=client_credentials_secret,
            auth_type=auth_type,
        )
    elif api_key:
        client_coro = ControllerClient.for_api_key(
            api_key,
            insecure=insecure,
            insecure_skip_verify=insecure_skip_verify,
            ca_cert_file_path=ca_cert_file_path,
            client_id=client_id,
            client_credentials_secret=client_credentials_secret,
            auth_type=auth_type,
        )
    else:
        # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``, which previously let
        # this case fall through to an UnboundLocalError on ``client_coro``.
        raise ValueError("Either endpoint or api_key must be provided when initializing remote controller")

    return RemoteController(client_coro=client_coro)