flyte 0.1.0__py3-none-any.whl → 0.2.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic.
- flyte/__init__.py +78 -2
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +152 -0
- flyte/_build.py +26 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +145 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +323 -0
- flyte/_code_bundle/bundle.py +209 -0
- flyte/_context.py +152 -0
- flyte/_deploy.py +243 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +84 -0
- flyte/_excepthook.py +37 -0
- flyte/_group.py +32 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +762 -0
- flyte/_initialize.py +492 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +128 -0
- flyte/_internal/controllers/_local_controller.py +193 -0
- flyte/_internal/controllers/_trace.py +41 -0
- flyte/_internal/controllers/remote/__init__.py +60 -0
- flyte/_internal/controllers/remote/_action.py +146 -0
- flyte/_internal/controllers/remote/_client.py +47 -0
- flyte/_internal/controllers/remote/_controller.py +494 -0
- flyte/_internal/controllers/remote/_core.py +410 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +427 -0
- flyte/_internal/imagebuild/image_builder.py +246 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +342 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +330 -0
- flyte/_internal/runtime/taskrunner.py +191 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +135 -0
- flyte/_map.py +215 -0
- flyte/_pod.py +19 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +71 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +100 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +177 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/common_pb2.py +27 -0
- flyte/_protos/workflow/common_pb2.pyi +14 -0
- flyte/_protos/workflow/common_pb2_grpc.py +4 -0
- flyte/_protos/workflow/environment_pb2.py +29 -0
- flyte/_protos/workflow/environment_pb2.pyi +12 -0
- flyte/_protos/workflow/environment_pb2_grpc.py +4 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +105 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +146 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +314 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +129 -0
- flyte/_protos/workflow/run_service_pb2.pyi +171 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +66 -0
- flyte/_protos/workflow/state_service_pb2.pyi +75 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +79 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +81 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +60 -0
- flyte/_protos/workflow/task_service_pb2.pyi +59 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +138 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +482 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +449 -0
- flyte/_task_environment.py +183 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +120 -0
- flyte/_utils/__init__.py +26 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +23 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +134 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/cli/__init__.py +3 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_common.py +337 -0
- flyte/cli/_create.py +145 -0
- flyte/cli/_delete.py +23 -0
- flyte/cli/_deploy.py +152 -0
- flyte/cli/_gen.py +163 -0
- flyte/cli/_get.py +310 -0
- flyte/cli/_params.py +538 -0
- flyte/cli/_run.py +231 -0
- flyte/cli/main.py +166 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +216 -0
- flyte/config/_internal.py +64 -0
- flyte/config/_reader.py +207 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +172 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +263 -0
- flyte/io/__init__.py +27 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +467 -0
- flyte/io/_structured_dataset/__init__.py +129 -0
- flyte/io/_structured_dataset/basic_dfs.py +219 -0
- flyte/io/_structured_dataset/structured_dataset.py +1061 -0
- flyte/models.py +391 -0
- flyte/remote/__init__.py +26 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +133 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +215 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +159 -0
- flyte/remote/_logs.py +176 -0
- flyte/remote/_project.py +85 -0
- flyte/remote/_run.py +970 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +391 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +29 -0
- flyte/storage/_config.py +233 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +271 -0
- flyte/storage/_utils.py +5 -0
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +371 -0
- flyte/types/__init__.py +36 -0
- flyte/types/_interface.py +40 -0
- flyte/types/_pickle.py +118 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2287 -0
- flyte/types/_utils.py +80 -0
- flyte-0.2.0a0.dist-info/METADATA +249 -0
- flyte-0.2.0a0.dist-info/RECORD +218 -0
- {flyte-0.1.0.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +2 -1
- flyte-0.2.0a0.dist-info/entry_points.txt +3 -0
- flyte-0.2.0a0.dist-info/top_level.txt +1 -0
- flyte-0.1.0.dist-info/METADATA +0 -6
- flyte-0.1.0.dist-info/RECORD +0 -5
flyte/_internal/runtime/io.py
@@ -0,0 +1,136 @@
+"""
+This module contains the methods for uploading and downloading inputs and outputs.
+It uses the storage module to handle the actual uploading and downloading of files.
+
+TODO: Convert to use streaming apis
+"""
+
+import logging
+
+from flyteidl.core import errors_pb2, execution_pb2
+
+import flyte.storage as storage
+from flyte._protos.workflow import run_definition_pb2
+
+from ..._logging import log
+from .convert import Inputs, Outputs, _clean_error_code
+
+# ------------------------------- CONSTANTS ------------------------------- #
+_INPUTS_FILE_NAME = "inputs.pb"
+_OUTPUTS_FILE_NAME = "outputs.pb"
+_CHECKPOINT_FILE_NAME = "_flytecheckpoints"
+_ERROR_FILE_NAME = "error.pb"
+_REPORT_FILE_NAME = "report.html"
+_PKL_EXT = ".pkl.gz"
+
+
+def pkl_path(base_path: str, pkl_name: str) -> str:
+    return storage.join(base_path, f"{pkl_name}{_PKL_EXT}")
+
+
+def inputs_path(base_path: str) -> str:
+    return storage.join(base_path, _INPUTS_FILE_NAME)
+
+
+def outputs_path(base_path: str) -> str:
+    return storage.join(base_path, _OUTPUTS_FILE_NAME)
+
+
+def error_path(base_path: str) -> str:
+    return storage.join(base_path, _ERROR_FILE_NAME)
+
+
+def report_path(base_path: str) -> str:
+    return storage.join(base_path, _REPORT_FILE_NAME)
+
+
+# ------------------------------- UPLOAD Methods ------------------------------- #
+
+
+async def upload_inputs(inputs: Inputs, input_path: str):
+    """
+    :param Inputs inputs: Inputs
+    :param str input_path: The path to upload the input file.
+    """
+    await storage.put_stream(data_iterable=inputs.proto_inputs.SerializeToString(), to_path=input_path)
+
+
+async def upload_outputs(outputs: Outputs, output_path: str):
+    """
+    :param outputs: Outputs
+    :param output_path: The path to upload the output file.
+    """
+    output_uri = outputs_path(output_path)
+    await storage.put_stream(data_iterable=outputs.proto_outputs.SerializeToString(), to_path=output_uri)
+
+
+async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
+    """
+    :param err: execution_pb2.ExecutionError
+    :param output_prefix: The output prefix of the remote uri.
+    """
+    # TODO - clean this up + conditionally set kind
+    error_document = errors_pb2.ErrorDocument(
+        error=errors_pb2.ContainerError(
+            code=err.code,
+            message=err.message,
+            kind=errors_pb2.ContainerError.RECOVERABLE,
+            origin=err.kind,
+            timestamp=err.timestamp,
+            worker=err.worker,
+        )
+    )
+    error_uri = error_path(output_prefix)
+    await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)
+
+
+# ------------------------------- DOWNLOAD Methods ------------------------------- #
+@log(level=logging.INFO)
+async def load_inputs(path: str) -> Inputs:
+    """
+    :param path: Input file to be downloaded
+    :return: Inputs object
+    """
+    lm = run_definition_pb2.Inputs()
+    proto_str = b"".join([c async for c in storage.get_stream(path=path)])
+    lm.ParseFromString(proto_str)
+    return Inputs(proto_inputs=lm)
+
+
+async def load_outputs(path: str) -> Outputs:
+    """
+    :param path: output file to be loaded
+    :return: Outputs object
+    """
+    lm = run_definition_pb2.Outputs()
+    proto_str = b"".join([c async for c in storage.get_stream(path=path)])
+    lm.ParseFromString(proto_str)
+    return Outputs(proto_outputs=lm)
+
+
+async def load_error(path: str) -> execution_pb2.ExecutionError:
+    """
+    :param path: error file to be downloaded
+    :return: execution_pb2.ExecutionError
+    """
+    err = errors_pb2.ErrorDocument()
+    proto_str = b"".join([c async for c in storage.get_stream(path=path)])
+    err.ParseFromString(proto_str)
+
+    if err.error is not None:
+        user_code, server_code = _clean_error_code(err.error.code)
+        return execution_pb2.ExecutionError(
+            code=user_code,
+            message=err.error.message,
+            kind=err.error.origin,
+            error_uri=path,
+            timestamp=err.error.timestamp,
+            worker=err.error.worker,
+        )
+
+    return execution_pb2.ExecutionError(
+        code="Unknown",
+        message=f"Received unloadable error from path {path}",
+        kind=execution_pb2.ExecutionError.SYSTEM,
+        error_uri=path,
+    )
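For orientation (not part of the diff), here is a minimal sketch of how these helpers might compose inside a task runtime, assuming an initialized flyte.storage backend; the prefix URI below is made up.

import asyncio

from flyte._internal.runtime import io


async def roundtrip(base_path: str) -> None:
    # Resolve the conventional object name under the action's prefix.
    input_uri = io.inputs_path(base_path)  # <base_path>/inputs.pb

    # Download and parse the serialized inputs for this action.
    inputs = await io.load_inputs(input_uri)

    # ... run the task with `inputs` and build an Outputs object `outputs` ...
    # Note that upload_outputs joins outputs.pb onto the prefix itself:
    # await io.upload_outputs(outputs, base_path)


# Hypothetical prefix, for illustration only.
asyncio.run(roundtrip("s3://bucket/runs/r-123/a0"))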
flyte/_internal/runtime/resources_serde.py
@@ -0,0 +1,138 @@
+from typing import List, Optional, Tuple
+
+from flyteidl.core import tasks_pb2
+
+from flyte._resources import CPUBaseType, Resources
+
+ACCELERATOR_DEVICE_MAP = {
+    "A100": "nvidia-tesla-a100",
+    "A100 80G": "nvidia-a100-80gb",
+    "A10": "nvidia-a10",
+    "A10G": "nvidia-a10g",
+    "A100G": "nvidia-a100g",
+    "L4": "nvidia-l4",
+    "L40s": "nvidia-l40",
+    "L4_VWS": "nvidia-l4-vws",
+    "K80": "nvidia-tesla-k80",
+    "M60": "nvidia-tesla-m60",
+    "P4": "nvidia-tesla-p4",
+    "P100": "nvidia-tesla-p100",
+    "T4": "nvidia-tesla-t4",
+    "V100": "nvidia-tesla-v100",
+    "V5E": "tpu-v5-lite-podslice",
+    "V5P": "tpu-v5p-slice",
+    "V6E": "tpu-v6e-slice",
+}
+
+
+def _get_cpu_resource_entry(cpu: CPUBaseType) -> tasks_pb2.Resources.ResourceEntry:
+    return tasks_pb2.Resources.ResourceEntry(
+        name=tasks_pb2.Resources.ResourceName.CPU,
+        value=str(cpu),
+    )
+
+
+def _get_memory_resource_entry(memory: str) -> tasks_pb2.Resources.ResourceEntry:
+    return tasks_pb2.Resources.ResourceEntry(
+        name=tasks_pb2.Resources.ResourceName.MEMORY,
+        value=memory,
+    )
+
+
+def _get_gpu_resource_entry(gpu: int) -> tasks_pb2.Resources.ResourceEntry:
+    return tasks_pb2.Resources.ResourceEntry(
+        name=tasks_pb2.Resources.ResourceName.GPU,
+        value=str(gpu),
+    )
+
+
+def _get_gpu_extended_resource_entry(resources: Resources) -> Optional[tasks_pb2.GPUAccelerator]:
+    if resources is None:
+        return None
+    if resources.gpu is None or isinstance(resources.gpu, int):
+        return None
+    device = resources.get_device()
+    if device is None:
+        return None
+    if device.device not in ACCELERATOR_DEVICE_MAP:
+        raise ValueError(f"GPU of type {device.device} unknown, cannot map to device name")
+    return tasks_pb2.GPUAccelerator(
+        device=ACCELERATOR_DEVICE_MAP[device.device],
+        partition_size=device.partition if device.partition else None,
+    )
+
+
+def _get_disk_resource_entry(disk: str) -> tasks_pb2.Resources.ResourceEntry:
+    return tasks_pb2.Resources.ResourceEntry(
+        name=tasks_pb2.Resources.ResourceName.EPHEMERAL_STORAGE,
+        value=disk,
+    )
+
+
+def get_proto_extended_resources(resources: Resources | None) -> Optional[tasks_pb2.ExtendedResources]:
+    """
+    TODO Implement partitioning logic string handling for GPU
+    :param resources:
+    """
+    if resources is None:
+        return None
+    acc = _get_gpu_extended_resource_entry(resources)
+    shm = resources.get_shared_memory()
+    if acc is None and shm is None:
+        return None
+    proto_shm = None
+    if shm is not None:
+        proto_shm = tasks_pb2.SharedMemory(
+            mount_path="/dev/shm",
+            mount_name="flyte-shm",
+            size_limit=shm,
+        )
+    return tasks_pb2.ExtendedResources(gpu_accelerator=acc, shared_memory=proto_shm)
+
+
+def _convert_resources_to_resource_entries(
+    resources: Resources | None,
+) -> Tuple[List[tasks_pb2.Resources.ResourceEntry], List[tasks_pb2.Resources.ResourceEntry]]:
+    request_entries: List[tasks_pb2.Resources.ResourceEntry] = []
+    limit_entries: List[tasks_pb2.Resources.ResourceEntry] = []
+    if resources is None:
+        return request_entries, limit_entries
+    if resources.cpu is not None:
+        if isinstance(resources.cpu, tuple):
+            request_entries.append(_get_cpu_resource_entry(resources.cpu[0]))
+            limit_entries.append(_get_cpu_resource_entry(resources.cpu[1]))
+        else:
+            request_entries.append(_get_cpu_resource_entry(resources.cpu))
+
+    if resources.memory is not None:
+        if isinstance(resources.memory, tuple):
+            request_entries.append(_get_memory_resource_entry(resources.memory[0]))
+            limit_entries.append(_get_memory_resource_entry(resources.memory[1]))
+        else:
+            request_entries.append(_get_memory_resource_entry(resources.memory))
+
+    if resources.gpu is not None:
+        device = resources.get_device()
+        if device is not None:
+            request_entries.append(_get_gpu_resource_entry(device.quantity))
+
+    if resources.disk is not None:
+        request_entries.append(_get_disk_resource_entry(resources.disk))
+
+    return request_entries, limit_entries
+
+
+def get_proto_resources(resources: Resources | None) -> Optional[tasks_pb2.Resources]:
+    """
+    Get main resources IDL representation from the resources object
+
+    :param resources: User facing Resources object containing potentially both requests and limits
+    :return: The given resources as requests and limits
+    """
+    if resources is None:
+        return None
+    request_entries, limit_entries = _convert_resources_to_resource_entries(resources)
+    if not request_entries and not limit_entries:
+        return None
+
+    return tasks_pb2.Resources(requests=request_entries, limits=limit_entries)
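A hedged sketch of driving these converters from a user-facing Resources object (not part of the diff). The keyword arguments below (cpu as a (request, limit) tuple, memory and disk as strings) are assumed from the attributes this module reads; the actual constructor lives in flyte/_resources.py and may differ.

from flyte._internal.runtime.resources_serde import (
    get_proto_extended_resources,
    get_proto_resources,
)
from flyte._resources import Resources

# Assumed keyword arguments; a (request, limit) tuple yields separate entries.
res = Resources(cpu=("1", "2"), memory="2Gi", disk="10Gi")

proto = get_proto_resources(res)
# -> requests: CPU=1, MEMORY=2Gi, EPHEMERAL_STORAGE=10Gi; limits: CPU=2

ext = get_proto_extended_resources(res)
# -> None here, since no accelerator or shared memory was requested
print(proto, ext)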
flyte/_internal/runtime/task_serde.py
@@ -0,0 +1,330 @@
+"""
+This module provides functionality to serialize and deserialize tasks to and from the wire format.
+It includes a Resolver interface for loading tasks, and functions to load classes and tasks.
+"""
+
+import copy
+import importlib
+import typing
+from datetime import timedelta
+from typing import Optional, Type, cast
+
+from flyteidl.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
+from google.protobuf import duration_pb2, wrappers_pb2
+
+import flyte.errors
+from flyte._cache.cache import VersionParameters, cache_from_request
+from flyte._logging import logger
+from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
+from flyte._protos.workflow import common_pb2, environment_pb2, task_definition_pb2
+from flyte._secret import SecretRequest, secrets_from_request
+from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
+from flyte.models import CodeBundle, SerializationContext
+
+from ..._retry import RetryStrategy
+from ..._timeout import TimeoutType, timeout_from_request
+from .resources_serde import get_proto_extended_resources, get_proto_resources
+from .types_serde import transform_native_to_typed_interface
+
+_MAX_ENV_NAME_LENGTH = 63  # Maximum length for environment names
+_MAX_TASK_SHORT_NAME_LENGTH = 63  # Maximum length for task short names
+
+
+def load_class(qualified_name) -> Type:
+    """
+    Load a class from a qualified name. The qualified name should be in the format 'module.ClassName'.
+    :param qualified_name: The qualified name of the class to load.
+    :return: The class object.
+    """
+    module_name, class_name = qualified_name.rsplit(".", 1)  # Split module and class
+    module = importlib.import_module(module_name)  # Import the module
+    return getattr(module, class_name)  # Retrieve the class
+
+
+def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
+    """
+    Load a task from a resolver. This is a placeholder function.
+
+    :param resolver: The resolver to use to load the task.
+    :param resolver_args: Arguments to pass to the resolver.
+    :return: The loaded task.
+    """
+    resolver_class = load_class(resolver)
+    resolver_instance = resolver_class()
+    return resolver_instance.load_task(resolver_args)
+
+
+def translate_task_to_wire(
+    task: TaskTemplate,
+    serialization_context: SerializationContext,
+    default_inputs: Optional[typing.List[common_pb2.NamedParameter]] = None,
+) -> task_definition_pb2.TaskSpec:
+    """
+    Translate a task to a wire format. This is a placeholder function.
+
+    :param task: The task to translate.
+    :param serialization_context: The serialization context to use for the translation.
+    :param default_inputs: Optional list of default inputs for the task.
+
+    :return: The translated task.
+    """
+    tt = get_proto_task(task, serialization_context)
+    env: environment_pb2.Environment | None = None
+    if task.parent_env and task.parent_env():
+        _env = task.parent_env()
+        if _env:
+            env = environment_pb2.Environment(name=_env.name[:_MAX_ENV_NAME_LENGTH])
+    return task_definition_pb2.TaskSpec(
+        task_template=tt,
+        default_inputs=default_inputs,
+        short_name=task.friendly_name[:_MAX_TASK_SHORT_NAME_LENGTH],
+        environment=env,
+    )
+
+
+def get_security_context(secrets: Optional[SecretRequest]) -> Optional[security_pb2.SecurityContext]:
+    """
+    Get the security context from a list of secrets. This is a placeholder function.
+
+    :param secrets: The list of secrets to use for the security context.
+
+    :return: The security context.
+    """
+    if secrets is None:
+        return None
+
+    secret_list = secrets_from_request(secrets)
+    return security_pb2.SecurityContext(
+        secrets=[
+            security_pb2.Secret(
+                group=secret.group,
+                key=secret.key,
+                mount_requirement=(
+                    security_pb2.Secret.MountType.ENV_VAR if secret.as_env_var else security_pb2.Secret.MountType.FILE
+                ),
+                env_var=secret.as_env_var,
+            )
+            for secret in secret_list
+        ]
+    )
+
+
+def get_proto_retry_strategy(retries: RetryStrategy | int | None) -> Optional[literals_pb2.RetryStrategy]:
+    if retries is None:
+        return None
+
+    if isinstance(retries, int):
+        raise AssertionError("Retries should be an instance of RetryStrategy, not int")
+
+    return literals_pb2.RetryStrategy(retries=retries.count)
+
+
+def get_proto_timeout(timeout: TimeoutType | None) -> Optional[duration_pb2.Duration]:
+    if timeout is None:
+        return None
+    max_runtime_timeout = timeout_from_request(timeout).max_runtime
+    if isinstance(max_runtime_timeout, int):
+        max_runtime_timeout = timedelta(seconds=max_runtime_timeout)
+    return duration_pb2.Duration(seconds=max_runtime_timeout.seconds)
+
+
+def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext) -> tasks_pb2.TaskTemplate:
+    task_id = identifier_pb2.Identifier(
+        resource_type=identifier_pb2.ResourceType.TASK,
+        project=serialize_context.project,
+        domain=serialize_context.domain,
+        org=serialize_context.org,
+        name=task.name,
+        version=serialize_context.version,
+    )
+    # TODO, there will be tasks that do not have images, handle that case
+    # if task.parent_env is None:
+    #     raise ValueError(f"Task {task.name} must have a parent environment")
+
+    # TODO Add support for SQL, extra_config, custom
+    extra_config: typing.Dict[str, str] = {}
+    custom = {}  # type: ignore
+
+    sql = None
+    if task.pod_template and not isinstance(task.pod_template, str):
+        container = None
+        pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template)
+        extra_config[_PRIMARY_CONTAINER_NAME_FIELD] = task.pod_template.primary_container_name
+    else:
+        container = _get_urun_container(serialize_context, task)
+        pod = None
+
+    # -------------- CACHE HANDLING ----------------------
+    task_cache = cache_from_request(task.cache)
+    cache_enabled = task_cache.is_enabled()
+    cache_version = None
+
+    if task_cache.is_enabled():
+        logger.debug(f"Cache enabled for task {task.name}")
+        if serialize_context.code_bundle and serialize_context.code_bundle.pkl:
+            logger.debug(f"Detected pkl bundle for task {task.name}, using computed version as cache version")
+            cache_version = serialize_context.code_bundle.computed_version
+        else:
+            version_parameters = None
+            if isinstance(task, AsyncFunctionTaskTemplate):
+                version_parameters = VersionParameters(func=task.func, image=task.image)
+            else:
+                version_parameters = VersionParameters(func=None, image=task.image)
+            cache_version = task_cache.get_version(version_parameters)
+            logger.debug(f"Cache version for task {task.name} is {cache_version}")
+    else:
+        logger.debug(f"Cache disabled for task {task.name}")
+
+    return tasks_pb2.TaskTemplate(
+        id=task_id,
+        type=task.task_type,
+        metadata=tasks_pb2.TaskMetadata(
+            discoverable=cache_enabled,
+            discovery_version=cache_version,
+            cache_serializable=task_cache.serialize,
+            cache_ignore_input_vars=task_cache.get_ignored_inputs() if cache_enabled else None,
+            runtime=tasks_pb2.RuntimeMetadata(
+                version=flyte.version(),
+                type=tasks_pb2.RuntimeMetadata.RuntimeType.FLYTE_SDK,
+                flavor="python",
+            ),
+            retries=get_proto_retry_strategy(task.retries),
+            timeout=get_proto_timeout(task.timeout),
+            pod_template_name=task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None,
+            interruptible=task.interruptable,
+            generates_deck=wrappers_pb2.BoolValue(value=task.report),
+        ),
+        interface=transform_native_to_typed_interface(task.native_interface),
+        custom=custom,
+        container=container,
+        task_type_version=task.task_type_version,
+        security_context=get_security_context(task.secrets),
+        config=extra_config,
+        k8s_pod=pod,
+        sql=sql,
+        extended_resources=get_proto_extended_resources(task.resources),
+    )
+
+
+def _get_urun_container(
+    serialize_context: SerializationContext, task_template: TaskTemplate
+) -> Optional[tasks_pb2.Container]:
+    env = (
+        [literals_pb2.KeyValuePair(key=k, value=v) for k, v in task_template.env.items()] if task_template.env else None
+    )
+    resources = get_proto_resources(task_template.resources)
+    # pr: under what conditions should this return None?
+    if isinstance(task_template.image, str):
+        raise flyte.errors.RuntimeSystemError("BadConfig", "Image is not a valid image")
+    image_id = task_template.image.identifier
+    if not serialize_context.image_cache or image_id not in serialize_context.image_cache.image_lookup:
+        # This computes the image uri, computing hashes as necessary so can fail if done remotely.
+        img_uri = task_template.image.uri
+    else:
+        img_uri = serialize_context.image_cache.image_lookup[image_id]
+
+    return tasks_pb2.Container(
+        image=img_uri,
+        command=[],
+        args=task_template.container_args(serialize_context),
+        resources=resources,
+        env=env,
+        data_config=task_template.data_loading_config(serialize_context),
+        config=task_template.config(serialize_context),
+    )
+
+
+def _sanitize_resource_name(resource: tasks_pb2.Resources.ResourceEntry) -> str:
+    return tasks_pb2.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")
+
+
+def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTemplate) -> Optional[tasks_pb2.K8sPod]:
+    """
+    Get the K8sPod representation of the task template.
+    :param task: The task to convert.
+    :return: The K8sPod representation of the task template.
+    """
+    from kubernetes.client import ApiClient, V1PodSpec
+    from kubernetes.client.models import V1EnvVar, V1ResourceRequirements
+
+    pod_template = copy.deepcopy(pod_template)
+    containers = cast(V1PodSpec, pod_template.pod_spec).containers
+    primary_exists = False
+
+    for container in containers:
+        if container.name == pod_template.primary_container_name:
+            primary_exists = True
+            break
+
+    if not primary_exists:
+        raise ValueError(
+            "No primary container defined in the pod spec."
+            f" You must define a primary container with the name '{pod_template.primary_container_name}'."
+        )
+    final_containers = []
+
+    for container in containers:
+        # We overwrite the primary container attributes with the values given to ContainerTask.
+        # The attributes include: image, command, args, resource, and env (env is unioned)
+
+        if container.name == pod_template.primary_container_name:
+            if container.image is None:
+                # Copy the image from primary_container only if the image is not specified in the pod spec.
+                container.image = primary_container.image
+
+            container.command = list(primary_container.command)
+            container.args = list(primary_container.args)
+
+            limits, requests = {}, {}
+            for resource in primary_container.resources.limits:
+                limits[_sanitize_resource_name(resource)] = resource.value
+            for resource in primary_container.resources.requests:
+                requests[_sanitize_resource_name(resource)] = resource.value
+
+            resource_requirements = V1ResourceRequirements(limits=limits, requests=requests)
+            if len(limits) > 0 or len(requests) > 0:
+                # Important! Only copy over resource requirements if they are non-empty.
+                container.resources = resource_requirements
+
+            if primary_container.env is not None:
+                container.env = [V1EnvVar(name=e.key, value=e.value) for e in primary_container.env] + (
+                    container.env or []
+                )
+
+        final_containers.append(container)
+
+    cast(V1PodSpec, pod_template.pod_spec).containers = final_containers
+    pod_spec = ApiClient().sanitize_for_serialization(pod_template.pod_spec)
+
+    metadata = tasks_pb2.K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations)
+    return tasks_pb2.K8sPod(pod_spec=pod_spec, metadata=metadata)
+
+
+def extract_code_bundle(task_spec: task_definition_pb2.TaskSpec) -> Optional[CodeBundle]:
+    """
+    Extract the code bundle from the task spec.
+    :param task_spec: The task spec to extract the code bundle from.
+    :return: The extracted code bundle or None if not present.
+    """
+    container = task_spec.task_template.container
+    if container and container.args:
+        pkl_path = None
+        tgz_path = None
+        dest_path: str = "."
+        version = ""
+        for i, v in enumerate(container.args):
+            if v == "--pkl":
+                # Extract the code bundle path from the argument
+                pkl_path = container.args[i + 1] if i + 1 < len(container.args) else None
+            elif v == "--tgz":
+                # Extract the code bundle path from the argument
+                tgz_path = container.args[i + 1] if i + 1 < len(container.args) else None
+            elif v == "--dest":
+                # Extract the destination path from the argument
+                dest_path = container.args[i + 1] if i + 1 < len(container.args) else "."
+            elif v == "--version":
+                # Extract the version from the argument
+                version = container.args[i + 1] if i + 1 < len(container.args) else ""
+        if pkl_path or tgz_path:
+            return CodeBundle(destination=dest_path, tgz=tgz_path, pkl=pkl_path, computed_version=version)
+    return None