flyte 2.0.0b32 — flyte-2.0.0b32-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic; see the package's registry page for more details.
- flyte/__init__.py +108 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +195 -0
- flyte/_bin/serve.py +178 -0
- flyte/_build.py +26 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +147 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/local_cache.py +216 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +121 -0
- flyte/_code_bundle/_packaging.py +218 -0
- flyte/_code_bundle/_utils.py +347 -0
- flyte/_code_bundle/bundle.py +266 -0
- flyte/_constants.py +1 -0
- flyte/_context.py +155 -0
- flyte/_custom_context.py +73 -0
- flyte/_debug/__init__.py +0 -0
- flyte/_debug/constants.py +38 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +307 -0
- flyte/_deploy.py +408 -0
- flyte/_deployer.py +109 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +122 -0
- flyte/_excepthook.py +37 -0
- flyte/_group.py +32 -0
- flyte/_hash.py +8 -0
- flyte/_image.py +1055 -0
- flyte/_initialize.py +628 -0
- flyte/_interface.py +119 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +129 -0
- flyte/_internal/controllers/_local_controller.py +239 -0
- flyte/_internal/controllers/_trace.py +48 -0
- flyte/_internal/controllers/remote/__init__.py +58 -0
- flyte/_internal/controllers/remote/_action.py +211 -0
- flyte/_internal/controllers/remote/_client.py +47 -0
- flyte/_internal/controllers/remote/_controller.py +583 -0
- flyte/_internal/controllers/remote/_core.py +465 -0
- flyte/_internal/controllers/remote/_informer.py +381 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +3 -0
- flyte/_internal/imagebuild/docker_builder.py +706 -0
- flyte/_internal/imagebuild/image_builder.py +277 -0
- flyte/_internal/imagebuild/remote_builder.py +386 -0
- flyte/_internal/imagebuild/utils.py +78 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +21 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +486 -0
- flyte/_internal/runtime/entrypoints.py +204 -0
- flyte/_internal/runtime/io.py +188 -0
- flyte/_internal/runtime/resources_serde.py +152 -0
- flyte/_internal/runtime/reuse.py +125 -0
- flyte/_internal/runtime/rusty.py +193 -0
- flyte/_internal/runtime/task_serde.py +362 -0
- flyte/_internal/runtime/taskrunner.py +209 -0
- flyte/_internal/runtime/trigger_serde.py +160 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_keyring/__init__.py +0 -0
- flyte/_keyring/file.py +115 -0
- flyte/_logging.py +300 -0
- flyte/_map.py +312 -0
- flyte/_module.py +72 -0
- flyte/_pod.py +30 -0
- flyte/_resources.py +473 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +102 -0
- flyte/_run.py +724 -0
- flyte/_secret.py +96 -0
- flyte/_task.py +550 -0
- flyte/_task_environment.py +316 -0
- flyte/_task_plugins.py +47 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +119 -0
- flyte/_trigger.py +1000 -0
- flyte/_utils/__init__.py +30 -0
- flyte/_utils/asyn.py +121 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +27 -0
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +134 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/module_loader.py +104 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +34 -0
- flyte/app/__init__.py +22 -0
- flyte/app/_app_environment.py +157 -0
- flyte/app/_deploy.py +125 -0
- flyte/app/_input.py +160 -0
- flyte/app/_runtime/__init__.py +3 -0
- flyte/app/_runtime/app_serde.py +347 -0
- flyte/app/_types.py +101 -0
- flyte/app/extras/__init__.py +3 -0
- flyte/app/extras/_fastapi.py +151 -0
- flyte/cli/__init__.py +12 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_build.py +114 -0
- flyte/cli/_common.py +468 -0
- flyte/cli/_create.py +371 -0
- flyte/cli/_delete.py +45 -0
- flyte/cli/_deploy.py +293 -0
- flyte/cli/_gen.py +176 -0
- flyte/cli/_get.py +370 -0
- flyte/cli/_option.py +33 -0
- flyte/cli/_params.py +554 -0
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_run.py +597 -0
- flyte/cli/_serve.py +64 -0
- flyte/cli/_update.py +37 -0
- flyte/cli/_user.py +17 -0
- flyte/cli/main.py +221 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +248 -0
- flyte/config/_internal.py +73 -0
- flyte/config/_reader.py +225 -0
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +270 -0
- flyte/connectors/_server.py +197 -0
- flyte/connectors/utils.py +135 -0
- flyte/errors.py +243 -0
- flyte/extend.py +19 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +286 -0
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +21 -0
- flyte/io/__init__.py +29 -0
- flyte/io/_dataframe/__init__.py +131 -0
- flyte/io/_dataframe/basic_dfs.py +223 -0
- flyte/io/_dataframe/dataframe.py +1026 -0
- flyte/io/_dir.py +910 -0
- flyte/io/_file.py +914 -0
- flyte/io/_hashing_io.py +342 -0
- flyte/models.py +479 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +35 -0
- flyte/remote/_action.py +738 -0
- flyte/remote/_app.py +57 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +189 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +403 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +213 -0
- flyte/remote/_client/auth/_client_config.py +85 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +152 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +128 -0
- flyte/remote/_common.py +30 -0
- flyte/remote/_console.py +19 -0
- flyte/remote/_data.py +161 -0
- flyte/remote/_logs.py +185 -0
- flyte/remote/_project.py +88 -0
- flyte/remote/_run.py +386 -0
- flyte/remote/_secret.py +142 -0
- flyte/remote/_task.py +527 -0
- flyte/remote/_trigger.py +306 -0
- flyte/remote/_user.py +33 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +182 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +36 -0
- flyte/storage/_config.py +237 -0
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +456 -0
- flyte/storage/_utils.py +5 -0
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +375 -0
- flyte/types/__init__.py +52 -0
- flyte/types/_interface.py +40 -0
- flyte/types/_pickle.py +145 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +119 -0
- flyte/types/_type_engine.py +2254 -0
- flyte/types/_utils.py +80 -0
- flyte-2.0.0b32.data/scripts/debug.py +38 -0
- flyte-2.0.0b32.data/scripts/runtime.py +195 -0
- flyte-2.0.0b32.dist-info/METADATA +351 -0
- flyte-2.0.0b32.dist-info/RECORD +204 -0
- flyte-2.0.0b32.dist-info/WHEEL +5 -0
- flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
- flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
- flyte-2.0.0b32.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import time
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
|
|
5
|
+
from flyte._context import contextual_run
|
|
6
|
+
from flyte._internal.controllers import Controller
|
|
7
|
+
from flyte._internal.controllers import create_controller as _create_controller
|
|
8
|
+
from flyte._internal.imagebuild.image_builder import ImageCache
|
|
9
|
+
from flyte._internal.runtime.entrypoints import download_code_bundle, load_pkl_task, load_task
|
|
10
|
+
from flyte._internal.runtime.taskrunner import extract_download_run_upload
|
|
11
|
+
from flyte._logging import logger
|
|
12
|
+
from flyte._task import TaskTemplate
|
|
13
|
+
from flyte._utils import adjust_sys_path
|
|
14
|
+
from flyte.models import ActionID, Checkpoints, CodeBundle, PathRewrite, RawDataPath
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
    """
    Download a tar.gz code bundle for the Rusty runtime.

    This only fetches the bundle; it does not load any task from it.
    ``adjust_sys_path()`` is invoked before the download is started.

    :param destination: Local path the bundle is downloaded into.
    :param version: Computed version recorded on the CodeBundle.
    :param tgz: Location of the tar.gz code bundle.
    :return: The CodeBundle returned by ``download_code_bundle``.
    """
    logger.info(f"[rusty] Downloading tgz code bundle from {tgz} to {destination} with version {version}")
    adjust_sys_path()
    return await download_code_bundle(
        CodeBundle(tgz=tgz, destination=destination, computed_version=version)
    )
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
async def download_load_pkl(destination: str, version: str, pkl: str) -> Tuple[CodeBundle, TaskTemplate]:
    """
    Download a pickled code bundle and load the task it contains.

    :param destination: Local path the bundle is downloaded into.
    :param version: Computed version recorded on the CodeBundle.
    :param pkl: Location of the task template in pickle format.
    :return: A ``(code_bundle, task)`` tuple, where ``code_bundle`` is what ``download_code_bundle``
        returned and ``task`` is the TaskTemplate produced by ``load_pkl_task`` from that bundle.
    """
    logger.info(f"[rusty] Downloading pkl code bundle from {pkl} to {destination} with version {version}")
    adjust_sys_path()

    requested = CodeBundle(pkl=pkl, destination=destination, computed_version=version)
    downloaded = await download_code_bundle(requested)
    task = load_pkl_task(downloaded)
    return downloaded, task
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def load_task_from_code_bundle(resolver: str, resolver_args: List[str]) -> TaskTemplate:
    """
    Resolve a task template through the given resolver.

    :param resolver: Resolver identifier, passed as the first argument to ``load_task``.
    :param resolver_args: Extra arguments, forwarded positionally to ``load_task``.
    :return: The task template returned by ``load_task``.
    """
    logger.debug(f"[rusty] Loading task from code bundle {resolver} with args: {resolver_args}")
    task = load_task(resolver, *resolver_args)
    return task
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
async def create_controller(
    endpoint: str = "host.docker.internal:8090",
    insecure: bool = False,
    api_key: str | None = None,
) -> Controller:
    """
    Create a remote controller instance for the Rusty runtime.

    Installs ``flyte.errors.silence_grpc_polling_error`` as the exception handler of the running
    event loop. Then initializes the in-cluster client via ``init_in_cluster.aio`` and builds a
    ``"remote"`` controller from the keyword arguments that call returns.

    :param endpoint: Endpoint passed to ``init_in_cluster``.
    :param insecure: Passed through to ``init_in_cluster``. NOTE(review): presumably disables TLS;
        confirm against ``init_in_cluster``.
    :param api_key: Optional API key passed through to ``init_in_cluster``.
    :return: The controller created by ``create_controller(ct="remote", ...)``.
    """
    logger.info(f"[rusty] Creating controller with endpoint {endpoint}")
    import flyte.errors
    from flyte._initialize import init_in_cluster

    # We are inside a coroutine, so a loop is guaranteed to be running. The asyncio docs recommend
    # get_running_loop() here rather than get_event_loop(), whose implicit-loop behaviour is deprecated.
    loop = asyncio.get_running_loop()
    loop.set_exception_handler(flyte.errors.silence_grpc_polling_error)

    # TODO Currently reference tasks are not supported in Rusty.
    controller_kwargs = await init_in_cluster.aio(api_key=api_key, endpoint=endpoint, insecure=insecure)
    return _create_controller(ct="remote", **controller_kwargs)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
