flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
union/remote/_run.py
ADDED
|
@@ -0,0 +1,820 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any, AsyncGenerator, Iterator, Literal, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import grpc
|
|
8
|
+
import rich.repr
|
|
9
|
+
from google.protobuf import timestamp
|
|
10
|
+
from rich.console import Console
|
|
11
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
|
|
12
|
+
|
|
13
|
+
from union._api_commons import syncer
|
|
14
|
+
from union._initialize import get_client, get_common_config, requires_client
|
|
15
|
+
from union._protos.common import identifier_pb2, list_pb2
|
|
16
|
+
from union._protos.workflow import run_definition_pb2, run_service_pb2
|
|
17
|
+
|
|
18
|
+
from ._logs import Logs
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _action_time_phase(action: run_definition_pb2.Action | run_definition_pb2.ActionDetails) -> rich.repr.Result:
    """
    Rich representation of the action time and phase.

    Yields ``(key, value)`` pairs for ``start_time``, ``end_time``, ``run_time`` and ``phase``. When ``action``
    is an ``ActionDetails`` message it additionally yields ``error`` (the ``error_info`` field, or ``"NA"``).

    :param action: An ``Action`` or ``ActionDetails`` protobuf message; its ``status`` is read.
    """
    from datetime import timezone

    # ``timestamp.to_datetime`` without a tz returns a *naive* datetime expressed in UTC, so any "now" we
    # compare against must also be naive UTC; ``datetime.now()`` is local time and would skew the result.
    start_time = timestamp.to_datetime(action.status.start_time)
    yield "start_time", start_time.isoformat()
    if _action_done_check(action.status.phase):
        end_time = timestamp.to_datetime(action.status.end_time)
        yield "end_time", end_time.isoformat()
        # total_seconds() includes whole days; timedelta.seconds alone wraps around every 24 hours.
        yield "run_time", f"{int((end_time - start_time).total_seconds())} secs"
    else:
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        yield "end_time", None
        yield "run_time", f"{int((now - start_time).total_seconds())} secs"
    yield "phase", run_definition_pb2.Phase.Name(action.status.phase)
    if isinstance(action, run_definition_pb2.ActionDetails):
        yield "error", action.error_info if action.HasField("error_info") else "NA"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _action_rich_repr(action: run_definition_pb2.Action, root: bool = False) -> rich.repr.Result:
    """
    Produce the rich ``(key, value)`` pairs that describe an ``Action`` protobuf message.

    :param action: The action message to describe.
    :param root: When True, the ``group`` and ``parent`` entries are left out.
    """
    action_id = action.id
    yield "run-name", action_id.run.name
    yield "name", action_id.name
    yield from _action_time_phase(action)
    meta = action.metadata
    yield "task", meta.task_id.name
    if root:
        return
    yield "group", meta.group
    yield "parent", meta.parent
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _action_details_rich_repr(action: run_definition_pb2.ActionDetails, root: bool = False) -> rich.repr.Result:
    """
    Produce the rich ``(key, value)`` pairs that describe an ``ActionDetails`` protobuf message.

    The ``name`` entry carries the *run* name (``action.id.run.name``), not the action's own name.

    :param action: The action details message to describe.
    :param root: When True, the ``group`` and ``parent`` entries are left out.
    """
    yield "name", action.id.run.name
    yield from _action_time_phase(action)
    meta = action.metadata
    yield "task", meta.task_id.name
    if not root:
        for key in ("group", "parent"):
            yield key, getattr(meta, key)
    # TODO attempt info
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _action_done_check(phase: run_definition_pb2.Phase) -> bool:
    """
    Return True when ``phase`` is terminal: failed, succeeded, aborted or timed out.

    :param phase: Numeric ``run_definition_pb2.Phase`` value.
    """
    terminal_phases = (
        run_definition_pb2.PHASE_FAILED,
        run_definition_pb2.PHASE_SUCCEEDED,
        run_definition_pb2.PHASE_ABORTED,
        run_definition_pb2.PHASE_TIMED_OUT,
    )
    return phase in terminal_phases
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@dataclass
class Run:
    """
    A class representing a run of a task. It is used to manage the run of a task and its state on the remote
    Union API.

    Attributes not defined on this class are looked up on the wrapped ``run_definition_pb2.Run`` message.
    """

    # Raw run protobuf; it must carry an ``action`` (checked in ``__post_init__``).
    pb2: run_definition_pb2.Run
    # Wrapper around ``pb2.action``; built in ``__post_init__``.
    action: Action | None = field(init=False, default=None)
    # Cached run details, supplied by ``get`` or fetched lazily by ``details()``.
    _details: RunDetails | None = None

    def __post_init__(self):
        """
        Wrap the run's root action, rejecting runs that have none.

        :raises RuntimeError: If ``pb2`` has no ``action`` field set.
        """
        if not self.pb2.HasField("action"):
            raise RuntimeError("Run does not have an action")
        self.action = Action(self.pb2.action)

    @classmethod
    @requires_client
    @syncer.wrap
    async def listall(
        cls,
        filters: str | None = None,
        sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
    ) -> Union[Iterator[Run], AsyncGenerator[Run, None]]:
        """
        Get all runs for the current project and domain.

        Runs are requested from the RunService in pages of 100 until the server returns an empty page token.

        :param filters: Currently not applied to the request.
        :param sort_by: The sorting criteria for the run list, in the format (field, order). Defaults to
            ``("created_at", "asc")``; any order other than ``"asc"`` sorts descending.
        :return: An iterator of runs.
        """
        token = None
        sort_by = sort_by or ("created_at", "asc")
        sort_pb2 = list_pb2.Sort(
            key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
        )
        cfg = get_common_config()
        while True:
            req = list_pb2.ListRequest(
                limit=100,
                token=token,
                sort_by=sort_pb2,
            )
            resp = await get_client().run_service.ListRuns(
                run_service_pb2.ListRunsRequest(
                    request=req,
                    org=cfg.org,
                    project_id=identifier_pb2.ProjectIdentifier(
                        organization=cfg.org,
                        domain=cfg.domain,
                        name=cfg.project,
                    ),
                )
            )
            token = resp.token
            for r in resp.runs:
                yield cls(r)
            if not token:
                break

    @classmethod
    @requires_client
    @syncer.wrap
    async def get(cls, name: str) -> Run:
        """
        Fetch a run by name from the configured project and domain.

        :param name: The name of the run.
        :return: The run, with its details already cached.
        """
        run_details: RunDetails = await RunDetails.get.aio(RunDetails, name=name)
        run = run_definition_pb2.Run(
            action=run_definition_pb2.Action(
                id=run_details.action_id,
                metadata=run_details.action_details.pb2.metadata,
                status=run_details.action_details.pb2.status,
            ),
        )
        return cls(pb2=run, _details=run_details)

    def __getattr__(self, item: str) -> Any:
        """
        Forward attributes not defined on this class to the underlying ``run_definition_pb2.Run`` message.
        """
        if item == "pb2":
            # ``pb2`` is only missing on a partially-built instance (e.g. while copy/pickle restore state);
            # forwarding it would re-enter __getattr__ and recurse forever.
            raise AttributeError(item)
        return getattr(self.pb2, item)

    @property
    def name(self) -> str:
        """
        Get the name of the run.
        """
        return self.pb2.action.id.run.name

    @property
    def phase(self) -> str:
        """
        Get the phase of the run as its ``run_definition_pb2.Phase`` enum name.
        """
        # Phase.Name expects the numeric enum value, so read it straight from the action's protobuf status.
        return run_definition_pb2.Phase.Name(self.action.pb2.status.phase)

    @syncer.wrap
    async def wait(self, quiet: bool = False) -> None:
        """
        Wait for the run to complete, displaying a rich progress panel with status transitions,
        time elapsed, and error details in case of failure.

        :param quiet: When True, the progress panel and all console messages are suppressed.
        """
        console = Console()
        if self.done():
            if not quiet:
                console.print(f"[bold green]Run '{self.name}' is already completed.[/bold green]")
            return

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            TimeElapsedColumn(),
            console=console,
            transient=True,
            disable=quiet,
        ) as progress:
            task_id = progress.add_task(f"Waiting for run '{self.name}'...", start=False)

            # ``watch`` never yields None; it stops on its own when the stream ends.
            async for ad in self.watch(cache_data_on_done=True):
                # Update progress description with the current phase
                progress.update(task_id, description=f"Phase: {ad.phase}")
                progress.start_task(task_id)

                if not ad.done():
                    continue

                # Terminal state reached: report the outcome (unless quiet) and stop watching.
                progress.stop_task(task_id)
                if not quiet:
                    if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
                        console.print(f"[bold green]Run '{self.name}' completed successfully.[/bold green]")
                    else:
                        console.print(
                            f"[bold red]Run '{self.name}' exited unsuccessfully in state {ad.phase} "
                            f"with error: {ad.error}[/bold red]"
                        )
                break

    async def watch(self, cache_data_on_done: bool = False) -> AsyncGenerator[ActionDetails, None]:
        """
        Stream ``ActionDetails`` updates for the run's root action.

        The stream ends when the underlying watch ends or produces ``None``.

        :param cache_data_on_done: Passed through to ``Action.watch_details``.
        """
        async for ad in self.action.watch_details(cache_data_on_done=cache_data_on_done):
            if ad is None:
                return
            yield ad

    async def show_logs(self, attempt: int = 1, max_lines: int = 10):
        """
        Create a log viewer for the run's root action.

        :param attempt: Attempt number whose logs are shown.
        :param max_lines: Maximum number of lines passed to the viewer.
        """
        return await Logs.create_viewer(action_id=self.action.action_id, attempt=attempt, max_lines=max_lines)

    async def details(self) -> RunDetails:
        """
        Get the details of the run, fetching them from the server on first use and caching them afterwards.
        """
        if self._details is None:
            # get_details is wrapped by ``syncer.wrap``; inside a coroutine its async variant (``.aio``) must
            # be awaited, as done by the other wrapped call sites in this module.
            self._details = await RunDetails.get_details.aio(RunDetails, self.pb2.action.id.run)
        return self._details

    @syncer.wrap
    async def cancel(self) -> None:
        """
        Cancel the run by sending an AbortRun request for it.
        """
        await get_client().run_service.AbortRun(
            run_service_pb2.AbortRunRequest(
                run_id=self.pb2.action.id.run,
            )
        )

    def done(self) -> bool:
        """
        Check if the run is done, i.e. whether its root action is done.
        """
        return self.action.done()

    def sync(self) -> Run:
        """
        Sync the run with the remote server. Currently a no-op that returns ``self``.
        """
        return self

    # TODO add add_done_callback, maybe implement sync apis etc

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the Run object.
        """
        yield from _action_rich_repr(self.pb2.action, root=True)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
@dataclass
class RunDetails:
    """
    Detailed view of a run on the remote Union API, wrapping a ``run_definition_pb2.RunDetails`` message.

    Attributes not defined on this class are looked up on the wrapped protobuf message.
    """

    # Raw run-details protobuf returned by the RunService.
    pb2: run_definition_pb2.RunDetails
    # Wrapper around ``pb2.action``; built in ``__post_init__``.
    action_details: ActionDetails | None = field(init=False, default=None)

    def __post_init__(self):
        """
        Wrap the run's root action details in an ``ActionDetails`` object.
        """
        self.action_details = ActionDetails(self.pb2.action)

    @classmethod
    @requires_client
    @syncer.wrap
    async def get_details(cls, run_id: run_definition_pb2.RunIdentifier) -> RunDetails:
        """
        Fetch the details of a run from the RunService.

        :param run_id: Fully-qualified identifier of the run.
        :return: The run details.
        """
        resp = await get_client().run_service.GetRunDetails(
            run_service_pb2.GetRunDetailsRequest(
                run_id=run_id,
            )
        )
        return cls(resp.details)

    @classmethod
    @requires_client
    @syncer.wrap
    async def get(cls, uri: str | None = None, /, name: str | None = None) -> RunDetails:
        """
        Get the details of a run by name, within the configured org, project and domain.

        :param uri: Currently ignored; only ``name`` is used for the lookup.
        :param name: The name of the run.
        :return: The run details.
        """
        cfg = get_common_config()
        return await RunDetails.get_details.aio(
            cls,
            run_id=run_definition_pb2.RunIdentifier(
                org=cfg.org,
                project=cfg.project,
                domain=cfg.domain,
                name=name,
            ),
        )

    @property
    def name(self) -> str:
        """
        Get the name of the run.
        """
        return self.action_details.run_name

    @property
    def task_name(self) -> str | None:
        """
        Get the name of the task executed by the run's root action.
        """
        return self.action_details.task_name

    @property
    def action_id(self) -> run_definition_pb2.ActionIdentifier:
        """
        Get the identifier of the run's root action.
        """
        return self.action_details.action_id

    def done(self) -> bool:
        """
        Check if the run is in a terminal state, as reported by its root action details.
        """
        return self.action_details.done()

    async def inputs(self) -> ActionInputs:
        """
        Get the inputs of the run's root action.
        """
        return await self.action_details.inputs()

    async def outputs(self) -> ActionOutputs:
        """
        Get the outputs of the run's root action.
        """
        return await self.action_details.outputs()

    def __getattr__(self, item: str) -> Any:
        """
        Forward attributes not defined on this class to the underlying ``run_definition_pb2.RunDetails`` message.
        """
        if item == "pb2":
            # ``pb2`` is only missing on a partially-built instance (e.g. while copy/pickle restore state);
            # forwarding it would re-enter __getattr__ and recurse forever.
            raise AttributeError(item)
        return getattr(self.pb2, item)

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the RunDetails object.
        """
        yield from _action_details_rich_repr(self.pb2.action, root=True)
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
@dataclass
class Action:
    """
    A class representing an action. It is used to manage the run of a task and its state on the remote Union API.

    Wraps a ``run_definition_pb2.Action`` message. The richer ``ActionDetails`` view is fetched on demand
    and cached on the instance.
    """

    pb2: run_definition_pb2.Action
    # Cached ActionDetails; populated by ``get``, ``details`` or ``watch_details``.
    _details: ActionDetails | None = None

    @classmethod
    @requires_client
    @syncer.wrap
    async def listall(
        cls,
        for_run_name: str,
        filters: str | None = None,
        sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
    ) -> Union[Iterator[Action], AsyncGenerator[Action, None]]:
        """
        Get all actions for a given run.

        Actions are requested from the run service in pages of 100 and yielded one at a time.

        :param for_run_name: The name of the run whose actions are listed.
        :param filters: The filters to apply to the action list.
            NOTE(review): this value is currently not forwarded to the request, so it has no effect.
        :param sort_by: The sorting criteria for the action list, in the format (field, order).
            Defaults to ("created_at", "asc").
        :return: An iterator of actions.
        """
        token = None
        sort_by = sort_by or ("created_at", "asc")
        sort_pb2 = list_pb2.Sort(
            key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
        )
        cfg = get_common_config()
        while True:
            req = list_pb2.ListRequest(
                limit=100,
                token=token,
                sort_by=sort_pb2,
            )
            resp = await get_client().run_service.ListActions(
                run_service_pb2.ListActionsRequest(
                    request=req,
                    run_id=run_definition_pb2.RunIdentifier(
                        org=cfg.org,
                        project=cfg.project,
                        domain=cfg.domain,
                        name=for_run_name,
                    ),
                )
            )
            # An empty continuation token marks the last page.
            token = resp.token
            for r in resp.actions:
                yield cls(r)
            if not token:
                break

    @classmethod
    @requires_client
    @syncer.wrap
    async def get(cls, uri: str | None = None, /, run_name: str | None = None, name: str | None = None) -> Action:
        """
        Get an action by the name of its run and its own name.

        The returned Action already has its ActionDetails cached.

        :param uri: The URI of the action. NOTE(review): currently not used to resolve the action.
        :param run_name: The name of the run the action belongs to.
        :param name: The name of the action.
        :raises ValueError: If ``uri`` is not given and either ``run_name`` or ``name`` is missing.
        """
        if not uri and (run_name is None or name is None):
            raise ValueError("Either uri or name and run_name must be provided")
        cfg = get_common_config()
        # The syncer-wrapped classmethod's ``.aio`` takes the class explicitly; it must be ActionDetails,
        # not ``cls`` (Action).
        details: ActionDetails = await ActionDetails.get_details.aio(
            ActionDetails,
            run_definition_pb2.ActionIdentifier(
                run=run_definition_pb2.RunIdentifier(
                    org=cfg.org,
                    project=cfg.project,
                    domain=cfg.domain,
                    name=run_name,
                ),
                name=name,
            ),
        )
        return cls(
            pb2=run_definition_pb2.Action(
                id=details.action_id,
                metadata=details.pb2.metadata,
                status=details.pb2.status,
            ),
            _details=details,
        )

    @property
    def phase(self) -> str:
        """
        Get the phase of the action as its enum name, read from the status cached on ``pb2``.
        """
        return run_definition_pb2.Phase.Name(self.pb2.status.phase)

    @property
    def name(self) -> str:
        """
        Get the name of the action.
        """
        return self.action_id.name

    @property
    def run_name(self) -> str:
        """
        Get the name of the run this action belongs to.
        """
        return self.action_id.run.name

    @property
    def task_name(self) -> str | None:
        """
        Get the name of the task, or None if the action metadata has no task id.
        """
        if self.pb2.metadata.HasField("task_id"):
            return self.pb2.metadata.task_id.name
        return None

    @property
    def action_id(self) -> run_definition_pb2.ActionIdentifier:
        """
        Get the action ID.
        """
        return self.pb2.id

    async def show_logs(self, attempt: int = 1, max_lines: int = 10):
        """
        Show the logs of this action by delegating to ``Logs.create_viewer``.

        :param attempt: The attempt whose logs are shown.
        :param max_lines: Passed through to ``Logs.create_viewer`` (presumably the number of lines displayed).
        """
        return await Logs.create_viewer(action_id=self.action_id, attempt=attempt, max_lines=max_lines)

    async def details(self) -> ActionDetails:
        """
        Get the details of the action, fetching them on first use and caching them on this instance.
        """
        if not self._details:
            self._details = await ActionDetails.get_details.aio(ActionDetails, self.action_id)
        return self._details

    async def watch_details(self, cache_data_on_done: bool = False) -> AsyncGenerator[ActionDetails, None]:
        """
        Watch the action for updates, yielding each new ActionDetails snapshot and caching the latest one.

        :param cache_data_on_done: If True, fetch and cache the action's inputs and outputs once a terminal
            snapshot has been yielded.
        """
        async for ad in ActionDetails.watch.aio(ActionDetails, self.action_id):
            if ad is None:
                return
            self._details = ad
            yield ad
            if cache_data_on_done and ad.done():
                # Use _cache_data rather than outputs(): a failed/aborted action has no outputs, and
                # outputs() would raise RuntimeError instead of just caching what exists.
                await ad._cache_data.aio(ad)

    def done(self) -> bool:
        """
        Check if the action is done, based on the status cached on ``pb2`` (no server call is made).
        """
        return _action_done_check(self.pb2.status.phase)

    async def sync(self) -> Action:
        """
        Sync the action with the remote server. Currently a placeholder that returns ``self`` unchanged.
        """
        return self

    def __getattr__(self, item: str) -> Any:
        """
        Forwards all other attributes to the underlying ``pb2`` action message.

        :raises AttributeError: If ``pb2`` is not set on the instance, or the message lacks ``item``.
        """
        if item == "pb2":
            # ``pb2`` has no class-level default, so it reaches __getattr__ only when unset (e.g. an instance
            # created without __init__ by copy/pickle); forwarding would recurse until RecursionError.
            raise AttributeError(item)
        return getattr(self.pb2, item)

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the Action object.
        """
        yield from _action_rich_repr(self.pb2, root=True)
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
@dataclass
class ActionDetails:
    """
    A class representing the details of an action on the remote Union API.

    Wraps a ``run_definition_pb2.ActionDetails`` message. The action's inputs and outputs are fetched on
    demand and cached on the instance.
    """

    pb2: run_definition_pb2.ActionDetails
    # Cached data, populated by ``_cache_data`` and carried over between snapshots produced by ``watch``.
    _inputs: ActionInputs | None = None
    _outputs: ActionOutputs | None = None

    @classmethod
    @requires_client
    @syncer.wrap
    async def get_details(cls, action_id: run_definition_pb2.ActionIdentifier) -> ActionDetails:
        """
        Fetch the details of the action identified by ``action_id`` from the run service.

        :param action_id: Fully qualified identifier of the action.
        :return: A new ActionDetails wrapping the server response.
        """
        resp = await get_client().run_service.GetActionDetails(
            run_service_pb2.GetActionDetailsRequest(
                action_id=action_id,
            )
        )
        return ActionDetails(resp.details)

    @classmethod
    @requires_client
    @syncer.wrap
    async def get(
        cls, uri: str | None = None, /, run_name: str | None = None, name: str | None = None
    ) -> ActionDetails:
        """
        Get the details of an action by the name of its run and its own name.

        :param uri: The URI of the action. NOTE(review): currently not used to resolve the action.
        :param name: The name of the action.
        :param run_name: The name of the run.
        :raises ValueError: If ``uri`` is not given and either ``name`` or ``run_name`` is missing.
        """
        # Explicit check instead of ``assert`` so the validation survives ``python -O``.
        if not uri and (name is None or run_name is None):
            raise ValueError("Either uri or name and run_name must be provided")
        cfg = get_common_config()
        return await cls.get_details.aio(
            cls,
            run_definition_pb2.ActionIdentifier(
                run=run_definition_pb2.RunIdentifier(
                    org=cfg.org,
                    project=cfg.project,
                    domain=cfg.domain,
                    name=run_name,
                ),
                name=name,
            ),
        )

    @classmethod
    @requires_client
    @syncer.wrap
    async def watch(cls, action_id: run_definition_pb2.ActionIdentifier) -> AsyncGenerator[ActionDetails, None]:
        """
        Stream ActionDetails snapshots for an action, stopping after the first terminal one.

        Each snapshot carries over the inputs/outputs already cached on the previous snapshot. A gRPC
        ``CANCELLED`` error ends the stream quietly; any other gRPC error is re-raised.

        :param action_id: Identifier of the action to watch.
        :raises ValueError: If ``action_id`` is empty.
        """
        if not action_id:
            raise ValueError("Action ID is required")

        call = get_client().run_service.WatchActionDetails(
            request=run_service_pb2.WatchActionDetailsRequest(
                action_id=action_id,
            )
        )
        try:
            v = None
            async for resp in call:
                # Carry over the cached *data* fields. ``v.inputs``/``v.outputs`` are async methods, and
                # storing those would poison the cache of the new snapshot.
                v = cls(
                    resp.details,
                    _inputs=v._inputs if v is not None else None,
                    _outputs=v._outputs if v is not None else None,
                )
                yield v
                if v.done():
                    return
        except grpc.aio.AioRpcError as e:
            if e.code() != grpc.StatusCode.CANCELLED:
                raise
        finally:
            # grpc.aio ``Call.cancel()`` is a synchronous method returning a bool, so it must not be awaited.
            # Only cancel a stream that is still open.
            if call is not None and not call.done():
                call.cancel()

    async def watch_updates(self, cache_data_on_done: bool = False) -> AsyncGenerator[ActionDetails, None]:
        """
        Watch this action for updates, keeping ``self.pb2`` in sync with the latest snapshot.

        :param cache_data_on_done: If True, fetch and cache inputs/outputs once the action is terminal.
        """
        # ``watch`` is a syncer-wrapped classmethod whose ``.aio`` expects the class as its first argument.
        async for d in type(self).watch.aio(type(self), self.pb2.id):
            # Track every snapshot so ``self`` reflects the latest observed state even if the stream stops early.
            self.pb2 = d.pb2
            yield d
            if d.done():
                break

        if cache_data_on_done and self.done():
            await self._cache_data.aio(self)

    @property
    def phase(self) -> str:
        """
        Get the phase of the action as its enum name.
        """
        return run_definition_pb2.Phase.Name(self.status.phase)

    @property
    def name(self) -> str:
        """
        Get the name of the action.
        """
        return self.action_id.name

    @property
    def run_name(self) -> str:
        """
        Get the name of the run.
        """
        return self.action_id.run.name

    @property
    def task_name(self) -> str | None:
        """
        Get the name of the task, or None if the action metadata has no task id.
        """
        if self.pb2.metadata.HasField("task_id"):
            return self.pb2.metadata.task_id.name
        return None

    @property
    def action_id(self) -> run_definition_pb2.ActionIdentifier:
        """
        Get the action ID.
        """
        return self.pb2.id

    @property
    def metadata(self) -> run_definition_pb2.ActionMetadata:
        """
        Get the metadata message of the action.
        """
        return self.pb2.metadata

    @property
    def status(self) -> run_definition_pb2.ActionStatus:
        """
        Get the status message of the action.
        """
        return self.pb2.status

    @property
    def error_info(self) -> run_definition_pb2.ErrorInfo | None:
        """
        Get the error info of the action, or None if the message does not carry one.
        """
        if self.pb2.HasField("error_info"):
            return self.pb2.error_info
        return None

    @property
    def abort_info(self) -> run_definition_pb2.AbortInfo | None:
        """
        Get the abort info of the action, or None if the message does not carry one.
        """
        if self.pb2.HasField("abort_info"):
            return self.pb2.abort_info
        return None

    @syncer.wrap
    async def _cache_data(self) -> bool:
        """
        Cache the inputs and outputs of the action.

        :return: True if outputs are available (and now cached), else False.
        """
        if self._inputs and self._outputs:
            return True
        if self._inputs and not self.done():
            # Inputs are cached and a non-terminal action cannot have outputs yet: skip the round-trip.
            return False
        resp = await get_client().run_service.GetActionData(
            request=run_service_pb2.GetActionDataRequest(
                action_id=self.pb2.id,
            )
        )
        self._inputs = ActionInputs(resp.inputs)
        self._outputs = ActionOutputs(resp.outputs) if resp.HasField("outputs") else None
        return self._outputs is not None

    async def inputs(self) -> ActionInputs:
        """
        Get the inputs of the action, fetching and caching them on first use.
        """
        if not self._inputs:
            await self._cache_data.aio(self)
        return self._inputs

    async def outputs(self) -> ActionOutputs:
        """
        Get the outputs of the action, fetching and caching them on first use.

        :raises RuntimeError: If no outputs are available, either because the action has not finished yet or
            because it finished without producing outputs.
        """
        if not self._outputs:
            if not await self._cache_data.aio(self):
                if self.done():
                    # Terminal (e.g. failed/aborted) but the server returned no outputs; waiting won't help.
                    raise RuntimeError("Action is in a terminal state but has no outputs available.")
                raise RuntimeError(
                    "Action is not in a terminal state, outputs are not available. "
                    "Please wait for the action to complete."
                )
        return self._outputs

    def done(self) -> bool:
        """
        Check if the action is in a terminal state (completed or failed), based on the cached status.
        """
        return _action_done_check(self.pb2.status.phase)

    def __getattr__(self, item: str) -> Any:
        """
        Forwards all other attributes to the underlying ``pb2`` message.

        :raises AttributeError: If ``pb2`` is not set on the instance, or the message lacks ``item``.
        """
        if item == "pb2":
            # ``pb2`` has no class-level default, so it reaches __getattr__ only when unset (e.g. an instance
            # created without __init__ by copy/pickle); forwarding would recurse until RecursionError.
            raise AttributeError(item)
        return getattr(self.pb2, item)

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the ActionDetails object.
        """
        yield from _action_details_rich_repr(self.pb2, root=True)
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
@dataclass
class ActionInputs:
    """
    Inputs of an action on the remote Union API.

    Thin wrapper around a ``run_definition_pb2.Inputs`` message that provides a readable ``repr``.
    """

    pb2: run_definition_pb2.Inputs

    def __repr__(self):
        # Rendering helpers are imported lazily, only when a repr is actually requested.
        import rich.pretty

        # NOTE(review): this package is ``flyte`` but the helper is imported from ``union.types`` —
        # confirm that module is importable in this distribution.
        import union.types as union_types

        readable = union_types.literal_string_repr(self.pb2)
        return rich.pretty.pretty_repr(readable)
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
@dataclass
class ActionOutputs:
    """
    Outputs of an action on the remote Union API.

    Thin wrapper around a ``run_definition_pb2.Outputs`` message that provides a readable ``repr``.
    """

    pb2: run_definition_pb2.Outputs

    def __repr__(self):
        # Rendering helpers are imported lazily, only when a repr is actually requested.
        import rich.pretty

        # NOTE(review): this package is ``flyte`` but the helper is imported from ``union.types`` —
        # confirm that module is importable in this distribution.
        import union.types as union_types

        readable = union_types.literal_string_repr(self.pb2)
        return rich.pretty.pretty_repr(readable)
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
async def main():
    """
    Manual smoke test: fetch a hard-coded run by name and await its ``wait()`` method.
    """
    run_obj = await Run.get.aio(cls=Run, name="random-run-5783fbc8")
    await run_obj.wait()
|
|
808
|
+
|
|
809
|
+
|
|
810
|
+
if __name__ == "__main__":
    # Local development harness: initialises a client against a backend at localhost:8090 and runs main().
    import asyncio

    # NOTE(review): this package is ``flyte``, but the harness initialises via ``union`` — confirm that
    # package is installed, or switch to the flyte equivalent.
    import union

    union.init(
        endpoint="dns:///localhost:8090", insecure=True, org="testorg", project="testproject", domain="development"
    )
    asyncio.run(main())
|