flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
flyte/_interface.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from typing import Dict, Generator, Tuple, Type, TypeVar, Union, cast, get_args, get_type_hints
|
|
5
|
+
|
|
6
|
+
from flyte._logging import logger
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def default_output_name(index: int = 0) -> str:
    """
    Build the auto-generated name of a positional output.

    :param index: Zero-based position of the output.
    :return: The output name, e.g. ``"o0"`` for index 0 and ``"o2"`` for index 2.
    """
    return "o{}".format(index)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def output_name_generator(length: int) -> Generator[str, None, None]:
    """
    Lazily yield ``length`` default output names, in positional order.

    :param length: Number of names to produce.
    :return: A generator yielding ``default_output_name(0)`` .. ``default_output_name(length - 1)``.
    """
    yield from map(default_output_name, range(length))
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _reject_variadic_tuple(args: Tuple, return_annotation: object) -> None:
    """Raise TypeError when a tuple return annotation is variable-length (contains ``...``)."""
    # A variadic tuple such as Tuple[int, ...] has no fixed arity, so it cannot be mapped onto a fixed set of
    # named outputs; without this check the Ellipsis object would silently become the type of output "o1".
    if any(arg is Ellipsis for arg in args):
        raise TypeError(
            f"Variable-length tuple {return_annotation} cannot be used as a return annotation; "
            "declare every returned element explicitly, e.g. Tuple[int, str]."
        )


def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Dict[str, Type]:
    """
    The input to this function should be sig.return_annotation where sig = inspect.signature(some_func)
    The purpose of this function is to sort out whether a function is returning one thing, or multiple things, and to
    name the outputs accordingly, either by using our default name function, or from a typing.NamedTuple.

        # Option 1
        nt1 = typing.NamedTuple("NT1", x_str=str, y_int=int)
        def t(a: int, b: str) -> nt1: ...

        # Option 2
        def t(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): ...

        # Option 3
        def t(a: int, b: str) -> typing.Tuple[int, str]: ...

        # Option 4
        def t(a: int, b: str) -> (int, str): ...

        # Option 5
        def t(a: int, b: str) -> str: ...

        # Option 6
        def t(a: int, b: str) -> None: ...

        # Options 7/8
        def t(a: int, b: str) -> List[int]: ...
        def t(a: int, b: str) -> Dict[str, int]: ...

    Note that Options 1 and 2 are identical, just syntactic sugar. In the NamedTuple case, we'll use the names in the
    definition. In all other cases, we'll automatically generate output names, indexed starting at 0.

    :param return_annotation: The return annotation to analyse.
    :return: Mapping of output name to output type; empty when the function returns nothing.
    :raises TypeError: If a tuple annotation has exactly one element, or is variable-length (``Tuple[int, ...]``).
    """
    # Option 6: nothing is returned.
    # We can think about whether we should add a default output name with type None in the future.
    if return_annotation in (None, type(None), inspect.Signature.empty):
        return {}

    # Options 1 and 2. A NamedTuple is a single value for Python but multi-valued for us.
    # isinstance / issubclass does not work for NamedTuple, so detect it structurally: a class whose only
    # base is tuple and which exposes ``_fields``.
    if hasattr(return_annotation, "__bases__") and (
        isinstance(return_annotation, type) or isinstance(return_annotation, TypeVar)
    ):
        bases = return_annotation.__bases__  # type: ignore
        if len(bases) == 1 and bases[0] is tuple and hasattr(return_annotation, "_fields"):
            logger.debug(f"Task returns named tuple {return_annotation}")
            return dict(get_type_hints(cast(Type, return_annotation), include_extras=True))

    # Option 3: typing.Tuple[...] / tuple[...]
    if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple:  # type: ignore
        logger.debug(f"Task returns unnamed typing.Tuple {return_annotation}")
        if len(return_annotation.__args__) == 1:  # type: ignore
            raise TypeError("Tuples should be used to indicate multiple return values, found only one return variable.")
        ra = get_args(return_annotation)
        _reject_variadic_tuple(ra, return_annotation)
        return dict(zip(output_name_generator(len(ra)), ra))

    # Option 4: a literal tuple of types, e.g. ``-> (int, str)``
    if isinstance(return_annotation, tuple):
        logger.debug(f"Task returns unnamed native tuple {return_annotation}")
        if len(return_annotation) == 1:
            raise TypeError("Please don't use a tuple if you're just returning one thing.")
        _reject_variadic_tuple(return_annotation, return_annotation)
        return dict(zip(output_name_generator(len(return_annotation)), return_annotation))

    # Options 5, 7 and 8: any other single return type gets the default output name.
    logger.debug(f"Task returns single unnamed value {return_annotation}")
    return {default_output_name(): cast(Type, return_annotation)}
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from typing import Any, Literal, Optional, Protocol, Tuple, TypeVar
|
|
3
|
+
|
|
4
|
+
from flyte._datastructures import ActionID, NativeInterface
|
|
5
|
+
from flyte._task import TaskTemplate
|
|
6
|
+
|
|
7
|
+
from ._trace import TraceInfo
|
|
8
|
+
|
|
9
|
+
from ..._protos.workflow import task_definition_pb2

