flyte 0.1.0__py3-none-any.whl → 0.2.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -2
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +299 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_params.py +538 -0
- flyte/_cli/_run.py +174 -0
- flyte/_cli/main.py +98 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +757 -0
- flyte/_initialize.py +643 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +205 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +410 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/config/__init__.py +168 -0
- flyte/config/_config.py +196 -0
- flyte/config/_internal.py +64 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2211 -0
- flyte/types/_utils.py +80 -0
- flyte-0.2.0b0.dist-info/METADATA +179 -0
- flyte-0.2.0b0.dist-info/RECORD +204 -0
- {flyte-0.1.0.dist-info → flyte-0.2.0b0.dist-info}/WHEEL +2 -1
- flyte-0.2.0b0.dist-info/entry_points.txt +3 -0
- flyte-0.2.0b0.dist-info/top_level.txt +1 -0
- flyte-0.1.0.dist-info/METADATA +0 -6
- flyte-0.1.0.dist-info/RECORD +0 -5
|
@@ -0,0 +1,361 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
|
|
7
|
+
|
|
8
|
+
import flyte
|
|
9
|
+
import flyte.errors
|
|
10
|
+
import flyte.storage as storage
|
|
11
|
+
import flyte.types as types
|
|
12
|
+
from flyte._code_bundle import build_pkl_bundle
|
|
13
|
+
from flyte._context import internal_ctx
|
|
14
|
+
from flyte._datastructures import ActionID, NativeInterface, SerializationContext
|
|
15
|
+
from flyte._internal.controllers import TraceInfo
|
|
16
|
+
from flyte._internal.controllers.remote._action import Action
|
|
17
|
+
from flyte._internal.controllers.remote._core import Controller
|
|
18
|
+
from flyte._internal.controllers.remote._service_protocol import ClientSet
|
|
19
|
+
from flyte._internal.runtime import convert, io
|
|
20
|
+
from flyte._internal.runtime.task_serde import translate_task_to_wire
|
|
21
|
+
from flyte._logging import logger
|
|
22
|
+
from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
|
|
23
|
+
from flyte._task import TaskTemplate
|
|
24
|
+
|
|
25
|
+
# Generic type variable. It is not referenced anywhere in this module; presumably kept for
# importers or future use -- verify before removing.
R = TypeVar("R")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
async def upload_inputs_with_retry(inputs: convert.Inputs, inputs_uri: str) -> None:
    """
    Upload inputs to the specified URI with error handling.

    Despite the name, no retry is performed yet: a single upload attempt is made (see the TODO below).

    Args:
        inputs: The inputs to upload
        inputs_uri: The destination URI

    Raises:
        RuntimeSystemError: If the upload fails. The original exception is chained as ``__cause__``.
    """
    try:
        # TODO Add retry decorator to this
        await io.upload_inputs(inputs, inputs_uri)
    except Exception as e:
        # ``logger.exception`` already attaches the active traceback. Passing ``e`` as an extra
        # positional argument (with no %-placeholder in the message) makes a stdlib logger fail to
        # format the record, so the real message is lost. Use a proper placeholder instead.
        logger.exception("Failed to upload inputs to %s", inputs_uri)
        raise flyte.errors.RuntimeSystemError(type(e).__name__, str(e)) from e
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
async def handle_action_failure(action: Action, task_name: str) -> Exception:
    """
    Build the exception describing why an action failed.

    The error is taken from ``action.err``, falling back to ``action.client_err``. If neither is set
    but the action's phase is ``PHASE_FAILED``, the error file under the action's output directory
    is loaded instead.

    Args:
        action: The updated action
        task_name: The name of the task, used in the fallback error message

    Returns:
        The converted native exception. A ``RuntimeSystemError`` is returned instead if the error
        file could not be loaded or the error could not be converted. The exception is returned,
        not raised; callers are expected to raise it.
    """
    err = action.err or action.client_err
    if not err and action.phase == run_definition_pb2.PHASE_FAILED:
        logger.error(f"Server reported failure for action {action.name}, checking error file.")
        try:
            # NOTE(review): the trailing "/1" hard-codes the attempt directory -- confirm retried
            # actions do not write their error file under a different attempt number.
            error_path = io.error_path(f"{action.run_output_base}/{action.action_id.name}/1")
            err = await io.load_error(error_path)
        except Exception as e:
            # Message only: logger.exception attaches the traceback itself. Passing ``e`` as a
            # %-argument without a placeholder would break record formatting.
            logger.exception("Failed to load error file")
            # This is already a native exception; return it directly rather than passing it to
            # convert_error_to_native, which otherwise receives the loaded (non-native) error.
            return flyte.errors.RuntimeSystemError(type(e).__name__, f"Failed to load error file: {e}")
    else:
        logger.error(f"Server reported failure for action {action.action_id.name}, error: {err}")

    exc = convert.convert_error_to_native(err)
    if not exc:
        return flyte.errors.RuntimeSystemError("UnableToConvertError", f"Error in task {task_name}: {err}")
    return exc
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri: str) -> Any:
    """
    Read the outputs stored for an action and turn them into native Python values.

    Args:
        iface: Native interface describing the expected outputs
        realized_outputs_uri: Base URI of the action's outputs; the outputs file path is derived
            from it via ``io.outputs_path``

    Returns:
        Whatever ``convert.convert_outputs_to_native`` produces for ``iface``
    """
    loaded = await io.load_outputs(io.outputs_path(realized_outputs_uri))
    return await convert.convert_outputs_to_native(iface, loaded)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class RemoteController(Controller):
    """
    A specialized controller that wraps the core ``Controller`` and performs the IO, serialization
    and deserialization needed to run sub-actions.

    Inputs are converted and uploaded before submission; outputs and errors are downloaded and
    converted back to native values after completion. Submissions made through :meth:`submit` are
    throttled per parent action by an ``asyncio.Semaphore``.
    """

    def __init__(
        self,
        client_coro: Awaitable[ClientSet],
        workers: int,
        max_system_retries: int,
        default_parent_concurrency: int = 100,
    ):
        """
        :param client_coro: Awaitable producing the ``ClientSet``; forwarded to the core controller.
        :param workers: Forwarded to the core controller.
        :param max_system_retries: Forwarded to the core controller.
        :param default_parent_concurrency: Maximum number of :meth:`submit` calls that may be in flight
            at once for a single parent action.
        """
        super().__init__(
            client_coro=client_coro,
            workers=workers,
            max_system_retries=max_system_retries,
        )
        self._default_parent_concurrency = default_parent_concurrency
        # One semaphore per parent action name. It is created lazily on first use and dropped in
        # finalize_parent_action.
        self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
            lambda: asyncio.Semaphore(default_parent_concurrency)
        )

    @staticmethod
    def _current_task_context():
        """Return the current task context, raising RuntimeSystemError("BadContext") if it is missing."""
        tctx = internal_ctx().data.task_context
        if tctx is None:
            raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
        return tctx

    @staticmethod
    def _sub_action_identifier(current_action_id: ActionID, sub_action_name: str) -> run_definition_pb2.ActionIdentifier:
        """Build the wire identifier of a sub-action in the same run as ``current_action_id``."""
        return run_definition_pb2.ActionIdentifier(
            name=sub_action_name,
            run=run_definition_pb2.RunIdentifier(
                name=current_action_id.run_name,
                project=current_action_id.project,
                domain=current_action_id.domain,
                org=current_action_id.org,
            ),
        )

    async def _run_sub_action(self, action: Action, task_name: str, native_interface: NativeInterface) -> Any:
        """
        Submit ``action``, wait for it to complete, and return its native outputs (or ``None``).

        If the waiting coroutine is cancelled, the action is also cancelled on the server before the
        cancellation is re-raised. A failed action is turned into an exception by
        ``handle_action_failure`` and raised.
        """
        try:
            logger.info(
                f"Submitting action Run:[{action.run_name}], Parent:[{action.parent_action_name}], "
                f"task:[{task_name}], action:[{action.name}]"
            )
            n = await self.submit_action(action)
            logger.info(f"Action for task [{task_name}] action id: {action.name}, completed!")
        except asyncio.CancelledError:
            # If the action is cancelled, we need to cancel the action on the server as well
            logger.info(f"Action {action.action_id.name} cancelled, cancelling on server")
            await self.cancel_action(action)
            raise

        if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED:
            # NOTE(review): the submitted ``action`` (not ``n``) is inspected for the error. This is
            # only correct if the core controller updates the same object in place -- confirm.
            raise await handle_action_failure(action, task_name)

        if native_interface.outputs:
            if not n.realized_outputs_uri:
                raise flyte.errors.RuntimeSystemError(
                    "RuntimeError",
                    f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.",
                )
            return await load_and_convert_outputs(native_interface, n.realized_outputs_uri)
        return None

    async def _submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Serialize ``_task`` and its inputs, submit it as a sub-action of the current action, and wait
        for the result. Called by :meth:`submit` while holding the parent's semaphore.
        """
        tctx = self._current_task_context()
        current_action_id = tctx.action

        inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _task.name, inputs)

        # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
        # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
        code_bundle = tctx.code_bundle

        if code_bundle and code_bundle.pkl:
            logger.debug(f"Building new pkl bundle for task {sub_action_id.name}")
            code_bundle = await build_pkl_bundle(
                _task,
                upload_to_controlplane=False,
                upload_from_dataplane_path=io.pkl_path(sub_action_output_path),
            )

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(inputs, inputs_uri)

        root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
        # Don't set the output path in the serialization context, because the node executor will set it.
        new_serialization_context = SerializationContext(
            project=current_action_id.project,
            domain=current_action_id.domain,
            org=current_action_id.org,
            code_bundle=code_bundle,
            version=tctx.version,  # version supplied by the current task context
            input_path=inputs_uri,
            image_cache=tctx.compiled_image_cache,
            root_dir=root_dir,
        )

        task_spec = translate_task_to_wire(_task, new_serialization_context)

        action = Action.from_task(
            sub_action_id=self._sub_action_identifier(current_action_id, sub_action_id.name),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=task_spec,
            inputs_uri=inputs_uri,
            run_output_base=tctx.run_base_dir,
        )
        return await self._run_sub_action(action, _task.name, _task.native_interface)

    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a task to the remote controller. This creates a new action on the queue service.

        At most ``default_parent_concurrency`` submissions per parent action run concurrently; further
        calls wait on that parent's semaphore.
        """
        current_action_id = self._current_task_context().action
        async with self._parent_action_semaphore[current_action_id.name]:
            return await self._submit(_task, *args, **kwargs)

    async def finalize_parent_action(self, action_id: ActionID):
        """
        Invoked when the parent action is finished.

        Delegates to the core controller's ``_finalize_parent_action`` for the parent's run, then
        drops the parent's concurrency semaphore.
        """
        run_id = run_definition_pb2.RunIdentifier(
            name=action_id.run_name,
            project=action_id.project,
            domain=action_id.domain,
            org=action_id.org,
        )
        await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
        self._parent_action_semaphore.pop(action_id.name, None)

    async def get_action_outputs(
        self, _interface: NativeInterface, _func_name: str, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        Look up a previously recorded result for the traced function ``_func_name`` with these inputs.

        The inputs are converted and uploaded. The sub-action id, derived from the function name and
        inputs, is then used to query for an earlier action.

        :param _interface: NativeInterface of the traced function
        :param _func_name: Function name, used to derive the sub-action id
        :param args: Positional arguments of the call
        :param kwargs: Keyword arguments of the call
        :return: ``(info, True)`` when a previous output or error was found and is carried by ``info``.
            ``(info, False)`` when the function must be (re-)executed and its result recorded with
            :meth:`record_trace`.
        """
        tctx = self._current_task_context()
        current_action_id = tctx.action

        inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _func_name, inputs)

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(inputs, inputs_uri)

        prev_action = await self.get_action(
            self._sub_action_identifier(current_action_id, sub_action_id.name),
            current_action_id.name,
        )

        if prev_action is None:
            return TraceInfo(sub_action_id, _interface, inputs_uri), False

        if prev_action.phase == run_definition_pb2.PHASE_FAILED:
            if prev_action.has_error():
                exc = convert.convert_error_to_native(prev_action.err)
                if not exc:
                    # Same fallback as handle_action_failure: never report a failure without an error.
                    exc = flyte.errors.RuntimeSystemError(
                        "UnableToConvertError", f"Error in trace {_func_name}: {prev_action.err}"
                    )
                return TraceInfo(sub_action_id, _interface, inputs_uri, error=exc), True
            logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!")
        elif prev_action.realized_outputs_uri:
            # Truthiness (not ``is not None``) matches the output-path check used for submitted tasks,
            # so an empty URI re-runs the trace instead of loading from a bogus path.
            outputs = await load_and_convert_outputs(_interface, prev_action.realized_outputs_uri)
            return TraceInfo(sub_action_id, _interface, inputs_uri, output=outputs), True

        return TraceInfo(sub_action_id, _interface, inputs_uri), False

    async def record_trace(self, info: TraceInfo):
        """
        Persist the result of a traced function so a later :meth:`get_action_outputs` can reuse it.

        The output (or error) is uploaded under ``<current output path>/<info.action.name>``. Nothing
        is written when the traced interface declares no outputs.

        :param info: TraceInfo for the trace, with ``output`` or ``error`` filled in
        :raises RuntimeSystemError: if the interface has outputs but ``info`` carries neither
        """
        tctx = self._current_task_context()
        sub_run_output_path = storage.join(tctx.output_path, info.action.name)

        if not info.interface.has_outputs():
            # NOTE(review): errors raised by traced functions without outputs are not recorded
            # either -- confirm this is intended.
            return

        # Compare against None: a legitimate falsy return value (0, "", False, []) must still be
        # recorded as output instead of falling through to the error/"BadTraceInfo" branches.
        if info.output is not None:
            outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
            outputs_file_path = io.outputs_path(sub_run_output_path)
            await io.upload_outputs(outputs, outputs_file_path)
        elif info.error is not None:
            err = convert.convert_from_native_to_error(info.error)
            error_path = io.error_path(sub_run_output_path)
            await io.upload_error(err.err, error_path)
        else:
            raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error")

    async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
        """
        Submit an already-defined task (given by its ``TaskDetails``) as a sub-action of the current
        action and wait for the result.

        Unlike :meth:`submit`, the existing task spec is reused as-is and no code bundle is built. The
        native interface is guessed from the task template's interface. The per-parent concurrency
        semaphore is not applied (NOTE(review): confirm whether that is intended).
        """
        tctx = self._current_task_context()
        current_action_id = tctx.action
        task_name = _task.spec.task_template.id.name

        native_interface = types.guess_interface(_task.spec.task_template.interface)
        inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, task_name, inputs)

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(inputs, inputs_uri)

        action = Action.from_task(
            sub_action_id=self._sub_action_identifier(current_action_id, sub_action_id.name),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=_task.spec,
            inputs_uri=inputs_uri,
            run_output_base=tctx.run_base_dir,
        )
        return await self._run_sub_action(action, task_name, native_interface)
|