flyte 0.2.0b23__py3-none-any.whl → 0.2.0b25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +1 -15
- flyte/_bin/runtime.py +10 -1
- flyte/_environment.py +7 -0
- flyte/_image.py +37 -22
- flyte/_internal/controllers/remote/_action.py +29 -0
- flyte/_internal/controllers/remote/_controller.py +33 -5
- flyte/_internal/controllers/remote/_core.py +28 -14
- flyte/_internal/imagebuild/__init__.py +1 -13
- flyte/_internal/imagebuild/docker_builder.py +5 -1
- flyte/_internal/imagebuild/remote_builder.py +11 -3
- flyte/_internal/runtime/convert.py +40 -5
- flyte/_internal/runtime/entrypoints.py +58 -15
- flyte/_internal/runtime/reuse.py +121 -0
- flyte/_internal/runtime/rusty.py +165 -0
- flyte/_internal/runtime/task_serde.py +26 -35
- flyte/_pod.py +11 -1
- flyte/_resources.py +67 -2
- flyte/_reusable_environment.py +57 -2
- flyte/_run.py +35 -16
- flyte/_secret.py +30 -0
- flyte/_task.py +1 -5
- flyte/_task_environment.py +34 -2
- flyte/_task_plugins.py +45 -0
- flyte/_version.py +2 -2
- flyte/errors.py +2 -2
- flyte/extend.py +12 -0
- flyte/models.py +10 -1
- flyte/types/_type_engine.py +16 -2
- flyte-0.2.0b25.data/scripts/runtime.py +169 -0
- {flyte-0.2.0b23.dist-info → flyte-0.2.0b25.dist-info}/METADATA +1 -1
- {flyte-0.2.0b23.dist-info → flyte-0.2.0b25.dist-info}/RECORD +34 -29
- {flyte-0.2.0b23.dist-info → flyte-0.2.0b25.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b23.dist-info → flyte-0.2.0b25.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b23.dist-info → flyte-0.2.0b25.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import typing
|
|
3
|
+
from venv import logger
|
|
4
|
+
|
|
5
|
+
from flyteidl.core import tasks_pb2
|
|
6
|
+
|
|
7
|
+
import flyte.errors
|
|
8
|
+
from flyte import ReusePolicy
|
|
9
|
+
from flyte._pod import _PRIMARY_CONTAINER_DEFAULT_NAME, _PRIMARY_CONTAINER_NAME_FIELD
|
|
10
|
+
from flyte.models import CodeBundle
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def extract_unique_id_and_image(
    env_name: str,
    code_bundle: CodeBundle | None,
    task: tasks_pb2.TaskTemplate,
    reuse_policy: ReusePolicy,
) -> typing.Tuple[str, str]:
    """
    Compute a unique ID for a reusable environment and find the image it runs.

    The ID is the SHA-256 hex digest of a colon-joined string assembled from: the environment name, the
    serialized container or pod (with the container args removed), the replica settings, the TTL (when set),
    the code bundle version or the reuse salt, the serialized security context, the interruptible flag and the
    pod template name.

    :param env_name: Name of the reusable environment.
    :param code_bundle: The code bundle associated with the task. Its computed version is hashed only when the
        reuse policy carries no reuse salt.
    :param task: The task template.
    :param reuse_policy: The reuse policy for the task.
    :return: A tuple of the unique ID (hex digest) and the image URI. The image is an empty string when the task
        has no container and no pod container matches the primary container name.
    """
    image = ""
    container_ser: str | bytes = ""
    if task.HasField("container"):
        # Work on a copy so the caller's task template keeps its args.
        copied_container = tasks_pb2.Container()
        copied_container.CopyFrom(task.container)
        copied_container.args.clear()  # Clear args to ensure deterministic serialization
        container_ser = copied_container.SerializeToString(deterministic=True)
        image = copied_container.image

    if task.HasField("k8s_pod"):
        copied_k8s_pod = tasks_pb2.K8sPod()
        copied_k8s_pod.CopyFrom(task.k8s_pod)
        # Test membership explicitly: a protobuf map field is never None, and indexing a missing key does not
        # reliably signal absence (it may hand back, and even store, the default value).
        if _PRIMARY_CONTAINER_NAME_FIELD in task.config:
            primary_container_name = task.config[_PRIMARY_CONTAINER_NAME_FIELD]
        else:
            primary_container_name = _PRIMARY_CONTAINER_DEFAULT_NAME
        for container in copied_k8s_pod.pod_spec["containers"]:
            if "name" in container and container["name"] == primary_container_name:
                image = container["image"]
                # Clear args to ensure deterministic serialization; a container may legitimately have none.
                if "args" in container:
                    del container["args"]
        container_ser = copied_k8s_pod.SerializeToString(deterministic=True)

    # NOTE: the exact layout of this string determines the resulting ID, so it is kept unchanged.
    components = f"{env_name}:{container_ser}"
    if isinstance(reuse_policy.replicas, tuple):
        components += f":{reuse_policy.replicas[0]}:{reuse_policy.replicas[1]}"
    else:
        components += f":{reuse_policy.replicas}"
    if reuse_policy.ttl is not None:
        components += f":{reuse_policy.ttl.total_seconds()}"
    if reuse_policy.reuse_salt is None and code_bundle is not None:
        components += f":{code_bundle.computed_version}"
    else:
        components += f":{reuse_policy.reuse_salt}"
    if task.security_context is not None:
        security_ctx_str = task.security_context.SerializeToString(deterministic=True)
        components += f":{security_ctx_str}"
    if task.metadata.interruptible is not None:
        components += f":{task.metadata.interruptible}"
    if task.metadata.pod_template_name is not None:
        components += f":{task.metadata.pod_template_name}"
    sha256 = hashlib.sha256()
    sha256.update(components.encode("utf-8"))
    return sha256.hexdigest(), image
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def add_reusable(
    task: tasks_pb2.TaskTemplate,
    reuse_policy: ReusePolicy,
    code_bundle: CodeBundle | None,
    parent_env_name: str | None = None,
) -> tasks_pb2.TaskTemplate:
    """
    Attach a reuse ("actor") configuration to a task template.

    The task's ``custom`` field is filled with the environment name, a 15-character version derived from
    ``extract_unique_id_and_image`` and a spec built from the reuse policy; the task type is set to ``"actor"``.

    :param task: The task to which the reusable policy will be added. It is modified in place.
    :param reuse_policy: The reuse policy to apply. When it is None the task is returned untouched.
    :param code_bundle: The code bundle associated with the task.
    :param parent_env_name: The name of the parent environment, if any. When None, the part of the task name
        before the first "." is used instead.
    :return: The same task object, with the reusable policy added.
    :raises flyte.errors.RuntimeUserError: If the task already has a ``custom`` field set.
    """
    if reuse_policy is None:
        return task

    if task.HasField("custom"):
        raise flyte.errors.RuntimeUserError(
            "BadConfiguration", "Plugins do not support reusable policy. Only container tasks and pods."
        )

    logger.debug(f"Adding reusable policy for task: {task.id.name}")
    if parent_env_name is None:
        name = task.id.name.split(".")[0]
    else:
        name = parent_env_name or ""

    version, image_uri = extract_unique_id_and_image(
        env_name=name,
        code_bundle=code_bundle,
        task=task,
        reuse_policy=reuse_policy,
    )

    actor_spec = {
        "container_image": image_uri,
        "backlog_length": None,
        "parallelism": reuse_policy.concurrency,
        "replica_count": reuse_policy.max_replicas,
        "ttl_seconds": reuse_policy.ttl.total_seconds() if reuse_policy.ttl else None,
    }
    task.custom = {
        "name": name,
        "version": version[:15],  # Use only the first 15 characters for the version
        "type": "actor",
        "spec": actor_spec,
    }

    task.type = "actor"
    logger.info(f"Reusable task {task.id.name} with config {task.custom}")

    return task
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from typing import Any, List, Tuple
|
|
3
|
+
|
|
4
|
+
from flyte._context import contextual_run
|
|
5
|
+
from flyte._internal.controllers import Controller
|
|
6
|
+
from flyte._internal.controllers import create_controller as _create_controller
|
|
7
|
+
from flyte._internal.imagebuild.image_builder import ImageCache
|
|
8
|
+
from flyte._internal.runtime.entrypoints import download_code_bundle, load_pkl_task, load_task
|
|
9
|
+
from flyte._internal.runtime.taskrunner import extract_download_run_upload
|
|
10
|
+
from flyte._logging import logger
|
|
11
|
+
from flyte._task import TaskTemplate
|
|
12
|
+
from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
    """
    Download a tgz code bundle and return the resulting CodeBundle.

    The current directory (".") is put at the front of ``sys.path`` before the download starts.

    :param destination: The path to save the downloaded code bundle.
    :param version: The version recorded on the code bundle as its computed version.
    :param tgz: The path to the code bundle in tar.gz format.
    :return: The CodeBundle object returned by ``download_code_bundle``.
    """
    logger.info(f"[rusty] Downloading tgz code bundle from {tgz} to {destination} with version {version}")
    sys.path.insert(0, ".")

    bundle = CodeBundle(tgz=tgz, destination=destination, computed_version=version)
    downloaded = await download_code_bundle(bundle)
    return downloaded
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
async def download_load_pkl(destination: str, version: str, pkl: str) -> Tuple[CodeBundle, TaskTemplate]:
    """
    Download a pickled code bundle and load the task it contains.

    The current directory (".") is put at the front of ``sys.path`` before the download starts.

    :param destination: The path to save the downloaded code bundle.
    :param version: The version recorded on the code bundle as its computed version.
    :param pkl: The path to the task template in pickle format.
    :return: A tuple of the downloaded CodeBundle and the task loaded from it.
    """
    logger.info(f"[rusty] Downloading pkl code bundle from {pkl} to {destination} with version {version}")
    sys.path.insert(0, ".")

    requested = CodeBundle(pkl=pkl, destination=destination, computed_version=version)
    downloaded = await download_code_bundle(requested)
    task_template = load_pkl_task(downloaded)
    return downloaded, task_template
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def load_task_from_code_bundle(resolver: str, resolver_args: List[str]) -> TaskTemplate:
    """
    Load a task through the given resolver.

    :param resolver: The resolver to use to load the task.
    :param resolver_args: The arguments to pass to the resolver; they are forwarded as positional arguments.
    :return: The loaded task template.
    """
    logger.debug(f"[rusty] Loading task from code bundle {resolver} with args: {resolver_args}")
    task_template = load_task(resolver, *resolver_args)
    return task_template
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
async def create_controller(
    endpoint: str = "host.docker.internal:8090",
    insecure: bool = False,
    api_key: str | None = None,
) -> Controller:
    """
    Creates a controller instance for remote operations.

    :param endpoint: Controller endpoint. It is only passed to the controller when no API key is given.
    :param insecure: Whether to use an insecure connection. Forced to True when the endpoint is used and
        contains "localhost" or "docker".
    :param api_key: Optional API key. When given, it is passed to the controller instead of the endpoint.
    :return: The controller created with ``ct="remote"``.
    """
    logger.info(f"[rusty] Creating controller with endpoint {endpoint}")
    from flyte._initialize import init

    # TODO Currently reference tasks are not supported in Rusty.
    await init.aio()

    controller_kwargs: dict[str, Any]
    if api_key:
        logger.info("Using api key from environment")
        controller_kwargs = {"insecure": insecure, "api_key": api_key}
    else:
        looks_local = "localhost" in endpoint or "docker" in endpoint
        controller_kwargs = {"insecure": True if looks_local else insecure, "endpoint": endpoint}
        logger.debug(f"Using controller endpoint: {endpoint} with kwargs: {controller_kwargs}")

    return _create_controller(ct="remote", **controller_kwargs)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
async def run_task(
    task: TaskTemplate,
    controller: Controller,
    org: str,
    project: str,
    domain: str,
    run_name: str,
    name: str,
    raw_data_path: str,
    output_path: str,
    run_base_dir: str,
    version: str,
    image_cache: str | None = None,
    checkpoint_path: str | None = None,
    prev_checkpoint: str | None = None,
    code_bundle: CodeBundle | None = None,
    input_path: str | None = None,
):
    """
    Runs the task with the provided parameters.

    The task is executed by handing ``extract_download_run_upload`` to ``contextual_run``; a log line is
    written before and after the run. Nothing is returned.

    :param task: The task template to run.
    :param controller: The controller to use for the task.
    :param org: organization to run the task in.
    :param project: project to run the task in.
    :param domain: domain to run the task in.
    :param run_name: Parent run name to use for the task.
    :param name: Action name to run.
    :param raw_data_path: The path to the raw data.
    :param output_path: The path to save the output.
    :param run_base_dir: The base directory for the run.
    :param version: The version of the task to run.
    :param image_cache: Image cache to use for the task, in its transport (string) form. Decoded with
        ``ImageCache.from_transport`` when provided.
    :param checkpoint_path: Checkpoint path to save the current state.
    :param prev_checkpoint: Previous checkpoint path to resume from.
    :param code_bundle: Optional code bundle for the task.
    :param input_path: Optional input path for the task.
    """
    logger.info(f"[rusty] Running task {task.name}")

    action_id = ActionID(name=name, org=org, project=project, domain=domain, run_name=run_name)
    raw_data = RawDataPath(path=raw_data_path)
    checkpoints = Checkpoints(prev_checkpoint_path=prev_checkpoint, checkpoint_path=checkpoint_path)
    if image_cache:
        decoded_image_cache = ImageCache.from_transport(image_cache)
    else:
        decoded_image_cache = None

    await contextual_run(
        extract_download_run_upload,
        task,
        action=action_id,
        version=version,
        controller=controller,
        raw_data_path=raw_data,
        output_path=output_path,
        run_base_dir=run_base_dir,
        checkpoints=checkpoints,
        code_bundle=code_bundle,
        input_path=input_path,
        image_cache=decoded_image_cache,
    )
    logger.info(f"[rusty] Finished task {task.name}")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
async def ping(name: str) -> str:
    """
    A simple ping function to test the Rusty entrypoint.

    Prints a notice about the request and answers with a pong message.

    :param name: The name of the caller sending the ping.
    :return: The pong message addressed to ``name``.
    """
    print(f"Received ping request from {name} in Rusty!")
    reply = f"pong from Rusty to {name}!"
    return reply
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
async def hello(name: str):
    """
    A simple hello world function to test the Rusty entrypoint.

    Only prints a message; nothing is returned.

    :param name: The name to greet.
    """
    message = f"Received hello request in Rusty with name: {name}!"
    print(message)
|
|
@@ -4,10 +4,9 @@ It includes a Resolver interface for loading tasks, and functions to load classe
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
import copy
|
|
7
|
-
import importlib
|
|
8
7
|
import typing
|
|
9
8
|
from datetime import timedelta
|
|
10
|
-
from typing import Optional,
|
|
9
|
+
from typing import Optional, cast
|
|
11
10
|
|
|
12
11
|
from flyteidl.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
|
|
13
12
|
from google.protobuf import duration_pb2, wrappers_pb2
|
|
@@ -21,39 +20,17 @@ from flyte._secret import SecretRequest, secrets_from_request
|
|
|
21
20
|
from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
|
|
22
21
|
from flyte.models import CodeBundle, SerializationContext
|
|
23
22
|
|
|
23
|
+
from ... import ReusePolicy
|
|
24
24
|
from ..._retry import RetryStrategy
|
|
25
25
|
from ..._timeout import TimeoutType, timeout_from_request
|
|
26
26
|
from .resources_serde import get_proto_extended_resources, get_proto_resources
|
|
27
|
+
from .reuse import add_reusable
|
|
27
28
|
from .types_serde import transform_native_to_typed_interface
|
|
28
29
|
|
|
29
30
|
_MAX_ENV_NAME_LENGTH = 63 # Maximum length for environment names
|
|
30
31
|
_MAX_TASK_SHORT_NAME_LENGTH = 63 # Maximum length for task short names
|
|
31
32
|
|
|
32
33
|
|
|
33
|
-
def load_class(qualified_name) -> Type:
|
|
34
|
-
"""
|
|
35
|
-
Load a class from a qualified name. The qualified name should be in the format 'module.ClassName'.
|
|
36
|
-
:param qualified_name: The qualified name of the class to load.
|
|
37
|
-
:return: The class object.
|
|
38
|
-
"""
|
|
39
|
-
module_name, class_name = qualified_name.rsplit(".", 1) # Split module and class
|
|
40
|
-
module = importlib.import_module(module_name) # Import the module
|
|
41
|
-
return getattr(module, class_name) # Retrieve the class
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
|
|
45
|
-
"""
|
|
46
|
-
Load a task from a resolver. This is a placeholder function.
|
|
47
|
-
|
|
48
|
-
:param resolver: The resolver to use to load the task.
|
|
49
|
-
:param resolver_args: Arguments to pass to the resolver.
|
|
50
|
-
:return: The loaded task.
|
|
51
|
-
"""
|
|
52
|
-
resolver_class = load_class(resolver)
|
|
53
|
-
resolver_instance = resolver_class()
|
|
54
|
-
return resolver_instance.load_task(resolver_args)
|
|
55
|
-
|
|
56
|
-
|
|
57
34
|
def translate_task_to_wire(
|
|
58
35
|
task: TaskTemplate,
|
|
59
36
|
serialization_context: SerializationContext,
|
|
@@ -137,23 +114,22 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
137
114
|
name=task.name,
|
|
138
115
|
version=serialize_context.version,
|
|
139
116
|
)
|
|
140
|
-
# TODO, there will be tasks that do not have images, handle that case
|
|
141
|
-
# if task.parent_env is None:
|
|
142
|
-
# raise ValueError(f"Task {task.name} must have a parent environment")
|
|
143
117
|
|
|
144
118
|
# TODO Add support for SQL, extra_config, custom
|
|
145
119
|
extra_config: typing.Dict[str, str] = {}
|
|
146
|
-
custom = {} # type: ignore
|
|
147
120
|
|
|
148
|
-
sql = None
|
|
149
121
|
if task.pod_template and not isinstance(task.pod_template, str):
|
|
150
|
-
container = None
|
|
151
122
|
pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template)
|
|
152
123
|
extra_config[_PRIMARY_CONTAINER_NAME_FIELD] = task.pod_template.primary_container_name
|
|
124
|
+
container = None
|
|
153
125
|
else:
|
|
154
126
|
container = _get_urun_container(serialize_context, task)
|
|
155
127
|
pod = None
|
|
156
128
|
|
|
129
|
+
custom = task.custom_config(serialize_context)
|
|
130
|
+
|
|
131
|
+
sql = None
|
|
132
|
+
|
|
157
133
|
# -------------- CACHE HANDLING ----------------------
|
|
158
134
|
task_cache = cache_from_request(task.cache)
|
|
159
135
|
cache_enabled = task_cache.is_enabled()
|
|
@@ -175,7 +151,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
175
151
|
else:
|
|
176
152
|
logger.debug(f"Cache disabled for task {task.name}")
|
|
177
153
|
|
|
178
|
-
|
|
154
|
+
task_template = tasks_pb2.TaskTemplate(
|
|
179
155
|
id=task_id,
|
|
180
156
|
type=task.task_type,
|
|
181
157
|
metadata=tasks_pb2.TaskMetadata(
|
|
@@ -195,7 +171,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
195
171
|
generates_deck=wrappers_pb2.BoolValue(value=task.report),
|
|
196
172
|
),
|
|
197
173
|
interface=transform_native_to_typed_interface(task.native_interface),
|
|
198
|
-
custom=custom,
|
|
174
|
+
custom=custom if len(custom) > 0 else None,
|
|
199
175
|
container=container,
|
|
200
176
|
task_type_version=task.task_type_version,
|
|
201
177
|
security_context=get_security_context(task.secrets),
|
|
@@ -205,6 +181,20 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
205
181
|
extended_resources=get_proto_extended_resources(task.resources),
|
|
206
182
|
)
|
|
207
183
|
|
|
184
|
+
if task.reusable is not None:
|
|
185
|
+
if not isinstance(task.reusable, ReusePolicy):
|
|
186
|
+
raise flyte.errors.RuntimeUserError(
|
|
187
|
+
"BadConfig", f"Expected ReusePolicy, got {type(task.reusable)} for task {task.name}"
|
|
188
|
+
)
|
|
189
|
+
env_name = None
|
|
190
|
+
if task.parent_env is not None:
|
|
191
|
+
env = task.parent_env()
|
|
192
|
+
if env is not None:
|
|
193
|
+
env_name = env.name
|
|
194
|
+
return add_reusable(task_template, task.reusable, serialize_context.code_bundle, env_name)
|
|
195
|
+
|
|
196
|
+
return task_template
|
|
197
|
+
|
|
208
198
|
|
|
209
199
|
def _get_urun_container(
|
|
210
200
|
serialize_context: SerializationContext, task_template: TaskTemplate
|
|
@@ -220,9 +210,10 @@ def _get_urun_container(
|
|
|
220
210
|
if not serialize_context.image_cache:
|
|
221
211
|
# This computes the image uri, computing hashes as necessary so can fail if done remotely.
|
|
222
212
|
img_uri = task_template.image.uri
|
|
223
|
-
elif image_id not in serialize_context.image_cache.image_lookup:
|
|
213
|
+
elif serialize_context.image_cache and image_id not in serialize_context.image_cache.image_lookup:
|
|
224
214
|
img_uri = task_template.image.uri
|
|
225
215
|
from flyte._version import __version__
|
|
216
|
+
|
|
226
217
|
logger.warning(
|
|
227
218
|
f"Image {task_template.image} not found in the image cache: {serialize_context.image_cache.image_lookup}.\n"
|
|
228
219
|
f"This typically occurs when the Flyte SDK version (`{__version__}`) used in the task environment "
|
flyte/_pod.py
CHANGED
|
@@ -2,7 +2,8 @@ from dataclasses import dataclass, field
|
|
|
2
2
|
from typing import TYPE_CHECKING, Dict, Optional
|
|
3
3
|
|
|
4
4
|
if TYPE_CHECKING:
|
|
5
|
-
from
|
|
5
|
+
from flyteidl.core.tasks_pb2 import K8sPod
|
|
6
|
+
from kubernetes.client import ApiClient, V1PodSpec
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
|
|
@@ -17,3 +18,12 @@ class PodTemplate(object):
|
|
|
17
18
|
primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME
|
|
18
19
|
labels: Optional[Dict[str, str]] = None
|
|
19
20
|
annotations: Optional[Dict[str, str]] = None
|
|
21
|
+
|
|
22
|
+
def to_k8s_pod(self) -> "K8sPod":
    """
    Build the flyteidl ``K8sPod`` message for this pod template.

    The pod spec is passed through ``ApiClient().sanitize_for_serialization`` before it is placed in the
    message; labels and annotations go into the pod metadata.

    :return: A ``K8sPod`` carrying this template's metadata, pod spec and primary container name.
    """
    from flyteidl.core.tasks_pb2 import K8sObjectMetadata, K8sPod

    # The module-level import of ApiClient appears to sit under `if TYPE_CHECKING:`, where it does not exist
    # at runtime, so import it here to avoid a NameError when this method is called.
    from kubernetes.client import ApiClient

    return K8sPod(
        metadata=K8sObjectMetadata(labels=self.labels, annotations=self.annotations),
        pod_spec=ApiClient().sanitize_for_serialization(self.pod_spec),
        primary_container_name=self.primary_container_name,
    )
|
flyte/_resources.py
CHANGED
|
@@ -1,8 +1,15 @@
|
|
|
1
|
-
from dataclasses import dataclass
|
|
2
|
-
from typing import Literal, Optional, Tuple, Union, get_args
|
|
1
|
+
from dataclasses import dataclass, fields
|
|
2
|
+
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union, get_args
|
|
3
3
|
|
|
4
4
|
import rich.repr
|
|
5
5
|
|
|
6
|
+
from flyte._pod import _PRIMARY_CONTAINER_DEFAULT_NAME
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from kubernetes.client import V1PodSpec
|
|
10
|
+
|
|
11
|
+
PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
|
|
12
|
+
|
|
6
13
|
GPUType = Literal["T4", "A100", "A100 80G", "H100", "L4", "L40s"]
|
|
7
14
|
GPUQuantity = Literal[1, 2, 3, 4, 5, 6, 7, 8]
|
|
8
15
|
A100Parts = Literal["1g.5gb", "2g.10gb", "3g.20gb", "4g.20gb", "7g.40gb"]
|
|
@@ -224,3 +231,61 @@ class Resources:
|
|
|
224
231
|
if self.shm == "auto":
|
|
225
232
|
return ""
|
|
226
233
|
return self.shm
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _check_resource_is_singular(resource: Resources):
|
|
237
|
+
"""
|
|
238
|
+
Raise a value error if the resource has a tuple.
|
|
239
|
+
"""
|
|
240
|
+
for field in fields(resource):
|
|
241
|
+
value = getattr(resource, field.name)
|
|
242
|
+
if isinstance(value, (tuple, list)):
|
|
243
|
+
raise ValueError(f"{value} can not be a list or tuple")
|
|
244
|
+
return resource
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def pod_spec_from_resources(
    primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME,
    requests: Optional[Resources] = None,
    limits: Optional[Resources] = None,
    k8s_gpu_resource_key: str = "nvidia.com/gpu",
) -> "V1PodSpec":
    """
    Build a single-container ``V1PodSpec`` whose container carries the given resource requests and limits.

    When only one of ``requests`` / ``limits`` yields any entries, the other one falls back to the same mapping.

    :param primary_container_name: Name given to the only container in the pod spec.
    :param requests: Resources to request; fields set to None are left out.
    :param limits: Resource limits; fields set to None are left out.
    :param k8s_gpu_resource_key: Kubernetes resource key used for the ``gpu`` field.
    :return: The constructed pod spec.
    :raises ValueError: If any resource field holds a list or tuple.
    """
    from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements

    # Resource field name -> Kubernetes resource key.
    # NOTE(review): a non-None field whose name is missing from this table raises KeyError below - confirm the
    # table matches the fields declared on Resources.
    k8s_keys = {
        "cpu": "cpu",
        "memory": "memory",
        "gpu": k8s_gpu_resource_key,
        "ephemeral_storage": "ephemeral-storage",
    }

    def _to_k8s_resources(resources: Optional[Resources]):
        if resources is None:
            return None

        _check_resource_is_singular(resources)
        converted = {}
        for spec in fields(resources):
            amount = getattr(resources, spec.name)
            if amount is None:
                continue
            converted[k8s_keys[spec.name]] = amount
        return converted

    request_map = _to_k8s_resources(requests)
    limit_map = _to_k8s_resources(limits)
    # Each side falls back to the other when it is None or empty.
    request_map = request_map or limit_map
    limit_map = limit_map or request_map

    primary = V1Container(
        name=primary_container_name,
        resources=V1ResourceRequirements(
            requests=request_map,
            limits=limit_map,
        ),
    )
    return V1PodSpec(containers=[primary])
|
flyte/_reusable_environment.py
CHANGED
|
@@ -2,6 +2,8 @@ from dataclasses import dataclass
|
|
|
2
2
|
from datetime import timedelta
|
|
3
3
|
from typing import Optional, Tuple, Union
|
|
4
4
|
|
|
5
|
+
from flyte._logging import logger
|
|
6
|
+
|
|
5
7
|
|
|
6
8
|
@dataclass
|
|
7
9
|
class ReusePolicy:
|
|
@@ -14,12 +16,65 @@ class ReusePolicy:
|
|
|
14
16
|
Caution: It is important to note that the environment is shared, so managing memory and resources is important.
|
|
15
17
|
|
|
16
18
|
:param replicas: Either a single int representing number of replicas or a tuple of two ints representing
|
|
17
|
-
the min and max
|
|
19
|
+
the min and max.
|
|
18
20
|
:param idle_ttl: The maximum idle duration for an environment replica, specified as either seconds (int) or a
|
|
19
21
|
timedelta. If not set, the environment's global default will be used.
|
|
20
22
|
When a replica remains idle — meaning no tasks are running — for this duration, it will be automatically
|
|
21
23
|
terminated.
|
|
24
|
+
:param concurrency: The maximum number of tasks that can run concurrently in one instance of the environment.
|
|
25
|
+
Concurrency greater than 1 is supported only for `async` tasks.
|
|
26
|
+
:param reuse_salt: Optional string used to control environment reuse.
|
|
27
|
+
If set, the environment will be reused even if the code bundle changes.
|
|
28
|
+
To force a new environment, either set this to `None` or change its value.
|
|
29
|
+
|
|
30
|
+
Example:
|
|
31
|
+
reuse_salt = "v1" # Environment is reused
|
|
32
|
+
reuse_salt = "v2" # Forces environment recreation
|
|
22
33
|
"""
|
|
23
34
|
|
|
24
|
-
replicas: Union[int, Tuple[int, int]] =
|
|
35
|
+
replicas: Union[int, Tuple[int, int]] = 2
|
|
25
36
|
idle_ttl: Optional[Union[int, timedelta]] = None
|
|
37
|
+
reuse_salt: str | None = None
|
|
38
|
+
concurrency: int = 1
|
|
39
|
+
|
|
40
|
+
def __post_init__(self):
    """
    Validate and normalize the policy fields after dataclass initialization.

    An int ``replicas`` becomes a ``(n, n)`` tuple and an int ``idle_ttl`` becomes a ``timedelta`` of that
    many seconds. A warning is logged when the maximum replica count and the concurrency are both 1.

    :raises ValueError: If ``replicas`` is None, is neither an int nor a tuple of length two, or if
        ``idle_ttl`` is set to something other than an int or a timedelta.
    """
    if self.replicas is None:
        raise ValueError("replicas cannot be None")
    if isinstance(self.replicas, int):
        self.replicas = (self.replicas, self.replicas)
    elif not isinstance(self.replicas, tuple):
        raise ValueError("replicas must be an int or a tuple of two ints")
    elif len(self.replicas) != 2:
        raise ValueError("replicas must be an int or a tuple of two ints")

    if self.idle_ttl:
        if isinstance(self.idle_ttl, int):
            self.idle_ttl = timedelta(seconds=int(self.idle_ttl))
        elif not isinstance(self.idle_ttl, timedelta):
            raise ValueError("idle_ttl must be an int (seconds) or a timedelta")

    if self.replicas[1] == 1 and self.concurrency == 1:
        # Each fragment ends with a space so the implicitly concatenated sentences do not run together.
        logger.warning(
            "It is recommended to use a minimum of 2 replicas, to avoid starvation. "
            "Starvation can occur if a task is running and no other replicas are available to handle new tasks. "
            "Options, increase concurrency, increase replicas or turn-off reuse for the parent task, "
            "that runs child tasks."
        )
|
|
63
|
+
|
|
64
|
+
@property
|
|
65
|
+
def ttl(self) -> timedelta | None:
|
|
66
|
+
"""
|
|
67
|
+
Returns the idle TTL as a timedelta. If idle_ttl is not set, returns the global default.
|
|
68
|
+
"""
|
|
69
|
+
if self.idle_ttl is None:
|
|
70
|
+
return None
|
|
71
|
+
if isinstance(self.idle_ttl, timedelta):
|
|
72
|
+
return self.idle_ttl
|
|
73
|
+
return timedelta(seconds=self.idle_ttl)
|
|
74
|
+
|
|
75
|
+
@property
def max_replicas(self) -> int:
    """
    Returns the maximum number of replicas: the second element when replicas is a tuple, otherwise replicas
    itself.
    """
    replicas = self.replicas
    if isinstance(replicas, tuple):
        return replicas[1]
    return replicas
|