async def run_task(
    task: TaskTemplate,
    controller: Controller,
    org: str,
    project: str,
    domain: str,
    run_name: str,
    name: str,
    raw_data_path: str,
    output_path: str,
    run_base_dir: str,
    version: str,
    image_cache: str | None = None,
    checkpoint_path: str | None = None,
    prev_checkpoint: str | None = None,
    code_bundle: CodeBundle | None = None,
    input_path: str | None = None,
    path_rewrite_cfg: str | None = None,
):
    """
    Run the task with the provided parameters via ``extract_download_run_upload``.

    If ``path_rewrite_cfg`` is given, it is parsed with ``PathRewrite.from_str``. It is applied to the
    raw data path only when its ``new_prefix`` exists in storage; otherwise an error is logged and the
    original paths are used. Start and end of the execution are logged with timestamps and duration.

    :param task: The task template to run.
    :param controller: The controller to use for the task.
    :param org: Organization to run the task in.
    :param project: Project to run the task in.
    :param domain: Domain to run the task in.
    :param run_name: Parent run name to use for the task.
    :param name: Action name to run.
    :param raw_data_path: The path to the raw data.
    :param output_path: The path to save the output.
    :param run_base_dir: The base directory for the run.
    :param version: The version of the task to run.
    :param image_cache: Serialized image cache, decoded with ``ImageCache.from_transport``.
    :param checkpoint_path: Checkpoint path to save the current state.
    :param prev_checkpoint: Previous checkpoint path to resume from.
    :param code_bundle: Optional code bundle for the task.
    :param input_path: Optional input path for the task.
    :param path_rewrite_cfg: Optional path rewrite configuration string.
    :return: None.
    :raises Exception: Any exception raised while running the task is logged and re-raised.
    """
    # Wall-clock time is only used for the human-readable timestamps in the logs.
    start_time = time.time()
    # The duration comes from a monotonic clock so that system clock adjustments cannot skew it.
    start_tick = time.monotonic()
    action_id = f"{org}/{project}/{domain}/{run_name}/{name}"

    logger.info(
        f"[rusty] Running task '{task.name}' (action: {action_id})"
        f" at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}"
    )

    path_rewrite = PathRewrite.from_str(path_rewrite_cfg) if path_rewrite_cfg else None
    if path_rewrite:
        import flyte.storage as storage

        # Only rewrite paths if the new prefix is actually reachable; otherwise keep the originals.
        if not await storage.exists(path_rewrite.new_prefix):
            logger.error(
                f"[rusty] Path rewrite failed for path {path_rewrite.new_prefix}, "
                f"not found, reverting to original path {path_rewrite.old_prefix}"
            )
            path_rewrite = None
        else:
            logger.info(f"[rusty] Using path rewrite: {path_rewrite}")

    try:
        await contextual_run(
            extract_download_run_upload,
            task,
            action=ActionID(name=name, org=org, project=project, domain=domain, run_name=run_name),
            version=version,
            controller=controller,
            raw_data_path=RawDataPath(path=raw_data_path, path_rewrite=path_rewrite),
            output_path=output_path,
            run_base_dir=run_base_dir,
            checkpoints=Checkpoints(prev_checkpoint_path=prev_checkpoint, checkpoint_path=checkpoint_path),
            code_bundle=code_bundle,
            input_path=input_path,
            image_cache=ImageCache.from_transport(image_cache) if image_cache else None,
        )
    except Exception as e:
        logger.error(f"[rusty] Task failed: {e!s}")
        raise
    finally:
        end_time = time.time()
        duration = time.monotonic() - start_tick
        logger.info(
            f"[rusty] TASK_EXECUTION_END: Task '{task.name}' (action: {action_id})"
            f" done after {duration:.2f}s at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))}"
        )
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
async def ping(name: str) -> str:
    """
    Health-check coroutine for the Rusty entrypoint.

    Prints a notice that a ping arrived, then answers with a pong.

    :param name: Identifier of the caller sending the ping.
    :return: The string ``"pong from Rusty to <name>!"``.
    """
    reply = f"pong from Rusty to {name}!"
    print(f"Received ping request from {name} in Rusty!")
    return reply
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
async def hello(name: str):
    """
    Greeting coroutine used to smoke-test the Rusty entrypoint.

    Prints a greeting line to stdout. Nothing is returned, so the result is ``None``.

    :param name: The name to greet.
    """
    greeting = f"Received hello request in Rusty with name: {name}!"
    print(greeting)
