flyte 0.1.0__py3-none-any.whl → 0.2.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +78 -2
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +152 -0
- flyte/_build.py +26 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +145 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +323 -0
- flyte/_code_bundle/bundle.py +209 -0
- flyte/_context.py +152 -0
- flyte/_deploy.py +243 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +84 -0
- flyte/_excepthook.py +37 -0
- flyte/_group.py +32 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +762 -0
- flyte/_initialize.py +492 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +128 -0
- flyte/_internal/controllers/_local_controller.py +193 -0
- flyte/_internal/controllers/_trace.py +41 -0
- flyte/_internal/controllers/remote/__init__.py +60 -0
- flyte/_internal/controllers/remote/_action.py +146 -0
- flyte/_internal/controllers/remote/_client.py +47 -0
- flyte/_internal/controllers/remote/_controller.py +494 -0
- flyte/_internal/controllers/remote/_core.py +410 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +427 -0
- flyte/_internal/imagebuild/image_builder.py +246 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +342 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +330 -0
- flyte/_internal/runtime/taskrunner.py +191 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +135 -0
- flyte/_map.py +215 -0
- flyte/_pod.py +19 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +71 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +100 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +177 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/common_pb2.py +27 -0
- flyte/_protos/workflow/common_pb2.pyi +14 -0
- flyte/_protos/workflow/common_pb2_grpc.py +4 -0
- flyte/_protos/workflow/environment_pb2.py +29 -0
- flyte/_protos/workflow/environment_pb2.pyi +12 -0
- flyte/_protos/workflow/environment_pb2_grpc.py +4 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +105 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +146 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +314 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +129 -0
- flyte/_protos/workflow/run_service_pb2.pyi +171 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +66 -0
- flyte/_protos/workflow/state_service_pb2.pyi +75 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +79 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +81 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +60 -0
- flyte/_protos/workflow/task_service_pb2.pyi +59 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +138 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +482 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +449 -0
- flyte/_task_environment.py +183 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +120 -0
- flyte/_utils/__init__.py +26 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +23 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +134 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/cli/__init__.py +3 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_common.py +337 -0
- flyte/cli/_create.py +145 -0
- flyte/cli/_delete.py +23 -0
- flyte/cli/_deploy.py +152 -0
- flyte/cli/_gen.py +163 -0
- flyte/cli/_get.py +310 -0
- flyte/cli/_params.py +538 -0
- flyte/cli/_run.py +231 -0
- flyte/cli/main.py +166 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +216 -0
- flyte/config/_internal.py +64 -0
- flyte/config/_reader.py +207 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +172 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +263 -0
- flyte/io/__init__.py +27 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +467 -0
- flyte/io/_structured_dataset/__init__.py +129 -0
- flyte/io/_structured_dataset/basic_dfs.py +219 -0
- flyte/io/_structured_dataset/structured_dataset.py +1061 -0
- flyte/models.py +391 -0
- flyte/remote/__init__.py +26 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +133 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +215 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +159 -0
- flyte/remote/_logs.py +176 -0
- flyte/remote/_project.py +85 -0
- flyte/remote/_run.py +970 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +391 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +29 -0
- flyte/storage/_config.py +233 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +271 -0
- flyte/storage/_utils.py +5 -0
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +371 -0
- flyte/types/__init__.py +36 -0
- flyte/types/_interface.py +40 -0
- flyte/types/_pickle.py +118 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2287 -0
- flyte/types/_utils.py +80 -0
- flyte-0.2.0a0.dist-info/METADATA +249 -0
- flyte-0.2.0a0.dist-info/RECORD +218 -0
- {flyte-0.1.0.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +2 -1
- flyte-0.2.0a0.dist-info/entry_points.txt +3 -0
- flyte-0.2.0a0.dist-info/top_level.txt +1 -0
- flyte-0.1.0.dist-info/METADATA +0 -6
- flyte-0.1.0.dist-info/RECORD +0 -5
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import os
|
|
3
|
+
import pathlib
|
|
4
|
+
import sys
|
|
5
|
+
from typing import Tuple
|
|
6
|
+
|
|
7
|
+
from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def extract_task_module(task: TaskTemplate, /, source_dir: pathlib.Path | None = None) -> Tuple[str, str]:
    """
    Extract the entity name and the module name of a task.

    :param task: The task template to extract the module from. Only ``AsyncFunctionTaskTemplate``
        instances are supported.
    :param source_dir: The source directory that the task's module file is made relative to.
    :return: A tuple containing the entity name, module
    :raises ValueError: If the task function has no module, the module has no file, ``source_dir`` is not
        provided, or the module file is not located under ``source_dir``.
    :raises NotImplementedError: If ``task`` is not an ``AsyncFunctionTaskTemplate``.
    """
    entity_name = task.name
    if not isinstance(task, AsyncFunctionTaskTemplate):
        raise NotImplementedError(f"Task module {entity_name} not implemented.")

    entity_module = inspect.getmodule(task.func)
    if entity_module is None:
        raise ValueError(f"Task {entity_name} has no module.")

    fp = entity_module.__file__
    if fp is None:
        raise ValueError(f"Task {entity_name} has no module.")

    # Without this guard ``str(None)`` would be treated as a directory literally named "None" and the
    # failure would surface as an unrelated "not in the subpath" error from ``relative_to``.
    if source_dir is None:
        raise ValueError(f"Task {entity_name} cannot be resolved to a module: no source directory was provided.")

    file_path = pathlib.Path(fp)
    # Get the relative path to the source directory
    # Will raise ValueError if the file is not in the source directory
    relative_path = file_path.relative_to(str(source_dir))

    if relative_path == pathlib.Path("."):
        entity_module_name = entity_module.__name__
    else:
        # Replace file separators with dots and remove the '.py' extension
        entity_module_name = os.path.splitext(str(relative_path))[0].replace(os.sep, ".")

    entity_name = task.func.__name__

    if entity_module_name == "__main__":
        # This case is for the case in which the task is run from the main module: use the file name
        # of the main script (without extension) as the module name.
        fp = sys.modules["__main__"].__file__
        if fp is None:
            raise ValueError(f"Task {entity_name} has no module.")
        entity_module_name = pathlib.Path(fp).stem

    return entity_name, entity_module_name
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from asyncio import Protocol
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
from flyte._task import TaskTemplate
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Resolver(Protocol):
    """
    Resolver interface for loading tasks. This interface should be implemented by Resolvers.

    NOTE(review): ``Protocol`` is imported from ``asyncio`` in this module, not from ``typing``, so this
    is an ordinary base class rather than a structural typing protocol - confirm whether
    ``typing.Protocol`` was intended.
    """

    @property
    def import_path(self) -> str:
        """
        The import path of the resolver. This should be a valid python import path.

        The base implementation returns an empty string.
        """
        return ""

    def load_task(self, loader_args: List[str]) -> TaskTemplate:
        """
        Given the set of identifier keys, should return one TaskTemplate or raise an error if not found

        :param loader_args: The identifier strings describing the task to load.
        :raises NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def loader_args(self, t: TaskTemplate, root_dir: Optional[Path]) -> List[str]:
        """
        Return a list of strings that can help identify the parameter TaskTemplate. Each string should not have
        spaces or special characters. This is used to identify the task in the resolver.

        :param t: The task template to describe.
        :param root_dir: Optional root directory; unused by this base implementation.
        :return: An empty list in this base implementation.
        """
        return []
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
from flyte._internal.resolvers._task_module import extract_task_module
|
|
6
|
+
from flyte._internal.resolvers.common import Resolver
|
|
7
|
+
from flyte._task import TaskTemplate
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DefaultTaskResolver(Resolver):
    """
    Resolver that finds a task by importing a module and reading one of its attributes.

    The loader arguments it produces and consumes have the form
    ``["mod", <module name>, "instance", <attribute name>]``.
    """

    @property
    def import_path(self) -> str:
        """The dotted import path of this resolver class."""
        return "flyte._internal.resolvers.default.DefaultTaskResolver"

    def load_task(self, loader_args: List[str]) -> TaskTemplate:
        """
        Import the module named in ``loader_args`` and return the named attribute from it.

        :param loader_args: At least four strings; the second is the module name and the fourth is the
            attribute name. The first and third entries and any extras are ignored.
        :return: The attribute found on the imported module.
        """
        _, module_name, _, attribute_name, *_ = loader_args

        imported_module = importlib.import_module(name=module_name)
        return getattr(imported_module, attribute_name)

    def loader_args(self, task: TaskTemplate, root_dir: Optional[Path] = None) -> List[str]:  # type:ignore
        """
        Build the loader arguments that identify ``task``.

        :param task: The task template to describe.
        :param root_dir: Directory passed to ``extract_task_module`` as the source directory.
        :return: ``["mod", <module name>, "instance", <entity name>]``.
        """
        entity_name, module_name = extract_task_module(task, root_dir)
        return ["mod", module_name, "instance", entity_name]
|
|
File without changes
|
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import base64
|
|
5
|
+
import hashlib
|
|
6
|
+
import inspect
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from types import NoneType
|
|
9
|
+
from typing import Any, Dict, List, Tuple, Union, get_args
|
|
10
|
+
|
|
11
|
+
from flyteidl.core import execution_pb2, interface_pb2, literals_pb2
|
|
12
|
+
|
|
13
|
+
import flyte.errors
|
|
14
|
+
import flyte.storage as storage
|
|
15
|
+
from flyte._protos.workflow import common_pb2, run_definition_pb2, task_definition_pb2
|
|
16
|
+
from flyte.models import ActionID, NativeInterface, TaskContext
|
|
17
|
+
from flyte.types import TypeEngine, TypeTransformerFailedError
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
class Inputs:
    """
    Immutable wrapper around the protobuf ``Inputs`` message.
    """

    # The wrapped protobuf message; its ``literals`` field carries the named input literals.
    proto_inputs: run_definition_pb2.Inputs

    @classmethod
    def empty(cls) -> "Inputs":
        """
        Return an ``Inputs`` wrapping a default-constructed protobuf ``Inputs`` message.
        """
        return cls(proto_inputs=run_definition_pb2.Inputs())
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True)
class Outputs:
    """
    Immutable wrapper around the protobuf ``Outputs`` message.
    """

    # The wrapped protobuf message; its ``literals`` field carries the named output literals.
    proto_outputs: run_definition_pb2.Outputs
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class Error:
    """
    Wrapper around a protobuf ``ExecutionError``.
    """

    # The wrapped protobuf execution error.
    err: execution_pb2.ExecutionError
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# ------------------------------- CONVERT Methods ------------------------------- #
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _clean_error_code(code: str) -> Tuple[str, str | None]:
    """
    Split an error code into its user part and its optional server-injected part.

    The code is either ``<code>`` or ``<server code>|<code>`` (for example ``RetriesExhaustedError|<code>``).
    Only the first ``|`` separates the two parts; both parts are stripped of surrounding whitespace.

    :param code: The raw error code.
    :return: "user code", optional server code (``None`` when the code contains no ``|``)
    """
    server_part, separator, user_part = code.partition("|")
    if not separator:
        return code.strip(), None
    return user_part.strip(), server_part.strip()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
async def convert_inputs_to_native(inputs: Inputs, python_interface: NativeInterface) -> Dict[str, Any]:
    """
    Convert the literals held by ``inputs`` into native Python values.

    :param inputs: Wrapper whose proto message holds the named input literals.
    :param python_interface: Interface whose input types drive the conversion.
    :return: The result of ``TypeEngine.literal_map_to_kwargs`` for those literals and input types.
    """
    literal_map = literals_pb2.LiteralMap(
        literals={entry.name: entry.value for entry in inputs.proto_inputs.literals}
    )
    return await TypeEngine.literal_map_to_kwargs(literal_map, python_interface.get_input_types())
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
async def convert_upload_default_inputs(interface: NativeInterface) -> List[common_pb2.NamedParameter]:
    """
    Build the list of ``NamedParameter`` messages for the inputs of ``interface`` that have a default.

    Inputs whose default is ``inspect.Parameter.empty`` are skipped. The default values are converted to
    literals concurrently.

    :param interface: The native interface whose defaults should be converted.
    :return: One ``NamedParameter`` per defaulted input, in interface order; empty if there are no inputs.
    """
    if not interface.inputs:
        return []

    named_types = []  # (input name, literal type) for every input that has a default
    pending_literals = []  # matching, not-yet-awaited conversions of the default values
    for input_name, (input_type, default_value) in interface.inputs.items():
        if default_value is inspect.Parameter.empty:
            continue
        literal_type = TypeEngine.to_literal_type(input_type)
        pending_literals.append(TypeEngine.to_literal(default_value, input_type, literal_type))
        named_types.append((input_name, literal_type))

    converted: List[literals_pb2.Literal] = await asyncio.gather(*pending_literals)
    return [
        common_pb2.NamedParameter(
            name=name,
            parameter=interface_pb2.Parameter(
                var=interface_pb2.Variable(
                    type=literal_type,
                ),
                default=literal,
            ),
        )
        for (name, literal_type), literal in zip(named_types, converted)
    ]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def is_optional_type(tp) -> bool:
|
|
98
|
+
"""
|
|
99
|
+
True if the *annotation* `tp` is equivalent to Optional[…].
|
|
100
|
+
Works for Optional[T], Union[T, None], and T | None.
|
|
101
|
+
"""
|
|
102
|
+
return NoneType in get_args(tp) # fastest check
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwargs) -> Inputs:
    """
    Convert native call arguments into an ``Inputs`` wrapper for the given interface.

    Arguments are first normalised to keyword form via ``interface.convert_to_kwargs``. Inputs that were
    not supplied are filled from the interface defaults; defaults marked with ``NativeInterface.has_default``
    are taken, already converted, from ``interface._remote_defaults``.

    :param interface: The native interface describing the inputs.
    :param args: Positional arguments of the call.
    :param kwargs: Keyword arguments of the call.
    :return: An ``Inputs`` whose literals follow the order of ``interface.inputs``.
    :raises ValueError: If fewer inputs than ``interface.num_required_inputs()`` were supplied, or if an
        input is marked as having a default that is missing from ``interface._remote_defaults``.
    """
    kwargs = interface.convert_to_kwargs(*args, **kwargs)

    if len(kwargs) < interface.num_required_inputs():
        raise ValueError(
            f"Received {len(kwargs)} inputs but interface has {interface.num_required_inputs()} required inputs. "
            f"Please provide all required inputs. Inputs received: {kwargs}, interface: {interface}"
        )

    if len(interface.inputs) == 0:
        return Inputs.empty()

    # fill in defaults if missing
    type_hints: Dict[str, type] = {}
    already_converted_kwargs: Dict[str, literals_pb2.Literal] = {}
    for input_name, (input_type, default_value) in interface.inputs.items():
        if input_name in kwargs:
            type_hints[input_name] = input_type
        # A default applies when it is a real (non-empty, non-None) value, or when it is None and the
        # annotation itself allows None.
        elif (default_value is not None and default_value is not inspect.Signature.empty) or (
            default_value is None and is_optional_type(input_type)
        ):
            if default_value == NativeInterface.has_default:
                # The default is only known as an already converted literal on the interface.
                if interface._remote_defaults is None or input_name not in interface._remote_defaults:
                    raise ValueError(f"Input '{input_name}' has a default value but it is not set in the interface.")
                already_converted_kwargs[input_name] = interface._remote_defaults[input_name]
            else:
                kwargs[input_name] = default_value
                type_hints[input_name] = input_type

    literal_map = await TypeEngine.dict_to_literal_map(kwargs, type_hints)
    if len(already_converted_kwargs) > 0:
        # Rebuild the literal map so it also contains the literals that needed no conversion.
        copied_literals: Dict[str, literals_pb2.Literal] = {}
        for k, v in literal_map.literals.items():
            copied_literals[k] = v
        # Add the already converted kwargs to the literal map
        for k, v in already_converted_kwargs.items():
            copied_literals[k] = v
        literal_map = literals_pb2.LiteralMap(literals=copied_literals)
    # Make sure we use the interface, not literal_map or kwargs, because those may have a different order
    return Inputs(
        proto_inputs=run_definition_pb2.Inputs(
            literals=[
                run_definition_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()
            ]
        )
    )
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, task_name: str = "") -> Outputs:
    """
    Convert the native return value of a task into an ``Outputs`` wrapper.

    :param o: The native return value; a non-tuple value is treated as a single output.
    :param interface: The native interface whose ``outputs`` give the output names and types.
    :param task_name: Task name used when reporting a conversion failure.
    :return: An ``Outputs`` holding one named literal per declared output.
    :raises ValueError: If the interface declares outputs and the number of returned values differs.
    :raises flyte.errors.RuntimeDataValidationError: If a value cannot be converted to its declared type.
    """
    # Always make it a tuple even if it's just one item to simplify logic below
    if not isinstance(o, tuple):
        o = (o,)

    # The previous check compared ``len(interface.outputs)`` with itself and could never fail.
    # An interface without outputs still arrives here with a 1-tuple (the wrapped return value),
    # so the count is only enforced when outputs are declared.
    if interface.outputs and len(o) != len(interface.outputs):
        raise ValueError(f"Received {len(o)} outputs but interface has {len(interface.outputs)}")

    named = []
    for (output_name, python_type), v in zip(interface.outputs.items(), o):
        try:
            lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
            named.append(run_definition_pb2.NamedLiteral(name=output_name, value=lit))
        except TypeTransformerFailedError as e:
            raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name) from e

    return Outputs(proto_outputs=run_definition_pb2.Outputs(literals=named))
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
async def convert_outputs_to_native(interface: NativeInterface, outputs: Outputs) -> Union[Any, Tuple[Any, ...]]:
    """
    Convert the literals held by ``outputs`` into native Python values.

    :param interface: The native interface whose ``outputs`` give the output names and types.
    :param outputs: Wrapper whose proto message holds the named output literals.
    :return: ``None`` when nothing was converted, the bare value when there is exactly one, otherwise a
        tuple ordered like ``interface.outputs``.
    """
    literal_map = literals_pb2.LiteralMap(
        literals={entry.name: entry.value for entry in outputs.proto_outputs.literals}
    )
    native_values = await TypeEngine.literal_map_to_kwargs(literal_map, interface.outputs)

    if len(native_values) == 0:
        return None
    if len(native_values) == 1:
        (only_value,) = native_values.values()
        return only_value
    # Return as tuple if multiple outputs, make sure to order correctly as it seems proto maps can change ordering
    return tuple(native_values[name] for name in interface.outputs.keys())
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def convert_error_to_native(err: execution_pb2.ExecutionError | Exception | Error) -> Exception | None:
    """
    Convert an error into a native exception instance.

    A falsy ``err`` yields ``None`` and an ``Exception`` is returned unchanged. Otherwise the
    ``ExecutionError`` (unwrapped from ``Error`` if needed) is mapped by its ``kind``; for user errors
    the raw error code is searched for known markers to pick a more specific exception type.

    :param err: The error to convert.
    :return: The native exception, or ``None`` if ``err`` is falsy or its kind is not recognised.
    """
    if not err:
        return None
    if isinstance(err, Exception):
        return err

    proto_err = err.err if isinstance(err, Error) else err
    user_code, _server_code = _clean_error_code(proto_err.code)

    def build(error_cls):
        # Every native error is created from the same three fields of the execution error.
        return error_cls(code=user_code, message=proto_err.message, worker=proto_err.worker)

    kind = proto_err.kind
    if kind == execution_pb2.ExecutionError.UNKNOWN:
        return build(flyte.errors.RuntimeUnknownError)
    if kind == execution_pb2.ExecutionError.SYSTEM:
        return build(flyte.errors.RuntimeSystemError)
    if kind != execution_pb2.ExecutionError.USER:
        return None

    raw_code = proto_err.code
    # The OOM marker is the only one matched case-insensitively.
    if "OOM" in raw_code.upper():
        return build(flyte.errors.OOMError)

    # Checked in this order; the first marker contained in the raw code wins.
    marker_table = (
        ("Interrupted", flyte.errors.TaskInterruptedError),
        ("PrimaryContainerNotFound", flyte.errors.PrimaryContainerNotFoundError),
        ("RetriesExhausted", flyte.errors.RetriesExhaustedError),
        ("Unknown", flyte.errors.RuntimeUnknownError),
        ("InvalidImageName", flyte.errors.InvalidImageNameError),
        ("ImagePullBackOff", flyte.errors.ImagePullBackOffError),
    )
    for marker, error_cls in marker_table:
        if marker in raw_code:
            return build(error_cls)
    return build(flyte.errors.RuntimeUserError)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def convert_from_native_to_error(err: BaseException) -> Error:
    """
    Convert a native exception into an ``Error`` wrapping an ``ExecutionError``.

    ``RuntimeUnknownError``, ``RuntimeUserError`` and ``RuntimeSystemError`` (checked in that order) keep
    their own ``code`` and ``worker``. Any other exception is reported with kind ``UNKNOWN``, its class
    name as the code and ``"UNKNOWN"`` as the worker.

    :param err: The exception to convert.
    :return: The wrapped execution error; the message is always ``str(err)``.
    """
    if isinstance(err, flyte.errors.RuntimeUnknownError):
        kind = execution_pb2.ExecutionError.UNKNOWN
    elif isinstance(err, flyte.errors.RuntimeUserError):
        kind = execution_pb2.ExecutionError.USER
    elif isinstance(err, flyte.errors.RuntimeSystemError):
        kind = execution_pb2.ExecutionError.SYSTEM
    else:
        # Not one of the runtime error types: there is no code or worker to carry over.
        return Error(
            err=execution_pb2.ExecutionError(
                kind=execution_pb2.ExecutionError.UNKNOWN,
                code=type(err).__name__,
                message=str(err),
                worker="UNKNOWN",
            )
        )

    return Error(
        err=execution_pb2.ExecutionError(
            kind=kind,
            code=err.code,
            message=str(err),
            worker=err.worker,
        )
    )
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def hash_data(data: Union[str, bytes]) -> str:
    """
    Generate a hash for the given data. If the data is a string, it will be encoded to bytes (UTF-8)
    before hashing.

    :param data: The data to hash, can be a string or bytes.
    :return: The SHA-256 digest of the data, base64-encoded and returned as a string.
    """
    payload = data.encode("utf-8") if isinstance(data, str) else data
    return base64.b64encode(hashlib.sha256(payload).digest()).decode("utf-8")
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
    """
    Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.

    :param serialized_inputs: The serialized inputs, as a string or bytes.
    :return: The value produced by ``hash_data`` for the serialized inputs (a base64-encoded SHA-256 digest).
    """
    return hash_data(serialized_inputs)
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def generate_cache_key_hash(
    task_name: str,
    inputs_hash: str,
    task_interface: interface_pb2.TypedInterface,
    cache_version: str,
    ignored_input_vars: List[str],
    proto_inputs: run_definition_pb2.Inputs,
) -> str:
    """
    Generate the cache key hash of a task from its inputs, name, interface and cache version.

    :param task_name: The name of the task.
    :param inputs_hash: The hash of the inputs; used as-is when no inputs are ignored.
    :param task_interface: The interface of the task.
    :param cache_version: The version of the cache.
    :param ignored_input_vars: A list of input variable names to ignore when generating the cache key.
    :param proto_inputs: The proto inputs for the task, only used if there are ignored inputs.
    :return: The value produced by ``hash_data`` for the combined key material.
    """
    if ignored_input_vars:
        # Re-serialize the inputs without the ignored ones instead of using the precomputed hash.
        kept_literals = [entry for entry in proto_inputs.literals if entry.name not in ignored_input_vars]
        inputs_component = run_definition_pb2.Inputs(literals=kept_literals).SerializeToString(deterministic=True)
    else:
        inputs_component = inputs_hash

    interface_bytes = task_interface.SerializeToString(deterministic=True)
    # NOTE: bytes values are interpolated through their ``repr`` (``b'...'``) here; that exact text is
    # part of the key material and must stay as it is to keep cache keys stable.
    key_material = f"{inputs_component}{task_name}{interface_bytes}{cache_version}"
    return hash_data(key_material)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def generate_sub_action_id_and_output_path(
    tctx: TaskContext,
    task_spec_or_name: task_definition_pb2.TaskSpec | str,
    inputs_hash: str,
    invoke_seq: int,
) -> Tuple[ActionID, str]:
    """
    Generate a sub-action ID and output path based on the current task context, task name, and inputs.

    action name = current action name + task name + input hash + group name (if available)

    :param tctx: The current task context; supplies the parent action, the run base directory and the
        optional group.
    :param task_spec_or_name: task specification or task name. Task name is only used in case of trace actions.
    :param inputs_hash: Consistent hash string of the inputs
    :param invoke_seq: The sequence number of the invocation, used to differentiate between multiple invocations.
    :return: The sub-action ID and the output path for that sub-action (the run base directory joined
        with the sub-action name).
    """
    if isinstance(task_spec_or_name, task_definition_pb2.TaskSpec):
        # A task spec is identified by the hash of its deterministic serialization.
        task_hash = hash_data(task_spec_or_name.SerializeToString(deterministic=True))
    else:
        # A plain task name is used as the task hash directly.
        task_hash = task_spec_or_name

    sub_action_id = tctx.action.new_sub_action_from(
        task_hash=task_hash,
        input_hash=inputs_hash,
        group=tctx.group_data.name if tctx.group_data else None,
        task_call_seq=invoke_seq,
    )
    sub_run_output_path = storage.join(tctx.run_base_dir, sub_action_id.name)
    return sub_action_id, sub_run_output_path
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from typing import List, Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import flyte.errors
|
|
4
|
+
from flyte._code_bundle import download_bundle
|
|
5
|
+
from flyte._context import contextual_run
|
|
6
|
+
from flyte._internal import Controller
|
|
7
|
+
from flyte._internal.imagebuild.image_builder import ImageCache
|
|
8
|
+
from flyte._logging import log, logger
|
|
9
|
+
from flyte._task import TaskTemplate
|
|
10
|
+
from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
|
|
11
|
+
|
|
12
|
+
from .convert import Error, Inputs, Outputs
|
|
13
|
+
from .task_serde import load_task
|
|
14
|
+
from .taskrunner import (
|
|
15
|
+
convert_and_run,
|
|
16
|
+
extract_download_run_upload,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
async def direct_dispatch(
    task: TaskTemplate,
    *,
    action: ActionID,
    raw_data_path: RawDataPath,
    controller: Controller,
    version: str,
    output_path: str,
    run_base_dir: str,
    checkpoints: Checkpoints | None = None,
    code_bundle: CodeBundle | None = None,
    inputs: Inputs | None = None,
) -> Tuple[Optional[Outputs], Optional[Error]]:
    """
    Run an already-loaded task by handing it to ``convert_and_run`` through ``contextual_run``.

    This entry point is used today by the local_controller and is positioned to be used by a rust core
    in the future. The caller is responsible for loading the task; this method then takes care of
    converting the inputs to native types and running the task. The split exists because the rust
    entrypoint will not have access to the python context and therefore cannot run tasks in the
    context tree on its own.

    :param task: The task to run, already loaded by the caller.
    :param action: The ActionID forwarded to ``convert_and_run``.
    :param raw_data_path: The raw data path forwarded to ``convert_and_run``.
    :param controller: The controller forwarded to ``convert_and_run``.
    :param version: The version forwarded to ``convert_and_run``.
    :param output_path: The output path forwarded to ``convert_and_run``.
    :param run_base_dir: The run base directory forwarded to ``convert_and_run``.
    :param checkpoints: Optional checkpoints forwarded to ``convert_and_run``.
    :param code_bundle: Optional code bundle forwarded to ``convert_and_run``.
    :param inputs: Inputs for the task; ``Inputs.empty()`` is substituted when this is falsy.
    :return: The result of the ``contextual_run`` call, annotated as an (outputs, error) pair.
    """
    # Keyword order mirrors the order in which the arguments were historically passed.
    forwarded = dict(
        task=task,
        inputs=inputs or Inputs.empty(),
        action=action,
        raw_data_path=raw_data_path,
        checkpoints=checkpoints,
        code_bundle=code_bundle,
        controller=controller,
        version=version,
        output_path=output_path,
        run_base_dir=run_base_dir,
    )
    return await contextual_run(convert_and_run, **forwarded)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
async def _download_and_load_task(
    code_bundle: CodeBundle | None, resolver: str | None = None, resolver_args: List[str] | None = None
) -> TaskTemplate:
    """
    Download the code bundle, if one is given, and load the task to run.

    Three cases are handled:

    * The bundle has a ``pkl``: the bundle is downloaded, the downloaded file is opened with gzip and
      the task is unpickled with cloudpickle. The resolver is not consulted.
    * The bundle has a ``tgz`` (and no ``pkl``): the bundle is downloaded and the task is then loaded
      through ``load_task`` using the resolver and its arguments.
    * No usable bundle: the task is loaded through ``load_task`` using the resolver and its arguments.

    :param code_bundle: The code bundle to download, or None.
    :param resolver: The resolver used to load the task when it is not unpickled.
    :param resolver_args: The arguments passed to the resolver.
    :return: The loaded task.
    :raises flyte.errors.RuntimeSystemError: If the task must be loaded through the resolver but the
        resolver or its arguments are missing or empty.
    """
    # Holds the bundle with its downloaded path set; stays None when there was nothing to download.
    downloaded_bundle = None
    if code_bundle and (code_bundle.tgz or code_bundle.pkl):
        logger.debug(f"Downloading {code_bundle}")
        downloaded_path = await download_bundle(code_bundle)
        downloaded_bundle = code_bundle.with_downloaded_path(downloaded_path)
        if downloaded_bundle.pkl:
            try:
                logger.debug(f"Loading task from pkl: {downloaded_bundle.downloaded_path}")
                import gzip

                import cloudpickle

                # NOTE(security): unpickling executes arbitrary code embedded in the payload, so the
                # bundle location must only ever point at trusted, system-produced artifacts.
                with gzip.open(str(downloaded_bundle.downloaded_path), "rb") as f:
                    return cloudpickle.load(f)
            except Exception as e:
                logger.exception(
                    f"Failed to load pickled task from {downloaded_bundle.downloaded_path}. Reason: {e!s}"
                )
                raise

    # Both remaining paths (tgz bundle and no bundle) load the task through the resolver,
    # so the validation and the load call live in one place.
    if not resolver or not resolver_args:
        raise flyte.errors.RuntimeSystemError("MalformedCommand", "Resolver and resolver args are required. for task")
    if downloaded_bundle is not None:
        logger.debug(
            f"Loading task from tgz: {downloaded_bundle.downloaded_path}, resolver: {resolver}, args: {resolver_args}"
        )
    else:
        logger.debug(f"No code bundle provided, loading task from resolver: {resolver}, args: {resolver_args}")
    return load_task(resolver, *resolver_args)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@log
async def load_and_run_task(
    action: ActionID,
    raw_data_path: RawDataPath,
    output_path: str,
    run_base_dir: str,
    version: str,
    controller: Controller,
    resolver: str,
    resolver_args: List[str],
    checkpoints: Checkpoints | None = None,
    code_bundle: CodeBundle | None = None,
    input_path: str | None = None,
    image_cache: ImageCache | None = None,
):
    """
    Load a task and run it; this is the entry point invoked from the runtime/CLI.

    The task is first loaded (downloading the code bundle when one is given), and is then executed
    through ``contextual_run`` with ``extract_download_run_upload``, which creates the context tree
    for the tasks to run in. Nothing is returned from this method.

    :param action: The ActionID to use for the task.
    :param raw_data_path: The raw data path to use for the task.
    :param output_path: The output path to use for the task.
    :param run_base_dir: Base output directory to pass down to child tasks.
    :param version: The version of the task to run.
    :param controller: Controller to use for the task.
    :param resolver: The resolver to use to load the task.
    :param resolver_args: The arguments to pass to the resolver.
    :param checkpoints: The checkpoints to use for the task.
    :param code_bundle: The code bundle to use for the task.
    :param input_path: The input path to use for the task.
    :param image_cache: Mappings of Image identifiers to image URIs.
    """
    loaded_task = await _download_and_load_task(code_bundle, resolver, resolver_args)

    # The code bundle is forwarded exactly as received by this function.
    run_options = dict(
        action=action,
        version=version,
        controller=controller,
        raw_data_path=raw_data_path,
        output_path=output_path,
        run_base_dir=run_base_dir,
        checkpoints=checkpoints,
        code_bundle=code_bundle,
        input_path=input_path,
        image_cache=image_cache,
    )
    await contextual_run(extract_download_run_upload, loaded_task, **run_options)
|