flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from typing import Any, Literal, Optional, Protocol
|
|
3
|
+
|
|
4
|
+
from union._datastructures import ActionID
|
|
5
|
+
from union._task import TaskTemplate
|
|
6
|
+
|
|
7
|
+
# Names exported as the public API of this package.
__all__ = ["Controller", "ControllerType", "create_controller", "get_controller"]
|
|
8
|
+
|
|
9
|
+
# Controller kinds a caller may request from create_controller().
# NOTE(review): create_controller() also matches "hybrid" (handled the same as "remote"), which this
# Literal does not list -- confirm whether "hybrid" should be part of the declared type.
ControllerType = Literal["local", "remote"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Controller(Protocol):
    """
    Structural interface for objects that execute tasks.

    Implementations are free to choose how the work is carried out, for example locally or remotely.
    """

    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a task to the controller and wait for its result.

        The call is asynchronous: the awaiting coroutine is suspended until the result is available.

        :param _task: The task template to execute.
        :param args: Positional arguments forwarded with the submission.
        :param kwargs: Keyword arguments forwarded with the submission.
        :return: The result produced for the submitted task.
        """
        ...

    async def finalize_parent_action(self, action: ActionID):
        """
        Finalize the parent action.

        Intended as a cleanup hook that is called after the parent task completes.

        :param action: Action ID of the parent action.
        :return:
        """
        ...

    def stop(self):
        """
        Stop the engine. Call this once the engine is no longer needed.
        """
        ...
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Module-private holder for the most recently created controller.
class _ControllerState:
    """Class-level storage for the current controller and the lock used when replacing it."""

    # The controller installed by create_controller(); None until one has been created.
    controller: Optional[Controller] = None
    # Held by create_controller() while it assigns a new controller.
    lock = threading.Lock()
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
async def get_controller() -> Controller:
    """
    Return the controller previously installed by ``create_controller``.

    :return: The current controller instance.
    :raises RuntimeError: If no controller has been created yet.
    """
    # Read the shared attribute once so the value that is checked is the value that is returned.
    controller = _ControllerState.controller
    if controller is None:
        # Point callers at the factory this module actually exports.
        raise RuntimeError("Controller is not initialized. Please call create_controller() first.")
    return controller
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def create_controller(
    ct: ControllerType,
    **kwargs,
) -> Controller:
    """
    Build a controller of the requested kind and store it as the current controller.

    :param ct: Controller kind. "local" builds a ``LocalController``; "remote" and "hybrid" both build
        a remote controller.
    :param kwargs: Forwarded to ``create_remote_controller`` for the remote kinds; not used for "local".
    :return: The newly created controller.
    :raises ValueError: If ``ct`` is not one of the handled kinds.
    """
    # Each implementation is imported only when its kind is requested.
    if ct == "local":
        from ._local_controller import LocalController

        instance = LocalController()
    elif ct in ("remote", "hybrid"):
        from union._internal.controllers.remote import create_remote_controller

        instance = create_remote_controller(**kwargs)
    else:
        raise ValueError(f"{ct} is not a valid controller type.")

    # Publish the new controller under the lock so get_controller() can return it.
    with _ControllerState.lock:
        _ControllerState.controller = instance
    return instance
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from typing import Any, Tuple
|
|
2
|
+
|
|
3
|
+
from union import storage
|
|
4
|
+
from union._context import internal_ctx
|
|
5
|
+
from union._datastructures import ActionID, RawDataPath
|
|
6
|
+
from union._internal.controllers import pbhash
|
|
7
|
+
from union._internal.runtime.convert import (
|
|
8
|
+
Inputs,
|
|
9
|
+
convert_error_to_native,
|
|
10
|
+
convert_from_native_to_inputs,
|
|
11
|
+
convert_outputs_to_native,
|
|
12
|
+
)
|
|
13
|
+
from union._internal.runtime.entrypoints import direct_dispatch
|
|
14
|
+
from union._logging import log, logger
|
|
15
|
+
from union._task import TaskTemplate
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LocalController:
    """
    Controller implementation that executes submitted tasks through ``direct_dispatch``.
    """

    def __init__(self):
        logger.debug("LocalController init")

    def _get_run_params(self, inputs: Inputs) -> Tuple[ActionID, RawDataPath]:
        """
        Derive the sub-action and raw data path for a child run from the current task context.

        :param inputs: Not read by this method.
        :return: A new sub-action of the current action, and the context's raw data path unchanged.
        """
        data = internal_ctx().data
        parent_action = data.task_context.action
        # TODO assuming the raw_data_path is local, and for now not getting manipulated by the controller
        #  We will need to change this in case of remote execution, or create data sandboxes.
        raw_data_path = data.raw_data_path
        # TODO ideally we should generate the name deterministically using the inputs etc
        return parent_action.new_sub_action(), raw_data_path

    @log
    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Run ``_task`` as a sub-action of the current action and return its outputs.

        The native arguments are converted to inputs, hashed to derive the sub-action id, and the task is
        dispatched with an output path nested under the current action's output path.

        :param _task: The task template to execute.
        :param args: Positional native arguments for the task.
        :param kwargs: Keyword native arguments for the task.
        :return: The dispatch result, converted to native values when the task declares outputs.
        :raises Exception: The native form of the error reported by the dispatch, if any.
        """
        task_ctx = internal_ctx().data.task_context
        parent_action_id = task_ctx.action
        parent_output_path = task_ctx.output_path

        task_inputs = await convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
        digest = self._input_hash(task_inputs)
        if task_ctx.group_data:
            group_name = task_ctx.group_data.name
        else:
            group_name = None
        child_action_id = parent_action_id.new_sub_action_from(input_hash=digest, group=group_name)

        # The child's outputs and raw data share one location under the parent's output path.
        child_output_path = storage.join(parent_output_path, child_action_id.name)
        child_raw_data_path = RawDataPath(path=child_output_path)

        result, failure = await direct_dispatch(
            _task,
            controller=self,
            action=child_action_id,
            raw_data_path=child_raw_data_path,
            inputs=task_inputs,
            version=task_ctx.version,
            checkpoints=task_ctx.checkpoints,
            code_bundle=task_ctx.code_bundle,
            output_path=child_output_path,
        )
        if failure:
            raise convert_error_to_native(failure)
        if not _task.native_interface.outputs:
            return result
        return await convert_outputs_to_native(_task.native_interface, result)

    async def finalize_parent_action(self, action: ActionID):
        """No-op in this controller."""
        pass

    def stop(self):
        """No-op in this controller."""
        pass

    def _input_hash(self, inputs: Inputs) -> str:
        """
        Return the hash string computed from the proto form of ``inputs``.
        """
        return pbhash.compute_hash_string(inputs.proto_inputs)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# This is a module that provides hashing utilities for Protobuf objects.
|
|
2
|
+
import base64
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
|
|
6
|
+
from google.protobuf import json_format
|
|
7
|
+
from google.protobuf.message import Message
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def compute_hash(pb: Message) -> bytes:
    """
    Compute a deterministic SHA-256 digest for a Protobuf message.

    The message is converted to a dict, serialised as compact JSON with sorted keys, and the UTF-8
    encoding of that JSON is hashed.

    :param pb: The Protobuf message to hash.
    :return: The 32-byte SHA-256 digest.
    :raises ValueError: If the message cannot be marshalled to JSON or hashed. The underlying
        exception is attached as ``__cause__``.
    """
    try:
        pb_dict = json_format.MessageToDict(pb)
        # Sorted keys keep the output stable; the separators remove all optional whitespace.
        stable_json_str = json.dumps(pb_dict, sort_keys=True, separators=(",", ":"))
    except Exception as e:
        # Chain explicitly so the original failure is reported as the direct cause.
        raise ValueError(f"Failed to marshal Protobuf object {pb} to JSON with error: {e}") from e

    try:
        # Deterministically hash the JSON text to a byte array using SHA-256.
        hash_obj = hashlib.sha256(stable_json_str.encode("utf-8"))
    except Exception as e:
        raise ValueError(f"Failed to hash JSON for Protobuf object {pb} with error: {e}") from e

    # The digest is guaranteed to be 32 bytes long
    return hash_obj.digest()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def compute_hash_string(pb: Message) -> str:
    """
    Return the digest from ``compute_hash`` for the Protobuf object as a base64 text string.

    :param pb: The Protobuf message to hash.
    :return: The base64 encoding of the digest, decoded as UTF-8 text.
    """
    encoded = base64.b64encode(compute_hash(pb))
    return encoded.decode("utf-8")
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from union.remote._client.auth import AuthType, ClientConfig
|
|
4
|
+
|
|
5
|
+
from ._controller import RemoteController
|
|
6
|
+
|
|
7
|
+
__all__ = ["RemoteController", "create_remote_controller"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def create_remote_controller(
    *,
    api_key: str | None = None,
    auth_type: AuthType = "Pkce",
    endpoint: str | None = None,
    client_config: ClientConfig | None = None,
    headless: bool = False,
    insecure: bool = False,
    insecure_skip_verify: bool = False,
    ca_cert_file_path: str | None = None,
    command: List[str] | None = None,
    proxy_command: List[str] | None = None,
    client_id: str | None = None,
    client_credentials_secret: str | None = None,
    rpc_retries: int = 3,
    http_proxy_url: str | None = None,
) -> RemoteController:
    """
    Create a new instance of the remote controller.

    Only ``endpoint``, ``insecure`` and ``insecure_skip_verify`` are read by this function; they are
    passed to ``ControllerClient.for_endpoint``. The remaining keyword arguments are accepted but not
    used in this body. The controller is always built with 10 workers and 5 system retries.

    :return: The newly constructed ``RemoteController``.
    """
    # Imported here rather than at module level, as in the original layout of this function.
    from ._client import ControllerClient

    # for_endpoint is an async classmethod, so this is an un-awaited coroutine handed to the controller.
    pending_client = ControllerClient.for_endpoint(
        endpoint=endpoint,
        insecure=insecure,
        insecure_skip_verify=insecure_skip_verify,
    )
    return RemoteController(client_coro=pending_client, workers=10, max_system_retries=5)
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from flyteidl.core import execution_pb2
|
|
6
|
+
|
|
7
|
+
from union._datastructures import GroupData
|
|
8
|
+
from union._protos.workflow import run_definition_pb2, state_service_pb2, task_definition_pb2
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class Action:
    """
    In-memory state of a task, combining what was submitted locally with what the remote side reports.

    Coroutine safe, because no method performs an await.
    """

    action_id: run_definition_pb2.ActionIdentifier
    parent_action_name: str
    friendly_name: str | None = None
    group: GroupData | None = None
    task: task_definition_pb2.TaskSpec | None = None
    inputs_uri: str | None = None
    outputs_uri: str | None = None
    err: execution_pb2.ExecutionError | None = None
    phase: run_definition_pb2.Phase | None = None
    started: bool = False
    retries: int = 0
    client_err: Exception | None = None  # This error is set when something goes wrong in the controller.

    @property
    def name(self) -> str:
        """Name taken from the action identifier."""
        return self.action_id.name

    @property
    def run_name(self) -> str:
        """Name of the run the action identifier belongs to."""
        return self.action_id.run.name

    def is_terminal(self) -> bool:
        """Return True when the phase is failed, succeeded, aborted or timed out."""
        current = self.phase
        if current is None:
            return False
        terminal_phases = (
            run_definition_pb2.Phase.PHASE_FAILED,
            run_definition_pb2.Phase.PHASE_SUCCEEDED,
            run_definition_pb2.Phase.PHASE_ABORTED,
            run_definition_pb2.Phase.PHASE_TIMED_OUT,
        )
        return current in terminal_phases

    def increment_retries(self):
        """Add one to the retry counter."""
        self.retries = self.retries + 1

    def is_started(self) -> bool:
        """Return the ``started`` flag."""
        return self.started

    def mark_started(self):
        """Set the ``started`` flag and drop the stored task spec."""
        self.started = True
        self.task = None

    def merge_state(self, obj: state_service_pb2.ActionUpdate):
        """
        Merge an update received from the watch API into this action.

        The action may have had no phase information before this call.

        :param obj: The update describing the action's reported state.
        :return:
        """
        reported_phase = obj.phase
        if self.phase != reported_phase:
            self.phase = reported_phase
            self.err = obj.error if obj.HasField("error") else None
        # NOTE(review): the diff rendering dropped indentation; this assumes `started` is set on every
        # update rather than only when the phase changes -- confirm against the packaged file.
        self.started = True

    def merge_in_action_from_submit(self, action: Action):
        """
        Merge a freshly submitted action into this one, which was first observed through the watch.

        URIs, group and friendly name are copied over while the observed phase is left untouched.
        The task spec is copied only if this action has not started.

        :param action: The submitted action
        """
        self.friendly_name = action.friendly_name
        self.group = action.group
        self.inputs_uri = action.inputs_uri
        self.outputs_uri = action.outputs_uri
        if self.started:
            return
        self.task = action.task

    def set_client_error(self, exc: Exception):
        """Record an error raised on the controller side."""
        self.client_err = exc

    def has_error(self) -> bool:
        """Return True if either a controller-side error or a reported execution error is set."""
        return not (self.client_err is None and self.err is None)

    @classmethod
    def from_task(
        cls,
        parent_action_name: str,
        sub_action_id: run_definition_pb2.ActionIdentifier,
        group_data: GroupData,
        task_spec: task_definition_pb2.TaskSpec,
        inputs_uri: str,
        outputs_prefix_uri: str,
    ) -> Action:
        """
        Build an action for a task that is being submitted.

        The friendly name is read from the task template id inside ``task_spec``.
        """
        display_name = task_spec.task_template.id.name
        return cls(
            action_id=sub_action_id,
            parent_action_name=parent_action_name,
            friendly_name=display_name,
            group=group_data,
            task=task_spec,
            inputs_uri=inputs_uri,
            outputs_uri=outputs_prefix_uri,
        )

    @classmethod
    def from_state(cls, parent_action_name: str, obj: state_service_pb2.ActionUpdate) -> Action:
        """
        Build an action from a watch API update, before the task itself has been seen locally.

        This can happen during a recovery, where the state service already knows about future actions
        and sends them to the informer. Only the action id, phase and error are known at that point,
        and the action is created as already started.

        :param parent_action_name:
        :param obj:
        :return:
        """
        reported_error = obj.error if obj.HasField("error") else None
        return cls(
            action_id=obj.action_id,
            parent_action_name=parent_action_name,
            err=reported_error,
            phase=obj.phase,
            started=True,
        )
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import grpc.aio
|
|
4
|
+
|
|
5
|
+
from union._protos.workflow import queue_service_pb2_grpc, state_service_pb2_grpc
|
|
6
|
+
from union.remote import create_channel
|
|
7
|
+
|
|
8
|
+
from ._service_protocol import QueueService, StateService
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ControllerClient:
    """
    Client for the Controller API: one channel shared by a state service stub and a queue service stub.
    """

    def __init__(self, channel: grpc.aio.Channel):
        """
        Create both service stubs on top of the given channel.

        :param channel: channel used by the state service and queue service stubs.
        """
        self._channel = channel
        self._queue_service = queue_service_pb2_grpc.QueueServiceStub(channel=channel)
        self._state_service = state_service_pb2_grpc.StateServiceStub(channel=channel)

    @classmethod
    async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ControllerClient:
        """
        Open a channel to ``endpoint`` and wrap it in a client.

        :param endpoint: address handed to ``create_channel``.
        :param insecure: forwarded to ``create_channel``.
        :param kwargs: any further keyword arguments, forwarded to ``create_channel`` untouched.
        :return: a client bound to the newly created channel.
        """
        new_channel = await create_channel(endpoint, insecure=insecure, **kwargs)
        return cls(new_channel)

    @property
    def state_service(self) -> StateService:
        """
        The state service stub created for this client's channel.
        """
        return self._state_service

    @property
    def queue_service(self) -> QueueService:
        """
        The queue service stub created for this client's channel.
        """
        return self._queue_service

    def close(self, grace: float | None = None):
        """
        Close the underlying channel.

        :param grace: passed through to the channel's ``close`` as ``grace``.
        :return: whatever the channel's ``close`` returns, unchanged.
            NOTE(review): for a ``grpc.aio.Channel`` this is presumably an awaitable - confirm callers await it.
        """
        outcome = self._channel.close(grace=grace)
        return outcome
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from typing import Any, Awaitable
|
|
6
|
+
|
|
7
|
+
import union.storage as storage
|
|
8
|
+
from union._code_bundle import build_pkl_bundle
|
|
9
|
+
from union._context import internal_ctx
|
|
10
|
+
from union._datastructures import ActionID, SerializationContext
|
|
11
|
+
from union._internal.controllers import pbhash
|
|
12
|
+
from union._internal.controllers.remote._action import Action
|
|
13
|
+
from union._internal.controllers.remote._core import Controller
|
|
14
|
+
from union._internal.controllers.remote._service_protocol import ClientSet
|
|
15
|
+
from union._internal.runtime import io
|
|
16
|
+
from union._internal.runtime.convert import (
|
|
17
|
+
Inputs,
|
|
18
|
+
convert_error_to_native,
|
|
19
|
+
convert_from_native_to_inputs,
|
|
20
|
+
convert_outputs_to_native,
|
|
21
|
+
)
|
|
22
|
+
from union._internal.runtime.task_serde import translate_task_to_wire
|
|
23
|
+
from union._logging import logger
|
|
24
|
+
from union._protos.workflow import run_definition_pb2
|
|
25
|
+
from union._task import TaskTemplate
|
|
26
|
+
from union.errors import RuntimeSystemError
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class RemoteController(Controller):
    """
    This is a specialized controller that wraps the core controller and performs IO, serialization and
    deserialization.
    """

    def __init__(
        self,
        client_coro: Awaitable[ClientSet],
        workers: int,
        max_system_retries: int,
        default_parent_concurrency: int = 100,
    ):
        """
        Set up the core controller and the per-parent-action concurrency limits.

        :param client_coro: awaitable resolving to the client set, forwarded to the core controller.
        :param workers: forwarded to the core controller.
        :param max_system_retries: forwarded to the core controller.
        :param default_parent_concurrency: size of the semaphore created for each parent action name,
            i.e. how many ``submit`` calls of one parent action may be in flight at the same time.
        """
        super().__init__(
            client_coro=client_coro,
            workers=workers,
            max_system_retries=max_system_retries,
        )
        self._default_parent_concurrency = default_parent_concurrency
        # One semaphore per parent action name, created lazily on first use.
        self._parent_action_semaphore = defaultdict(lambda: asyncio.Semaphore(default_parent_concurrency))

    async def _submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Serialize the inputs, upload them, submit the task as a sub action and return its outputs.

        :param _task: the task to run as a sub action of the current task context's action.
        :param args: positional native inputs for the task.
        :param kwargs: keyword native inputs for the task.
        :return: the native outputs of the task, or None when the task declares no outputs.
        :raises RuntimeSystemError: when uploading the inputs fails.
        :raises Exception: the native error converted from the action's failure, when the action
            finished with an error or in the failed phase.
        """
        ctx = internal_ctx()
        tctx = ctx.data.task_context
        current_action_id = tctx.action
        current_output_path = tctx.output_path

        inputs = await convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
        inputs_hash = self._input_hash(inputs)
        sub_action_id = current_action_id.new_sub_action_from(
            input_hash=inputs_hash,
            group=tctx.group_data.name if tctx.group_data else None,
        )
        sub_run_output_path = storage.join(current_output_path, sub_action_id.name)
        # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
        # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
        code_bundle = tctx.code_bundle

        if code_bundle:
            # but if we are using a pkl bundle, we need to build a new one for the downstream tasks
            if code_bundle.pkl:
                logger.debug(f"Building new pkl bundle for task {sub_action_id.name}")
                code_bundle = await build_pkl_bundle(
                    _task,
                    upload_to_controlplane=False,
                    upload_from_dataplane_path=io.pkl_path(sub_run_output_path),
                )

        inputs_uri = io.inputs_path(sub_run_output_path)
        try:
            # TODO Add retry decorator to this
            await io.upload_inputs(inputs, inputs_uri)
        except Exception as e:
            # The message needs a placeholder for the extra argument; without one the logging
            # machinery fails to format the record and the message is lost.
            logger.exception("Failed to upload inputs: %s", e)
            raise RuntimeSystemError(type(e).__name__, str(e)) from e
        new_serialization_context = SerializationContext(
            project=current_action_id.project,
            domain=current_action_id.domain,
            org=current_action_id.org,
            code_bundle=code_bundle,
            # Reuse the version carried by the current task context.
            version=tctx.version,
            input_path=inputs_uri,
            output_path=sub_run_output_path,
            image_cache=ctx.data.task_context.compiled_image_cache,
        )

        task_spec = translate_task_to_wire(_task, new_serialization_context)

        action = Action.from_task(
            sub_action_id=run_definition_pb2.ActionIdentifier(
                name=sub_action_id.name,
                run=run_definition_pb2.RunIdentifier(
                    name=current_action_id.run_name,
                    project=current_action_id.project,
                    domain=current_action_id.domain,
                    org=current_action_id.org,
                ),
            ),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=task_spec,
            inputs_uri=inputs_uri,
            outputs_prefix_uri=sub_run_output_path,
        )

        n = await self.submit_action(action)

        if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED:
            err = n.err or n.client_err
            if not err and n.phase == run_definition_pb2.PHASE_FAILED:
                logger.error(f"Server reported failure for action {n.action_id.name}, checking error file.")
                error_path = io.error_path(n.outputs_uri)
                # It is possible that the error file is not present in the case of an image pull failure or
                # other reasons for failure. Ideally the error message should be sent by the server, but in
                # case it is missing, fail with an unknown error.
                try:
                    err = await io.load_error(error_path)
                except Exception as e:
                    logger.exception("Failed to load error file: %s", e)
                    err = RuntimeSystemError(
                        type(e).__name__,
                        f"Failed to load error file: {e}",
                    )
            else:
                logger.error(f"Server reported failure for action {n.action_id.name}, error: {err}")
            raise convert_error_to_native(err)

        if _task.native_interface.outputs:
            outputs_file_path = io.outputs_path(n.outputs_uri)
            o = await io.load_outputs(outputs_file_path)
            return await convert_outputs_to_native(_task.native_interface, o)
        return None

    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a task to the remote controller. This creates a new action on the queue service.

        The call is throttled by the semaphore that belongs to the current parent action, so at most
        ``default_parent_concurrency`` submissions of one parent action run concurrently.

        :param _task: the task to submit.
        :param args: positional native inputs for the task.
        :param kwargs: keyword native inputs for the task.
        :return: the native outputs of the task, or None when the task declares no outputs.
        """
        ctx = internal_ctx()
        current_action_id = ctx.data.task_context.action
        async with self._parent_action_semaphore[current_action_id.name]:
            return await self._submit(_task, *args, **kwargs)

    async def finalize_parent_action(self, action_id: ActionID):
        """
        This method is invoked when the parent action is finished. It will finalize the run and upload the outputs
        to the control plane.

        Once the core controller has finalized the parent action, the semaphore kept for that parent
        action name is dropped.

        :param action_id: identifier of the parent action that finished.
        """
        run_id = run_definition_pb2.RunIdentifier(
            name=action_id.run_name,
            project=action_id.project,
            domain=action_id.domain,
            org=action_id.org,
        )
        await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
        self._parent_action_semaphore.pop(action_id.name, None)

    def _input_hash(self, inputs: Inputs) -> str:
        """
        Compute the hash string of the serialized inputs.

        :param inputs: converted inputs; the hash is computed over ``inputs.proto_inputs``.
        :return: the hash string produced by ``pbhash.compute_hash_string``.
        """
        return pbhash.compute_hash_string(inputs.proto_inputs)
|