|
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
"""
This module serializes task templates into the Flyte IDL wire format and reads code-bundle
information back out of a serialized spec.

It builds ``TaskSpec`` / ``TaskTemplate`` protobufs, including container or K8s pod specs,
cache/retry/timeout metadata and security contexts. It also extracts a ``CodeBundle`` from a
task spec's container arguments.
"""
|
|
5
|
+
|
|
6
|
+
import copy
|
|
7
|
+
import typing
|
|
8
|
+
from datetime import timedelta
|
|
9
|
+
from typing import Optional, cast
|
|
10
|
+
|
|
11
|
+
from flyteidl2.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
|
|
12
|
+
from flyteidl2.task import common_pb2, environment_pb2, task_definition_pb2
|
|
13
|
+
from google.protobuf import duration_pb2, wrappers_pb2
|
|
14
|
+
|
|
15
|
+
import flyte.errors
|
|
16
|
+
from flyte._cache.cache import VersionParameters, cache_from_request
|
|
17
|
+
from flyte._logging import logger
|
|
18
|
+
from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
|
|
19
|
+
from flyte._secret import SecretRequest, secrets_from_request
|
|
20
|
+
from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
|
|
21
|
+
from flyte.models import CodeBundle, SerializationContext
|
|
22
|
+
|
|
23
|
+
from ... import ReusePolicy
|
|
24
|
+
from ..._retry import RetryStrategy
|
|
25
|
+
from ..._timeout import TimeoutType, timeout_from_request
|
|
26
|
+
from .resources_serde import get_proto_extended_resources, get_proto_resources
|
|
27
|
+
from .reuse import add_reusable
|
|
28
|
+
from .types_serde import transform_native_to_typed_interface
|
|
29
|
+
|
|
30
|
+
# Environment names and task short names are truncated to these lengths before they are written
# into a TaskSpec. NOTE(review): 63 looks like the Kubernetes/DNS label limit — confirm.
_MAX_ENV_NAME_LENGTH: int = 63
_MAX_TASK_SHORT_NAME_LENGTH: int = 63
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def translate_task_to_wire(
    task: TaskTemplate,
    serialization_context: SerializationContext,
    default_inputs: Optional[typing.List[common_pb2.NamedParameter]] = None,
) -> task_definition_pb2.TaskSpec:
    """
    Translate a task into its wire representation, a ``TaskSpec`` protobuf.

    The spec bundles:

    - the task template built by ``get_proto_task``;
    - the optional default inputs;
    - the task short name;
    - the parent environment's name, when ``task.parent_env`` is set and calling it returns an
      environment.

    Both names are truncated to their respective maximum lengths.

    :param task: The task to translate.
    :param serialization_context: The serialization context to use for the translation.
    :param default_inputs: Optional list of default inputs for the task.
    :return: The translated task spec.
    """
    tt = get_proto_task(task, serialization_context)

    # ``parent_env`` is a callable that may be unset or may return None. Call it once and reuse the
    # result instead of invoking it twice as before.
    parent = task.parent_env() if task.parent_env else None
    env: environment_pb2.Environment | None = None
    if parent:
        env = environment_pb2.Environment(name=parent.name[:_MAX_ENV_NAME_LENGTH])

    return task_definition_pb2.TaskSpec(
        task_template=tt,
        default_inputs=default_inputs,
        short_name=task.short_name[:_MAX_TASK_SHORT_NAME_LENGTH],
        environment=env,
    )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_security_context(
    secrets: Optional[SecretRequest],
) -> Optional[security_pb2.SecurityContext]:
    """
    Build the protobuf security context for a task's secret requests.

    :param secrets: The secret request(s) declared on the task, or None.
    :return: None when ``secrets`` is None. Otherwise, a SecurityContext with one
        ``security_pb2.Secret`` per entry returned by ``secrets_from_request``. Entries with a truthy
        ``as_env_var`` are mounted as environment variables; all others are mounted as files.
    """
    if secrets is None:
        return None

    mount_type = security_pb2.Secret.MountType
    proto_secrets = []
    for s in secrets_from_request(secrets):
        mount = mount_type.ENV_VAR if s.as_env_var else mount_type.FILE
        proto_secrets.append(
            security_pb2.Secret(
                group=s.group,
                key=s.key,
                mount_requirement=mount,
                env_var=s.as_env_var,
            )
        )
    return security_pb2.SecurityContext(secrets=proto_secrets)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def get_proto_retry_strategy(
    retries: RetryStrategy | int | None,
) -> Optional[literals_pb2.RetryStrategy]:
    """
    Convert a task's retry configuration into a protobuf RetryStrategy.

    :param retries: The retry strategy, or None when retries are not configured. A bare ``int`` is
        rejected. NOTE(review): presumably callers normalize ints into RetryStrategy before
        serialization — confirm.
    :return: None when ``retries`` is None; otherwise a RetryStrategy carrying ``retries.count``.
    :raises AssertionError: If ``retries`` is an ``int``.
    """
    if retries is None:
        return None
    if not isinstance(retries, int):
        return literals_pb2.RetryStrategy(retries=retries.count)
    raise AssertionError("Retries should be an instance of RetryStrategy, not int")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def get_proto_timeout(timeout: TimeoutType | None) -> Optional[duration_pb2.Duration]:
    """
    Convert a task timeout into a protobuf Duration for the task metadata.

    Only the ``max_runtime`` part of the normalized timeout is serialized. An ``int`` max runtime is
    interpreted as a number of seconds.

    :param timeout: The task timeout, or None for no timeout.
    :return: None when ``timeout`` is None; otherwise the full max runtime as a Duration.
    """
    if timeout is None:
        return None
    max_runtime_timeout = timeout_from_request(timeout).max_runtime
    if isinstance(max_runtime_timeout, int):
        max_runtime_timeout = timedelta(seconds=max_runtime_timeout)
    # Bug fix: ``timedelta.seconds`` is only the seconds-within-the-day component (0..86399). A
    # 1-day timeout used to serialize as 0s and a 25h timeout as 1h. ``FromTimedelta`` accounts for
    # days, seconds and microseconds and normalizes the result.
    duration = duration_pb2.Duration()
    duration.FromTimedelta(max_runtime_timeout)
    return duration
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _compute_cache_version(task: TaskTemplate, task_cache, serialize_context: SerializationContext):
    """Return the discovery (cache) version for ``task``; only meaningful when caching is enabled."""
    code_bundle = serialize_context.code_bundle
    if code_bundle and code_bundle.pkl:
        logger.debug(f"Detected pkl bundle for task {task.name}, using computed version as cache version")
        cache_version = code_bundle.computed_version
    else:
        # Function-based tasks contribute their function to the version; other tasks only their image.
        func = task.func if isinstance(task, AsyncFunctionTaskTemplate) else None
        cache_version = task_cache.get_version(VersionParameters(func=func, image=task.image))
    logger.debug(f"Cache version for task {task.name} is {cache_version}")
    return cache_version