__all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "get_controller"]

# Controller kinds understood by create_controller().
# NOTE(review): create_controller() also accepts "hybrid" (handled like "remote"), which is not listed
# in this Literal — confirm whether it should be.
ControllerType = Literal["local", "remote"]

# Generic result type variable.
# NOTE(review): not referenced elsewhere in this module's visible code — confirm before removing.
R = TypeVar("R")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Controller(Protocol):
    """
    Structural interface for the component that executes tasks.

    Concrete controllers decide how work is run (for example in-process, or through a remote backend);
    callers depend only on the asynchronous methods declared below.
    """

    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a task for execution and wait for its result.

        This coroutine suspends the caller until the result is available.
        :param _task: The task template to execute.
        :param args: Positional arguments for the task.
        :param kwargs: Keyword arguments for the task.
        :return: The task's result.
        """
        ...

    async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
        """
        Submit a task *reference* (already-registered task details) and wait for its result.

        This coroutine suspends the caller until the result is available.
        :param _task: Details of the referenced task.
        :param args: Positional arguments for the task.
        :param kwargs: Keyword arguments for the task.
        :return: The task's result.
        """
        ...

    async def finalize_parent_action(self, action: ActionID):
        """
        Clean up after a parent action; intended to be called once the parent task has completed.

        :param action: Identifier of the parent action.
        """
        ...

    async def watch_for_errors(self):
        """
        Watch the controller for errors.

        The contract is not specified by this protocol; each implementation defines what is watched and how
        errors are surfaced.
        """
        ...

    async def get_action_outputs(
        self, _interface: NativeInterface, _func_name: str, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        Look up the outputs of an action, if they are available.

        :param _interface: Native interface of the traced function.
        :param _func_name: Name of the traced function.
        :param args: Positional arguments of the call.
        :param kwargs: Keyword arguments of the call.
        :return: A ``(TraceInfo, found)`` pair. When ``found`` is False the action was not found and the
            TraceInfo carries only minimal information.
        """
        ...

    async def record_trace(self, info: TraceInfo):
        """
        Record the trace of an action; intended to be called once the action has completed.

        :param info: Trace information describing the completed action.
        """
        ...

    async def stop(self):
        """
        Shut the controller down; call this once it is no longer needed.
        """
        ...
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Internal, module-level holder for the active controller.
class _ControllerState:
    """Holds the currently registered controller and the lock used when it is (re)assigned."""

    lock = threading.Lock()
    controller: Optional[Controller] = None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
