flyte 0.1.0__py3-none-any.whl → 0.2.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +78 -2
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +152 -0
- flyte/_build.py +26 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +145 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +323 -0
- flyte/_code_bundle/bundle.py +209 -0
- flyte/_context.py +152 -0
- flyte/_deploy.py +243 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +84 -0
- flyte/_excepthook.py +37 -0
- flyte/_group.py +32 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +762 -0
- flyte/_initialize.py +492 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +128 -0
- flyte/_internal/controllers/_local_controller.py +193 -0
- flyte/_internal/controllers/_trace.py +41 -0
- flyte/_internal/controllers/remote/__init__.py +60 -0
- flyte/_internal/controllers/remote/_action.py +146 -0
- flyte/_internal/controllers/remote/_client.py +47 -0
- flyte/_internal/controllers/remote/_controller.py +494 -0
- flyte/_internal/controllers/remote/_core.py +410 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +427 -0
- flyte/_internal/imagebuild/image_builder.py +246 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +342 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +330 -0
- flyte/_internal/runtime/taskrunner.py +191 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +135 -0
- flyte/_map.py +215 -0
- flyte/_pod.py +19 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +71 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +100 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +177 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/common_pb2.py +27 -0
- flyte/_protos/workflow/common_pb2.pyi +14 -0
- flyte/_protos/workflow/common_pb2_grpc.py +4 -0
- flyte/_protos/workflow/environment_pb2.py +29 -0
- flyte/_protos/workflow/environment_pb2.pyi +12 -0
- flyte/_protos/workflow/environment_pb2_grpc.py +4 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +105 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +146 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +314 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +129 -0
- flyte/_protos/workflow/run_service_pb2.pyi +171 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +66 -0
- flyte/_protos/workflow/state_service_pb2.pyi +75 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +79 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +81 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +60 -0
- flyte/_protos/workflow/task_service_pb2.pyi +59 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +138 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +482 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +449 -0
- flyte/_task_environment.py +183 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +120 -0
- flyte/_utils/__init__.py +26 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +23 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +134 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/cli/__init__.py +3 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_common.py +337 -0
- flyte/cli/_create.py +145 -0
- flyte/cli/_delete.py +23 -0
- flyte/cli/_deploy.py +152 -0
- flyte/cli/_gen.py +163 -0
- flyte/cli/_get.py +310 -0
- flyte/cli/_params.py +538 -0
- flyte/cli/_run.py +231 -0
- flyte/cli/main.py +166 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +216 -0
- flyte/config/_internal.py +64 -0
- flyte/config/_reader.py +207 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +172 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +263 -0
- flyte/io/__init__.py +27 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +467 -0
- flyte/io/_structured_dataset/__init__.py +129 -0
- flyte/io/_structured_dataset/basic_dfs.py +219 -0
- flyte/io/_structured_dataset/structured_dataset.py +1061 -0
- flyte/models.py +391 -0
- flyte/remote/__init__.py +26 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +133 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +215 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +159 -0
- flyte/remote/_logs.py +176 -0
- flyte/remote/_project.py +85 -0
- flyte/remote/_run.py +970 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +391 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +29 -0
- flyte/storage/_config.py +233 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +271 -0
- flyte/storage/_utils.py +5 -0
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +371 -0
- flyte/types/__init__.py +36 -0
- flyte/types/_interface.py +40 -0
- flyte/types/_pickle.py +118 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2287 -0
- flyte/types/_utils.py +80 -0
- flyte-0.2.0a0.dist-info/METADATA +249 -0
- flyte-0.2.0a0.dist-info/RECORD +218 -0
- {flyte-0.1.0.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +2 -1
- flyte-0.2.0a0.dist-info/entry_points.txt +3 -0
- flyte-0.2.0a0.dist-info/top_level.txt +1 -0
- flyte-0.1.0.dist-info/METADATA +0 -6
- flyte-0.1.0.dist-info/RECORD +0 -5
|
@@ -0,0 +1,494 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import concurrent.futures
|
|
5
|
+
import os
|
|
6
|
+
import threading
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, AsyncIterable, Awaitable, DefaultDict, Tuple, TypeVar
|
|
11
|
+
|
|
12
|
+
import flyte
|
|
13
|
+
import flyte.errors
|
|
14
|
+
import flyte.storage as storage
|
|
15
|
+
import flyte.types as types
|
|
16
|
+
from flyte._code_bundle import build_pkl_bundle
|
|
17
|
+
from flyte._context import internal_ctx
|
|
18
|
+
from flyte._internal.controllers import TraceInfo
|
|
19
|
+
from flyte._internal.controllers.remote._action import Action
|
|
20
|
+
from flyte._internal.controllers.remote._core import Controller
|
|
21
|
+
from flyte._internal.controllers.remote._service_protocol import ClientSet
|
|
22
|
+
from flyte._internal.runtime import convert, io
|
|
23
|
+
from flyte._internal.runtime.task_serde import translate_task_to_wire
|
|
24
|
+
from flyte._logging import logger
|
|
25
|
+
from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
|
|
26
|
+
from flyte._task import TaskTemplate
|
|
27
|
+
from flyte._utils.helpers import _selector_policy
|
|
28
|
+
from flyte.models import ActionID, NativeInterface, SerializationContext
|
|
29
|
+
|
|
30
|
+
# Generic type variable; it is not referenced by any definition visible in this module.
R = TypeVar("R")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
async def upload_inputs_with_retry(serialized_inputs: AsyncIterable[bytes] | bytes, inputs_uri: str) -> None:
    """
    Stream the serialized inputs to ``inputs_uri`` via ``storage.put_stream``.

    Note that, despite the function name, exactly one upload attempt is made (see the TODO below).

    Args:
        serialized_inputs: The payload to upload, either as bytes or as an async iterable of byte chunks
        inputs_uri: The destination URI

    Raises:
        RuntimeSystemError: If the upload fails; the original exception is attached as the cause
    """
    try:
        # TODO Add retry decorator to this
        await storage.put_stream(serialized_inputs, to_path=inputs_uri)
    except Exception as upload_err:
        # Log with traceback, then surface the failure as a system error named after the original exception type.
        logger.exception("Failed to upload inputs")
        failure_kind = type(upload_err).__name__
        raise flyte.errors.RuntimeSystemError(failure_kind, str(upload_err)) from upload_err
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
async def handle_action_failure(action: Action, task_name: str) -> Exception:
    """
    Build the exception that represents a failed action.

    The error is taken from ``action.err`` or ``action.client_err``. When neither is set and the action
    phase is ``PHASE_FAILED``, the error file under the action's output location is loaded instead.

    Args:
        action: The action whose failure should be reported
        task_name: The name of the task, used in the fallback error message

    Returns:
        The converted native exception, or a RuntimeSystemError when the error cannot be converted.
        The exception is returned, not raised; the caller is expected to raise it.
    """
    err = action.err or action.client_err
    if not err and action.phase == run_definition_pb2.PHASE_FAILED:
        logger.error(f"Server reported failure for action {action.name}, checking error file.")
        try:
            # NOTE(review): the trailing "/1" looks like an attempt number - confirm against the writer side.
            error_path = io.error_path(f"{action.run_output_base}/{action.action_id.name}/1")
            err = await io.load_error(error_path)
        except Exception as e:
            # logger.exception already records the active exception; passing it as a positional
            # argument to a message without a placeholder breaks log formatting.
            logger.exception("Failed to load error file")
            err = flyte.errors.RuntimeSystemError(type(e).__name__, f"Failed to load error file: {e}")
    else:
        logger.error(f"Server reported failure for action {action.action_id.name}, error: {err}")

    exc = convert.convert_error_to_native(err)
    if not exc:
        return flyte.errors.RuntimeSystemError("UnableToConvertError", f"Error in task {task_name}: {err}")
    return exc
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri: str) -> Any:
    """
    Fetch the outputs stored under ``realized_outputs_uri`` and turn them into native values.

    Args:
        iface: The native interface used to drive the conversion
        realized_outputs_uri: The URI from which the outputs file path is derived

    Returns:
        Whatever ``convert.convert_outputs_to_native`` produces for the loaded outputs
    """
    raw_outputs = await io.load_outputs(io.outputs_path(realized_outputs_uri))
    native_outputs = await convert.convert_outputs_to_native(iface, raw_outputs)
    return native_outputs
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def unique_action_name(action_id: ActionID) -> str:
    """Return the key ``<action name>_<run name>`` used for per-parent-action bookkeeping."""
    return "{}_{}".format(action_id.name, action_id.run_name)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class RemoteController(Controller):
    """
    A specialized controller that wraps the core controller and performs IO, serialization and deserialization.
    """

    def __init__(
        self,
        client_coro: Awaitable[ClientSet],
        workers: int,
        max_system_retries: int,
        default_parent_concurrency: int = 1000,
    ):
        """
        Initialize the core controller and the per-parent-action bookkeeping.

        :param client_coro: Awaitable yielding the client set; forwarded to the base controller
        :param workers: Forwarded to the base controller
        :param max_system_retries: Forwarded to the base controller
        :param default_parent_concurrency: Size of the semaphore created lazily for each parent action, i.e. the
            maximum number of concurrent ``submit`` calls per parent action
        """
        super().__init__(
            client_coro=client_coro,
            workers=workers,
            max_system_retries=max_system_retries,
        )
        self._default_parent_concurrency = default_parent_concurrency
        # One semaphore per parent action (keyed by unique_action_name), created on first use.
        self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
            lambda: asyncio.Semaphore(default_parent_concurrency)
        )
        # parent action key -> (id(task object) -> number of calls seen so far)
        self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        # Lazily created by submit_sync.
        self._submit_loop: asyncio.AbstractEventLoop | None = None
        self._submit_thread: threading.Thread | None = None

    @staticmethod
    def _task_context_or_raise():
        """Return the current task context, raising RuntimeSystemError("BadContext") when none is set."""
        tctx = internal_ctx().data.task_context
        if tctx is None:
            raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
        return tctx

    @staticmethod
    def _sub_action_identifier(current_action_id: ActionID, name: str) -> run_definition_pb2.ActionIdentifier:
        """Build the identifier of a sub action called ``name`` that lives in the run of ``current_action_id``."""
        return run_definition_pb2.ActionIdentifier(
            name=name,
            run=run_definition_pb2.RunIdentifier(
                name=current_action_id.run_name,
                project=current_action_id.project,
                domain=current_action_id.domain,
                org=current_action_id.org,
            ),
        )

    @staticmethod
    def _cache_key_for(task_name: str, task_template: Any, inputs_hash: Any, proto_inputs: Any) -> Any:
        """Return the cache key for a discoverable task template, or None when the task is not discoverable."""
        md = task_template.metadata
        if not (md and md.discoverable):
            return None
        return convert.generate_cache_key_hash(
            task_name,
            inputs_hash,
            task_template.interface,
            md.discovery_version,
            list(md.cache_ignore_input_vars),
            proto_inputs,
        )

    async def _run_action_and_collect(self, action: Action, task_name: str, native_interface: NativeInterface) -> Any:
        """
        Submit ``action``, wait for it to finish and return its native outputs (None if the interface has none).

        A cancellation of the awaiting coroutine is propagated to the server before being re-raised. A failed
        action is turned into an exception via ``handle_action_failure`` and raised.
        """
        try:
            logger.info(
                f"Submitting action Run:[{action.run_name}, Parent:[{action.parent_action_name}], "
                f"task:[{task_name}], action:[{action.name}]"
            )
            n = await self.submit_action(action)
            logger.info(f"Action for task [{task_name}] action id: {action.name}, completed!")
        except asyncio.CancelledError:
            # If the action is cancelled, we need to cancel the action on the server as well
            logger.info(f"Action {action.action_id.name} cancelled, cancelling on server")
            await self.cancel_action(action)
            raise

        if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED:
            # NOTE(review): the failure is detected on the returned action `n` but the submitted `action` is
            # handed to the failure handler - confirm that both carry the same error/phase state.
            exc = await handle_action_failure(action, task_name)
            raise exc

        if native_interface.outputs:
            if not n.realized_outputs_uri:
                raise flyte.errors.RuntimeSystemError(
                    "RuntimeError",
                    f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.",
                )
            return await load_and_convert_outputs(native_interface, n.realized_outputs_uri)
        return None

    def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
        """
        Generate a task call sequence for the given task object and action ID.
        This is used to track the number of times a task is called within an action.

        The counter is keyed by ``id(task_obj)``, i.e. by object identity, and starts at 1.
        """
        sequencer = self._parent_action_task_call_sequence[unique_action_name(action_id)]
        task_key = id(task_obj)
        sequencer[task_key] += 1
        return sequencer[task_key]

    async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Serialize the task and its inputs, upload the inputs, submit the resulting action and return its outputs.

        :param _task_call_seq: Call sequence number of this task within the current parent action
        :param _task: The task to run
        :param args: Positional task inputs
        :param kwargs: Keyword task inputs
        :return: The native outputs of the task, or None if the task declares no outputs
        """
        tctx = self._task_context_or_raise()
        current_action_id = tctx.action

        # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
        # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
        code_bundle = tctx.code_bundle

        if code_bundle and code_bundle.pkl:
            logger.debug(f"Building new pkl bundle for task {_task.name}")
            code_bundle = await build_pkl_bundle(
                _task,
                upload_to_controlplane=False,
                upload_from_dataplane_base_path=tctx.run_base_dir,
            )

        inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)

        root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
        # Don't set output path in sec context because node executor will set it
        new_serialization_context = SerializationContext(
            project=current_action_id.project,
            domain=current_action_id.domain,
            org=current_action_id.org,
            code_bundle=code_bundle,
            version=tctx.version,
            image_cache=tctx.compiled_image_cache,
            root_dir=root_dir,
        )

        task_spec = translate_task_to_wire(_task, new_serialization_context)
        serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)

        inputs_hash = convert.generate_inputs_hash(serialized_inputs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, task_spec, inputs_hash, _task_call_seq
        )

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(serialized_inputs, inputs_uri)

        cache_key = self._cache_key_for(_task.name, task_spec.task_template, inputs_hash, inputs.proto_inputs)

        # Clear to free memory
        serialized_inputs = None  # type: ignore
        inputs_hash = None  # type: ignore

        action = Action.from_task(
            sub_action_id=self._sub_action_identifier(current_action_id, sub_action_id.name),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=task_spec,
            inputs_uri=inputs_uri,
            run_output_base=tctx.run_base_dir,
            cache_key=cache_key,
        )

        return await self._run_action_and_collect(action, _task.name, _task.native_interface)

    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a task to the remote controller. This creates a new action on the queue service.

        The number of concurrent submissions per parent action is bounded by that parent's semaphore.
        """
        tctx = self._task_context_or_raise()
        current_action_id = tctx.action
        task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
        async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
            return await self._submit(task_call_seq, _task, *args, **kwargs)

    def _sync_thread_loop_runner(self) -> None:
        """This method runs the event loop and should be invoked in a separate thread."""

        loop = self._submit_loop
        assert loop is not None
        try:
            loop.run_forever()
        finally:
            loop.close()

    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
        """
        This function creates a cached thread and loop for the purpose of calling the submit method synchronously,
        returning a concurrent Future whose result can be waited on. There's no need for a lock because this
        function itself is single threaded and non-async. This pattern here is basically the trivial/degenerate
        case of the thread pool in the LocalController.
        Please see additional comments in protocol.

        :param _task: The task to submit
        :param args: Positional task inputs
        :param kwargs: Keyword task inputs
        :return: A ``concurrent.futures.Future`` for the ``submit`` coroutine scheduled on the submit loop
        """
        if self._submit_thread is None:
            # Please see LocalController for the general implementation of this pattern.
            def exc_handler(loop, context):
                logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")

            with _selector_policy():
                self._submit_loop = asyncio.new_event_loop()
                self._submit_loop.set_exception_handler(exc_handler)

            self._submit_thread = threading.Thread(
                name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner
            )
            self._submit_thread.start()

        coro = self.submit(_task, *args, **kwargs)
        assert self._submit_loop is not None, "Submit loop should always have been initialized by now"
        fut = asyncio.run_coroutine_threadsafe(coro, self._submit_loop)
        return fut

    async def finalize_parent_action(self, action_id: ActionID):
        """
        This method is invoked when the parent action is finished. It delegates to the core controller's
        ``_finalize_parent_action`` and then drops the semaphore and call-sequence state kept for that parent.
        """
        run_id = run_definition_pb2.RunIdentifier(
            name=action_id.run_name,
            project=action_id.project,
            domain=action_id.domain,
            org=action_id.org,
        )
        await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
        parent_key = unique_action_name(action_id)
        self._parent_action_semaphore.pop(parent_key, None)
        self._parent_action_task_call_sequence.pop(parent_key, None)

    async def get_action_outputs(
        self, _interface: NativeInterface, _func: Callable, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        Look up a previously recorded result for a traced function call.

        :param _interface: NativeInterface of the traced function
        :param _func: The traced function; its ``__name__`` is used to derive the sub action id
        :param args: Arguments
        :param kwargs: Keyword arguments
        :return: A ``(TraceInfo, found)`` tuple. ``found`` is True when a previous action supplied an error or
            outputs (carried on the TraceInfo); otherwise it is False and the TraceInfo only identifies the
            sub action and its inputs location.
        """
        tctx = self._task_context_or_raise()
        current_action_id = tctx.action

        func_name = _func.__name__
        invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
        inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
        serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)

        # NOTE(review): the other call sites pass an inputs hash here, this one passes the raw serialized
        # inputs - confirm this is intended before changing it, since it feeds the generated action id.
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, func_name, serialized_inputs, invoke_seq_num
        )

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(serialized_inputs, inputs_uri)
        # Clear to free memory
        serialized_inputs = None  # type: ignore

        prev_action = await self.get_action(
            self._sub_action_identifier(current_action_id, sub_action_id.name),
            current_action_id.name,
        )

        if prev_action is None:
            return TraceInfo(sub_action_id, _interface, inputs_uri), False

        if prev_action.phase == run_definition_pb2.PHASE_FAILED:
            if prev_action.has_error():
                # Same error selection as handle_action_failure: prefer the server error, fall back to the
                # client-side one.
                exc = convert.convert_error_to_native(prev_action.err or prev_action.client_err)
                return TraceInfo(sub_action_id, _interface, inputs_uri, error=exc), True
            else:
                logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!")
        elif prev_action.realized_outputs_uri is not None:
            outputs_file_path = io.outputs_path(prev_action.realized_outputs_uri)
            o = await io.load_outputs(outputs_file_path)
            outputs = await convert.convert_outputs_to_native(_interface, o)
            return TraceInfo(sub_action_id, _interface, inputs_uri, output=outputs), True

        return TraceInfo(sub_action_id, _interface, inputs_uri), False

    async def record_trace(self, info: TraceInfo):
        """
        Persist the result of a traced call: its outputs or its error are uploaded below the current task's
        output path, in a location named after the trace action.

        Nothing is uploaded when the traced interface declares no outputs.

        :param info: The trace info carrying the action, interface and the output or error to record
        :raises RuntimeSystemError: If there is no task context, or if the interface has outputs but ``info``
            carries neither an output nor an error
        """
        tctx = self._task_context_or_raise()

        current_output_path = tctx.output_path
        sub_run_output_path = storage.join(current_output_path, info.action.name)

        if info.interface.has_outputs():
            # Compare against None: falsy outputs such as 0, False, "" or [] are valid results.
            if info.output is not None:
                outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
                outputs_file_path = io.outputs_path(sub_run_output_path)
                await io.upload_outputs(outputs, outputs_file_path)
            elif info.error is not None:
                err = convert.convert_from_native_to_error(info.error)
                error_path = io.error_path(sub_run_output_path)
                await io.upload_error(err.err, error_path)
            else:
                raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error")

    async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
        """
        Submit an already-serialized task (task details) as a sub action of the current action.

        NOTE(review): unlike ``submit``, this path does not acquire the per-parent semaphore - confirm whether
        that is intentional.

        :param _task: The task details whose spec is submitted as-is
        :param args: Positional task inputs
        :param kwargs: Keyword task inputs
        :return: The native outputs of the task, or None if the guessed interface has no outputs
        """
        tctx = self._task_context_or_raise()
        current_action_id = tctx.action
        task_template = _task.spec.task_template
        task_name = task_template.id.name

        invoke_seq_num = self.generate_task_call_sequence(_task, current_action_id)

        native_interface = types.guess_interface(task_template.interface, default_inputs=_task.spec.default_inputs)
        inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
        serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
        inputs_hash = convert.generate_inputs_hash(serialized_inputs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, task_name, inputs_hash, invoke_seq_num
        )

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(serialized_inputs, inputs_uri)
        # cache key - task name, task signature, inputs, cache version
        cache_key = self._cache_key_for(task_name, task_template, inputs_hash, inputs.proto_inputs)

        # Clear to free memory
        serialized_inputs = None  # type: ignore
        inputs_hash = None  # type: ignore

        action = Action.from_task(
            sub_action_id=self._sub_action_identifier(current_action_id, sub_action_id.name),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=_task.spec,
            inputs_uri=inputs_uri,
            run_output_base=tctx.run_base_dir,
            cache_key=cache_key,
        )

        return await self._run_action_and_collect(action, task_name, native_interface)
|