flyte 0.0.1b3__py3-none-any.whl → 0.2.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +20 -4
- flyte/_bin/runtime.py +33 -7
- flyte/_build.py +3 -2
- flyte/_cache/cache.py +1 -2
- flyte/_code_bundle/_packaging.py +1 -1
- flyte/_code_bundle/_utils.py +0 -16
- flyte/_code_bundle/bundle.py +43 -12
- flyte/_context.py +8 -2
- flyte/_deploy.py +56 -15
- flyte/_environment.py +45 -4
- flyte/_excepthook.py +37 -0
- flyte/_group.py +2 -1
- flyte/_image.py +8 -4
- flyte/_initialize.py +112 -254
- flyte/_interface.py +3 -3
- flyte/_internal/controllers/__init__.py +19 -6
- flyte/_internal/controllers/_local_controller.py +83 -8
- flyte/_internal/controllers/_trace.py +2 -1
- flyte/_internal/controllers/remote/__init__.py +27 -7
- flyte/_internal/controllers/remote/_action.py +7 -2
- flyte/_internal/controllers/remote/_client.py +5 -1
- flyte/_internal/controllers/remote/_controller.py +159 -26
- flyte/_internal/controllers/remote/_core.py +13 -5
- flyte/_internal/controllers/remote/_informer.py +4 -4
- flyte/_internal/controllers/remote/_service_protocol.py +6 -6
- flyte/_internal/imagebuild/docker_builder.py +12 -1
- flyte/_internal/imagebuild/image_builder.py +16 -11
- flyte/_internal/runtime/convert.py +164 -21
- flyte/_internal/runtime/entrypoints.py +1 -1
- flyte/_internal/runtime/io.py +3 -3
- flyte/_internal/runtime/task_serde.py +140 -20
- flyte/_internal/runtime/taskrunner.py +4 -3
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_logging.py +12 -1
- flyte/_map.py +215 -0
- flyte/_pod.py +19 -0
- flyte/_protos/common/list_pb2.py +3 -3
- flyte/_protos/common/list_pb2.pyi +2 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +28 -24
- flyte/_protos/logs/dataplane/payload_pb2.pyi +11 -2
- flyte/_protos/workflow/common_pb2.py +27 -0
- flyte/_protos/workflow/common_pb2.pyi +14 -0
- flyte/_protos/workflow/environment_pb2.py +29 -0
- flyte/_protos/workflow/environment_pb2.pyi +12 -0
- flyte/_protos/workflow/queue_service_pb2.py +40 -41
- flyte/_protos/workflow/queue_service_pb2.pyi +35 -30
- flyte/_protos/workflow/queue_service_pb2_grpc.py +15 -15
- flyte/_protos/workflow/run_definition_pb2.py +61 -61
- flyte/_protos/workflow/run_definition_pb2.pyi +8 -4
- flyte/_protos/workflow/run_service_pb2.py +20 -24
- flyte/_protos/workflow/run_service_pb2.pyi +2 -6
- flyte/_protos/workflow/state_service_pb2.py +36 -28
- flyte/_protos/workflow/state_service_pb2.pyi +19 -15
- flyte/_protos/workflow/state_service_pb2_grpc.py +28 -28
- flyte/_protos/workflow/task_definition_pb2.py +29 -22
- flyte/_protos/workflow/task_definition_pb2.pyi +21 -5
- flyte/_protos/workflow/task_service_pb2.py +27 -11
- flyte/_protos/workflow/task_service_pb2.pyi +29 -1
- flyte/_protos/workflow/task_service_pb2_grpc.py +34 -0
- flyte/_run.py +166 -95
- flyte/_task.py +110 -28
- flyte/_task_environment.py +55 -72
- flyte/_trace.py +6 -14
- flyte/_utils/__init__.py +6 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +0 -2
- flyte/_utils/helpers.py +45 -19
- flyte/_utils/org_discovery.py +57 -0
- flyte/_version.py +2 -2
- flyte/cli/__init__.py +3 -0
- flyte/cli/_abort.py +28 -0
- flyte/{_cli → cli}/_common.py +73 -23
- flyte/cli/_create.py +145 -0
- flyte/{_cli → cli}/_delete.py +4 -4
- flyte/{_cli → cli}/_deploy.py +26 -14
- flyte/cli/_gen.py +163 -0
- flyte/{_cli → cli}/_get.py +98 -23
- {union/_cli → flyte/cli}/_params.py +106 -147
- flyte/{_cli → cli}/_run.py +99 -20
- flyte/cli/main.py +166 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +216 -0
- flyte/config/_internal.py +64 -0
- flyte/config/_reader.py +207 -0
- flyte/errors.py +29 -0
- flyte/extras/_container.py +33 -43
- flyte/io/__init__.py +17 -1
- flyte/io/_dir.py +2 -2
- flyte/io/_file.py +3 -4
- flyte/io/{structured_dataset → _structured_dataset}/basic_dfs.py +1 -1
- flyte/io/{structured_dataset → _structured_dataset}/structured_dataset.py +1 -1
- flyte/{_datastructures.py → models.py} +56 -7
- flyte/remote/__init__.py +2 -1
- flyte/remote/_client/_protocols.py +2 -0
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_channel.py +34 -3
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +13 -13
- flyte/remote/_console.py +1 -1
- flyte/remote/_data.py +10 -6
- flyte/remote/_logs.py +89 -29
- flyte/remote/_project.py +8 -9
- flyte/remote/_run.py +228 -131
- flyte/remote/_secret.py +12 -12
- flyte/remote/_task.py +179 -15
- flyte/report/_report.py +4 -4
- flyte/storage/__init__.py +5 -0
- flyte/storage/_config.py +233 -0
- flyte/storage/_storage.py +23 -3
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +371 -0
- flyte/types/__init__.py +23 -0
- flyte/types/_interface.py +22 -7
- flyte/{io/pickle/transformer.py → types/_pickle.py} +2 -1
- flyte/types/_type_engine.py +95 -18
- flyte-0.2.0a0.dist-info/METADATA +249 -0
- flyte-0.2.0a0.dist-info/RECORD +218 -0
- {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/entry_points.txt +1 -1
- flyte/_api_commons.py +0 -3
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_create.py +0 -42
- flyte/_cli/main.py +0 -72
- flyte/_internal/controllers/pbhash.py +0 -39
- flyte/io/_dataframe.py +0 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte-0.0.1b3.dist-info/METADATA +0 -179
- flyte-0.0.1b3.dist-info/RECORD +0 -390
- union/__init__.py +0 -54
- union/_api_commons.py +0 -3
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +0 -113
- union/_build.py +0 -25
- union/_cache/__init__.py +0 -12
- union/_cache/cache.py +0 -141
- union/_cache/defaults.py +0 -9
- union/_cache/policy_function_body.py +0 -42
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +0 -263
- union/_cli/_create.py +0 -40
- union/_cli/_delete.py +0 -23
- union/_cli/_deploy.py +0 -120
- union/_cli/_get.py +0 -162
- union/_cli/_run.py +0 -150
- union/_cli/main.py +0 -72
- union/_code_bundle/__init__.py +0 -8
- union/_code_bundle/_ignore.py +0 -113
- union/_code_bundle/_packaging.py +0 -187
- union/_code_bundle/_utils.py +0 -342
- union/_code_bundle/bundle.py +0 -176
- union/_context.py +0 -146
- union/_datastructures.py +0 -295
- union/_deploy.py +0 -185
- union/_doc.py +0 -29
- union/_docstring.py +0 -26
- union/_environment.py +0 -43
- union/_group.py +0 -31
- union/_hash.py +0 -23
- union/_image.py +0 -760
- union/_initialize.py +0 -585
- union/_interface.py +0 -84
- union/_internal/__init__.py +0 -3
- union/_internal/controllers/__init__.py +0 -77
- union/_internal/controllers/_local_controller.py +0 -77
- union/_internal/controllers/pbhash.py +0 -39
- union/_internal/controllers/remote/__init__.py +0 -40
- union/_internal/controllers/remote/_action.py +0 -131
- union/_internal/controllers/remote/_client.py +0 -43
- union/_internal/controllers/remote/_controller.py +0 -169
- union/_internal/controllers/remote/_core.py +0 -341
- union/_internal/controllers/remote/_informer.py +0 -260
- union/_internal/controllers/remote/_service_protocol.py +0 -44
- union/_internal/imagebuild/__init__.py +0 -11
- union/_internal/imagebuild/docker_builder.py +0 -416
- union/_internal/imagebuild/image_builder.py +0 -243
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +0 -31
- union/_internal/resolvers/common.py +0 -24
- union/_internal/resolvers/default.py +0 -27
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +0 -163
- union/_internal/runtime/entrypoints.py +0 -121
- union/_internal/runtime/io.py +0 -136
- union/_internal/runtime/resources_serde.py +0 -134
- union/_internal/runtime/task_serde.py +0 -202
- union/_internal/runtime/taskrunner.py +0 -179
- union/_internal/runtime/types_serde.py +0 -53
- union/_logging.py +0 -124
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +0 -66
- union/_protos/common/authorization_pb2.pyi +0 -106
- union/_protos/common/identifier_pb2.py +0 -71
- union/_protos/common/identifier_pb2.pyi +0 -82
- union/_protos/common/identity_pb2.py +0 -48
- union/_protos/common/identity_pb2.pyi +0 -72
- union/_protos/common/identity_pb2_grpc.py +0 -4
- union/_protos/common/list_pb2.py +0 -36
- union/_protos/common/list_pb2.pyi +0 -69
- union/_protos/common/list_pb2_grpc.py +0 -4
- union/_protos/common/policy_pb2.py +0 -37
- union/_protos/common/policy_pb2.pyi +0 -27
- union/_protos/common/policy_pb2_grpc.py +0 -4
- union/_protos/common/role_pb2.py +0 -37
- union/_protos/common/role_pb2.pyi +0 -51
- union/_protos/common/role_pb2_grpc.py +0 -4
- union/_protos/common/runtime_version_pb2.py +0 -28
- union/_protos/common/runtime_version_pb2.pyi +0 -24
- union/_protos/common/runtime_version_pb2_grpc.py +0 -4
- union/_protos/logs/dataplane/payload_pb2.py +0 -96
- union/_protos/logs/dataplane/payload_pb2.pyi +0 -168
- union/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- union/_protos/secret/definition_pb2.py +0 -49
- union/_protos/secret/definition_pb2.pyi +0 -93
- union/_protos/secret/definition_pb2_grpc.py +0 -4
- union/_protos/secret/payload_pb2.py +0 -62
- union/_protos/secret/payload_pb2.pyi +0 -94
- union/_protos/secret/payload_pb2_grpc.py +0 -4
- union/_protos/secret/secret_pb2.py +0 -38
- union/_protos/secret/secret_pb2.pyi +0 -6
- union/_protos/secret/secret_pb2_grpc.py +0 -198
- union/_protos/validate/validate/validate_pb2.py +0 -76
- union/_protos/workflow/node_execution_service_pb2.py +0 -26
- union/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- union/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- union/_protos/workflow/queue_service_pb2.py +0 -75
- union/_protos/workflow/queue_service_pb2.pyi +0 -103
- union/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- union/_protos/workflow/run_definition_pb2.py +0 -100
- union/_protos/workflow/run_definition_pb2.pyi +0 -256
- union/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- union/_protos/workflow/run_logs_service_pb2.py +0 -41
- union/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- union/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- union/_protos/workflow/run_service_pb2.py +0 -133
- union/_protos/workflow/run_service_pb2.pyi +0 -173
- union/_protos/workflow/run_service_pb2_grpc.py +0 -412
- union/_protos/workflow/state_service_pb2.py +0 -58
- union/_protos/workflow/state_service_pb2.pyi +0 -69
- union/_protos/workflow/state_service_pb2_grpc.py +0 -138
- union/_protos/workflow/task_definition_pb2.py +0 -72
- union/_protos/workflow/task_definition_pb2.pyi +0 -65
- union/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- union/_protos/workflow/task_service_pb2.py +0 -44
- union/_protos/workflow/task_service_pb2.pyi +0 -31
- union/_protos/workflow/task_service_pb2_grpc.py +0 -104
- union/_resources.py +0 -226
- union/_retry.py +0 -32
- union/_reusable_environment.py +0 -25
- union/_run.py +0 -374
- union/_secret.py +0 -61
- union/_task.py +0 -354
- union/_task_environment.py +0 -186
- union/_timeout.py +0 -47
- union/_tools.py +0 -27
- union/_utils/__init__.py +0 -11
- union/_utils/asyn.py +0 -119
- union/_utils/file_handling.py +0 -71
- union/_utils/helpers.py +0 -46
- union/_utils/lazy_module.py +0 -54
- union/_utils/uv_script_parser.py +0 -49
- union/_version.py +0 -21
- union/connectors/__init__.py +0 -0
- union/errors.py +0 -128
- union/extras/__init__.py +0 -5
- union/extras/_container.py +0 -263
- union/io/__init__.py +0 -11
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +0 -425
- union/io/_file.py +0 -418
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +0 -117
- union/io/structured_dataset/__init__.py +0 -122
- union/io/structured_dataset/basic_dfs.py +0 -219
- union/io/structured_dataset/structured_dataset.py +0 -1057
- union/py.typed +0 -0
- union/remote/__init__.py +0 -23
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +0 -129
- union/remote/_client/auth/__init__.py +0 -12
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +0 -391
- union/remote/_client/auth/_authenticators/client_credentials.py +0 -73
- union/remote/_client/auth/_authenticators/device_code.py +0 -120
- union/remote/_client/auth/_authenticators/external_command.py +0 -77
- union/remote/_client/auth/_authenticators/factory.py +0 -200
- union/remote/_client/auth/_authenticators/pkce.py +0 -515
- union/remote/_client/auth/_channel.py +0 -184
- union/remote/_client/auth/_client_config.py +0 -83
- union/remote/_client/auth/_default_html.py +0 -32
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +0 -204
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +0 -144
- union/remote/_client/auth/_keyring.py +0 -154
- union/remote/_client/auth/_token_client.py +0 -258
- union/remote/_client/auth/errors.py +0 -16
- union/remote/_client/controlplane.py +0 -86
- union/remote/_data.py +0 -149
- union/remote/_logs.py +0 -74
- union/remote/_project.py +0 -86
- union/remote/_run.py +0 -820
- union/remote/_secret.py +0 -132
- union/remote/_task.py +0 -193
- union/report/__init__.py +0 -3
- union/report/_report.py +0 -178
- union/report/_template.html +0 -124
- union/storage/__init__.py +0 -24
- union/storage/_remote_fs.py +0 -34
- union/storage/_storage.py +0 -247
- union/storage/_utils.py +0 -5
- union/types/__init__.py +0 -11
- union/types/_renderer.py +0 -162
- union/types/_string_literals.py +0 -120
- union/types/_type_engine.py +0 -2131
- union/types/_utils.py +0 -80
- /union/_protos/common/authorization_pb2_grpc.py → /flyte/_protos/workflow/common_pb2_grpc.py +0 -0
- /union/_protos/common/identifier_pb2_grpc.py → /flyte/_protos/workflow/environment_pb2_grpc.py +0 -0
- /flyte/io/{structured_dataset → _structured_dataset}/__init__.py +0 -0
- {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +0 -0
- {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/top_level.txt +0 -0
|
@@ -1,26 +1,79 @@
|
|
|
1
|
-
|
|
1
|
+
import asyncio
|
|
2
|
+
import atexit
|
|
3
|
+
import concurrent.futures
|
|
4
|
+
import os
|
|
5
|
+
import threading
|
|
6
|
+
from typing import Any, Callable, Tuple, TypeVar
|
|
2
7
|
|
|
3
8
|
import flyte.errors
|
|
4
9
|
from flyte._context import internal_ctx
|
|
5
|
-
from flyte._datastructures import ActionID, NativeInterface, RawDataPath
|
|
6
10
|
from flyte._internal.controllers import TraceInfo
|
|
7
11
|
from flyte._internal.runtime import convert
|
|
8
12
|
from flyte._internal.runtime.entrypoints import direct_dispatch
|
|
9
13
|
from flyte._logging import log, logger
|
|
10
14
|
from flyte._protos.workflow import task_definition_pb2
|
|
11
15
|
from flyte._task import TaskTemplate
|
|
16
|
+
from flyte._utils.helpers import _selector_policy
|
|
17
|
+
from flyte.models import ActionID, NativeInterface
|
|
12
18
|
|
|
13
19
|
R = TypeVar("R")
|
|
14
20
|
|
|
15
21
|
|
|
22
|
+
class _TaskRunner:
|
|
23
|
+
"""A task runner that runs an asyncio event loop on a background thread."""
|
|
24
|
+
|
|
25
|
+
def __init__(self) -> None:
|
|
26
|
+
self.__loop: asyncio.AbstractEventLoop | None = None
|
|
27
|
+
self.__runner_thread: threading.Thread | None = None
|
|
28
|
+
self.__lock = threading.Lock()
|
|
29
|
+
atexit.register(self._close)
|
|
30
|
+
|
|
31
|
+
def _close(self) -> None:
|
|
32
|
+
if self.__loop:
|
|
33
|
+
self.__loop.stop()
|
|
34
|
+
|
|
35
|
+
def _execute(self) -> None:
|
|
36
|
+
loop = self.__loop
|
|
37
|
+
assert loop is not None
|
|
38
|
+
try:
|
|
39
|
+
loop.run_forever()
|
|
40
|
+
finally:
|
|
41
|
+
loop.close()
|
|
42
|
+
|
|
43
|
+
def get_exc_handler(self):
|
|
44
|
+
def exc_handler(loop, context):
|
|
45
|
+
logger.error(
|
|
46
|
+
f"Taskrunner for {self.__runner_thread.name if self.__runner_thread else 'no thread'} caught"
|
|
47
|
+
f" exception in {loop}: {context}"
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
return exc_handler
|
|
51
|
+
|
|
52
|
+
def get_run_future(self, coro: Any) -> concurrent.futures.Future:
|
|
53
|
+
"""Synchronously run a coroutine on a background thread."""
|
|
54
|
+
name = f"{threading.current_thread().name} : loop-runner"
|
|
55
|
+
with self.__lock:
|
|
56
|
+
if self.__loop is None:
|
|
57
|
+
with _selector_policy():
|
|
58
|
+
self.__loop = asyncio.new_event_loop()
|
|
59
|
+
|
|
60
|
+
exc_handler = self.get_exc_handler()
|
|
61
|
+
self.__loop.set_exception_handler(exc_handler)
|
|
62
|
+
self.__runner_thread = threading.Thread(target=self._execute, daemon=True, name=name)
|
|
63
|
+
self.__runner_thread.start()
|
|
64
|
+
fut = asyncio.run_coroutine_threadsafe(coro, self.__loop)
|
|
65
|
+
return fut
|
|
66
|
+
|
|
67
|
+
|
|
16
68
|
class LocalController:
|
|
17
69
|
def __init__(self):
|
|
18
70
|
logger.debug("LocalController init")
|
|
71
|
+
self._runner_map: dict[str, _TaskRunner] = {}
|
|
19
72
|
|
|
20
73
|
@log
|
|
21
74
|
async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
|
|
22
75
|
"""
|
|
23
|
-
|
|
76
|
+
Main entrypoint for submitting a task to the local controller.
|
|
24
77
|
"""
|
|
25
78
|
ctx = internal_ctx()
|
|
26
79
|
tctx = ctx.data.task_context
|
|
@@ -28,8 +81,12 @@ class LocalController:
|
|
|
28
81
|
raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
|
|
29
82
|
|
|
30
83
|
inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
|
|
31
|
-
|
|
32
|
-
|
|
84
|
+
serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
|
|
85
|
+
|
|
86
|
+
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
87
|
+
tctx, _task.name, serialized_inputs, 0
|
|
88
|
+
)
|
|
89
|
+
sub_action_raw_data_path = tctx.raw_data_path
|
|
33
90
|
|
|
34
91
|
out, err = await direct_dispatch(
|
|
35
92
|
_task,
|
|
@@ -54,6 +111,18 @@ class LocalController:
|
|
|
54
111
|
return result
|
|
55
112
|
return out
|
|
56
113
|
|
|
114
|
+
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
|
|
115
|
+
name = threading.current_thread().name + f"PID:{os.getpid()}"
|
|
116
|
+
coro = self.submit(_task, *args, **kwargs)
|
|
117
|
+
if name not in self._runner_map:
|
|
118
|
+
if len(self._runner_map) > 100:
|
|
119
|
+
logger.warning(
|
|
120
|
+
"More than 100 event loop runners created!!! This could be a case of runaway recursion..."
|
|
121
|
+
)
|
|
122
|
+
self._runner_map[name] = _TaskRunner()
|
|
123
|
+
|
|
124
|
+
return self._runner_map[name].get_run_future(coro)
|
|
125
|
+
|
|
57
126
|
async def finalize_parent_action(self, action: ActionID):
|
|
58
127
|
pass
|
|
59
128
|
|
|
@@ -64,7 +133,7 @@ class LocalController:
|
|
|
64
133
|
pass
|
|
65
134
|
|
|
66
135
|
async def get_action_outputs(
|
|
67
|
-
self, _interface: NativeInterface,
|
|
136
|
+
self, _interface: NativeInterface, _func: Callable, *args, **kwargs
|
|
68
137
|
) -> Tuple[TraceInfo, bool]:
|
|
69
138
|
"""
|
|
70
139
|
This method returns the outputs of the action, if it is available.
|
|
@@ -79,8 +148,13 @@ class LocalController:
|
|
|
79
148
|
if _interface.inputs:
|
|
80
149
|
converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
|
|
81
150
|
assert converted_inputs
|
|
151
|
+
|
|
152
|
+
serialized_inputs = converted_inputs.proto_inputs.SerializeToString(deterministic=True)
|
|
82
153
|
action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
83
|
-
tctx,
|
|
154
|
+
tctx,
|
|
155
|
+
_func.__name__,
|
|
156
|
+
serialized_inputs,
|
|
157
|
+
0,
|
|
84
158
|
)
|
|
85
159
|
assert action_output_path
|
|
86
160
|
return (
|
|
@@ -88,6 +162,7 @@ class LocalController:
|
|
|
88
162
|
action=action_id,
|
|
89
163
|
interface=_interface,
|
|
90
164
|
inputs_path=action_output_path,
|
|
165
|
+
name=_func.__name__,
|
|
91
166
|
),
|
|
92
167
|
True,
|
|
93
168
|
)
|
|
@@ -105,7 +180,7 @@ class LocalController:
|
|
|
105
180
|
|
|
106
181
|
if info.interface.outputs and info.output:
|
|
107
182
|
# If the result is not an AsyncGenerator, convert it directly
|
|
108
|
-
converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
|
|
183
|
+
converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface, info.name)
|
|
109
184
|
assert converted_outputs
|
|
110
185
|
elif info.error:
|
|
111
186
|
# If there is an error, convert it to a native error
|
|
@@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|
|
2
2
|
from datetime import timedelta
|
|
3
3
|
from typing import Any, Optional
|
|
4
4
|
|
|
5
|
-
from flyte.
|
|
5
|
+
from flyte.models import ActionID, NativeInterface
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
@dataclass
|
|
@@ -18,6 +18,7 @@ class TraceInfo:
|
|
|
18
18
|
duration: Optional[timedelta] = None
|
|
19
19
|
output: Optional[Any] = None
|
|
20
20
|
error: Optional[Exception] = None
|
|
21
|
+
name: str = ""
|
|
21
22
|
|
|
22
23
|
def add_outputs(self, output: Any, duration: timedelta):
|
|
23
24
|
"""
|
|
@@ -10,13 +10,13 @@ __all__ = ["RemoteController", "create_remote_controller"]
|
|
|
10
10
|
def create_remote_controller(
|
|
11
11
|
*,
|
|
12
12
|
api_key: str | None = None,
|
|
13
|
-
|
|
14
|
-
endpoint: str,
|
|
15
|
-
client_config: ClientConfig | None = None,
|
|
16
|
-
headless: bool = False,
|
|
13
|
+
endpoint: str | None = None,
|
|
17
14
|
insecure: bool = False,
|
|
18
15
|
insecure_skip_verify: bool = False,
|
|
19
16
|
ca_cert_file_path: str | None = None,
|
|
17
|
+
client_config: ClientConfig | None = None,
|
|
18
|
+
auth_type: AuthType = "Pkce",
|
|
19
|
+
headless: bool = False,
|
|
20
20
|
command: List[str] | None = None,
|
|
21
21
|
proxy_command: List[str] | None = None,
|
|
22
22
|
client_id: str | None = None,
|
|
@@ -27,13 +27,33 @@ def create_remote_controller(
|
|
|
27
27
|
"""
|
|
28
28
|
Create a new instance of the remote controller.
|
|
29
29
|
"""
|
|
30
|
+
assert endpoint or api_key, "Either endpoint or api_key must be provided when initializing remote controller"
|
|
30
31
|
from ._client import ControllerClient
|
|
31
32
|
from ._controller import RemoteController
|
|
32
33
|
|
|
34
|
+
if endpoint:
|
|
35
|
+
client_coro = ControllerClient.for_endpoint(
|
|
36
|
+
endpoint,
|
|
37
|
+
insecure=insecure,
|
|
38
|
+
insecure_skip_verify=insecure_skip_verify,
|
|
39
|
+
ca_cert_file_path=ca_cert_file_path,
|
|
40
|
+
client_id=client_id,
|
|
41
|
+
client_credentials_secret=client_credentials_secret,
|
|
42
|
+
auth_type=auth_type,
|
|
43
|
+
)
|
|
44
|
+
elif api_key:
|
|
45
|
+
client_coro = ControllerClient.for_api_key(
|
|
46
|
+
api_key,
|
|
47
|
+
insecure=insecure,
|
|
48
|
+
insecure_skip_verify=insecure_skip_verify,
|
|
49
|
+
ca_cert_file_path=ca_cert_file_path,
|
|
50
|
+
client_id=client_id,
|
|
51
|
+
client_credentials_secret=client_credentials_secret,
|
|
52
|
+
auth_type=auth_type,
|
|
53
|
+
)
|
|
54
|
+
|
|
33
55
|
controller = RemoteController(
|
|
34
|
-
client_coro=
|
|
35
|
-
endpoint=endpoint, insecure=insecure, insecure_skip_verify=insecure_skip_verify
|
|
36
|
-
),
|
|
56
|
+
client_coro=client_coro,
|
|
37
57
|
workers=10,
|
|
38
58
|
max_system_retries=5,
|
|
39
59
|
)
|
|
@@ -4,8 +4,8 @@ from dataclasses import dataclass
|
|
|
4
4
|
|
|
5
5
|
from flyteidl.core import execution_pb2
|
|
6
6
|
|
|
7
|
-
from flyte._datastructures import GroupData
|
|
8
7
|
from flyte._protos.workflow import run_definition_pb2, state_service_pb2, task_definition_pb2
|
|
8
|
+
from flyte.models import GroupData
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
@dataclass
|
|
@@ -28,6 +28,7 @@ class Action:
|
|
|
28
28
|
started: bool = False
|
|
29
29
|
retries: int = 0
|
|
30
30
|
client_err: Exception | None = None # This error is set when something goes wrong in the controller.
|
|
31
|
+
cache_key: str | None = None # None means no caching, otherwise it is the version of the cache.
|
|
31
32
|
|
|
32
33
|
@property
|
|
33
34
|
def name(self) -> str:
|
|
@@ -91,6 +92,8 @@ class Action:
|
|
|
91
92
|
if not self.started:
|
|
92
93
|
self.task = action.task
|
|
93
94
|
|
|
95
|
+
self.cache_key = action.cache_key
|
|
96
|
+
|
|
94
97
|
def set_client_error(self, exc: Exception):
|
|
95
98
|
self.client_err = exc
|
|
96
99
|
|
|
@@ -106,6 +109,7 @@ class Action:
|
|
|
106
109
|
task_spec: task_definition_pb2.TaskSpec,
|
|
107
110
|
inputs_uri: str,
|
|
108
111
|
run_output_base: str,
|
|
112
|
+
cache_key: str | None = None,
|
|
109
113
|
) -> Action:
|
|
110
114
|
return cls(
|
|
111
115
|
action_id=sub_action_id,
|
|
@@ -115,6 +119,7 @@ class Action:
|
|
|
115
119
|
task=task_spec,
|
|
116
120
|
inputs_uri=inputs_uri,
|
|
117
121
|
run_output_base=run_output_base,
|
|
122
|
+
cache_key=cache_key,
|
|
118
123
|
)
|
|
119
124
|
|
|
120
125
|
@classmethod
|
|
@@ -130,7 +135,7 @@ class Action:
|
|
|
130
135
|
"""
|
|
131
136
|
from flyte._logging import logger
|
|
132
137
|
|
|
133
|
-
logger.
|
|
138
|
+
logger.debug(f"In Action from_state {obj.action_id} {obj.phase} {obj.output_uri}")
|
|
134
139
|
return cls(
|
|
135
140
|
action_id=obj.action_id,
|
|
136
141
|
parent_action_name=parent_action_name,
|
|
@@ -20,7 +20,11 @@ class ControllerClient:
|
|
|
20
20
|
|
|
21
21
|
@classmethod
|
|
22
22
|
async def for_endpoint(cls, endpoint: str, insecure: bool = False, **kwargs) -> ControllerClient:
|
|
23
|
-
return cls(await create_channel(endpoint, insecure=insecure, **kwargs))
|
|
23
|
+
return cls(await create_channel(endpoint, None, insecure=insecure, **kwargs))
|
|
24
|
+
|
|
25
|
+
@classmethod
|
|
26
|
+
async def for_api_key(cls, api_key: str, insecure: bool = False, **kwargs) -> ControllerClient:
|
|
27
|
+
return cls(await create_channel(None, api_key, insecure=insecure, **kwargs))
|
|
24
28
|
|
|
25
29
|
@property
|
|
26
30
|
def state_service(self) -> StateService:
|
|
@@ -1,9 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import concurrent.futures
|
|
5
|
+
import os
|
|
6
|
+
import threading
|
|
4
7
|
from collections import defaultdict
|
|
8
|
+
from collections.abc import Callable
|
|
5
9
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
|
|
10
|
+
from typing import Any, AsyncIterable, Awaitable, DefaultDict, Tuple, TypeVar
|
|
7
11
|
|
|
8
12
|
import flyte
|
|
9
13
|
import flyte.errors
|
|
@@ -11,7 +15,6 @@ import flyte.storage as storage
|
|
|
11
15
|
import flyte.types as types
|
|
12
16
|
from flyte._code_bundle import build_pkl_bundle
|
|
13
17
|
from flyte._context import internal_ctx
|
|
14
|
-
from flyte._datastructures import ActionID, NativeInterface, SerializationContext
|
|
15
18
|
from flyte._internal.controllers import TraceInfo
|
|
16
19
|
from flyte._internal.controllers.remote._action import Action
|
|
17
20
|
from flyte._internal.controllers.remote._core import Controller
|
|
@@ -21,16 +24,18 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
|
|
|
21
24
|
from flyte._logging import logger
|
|
22
25
|
from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
|
|
23
26
|
from flyte._task import TaskTemplate
|
|
27
|
+
from flyte._utils.helpers import _selector_policy
|
|
28
|
+
from flyte.models import ActionID, NativeInterface, SerializationContext
|
|
24
29
|
|
|
25
30
|
R = TypeVar("R")
|
|
26
31
|
|
|
27
32
|
|
|
28
|
-
async def upload_inputs_with_retry(
|
|
33
|
+
async def upload_inputs_with_retry(serialized_inputs: AsyncIterable[bytes] | bytes, inputs_uri: str) -> None:
|
|
29
34
|
"""
|
|
30
35
|
Upload inputs to the specified URI with error handling.
|
|
31
36
|
|
|
32
37
|
Args:
|
|
33
|
-
|
|
38
|
+
serialized_inputs: The serialized inputs to upload
|
|
34
39
|
inputs_uri: The destination URI
|
|
35
40
|
|
|
36
41
|
Raises:
|
|
@@ -38,9 +43,9 @@ async def upload_inputs_with_retry(inputs: convert.Inputs, inputs_uri: str) -> N
|
|
|
38
43
|
"""
|
|
39
44
|
try:
|
|
40
45
|
# TODO Add retry decorator to this
|
|
41
|
-
await
|
|
46
|
+
await storage.put_stream(serialized_inputs, to_path=inputs_uri)
|
|
42
47
|
except Exception as e:
|
|
43
|
-
logger.exception("Failed to upload inputs"
|
|
48
|
+
logger.exception("Failed to upload inputs")
|
|
44
49
|
raise flyte.errors.RuntimeSystemError(type(e).__name__, str(e)) from e
|
|
45
50
|
|
|
46
51
|
|
|
@@ -89,6 +94,10 @@ async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri:
|
|
|
89
94
|
return await convert.convert_outputs_to_native(iface, outputs)
|
|
90
95
|
|
|
91
96
|
|
|
97
|
+
def unique_action_name(action_id: ActionID) -> str:
|
|
98
|
+
return f"{action_id.name}_{action_id.run_name}"
|
|
99
|
+
|
|
100
|
+
|
|
92
101
|
class RemoteController(Controller):
|
|
93
102
|
"""
|
|
94
103
|
This a specialized controller that wraps the core controller and performs IO, serialization and deserialization
|
|
@@ -99,7 +108,7 @@ class RemoteController(Controller):
|
|
|
99
108
|
client_coro: Awaitable[ClientSet],
|
|
100
109
|
workers: int,
|
|
101
110
|
max_system_retries: int,
|
|
102
|
-
default_parent_concurrency: int =
|
|
111
|
+
default_parent_concurrency: int = 1000,
|
|
103
112
|
):
|
|
104
113
|
""" """
|
|
105
114
|
super().__init__(
|
|
@@ -111,31 +120,44 @@ class RemoteController(Controller):
|
|
|
111
120
|
self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
|
|
112
121
|
lambda: asyncio.Semaphore(default_parent_concurrency)
|
|
113
122
|
)
|
|
123
|
+
self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
|
|
124
|
+
lambda: defaultdict(int)
|
|
125
|
+
)
|
|
126
|
+
self._submit_loop: asyncio.AbstractEventLoop | None = None
|
|
127
|
+
self._submit_thread: threading.Thread | None = None
|
|
114
128
|
|
|
115
|
-
|
|
129
|
+
def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
|
|
130
|
+
"""
|
|
131
|
+
Generate a task call sequence for the given task object and action ID.
|
|
132
|
+
This is used to track the number of times a task is called within an action.
|
|
133
|
+
"""
|
|
134
|
+
current_action_sequencer = self._parent_action_task_call_sequence[unique_action_name(action_id)]
|
|
135
|
+
current_task_id = id(task_obj)
|
|
136
|
+
v = current_action_sequencer[current_task_id]
|
|
137
|
+
new_seq = v + 1
|
|
138
|
+
current_action_sequencer[current_task_id] = new_seq
|
|
139
|
+
return new_seq
|
|
140
|
+
|
|
141
|
+
async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any:
|
|
116
142
|
ctx = internal_ctx()
|
|
117
143
|
tctx = ctx.data.task_context
|
|
118
144
|
if tctx is None:
|
|
119
145
|
raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
|
|
120
146
|
current_action_id = tctx.action
|
|
121
147
|
|
|
122
|
-
inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
|
|
123
|
-
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _task.name, inputs)
|
|
124
|
-
|
|
125
148
|
# In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
|
|
126
149
|
# It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
|
|
127
150
|
code_bundle = tctx.code_bundle
|
|
128
151
|
|
|
129
152
|
if code_bundle and code_bundle.pkl:
|
|
130
|
-
logger.debug(f"Building new pkl bundle for task {
|
|
153
|
+
logger.debug(f"Building new pkl bundle for task {_task.name}")
|
|
131
154
|
code_bundle = await build_pkl_bundle(
|
|
132
155
|
_task,
|
|
133
156
|
upload_to_controlplane=False,
|
|
134
|
-
|
|
157
|
+
upload_from_dataplane_base_path=tctx.run_base_dir,
|
|
135
158
|
)
|
|
136
159
|
|
|
137
|
-
|
|
138
|
-
await upload_inputs_with_retry(inputs, inputs_uri)
|
|
160
|
+
inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
|
|
139
161
|
|
|
140
162
|
root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
|
|
141
163
|
# Don't set output path in sec context because node executor will set it
|
|
@@ -146,12 +168,41 @@ class RemoteController(Controller):
|
|
|
146
168
|
code_bundle=code_bundle,
|
|
147
169
|
version=tctx.version,
|
|
148
170
|
# supplied version.
|
|
149
|
-
input_path=inputs_uri,
|
|
171
|
+
# input_path=inputs_uri,
|
|
150
172
|
image_cache=tctx.compiled_image_cache,
|
|
151
173
|
root_dir=root_dir,
|
|
152
174
|
)
|
|
153
175
|
|
|
154
176
|
task_spec = translate_task_to_wire(_task, new_serialization_context)
|
|
177
|
+
serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
|
|
178
|
+
|
|
179
|
+
inputs_hash = convert.generate_inputs_hash(serialized_inputs)
|
|
180
|
+
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
181
|
+
tctx, task_spec, inputs_hash, _task_call_seq
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
inputs_uri = io.inputs_path(sub_action_output_path)
|
|
185
|
+
await upload_inputs_with_retry(serialized_inputs, inputs_uri)
|
|
186
|
+
|
|
187
|
+
md = task_spec.task_template.metadata
|
|
188
|
+
ignored_input_vars = []
|
|
189
|
+
if len(md.cache_ignore_input_vars) > 0:
|
|
190
|
+
ignored_input_vars = list(md.cache_ignore_input_vars)
|
|
191
|
+
cache_key = None
|
|
192
|
+
if task_spec.task_template.metadata and task_spec.task_template.metadata.discoverable:
|
|
193
|
+
discovery_version = task_spec.task_template.metadata.discovery_version
|
|
194
|
+
cache_key = convert.generate_cache_key_hash(
|
|
195
|
+
_task.name,
|
|
196
|
+
inputs_hash,
|
|
197
|
+
task_spec.task_template.interface,
|
|
198
|
+
discovery_version,
|
|
199
|
+
ignored_input_vars,
|
|
200
|
+
inputs.proto_inputs,
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
# Clear to free memory
|
|
204
|
+
serialized_inputs = None # type: ignore
|
|
205
|
+
inputs_hash = None # type: ignore
|
|
155
206
|
|
|
156
207
|
action = Action.from_task(
|
|
157
208
|
sub_action_id=run_definition_pb2.ActionIdentifier(
|
|
@@ -168,6 +219,7 @@ class RemoteController(Controller):
|
|
|
168
219
|
task_spec=task_spec,
|
|
169
220
|
inputs_uri=inputs_uri,
|
|
170
221
|
run_output_base=tctx.run_base_dir,
|
|
222
|
+
cache_key=cache_key,
|
|
171
223
|
)
|
|
172
224
|
|
|
173
225
|
try:
|
|
@@ -205,8 +257,51 @@ class RemoteController(Controller):
|
|
|
205
257
|
if tctx is None:
|
|
206
258
|
raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
|
|
207
259
|
current_action_id = tctx.action
|
|
208
|
-
|
|
209
|
-
|
|
260
|
+
task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
|
|
261
|
+
async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
|
|
262
|
+
return await self._submit(task_call_seq, _task, *args, **kwargs)
|
|
263
|
+
|
|
264
|
+
def _sync_thread_loop_runner(self) -> None:
|
|
265
|
+
"""This method runs the event loop and should be invoked in a separate thread."""
|
|
266
|
+
|
|
267
|
+
loop = self._submit_loop
|
|
268
|
+
assert loop is not None
|
|
269
|
+
try:
|
|
270
|
+
loop.run_forever()
|
|
271
|
+
finally:
|
|
272
|
+
loop.close()
|
|
273
|
+
|
|
274
|
+
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
    """
    Schedule ``submit`` for the given task from synchronous (non-async) code.

    On the first call, while ``self._submit_thread`` is still None, this creates a dedicated event loop
    and a daemon thread that runs it through ``_sync_thread_loop_runner``. Both are cached on the
    controller and reused by later calls. The ``submit`` coroutine is then handed to that loop with
    ``asyncio.run_coroutine_threadsafe``.

    The lazy initialisation is not guarded by a lock. The original author's note is that none is needed
    because this function itself is single threaded and non-async, and that this pattern is the
    trivial/degenerate case of the thread pool in the LocalController.
    Please see additional comments in protocol.

    :param _task: Task template to submit.
    :param args: Positional arguments forwarded to ``submit``.
    :param kwargs: Keyword arguments forwarded to ``submit``.
    :return: The ``concurrent.futures.Future`` returned by ``asyncio.run_coroutine_threadsafe``. It is not
        directly awaitable; block on ``.result()`` or wrap it with ``asyncio.wrap_future`` to await it.
    """
    if self._submit_thread is None:
        # Please see LocalController for the general implementation of this pattern.
        def exc_handler(loop, context):
            # Installed as the loop's exception handler: exceptions the loop reports are logged here.
            logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")

        # NOTE(review): _selector_policy is defined elsewhere; presumably it pins the event loop
        # policy while the new loop is created - confirm.
        with _selector_policy():
            self._submit_loop = asyncio.new_event_loop()
            self._submit_loop.set_exception_handler(exc_handler)

        # Daemon thread, so it does not keep the interpreter alive at shutdown.
        self._submit_thread = threading.Thread(
            name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner
        )
        self._submit_thread.start()

    # The coroutine object is created here, in the calling thread, but executed on the submit loop's thread.
    coro = self.submit(_task, *args, **kwargs)
    assert self._submit_loop is not None, "Submit loop should always have been initialized by now"
    fut = asyncio.run_coroutine_threadsafe(coro, self._submit_loop)
    return fut
|
|
210
305
|
|
|
211
306
|
async def finalize_parent_action(self, action_id: ActionID):
|
|
212
307
|
"""
|
|
@@ -220,16 +315,17 @@ class RemoteController(Controller):
|
|
|
220
315
|
org=action_id.org,
|
|
221
316
|
)
|
|
222
317
|
await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
|
|
223
|
-
self._parent_action_semaphore.pop(action_id
|
|
318
|
+
self._parent_action_semaphore.pop(unique_action_name(action_id), None)
|
|
319
|
+
self._parent_action_task_call_sequence.pop(unique_action_name(action_id), None)
|
|
224
320
|
|
|
225
321
|
async def get_action_outputs(
|
|
226
|
-
self, _interface: NativeInterface,
|
|
322
|
+
self, _interface: NativeInterface, _func: Callable, *args, **kwargs
|
|
227
323
|
) -> Tuple[TraceInfo, bool]:
|
|
228
324
|
"""
|
|
229
325
|
This method returns the outputs of the action, if it is available.
|
|
230
326
|
If not available it raises a NotFoundError.
|
|
231
327
|
:param _interface: NativeInterface
|
|
232
|
-
:param
|
|
328
|
+
:param _func: Function name
|
|
233
329
|
:param args: Arguments
|
|
234
330
|
:param kwargs: Keyword arguments
|
|
235
331
|
:return:
|
|
@@ -240,11 +336,19 @@ class RemoteController(Controller):
|
|
|
240
336
|
raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
|
|
241
337
|
current_action_id = tctx.action
|
|
242
338
|
|
|
339
|
+
func_name = _func.__name__
|
|
340
|
+
invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
|
|
243
341
|
inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
|
|
244
|
-
|
|
342
|
+
serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
|
|
343
|
+
|
|
344
|
+
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
345
|
+
tctx, func_name, serialized_inputs, invoke_seq_num
|
|
346
|
+
)
|
|
245
347
|
|
|
246
348
|
inputs_uri = io.inputs_path(sub_action_output_path)
|
|
247
|
-
await upload_inputs_with_retry(
|
|
349
|
+
await upload_inputs_with_retry(serialized_inputs, inputs_uri)
|
|
350
|
+
# Clear to free memory
|
|
351
|
+
serialized_inputs = None # type: ignore
|
|
248
352
|
|
|
249
353
|
prev_action = await self.get_action(
|
|
250
354
|
run_definition_pb2.ActionIdentifier(
|
|
@@ -310,12 +414,40 @@ class RemoteController(Controller):
|
|
|
310
414
|
current_action_id = tctx.action
|
|
311
415
|
task_name = _task.spec.task_template.id.name
|
|
312
416
|
|
|
313
|
-
|
|
417
|
+
invoke_seq_num = self.generate_task_call_sequence(_task, current_action_id)
|
|
418
|
+
|
|
419
|
+
native_interface = types.guess_interface(
|
|
420
|
+
_task.spec.task_template.interface, default_inputs=_task.spec.default_inputs
|
|
421
|
+
)
|
|
314
422
|
inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
|
|
315
|
-
|
|
423
|
+
serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
|
|
424
|
+
inputs_hash = convert.generate_inputs_hash(serialized_inputs)
|
|
425
|
+
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
426
|
+
tctx, task_name, inputs_hash, invoke_seq_num
|
|
427
|
+
)
|
|
316
428
|
|
|
317
429
|
inputs_uri = io.inputs_path(sub_action_output_path)
|
|
318
|
-
await upload_inputs_with_retry(
|
|
430
|
+
await upload_inputs_with_retry(serialized_inputs, inputs_uri)
|
|
431
|
+
# cache key - task name, task signature, inputs, cache version
|
|
432
|
+
cache_key = None
|
|
433
|
+
md = _task.spec.task_template.metadata
|
|
434
|
+
ignored_input_vars = []
|
|
435
|
+
if len(md.cache_ignore_input_vars) > 0:
|
|
436
|
+
ignored_input_vars = list(md.cache_ignore_input_vars)
|
|
437
|
+
if _task.spec.task_template.metadata and _task.spec.task_template.metadata.discoverable:
|
|
438
|
+
discovery_version = _task.spec.task_template.metadata.discovery_version
|
|
439
|
+
cache_key = convert.generate_cache_key_hash(
|
|
440
|
+
task_name,
|
|
441
|
+
inputs_hash,
|
|
442
|
+
_task.spec.task_template.interface,
|
|
443
|
+
discovery_version,
|
|
444
|
+
ignored_input_vars,
|
|
445
|
+
inputs.proto_inputs,
|
|
446
|
+
)
|
|
447
|
+
|
|
448
|
+
# Clear to free memory
|
|
449
|
+
serialized_inputs = None # type: ignore
|
|
450
|
+
inputs_hash = None # type: ignore
|
|
319
451
|
|
|
320
452
|
action = Action.from_task(
|
|
321
453
|
sub_action_id=run_definition_pb2.ActionIdentifier(
|
|
@@ -332,6 +464,7 @@ class RemoteController(Controller):
|
|
|
332
464
|
task_spec=_task.spec,
|
|
333
465
|
inputs_uri=inputs_uri,
|
|
334
466
|
run_output_base=tctx.run_base_dir,
|
|
467
|
+
cache_key=cache_key,
|
|
335
468
|
)
|
|
336
469
|
|
|
337
470
|
try:
|
|
@@ -7,6 +7,7 @@ from asyncio import Event
|
|
|
7
7
|
from typing import Awaitable, Coroutine, Optional
|
|
8
8
|
|
|
9
9
|
import grpc.aio
|
|
10
|
+
from google.protobuf.wrappers_pb2 import StringValue
|
|
10
11
|
|
|
11
12
|
import flyte.errors
|
|
12
13
|
from flyte._logging import log, logger
|
|
@@ -32,7 +33,7 @@ class Controller:
|
|
|
32
33
|
max_system_retries: int = 5,
|
|
33
34
|
resource_log_interval_sec: float = 10.0,
|
|
34
35
|
min_backoff_on_err_sec: float = 0.1,
|
|
35
|
-
thread_wait_timeout_sec: float =
|
|
36
|
+
thread_wait_timeout_sec: float = 5.0,
|
|
36
37
|
enqueue_timeout_sec: float = 5.0,
|
|
37
38
|
):
|
|
38
39
|
"""
|
|
@@ -286,10 +287,11 @@ class Controller:
|
|
|
286
287
|
if started:
|
|
287
288
|
logger.info(f"Cancelling action: {action.name}")
|
|
288
289
|
try:
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
290
|
+
# TODO add support when the queue service supports aborting actions
|
|
291
|
+
# await self._queue_service.AbortQueuedAction(
|
|
292
|
+
# queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
|
|
293
|
+
# wait_for_ready=True,
|
|
294
|
+
# )
|
|
293
295
|
logger.info(f"Successfully cancelled action: {action.name}")
|
|
294
296
|
except grpc.aio.AioRpcError as e:
|
|
295
297
|
if e.code() in [grpc.StatusCode.NOT_FOUND, grpc.StatusCode.FAILED_PRECONDITION]:
|
|
@@ -310,6 +312,11 @@ class Controller:
|
|
|
310
312
|
if not action.is_started() and action.task is not None:
|
|
311
313
|
logger.debug(f"Attempting to launch action: {action.name}")
|
|
312
314
|
try:
|
|
315
|
+
cache_key = None
|
|
316
|
+
logger.info(f"Action {action.name} has cache version {action.cache_key}")
|
|
317
|
+
if action.cache_key:
|
|
318
|
+
cache_key = StringValue(value=action.cache_key)
|
|
319
|
+
|
|
313
320
|
await self._queue_service.EnqueueAction(
|
|
314
321
|
queue_service_pb2.EnqueueActionRequest(
|
|
315
322
|
action_id=action.action_id,
|
|
@@ -323,6 +330,7 @@ class Controller:
|
|
|
323
330
|
name=action.task.task_template.id.name,
|
|
324
331
|
),
|
|
325
332
|
spec=action.task,
|
|
333
|
+
cache_key=cache_key,
|
|
326
334
|
),
|
|
327
335
|
input_uri=action.inputs_uri,
|
|
328
336
|
run_output_base=action.run_output_base,
|