async def get_controller() -> Controller:
    """
    Return the controller registered by ``create_controller``.

    :return: The active controller instance.
    :raises RuntimeError: If no controller has been created yet.
    """
    # Read the shared slot once so the None-check and the returned value refer to the same object.
    controller = _ControllerState.controller
    if controller is None:
        # Point users at the factory that actually exists in this module (``create_controller``).
        raise RuntimeError("Controller is not initialized. Please call create_controller() first.")
    return controller
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def create_controller(
    ct: ControllerType,
    **kwargs,
) -> Controller:
    """
    Build a controller of the requested kind, register it as the active controller, and return it.

    :param ct: Controller kind. ``"local"`` builds a LocalController; ``"remote"`` and ``"hybrid"`` are both
        delegated to ``create_remote_controller``.
    :param kwargs: Forwarded to ``create_remote_controller``; not used by the local controller.
    :return: The new controller, which becomes the instance returned by ``get_controller``.
    :raises ValueError: If ``ct`` is not a recognized controller kind.
    """
    # Implementations are imported lazily so that only the selected backend is loaded.
    if ct == "local":
        from ._local_controller import LocalController

        new_controller: Controller = LocalController()
    elif ct in ("remote", "hybrid"):
        from flyte._internal.controllers.remote import create_remote_controller

        new_controller = create_remote_controller(**kwargs)
    else:
        raise ValueError(f"{ct} is not a valid controller type.")

    with _ControllerState.lock:
        _ControllerState.controller = new_controller
    return new_controller
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from typing import Any, Tuple, TypeVar
|
|
2
|
+
|
|
3
|
+
import flyte.errors
|
|
4
|
+
from flyte._context import internal_ctx
|
|
5
|
+
from flyte._datastructures import ActionID, NativeInterface, RawDataPath
|
|
6
|
+
from flyte._internal.controllers import TraceInfo
|
|
7
|
+
from flyte._internal.runtime import convert
|
|
8
|
+
from flyte._internal.runtime.entrypoints import direct_dispatch
|
|
9
|
+
from flyte._logging import log, logger
|
|
10
|
+
from flyte._protos.workflow import task_definition_pb2
|
|
11
|
+
from flyte._task import TaskTemplate
|
|
12
|
+
|
|
13
|
+
# Generic result type variable. NOTE(review): not referenced in this module's visible code — confirm before removing.
R = TypeVar("R")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LocalController:
    """
    Controller that executes sub-tasks in the current process.

    Submitted tasks are run through ``direct_dispatch`` using the current task context; reference tasks are
    not supported.
    """

    def __init__(self):
        logger.debug("LocalController init")

    @log
    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a node to the controller: run ``_task`` in-process and return its native result.

        :param _task: Task to execute.
        :param args: Positional task arguments.
        :param kwargs: Keyword task arguments.
        :return: The outputs converted to native values, or the raw dispatch result when the task declares
            no outputs or produced none.
        :raises flyte.errors.RuntimeSystemError: If there is no task context, or the task failed with an error
            that could not be converted to a native exception.
        """
        ctx = internal_ctx()
        tctx = ctx.data.task_context
        if not tctx:
            raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")

        inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
        # Sub-action id and output location are computed from the task context, the task name and the inputs.
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _task.name, inputs)
        sub_action_raw_data_path = RawDataPath(path=sub_action_output_path)

        out, err = await direct_dispatch(
            _task,
            controller=self,
            action=sub_action_id,
            raw_data_path=sub_action_raw_data_path,
            inputs=inputs,
            version=tctx.version,
            checkpoints=tctx.checkpoints,
            code_bundle=tctx.code_bundle,
            output_path=sub_action_output_path,
            run_base_dir=tctx.run_base_dir,
        )
        if err:
            exc = convert.convert_error_to_native(err)
            if exc:
                raise exc
            else:
                raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")
        if _task.native_interface.outputs and out is not None:
            result = await convert.convert_outputs_to_native(_task.native_interface, out)
            return result
        return out

    async def finalize_parent_action(self, action: ActionID):
        """No-op: the local controller keeps no per-action state to clean up."""
        pass

    async def stop(self):
        """No-op: there is nothing to shut down for in-process execution."""
        pass

    async def watch_for_errors(self):
        """No-op for the local controller."""
        pass

    async def get_action_outputs(
        self, _interface: NativeInterface, _func_name: str, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        Build the TraceInfo for a traced call made from within the current task.

        The local controller does not look up previously recorded outputs: it always returns a fresh TraceInfo
        (with no output or error set) together with ``True``.
        NOTE(review): the protocol describes the flag as "action found"; confirm callers treat ``True`` with no
        output/error as "execute the function".

        :param _interface: Native interface of the traced function.
        :param _func_name: Name of the traced function (used to derive the sub-action id).
        :param args: Positional arguments of the call.
        :param kwargs: Keyword arguments of the call.
        :return: ``(TraceInfo, True)``.
        :raises flyte.errors.NotInTaskContextError: If called outside a task context.
        """
        ctx = internal_ctx()
        tctx = ctx.data.task_context
        if not tctx:
            raise flyte.errors.NotInTaskContextError("BadContext", "Task context not initialized")
        converted_inputs = convert.Inputs.empty()
        if _interface.inputs:
            converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
            assert converted_inputs
        action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, _func_name, converted_inputs
        )
        assert action_output_path
        return (
            TraceInfo(
                action=action_id,
                interface=_interface,
                inputs_path=action_output_path,
            ),
            True,
        )

    async def record_trace(self, info: TraceInfo):
        """
        Check a completed trace locally; nothing is persisted.

        The output (or error) is converted and the converted value is discarded, so conversion only acts as a
        sanity check here. The trace must also carry an action and a duration.
        :param info: Trace information, populated via ``add_outputs`` or ``add_error``.
        :raises flyte.errors.NotInTaskContextError: If called outside a task context.
        """
        ctx = internal_ctx()
        tctx = ctx.data.task_context
        if not tctx:
            raise flyte.errors.NotInTaskContextError("BadContext", "Task context not initialized")

        # Compare against None rather than truthiness: falsy results such as 0, "" or [] are valid outputs
        # and must still be converted.
        if info.interface.outputs and info.output is not None:
            converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
            assert converted_outputs
        elif info.error:
            # If there is an error, convert it to its serialized error form
            converted_error = convert.convert_from_native_to_error(info.error)
            assert converted_error
        assert info.action
        # timedelta(0) is falsy, so a plain truthiness check would fail for very fast calls.
        assert info.duration is not None

    async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
        """Reference tasks are not supported locally; always raises ReferenceTaskError."""
        raise flyte.errors.ReferenceTaskError("Reference tasks cannot be executed locally, only remotely.")
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from datetime import timedelta
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
from flyte._datastructures import ActionID, NativeInterface
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class TraceInfo:
    """
    Record describing one traced action.

    It holds the action id, the native interface and an inputs path. Once the action finishes, either
    ``add_outputs`` or ``add_error`` fills in the outcome and the elapsed duration.
    """

    action: ActionID
    interface: NativeInterface
    inputs_path: str
    # Filled in when the action completes (see add_outputs / add_error).
    duration: Optional[timedelta] = None
    output: Optional[Any] = None
    error: Optional[Exception] = None

    def add_outputs(self, output: Any, duration: timedelta):
        """
        Store the successful result of the action.

        :param output: Value produced by the action.
        :param duration: Time the action took.
        """
        self.output, self.duration = output, duration

    def add_error(self, error: Exception, duration: timedelta):
        """
        Store the exception raised by the action.

        :param error: Exception raised by the action.
        :param duration: Time the action took before failing.
        """
        self.error, self.duration = error, duration
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# This is a module that provides hashing utilities for Protobuf objects.
|
|
2
|
+
import base64
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
|
|
6
|
+
from google.protobuf import json_format
|
|
7
|
+
from google.protobuf.message import Message
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def compute_hash(pb: Message) -> bytes:
    """
    Compute a deterministic SHA-256 digest of a Protobuf message.

    The message is converted with ``json_format.MessageToDict`` and serialized to JSON with sorted
    keys and compact separators, so key ordering and whitespace cannot change the hash.

    :param pb: The Protobuf message to hash.
    :return: The 32-byte SHA-256 digest of the canonical JSON form.
    :raises ValueError: If the message cannot be converted to JSON.
    """
    try:
        pb_dict = json_format.MessageToDict(pb)
        # Sorted keys and no extra whitespace keep the serialization stable across calls.
        stable_json_str = json.dumps(pb_dict, sort_keys=True, separators=(",", ":"))
    except Exception as e:
        # Chain the original exception so the root cause stays visible in the traceback.
        raise ValueError(f"Failed to marshal Protobuf object {pb} to JSON with error: {e}") from e

    # json.dumps escapes non-ASCII characters by default (ensure_ascii=True), so encoding cannot
    # fail here, and SHA-256 is deterministic for identical bytes. The digest is always 32 bytes.
    return hashlib.sha256(stable_json_str.encode("utf-8")).digest()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def compute_hash_string(pb: Message) -> str:
    """
    Compute a deterministic hash of a Protobuf message as a base64 string.

    :param pb: The Protobuf message to hash.
    :return: Standard base64 encoding of the digest produced by ``compute_hash``.
    """
    digest = compute_hash(pb)
    encoded = base64.b64encode(digest)
    return encoded.decode("utf-8")
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from flyte.remote._client.auth import AuthType, ClientConfig
|
|
4
|
+
|
|
5
|
+
from ._controller import RemoteController
|
|
6
|
+
|
|
7
|
+
__all__ = ["RemoteController", "create_remote_controller"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def create_remote_controller(
    *,
    api_key: str | None = None,
    auth_type: AuthType = "Pkce",
    endpoint: str,
    client_config: ClientConfig | None = None,
    headless: bool = False,
    insecure: bool = False,
    insecure_skip_verify: bool = False,
    ca_cert_file_path: str | None = None,
    command: List[str] | None = None,
    proxy_command: List[str] | None = None,
    client_id: str | None = None,
    client_credentials_secret: str | None = None,
    rpc_retries: int = 3,
    http_proxy_url: str | None = None,
) -> RemoteController:
    """
    Create a new instance of the remote controller.

    The controller receives the coroutine returned by ``ControllerClient.for_endpoint`` (not awaited
    here), which opens a channel to ``endpoint``. The controller is configured with ``workers=10``
    and ``max_system_retries=5``.

    :param endpoint: Address of the endpoint the controller client connects to.
    :param insecure: Forwarded to ``ControllerClient.for_endpoint``.
    :param insecure_skip_verify: Forwarded to ``ControllerClient.for_endpoint``.
    :return: The configured ``RemoteController``.

    NOTE(review): ``api_key``, ``auth_type``, ``client_config``, ``headless``, ``ca_cert_file_path``,
    ``command``, ``proxy_command``, ``client_id``, ``client_credentials_secret``, ``rpc_retries`` and
    ``http_proxy_url`` are accepted but not used, so auth/proxy/retry settings passed here currently
    have no effect. Confirm which of them ``create_channel`` supports before forwarding them.
    """
    # Local import, presumably to defer loading the gRPC client module until a controller is built.
    # ``RemoteController`` is already imported at module level, so it is not re-imported here.
    from ._client import ControllerClient

    return RemoteController(
        client_coro=ControllerClient.for_endpoint(
            endpoint=endpoint, insecure=insecure, insecure_skip_verify=insecure_skip_verify
        ),
        workers=10,
        max_system_retries=5,
    )
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from flyteidl.core import execution_pb2
|
|
6
|
+
|
|
7
|
+
from flyte._datastructures import GroupData
|
|
8
|
+
from flyte._protos.workflow import run_definition_pb2, state_service_pb2, task_definition_pb2
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class Action:
    """
    In-memory state of a task action, combining the locally submitted view with the state reported
    by the remote state service.

    Coroutine safe: none of the methods await, so each one runs to completion without yielding.
    """

    action_id: run_definition_pb2.ActionIdentifier
    parent_action_name: str
    friendly_name: str | None = None
    group: GroupData | None = None
    task: task_definition_pb2.TaskSpec | None = None
    inputs_uri: str | None = None
    run_output_base: str | None = None
    realized_outputs_uri: str | None = None
    err: execution_pb2.ExecutionError | None = None
    phase: run_definition_pb2.Phase | None = None
    started: bool = False
    retries: int = 0
    client_err: Exception | None = None  # Set when something goes wrong in the controller itself.

    @property
    def name(self) -> str:
        """Name of this action, taken from its action id."""
        return self.action_id.name

    @property
    def run_name(self) -> str:
        """Name of the run this action belongs to."""
        return self.action_id.run.name

    def is_terminal(self) -> bool:
        """Return True once the phase is failed, succeeded, aborted or timed out; False if unknown."""
        if self.phase is None:
            return False
        phases = run_definition_pb2.Phase
        terminal = (
            phases.PHASE_FAILED,
            phases.PHASE_SUCCEEDED,
            phases.PHASE_ABORTED,
            phases.PHASE_TIMED_OUT,
        )
        return self.phase in terminal

    def increment_retries(self):
        """Increase the retry counter by one."""
        self.retries = self.retries + 1

    def is_started(self) -> bool:
        """Return whether the action has been marked as started."""
        return self.started

    def mark_started(self):
        """Mark the action as started and drop the stored task spec."""
        self.started = True
        self.task = None

    def mark_cancelled(self):
        """Mark the action as started and move it to the aborted phase."""
        self.mark_started()
        self.phase = run_definition_pb2.Phase.PHASE_ABORTED

    def merge_state(self, obj: state_service_pb2.ActionUpdate):
        """
        Merge an update received from the watch API into this action.

        The phase, and the error that accompanies it, are replaced only when the phase differs from
        the current one; there may have been no phase information before this update. The output URI
        is always taken from the update, and the action is marked as started.

        :param obj: The action update from the state service.
        """
        if self.phase != obj.phase:
            self.phase = obj.phase
            if obj.HasField("error"):
                self.err = obj.error
            else:
                self.err = None
        self.realized_outputs_uri = obj.output_uri
        self.started = True

    def merge_in_action_from_submit(self, action: Action):
        """
        Merge a submitted action into an action that was previously observed from the watch.

        Submission metadata is copied over while the observed phase is preserved. The task spec is
        copied only if this action has not been started yet.

        :param action: The submitted action.
        """
        self.friendly_name = action.friendly_name
        self.group = action.group
        self.inputs_uri = action.inputs_uri
        self.run_output_base = action.run_output_base
        if self.started:
            return
        self.task = action.task

    def set_client_error(self, exc: Exception):
        """Record an error raised inside the controller for this action."""
        self.client_err = exc

    def has_error(self) -> bool:
        """Return True if either a controller-side or a remote execution error is recorded."""
        return not (self.client_err is None and self.err is None)

    @classmethod
    def from_task(
        cls,
        parent_action_name: str,
        sub_action_id: run_definition_pb2.ActionIdentifier,
        group_data: GroupData | None,
        task_spec: task_definition_pb2.TaskSpec,
        inputs_uri: str,
        run_output_base: str,
    ) -> Action:
        """
        Build a not-yet-started action for a locally submitted task.

        The friendly name is taken from the task template's id name.
        """
        display_name = task_spec.task_template.id.name
        return cls(
            action_id=sub_action_id,
            parent_action_name=parent_action_name,
            friendly_name=display_name,
            group=group_data,
            task=task_spec,
            inputs_uri=inputs_uri,
            run_output_base=run_output_base,
        )

    @classmethod
    def from_state(cls, parent_action_name: str, obj: state_service_pb2.ActionUpdate) -> Action:
        """
        Build an action from a watch API update, before its task has been seen locally.

        This can happen during recovery, when the state service already knows about actions the
        parent has not (re)submitted yet. The resulting action is marked as started and carries the
        reported phase, error and output URI.

        :param parent_action_name: Name of the parent action.
        :param obj: The action update from the state service.
        :return: A new ``Action`` reflecting the reported state.
        """
        from flyte._logging import logger

        logger.info(f"In Action from_state {obj.action_id} {obj.phase} {obj.output_uri}")
        reported_err = obj.error if obj.HasField("error") else None
        return cls(
            action_id=obj.action_id,
            parent_action_name=parent_action_name,
            phase=obj.phase,
            started=True,
            err=reported_err,
            realized_outputs_uri=obj.output_uri,
        )
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import grpc.aio
|
|
4
|
+
|
|
5
|
+
from flyte._protos.workflow import queue_service_pb2_grpc, state_service_pb2_grpc
|
|
6
|
+
from flyte.remote import create_channel
|
|
7
|
+
|
|
8
|
+
from ._service_protocol import QueueService, StateService
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ControllerClient:
    """
    Client for the Controller API: the state-service and queue-service stubs, both bound to one
    shared gRPC channel.
    """

    def __init__(self, channel: grpc.aio.Channel):
        self._channel = channel
        state_stub = state_service_pb2_grpc.StateServiceStub(channel=channel)
        queue_stub = queue_service_pb2_grpc.QueueServiceStub(channel=channel)
        self._state_service, self._queue_service = state_stub, queue_stub

    @classmethod
    async def for_endpoint(cls, endpoint: str, insecure: bool = False, **kwargs) -> ControllerClient:
        """
        Open a channel to ``endpoint`` with ``create_channel`` and wrap it in a client.

        Extra keyword arguments are forwarded unchanged to ``create_channel``.
        """
        channel = await create_channel(endpoint, insecure=insecure, **kwargs)
        return cls(channel)

    @property
    def state_service(self) -> StateService:
        """Stub for the state service."""
        return self._state_service

    @property
    def queue_service(self) -> QueueService:
        """Stub for the queue service."""
        return self._queue_service

    def close(self, grace: float | None = None):
        """
        Close the underlying channel.

        Returns whatever the channel's ``close`` returns. For a ``grpc.aio`` channel this is
        presumably an awaitable that the caller must await — confirm against the grpc version in use.
        """
        channel = self._channel
        return channel.close(grace=grace)
|