def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext) -> tasks_pb2.TaskTemplate:
    """
    Build the ``tasks_pb2.TaskTemplate`` protobuf for ``task``.

    The template carries:

    - the task identifier (project/domain/org/version from the serialization context);
    - either the runtime container, or a full K8s pod when the task has a non-string pod template;
    - cache, retry and timeout metadata;
    - the typed interface, plugin custom config and SQL;
    - secrets and extended resources.

    Reusable tasks are additionally passed through ``add_reusable``.

    :param task: The task to serialize.
    :param serialize_context: Serialization context providing ids, version, code bundle and image cache.
    :return: The task template protobuf, possibly transformed by ``add_reusable``.
    :raises flyte.errors.RuntimeUserError: If ``task.reusable`` is set but is not a ReusePolicy.
    """
    task_id = identifier_pb2.Identifier(
        resource_type=identifier_pb2.ResourceType.TASK,
        project=serialize_context.project,
        domain=serialize_context.domain,
        org=serialize_context.org,
        name=task.name,
        version=serialize_context.version,
    )

    # TODO Add support for extra_config, custom
    extra_config: typing.Dict[str, str] = {}

    # A pod template object is serialized into a full K8s pod that embeds the primary container.
    # A string pod template is only referenced by name in the metadata below.
    if task.pod_template and not isinstance(task.pod_template, str):
        pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template)
        extra_config[_PRIMARY_CONTAINER_NAME_FIELD] = task.pod_template.primary_container_name
        container = None
    else:
        container = _get_urun_container(serialize_context, task)
        pod = None
    pod_template_name = task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None

    custom = task.custom_config(serialize_context)

    sql = task.sql(serialize_context)

    # -------------- CACHE HANDLING ----------------------
    task_cache = cache_from_request(task.cache)
    cache_enabled = task_cache.is_enabled()
    cache_version = None

    if cache_enabled:
        logger.debug(f"Cache enabled for task {task.name}")
        cache_version = _compute_cache_version(task, task_cache, serialize_context)
    else:
        logger.debug(f"Cache disabled for task {task.name}")

    task_template = tasks_pb2.TaskTemplate(
        id=task_id,
        type=task.task_type,
        metadata=tasks_pb2.TaskMetadata(
            discoverable=cache_enabled,
            discovery_version=cache_version,
            cache_serializable=task_cache.serialize,
            cache_ignore_input_vars=(task_cache.get_ignored_inputs() if cache_enabled else None),
            runtime=tasks_pb2.RuntimeMetadata(
                version=flyte.version(),
                type=tasks_pb2.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                flavor="python",
            ),
            retries=get_proto_retry_strategy(task.retries),
            timeout=get_proto_timeout(task.timeout),
            pod_template_name=pod_template_name,
            interruptible=task.interruptible,
            generates_deck=wrappers_pb2.BoolValue(value=task.report),
            debuggable=task.debuggable,
        ),
        interface=transform_native_to_typed_interface(task.native_interface),
        # A truthiness check treats both an empty dict and None as "no custom config".
        # The old len() check crashed on None.
        custom=custom if custom else None,
        container=container,
        task_type_version=task.task_type_version,
        security_context=get_security_context(task.secrets),
        config=extra_config,
        k8s_pod=pod,
        sql=sql,
        extended_resources=get_proto_extended_resources(task.resources),
    )

    if task.reusable is None:
        return task_template

    if not isinstance(task.reusable, ReusePolicy):
        raise flyte.errors.RuntimeUserError(
            "BadConfig", f"Expected ReusePolicy, got {type(task.reusable)} for task {task.name}"
        )
    parent = task.parent_env() if task.parent_env is not None else None
    env_name = parent.name if parent is not None else None
    return add_reusable(task_template, task.reusable, serialize_context.code_bundle, env_name)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def lookup_image_in_cache(serialize_context: SerializationContext, env_name: str, image: flyte.Image) -> str:
    """
    Resolve the container image URI for the environment ``env_name``.

    Without an image cache on the serialization context, the URI is computed from ``image`` itself.
    With an image cache, the URI must already be recorded under ``env_name``.

    :param serialize_context: Serialization context that may carry an image cache.
    :param env_name: Name of the environment whose image is needed.
    :param image: The image to fall back to when no image cache is present.
    :return: The image URI.
    :raises flyte.errors.RuntimeUserError: If an image cache is present but has no entry for ``env_name``.
    """
    image_cache = serialize_context.image_cache
    if not image_cache:
        # This computes the image uri, computing hashes as necessary so can fail if done remotely.
        return image.uri

    lookup = image_cache.image_lookup
    if env_name in lookup:
        return lookup[env_name]

    raise flyte.errors.RuntimeUserError(
        "MissingEnvironment",
        f"Environment '{env_name}' not found in image cache.\n\n"
        "💡 To fix this:\n"
        " 1. If your parent environment calls a task in another environment,"
        " declare that dependency using 'depends_on=[...]'.\n"
        " Example:\n"
        " env1 = flyte.TaskEnvironment(\n"
        " name='outer',\n"
        " image=flyte.Image.from_debian_base().with_pip_packages('requests'),\n"
        " depends_on=[env2, env3],\n"
        " )\n"
        " 2. If you're using os.getenv() to set the environment name,"
        " make sure the runtime environment has the same environment variable defined.\n"
        " Example:\n"
        " env = flyte.TaskEnvironment(\n"
        ' name=os.getenv("my-name"),\n'
        ' env_vars={"my-name": os.getenv("my-name")},\n'
        " )\n",
    )
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def _get_urun_container(
    serialize_context: SerializationContext, task_template: TaskTemplate
) -> Optional[tasks_pb2.Container]:
    """
    Build the primary runtime container spec for ``task_template``.

    The image URI is resolved through ``lookup_image_in_cache`` using the task's parent environment
    name. ``command`` is left empty; the arguments come from ``task_template.container_args``.

    :param serialize_context: The serialization context.
    :param task_template: The task to build the container for.
    :return: The Container protobuf.
    :raises flyte.errors.RuntimeSystemError: If the task image is a plain string, or if the task has
        no parent environment name.
    """
    env_vars = task_template.env_vars
    if env_vars:
        env = [literals_pb2.KeyValuePair(key=key, value=value) for key, value in env_vars.items()]
    else:
        env = None

    resources = get_proto_resources(task_template.resources)

    image = task_template.image
    if isinstance(image, str):
        raise flyte.errors.RuntimeSystemError("BadConfig", "Image is not a valid image")

    parent_env_name = task_template.parent_env_name
    if parent_env_name is None:
        raise flyte.errors.RuntimeSystemError("BadConfig", f"Task {task_template.name} has no parent environment name")

    return tasks_pb2.Container(
        image=lookup_image_in_cache(serialize_context, parent_env_name, image),
        command=[],
        args=task_template.container_args(serialize_context),
        resources=resources,
        env=env,
        data_config=task_template.data_loading_config(serialize_context),
        config=task_template.config(serialize_context),
    )
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _sanitize_resource_name(resource: tasks_pb2.Resources.ResourceEntry) -> str:
    """
    Turn a resource entry's enum name into the key used in the pod's resource requirements.

    The ``ResourceName`` enum value name is lower-cased and its underscores are replaced by dashes.
    """
    enum_name = tasks_pb2.Resources.ResourceName.Name(resource.name)
    return enum_name.lower().replace("_", "-")
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTemplate) -> Optional[tasks_pb2.K8sPod]:
    """
    Merge the task's primary container into a copy of ``pod_template`` and return it as a K8sPod.

    The container named ``pod_template.primary_container_name`` receives:

    - the task's image, but only if the pod spec leaves it unset;
    - the task's command and args;
    - the task's resources, but only if any limits or requests exist;
    - the task's env vars, placed before the pod spec's own env entries.

    Other containers are kept as-is. The caller's template is deep-copied, so it is not mutated.

    :param primary_container: The task's runtime container protobuf.
    :param pod_template: The user-supplied pod template.
    :return: A K8sPod with the sanitized pod spec and the template's labels/annotations.
    :raises ValueError: If the pod spec has no container named ``primary_container_name``.
    """
    from kubernetes.client import ApiClient, V1PodSpec
    from kubernetes.client.models import V1EnvVar, V1ResourceRequirements

    template = copy.deepcopy(pod_template)
    spec = cast(V1PodSpec, template.pod_spec)
    primary_name = template.primary_container_name

    if not any(candidate.name == primary_name for candidate in spec.containers):
        raise ValueError(
            "No primary container defined in the pod spec."
            f" You must define a primary container with the name '{primary_name}'."
        )

    merged = []
    for c in spec.containers:
        # The primary container is overwritten with the task's values:
        # image (if unset), command, args, resources, and env (env is unioned).
        if c.name == primary_name:
            if c.image is None:
                c.image = primary_container.image

            c.command = list(primary_container.command)
            c.args = list(primary_container.args)

            limits = {_sanitize_resource_name(r): r.value for r in primary_container.resources.limits}
            requests = {_sanitize_resource_name(r): r.value for r in primary_container.resources.requests}
            # Important! Only copy over resource requirements if they are non-empty.
            if limits or requests:
                c.resources = V1ResourceRequirements(limits=limits, requests=requests)

            if primary_container.env is not None:
                task_env = [V1EnvVar(name=e.key, value=e.value) for e in primary_container.env]
                c.env = task_env + (c.env or [])

        merged.append(c)

    spec.containers = merged
    sanitized_spec = ApiClient().sanitize_for_serialization(template.pod_spec)

    return tasks_pb2.K8sPod(
        pod_spec=sanitized_spec,
        metadata=tasks_pb2.K8sObjectMetadata(labels=template.labels, annotations=template.annotations),
    )
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def extract_code_bundle(
    task_spec: task_definition_pb2.TaskSpec,
) -> Optional[CodeBundle]:
    """
    Recover the code bundle description from a task spec's container arguments.

    The container args are scanned for the flags ``--pkl``, ``--tgz``, ``--dest`` and ``--version``;
    each flag takes the argument that follows it, and a later occurrence wins. If a flag is the last
    argument, it falls back to its default: None for ``--pkl``/``--tgz``, ``"."`` for ``--dest``,
    and ``""`` for ``--version``.

    :param task_spec: The task spec to extract the code bundle from.
    :return: A CodeBundle when a pkl or tgz path was found; otherwise None.
    """
    container = task_spec.task_template.container
    if not (container and container.args):
        return None

    args = list(container.args)
    defaults = {"--pkl": None, "--tgz": None, "--dest": ".", "--version": ""}
    found = dict(defaults)
    for idx, token in enumerate(args):
        if token in defaults:
            has_value = idx + 1 < len(args)
            found[token] = args[idx + 1] if has_value else defaults[token]

    if not (found["--pkl"] or found["--tgz"]):
        return None

    return CodeBundle(
        destination=found["--dest"],
        tgz=found["--tgz"],
        pkl=found["--pkl"],
        computed_version=found["--version"],
    )
|