flyte 0.2.0b24__py3-none-any.whl → 0.2.0b26__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. The original page linked to more details at this point ("Click here"), but that link was not preserved in this copy.

@@ -456,6 +456,10 @@ class DockerImageBuilder(ImageBuilder):
456
456
  else:
457
457
  click.secho(f"Run command: {concat_command} ", fg="blue")
458
458
 
459
- await asyncio.to_thread(subprocess.run, command, check=True)
459
+ try:
460
+ await asyncio.to_thread(subprocess.run, command, check=True)
461
+ except subprocess.CalledProcessError as e:
462
+ logger.error(f"Failed to build image: {e}")
463
+ raise RuntimeError(f"Failed to build image: {e}")
460
464
 
461
465
  return image.uri
@@ -138,9 +138,17 @@ async def _validate_configuration(image: Image) -> Tuple[str, Optional[str]]:
138
138
 
139
139
  if any(context_path.iterdir()):
140
140
  # If there are files in the context directory, upload it
141
- _, context_url = await remote.upload_file.aio(
142
- Path(shutil.make_archive(str(tmp_path / "context"), "xztar", context_path))
143
- )
141
+ archive = Path(shutil.make_archive(str(tmp_path / "context"), "xztar", context_path))
142
+ st = archive.stat()
143
+ if st.st_size > 5 * 1024 * 1024:
144
+ logger.warning(
145
+ click.style(
146
+ f"Context size is {st.st_size / (1024 * 1024):.2f} MB, which is larger than 5 MB. "
147
+ "Upload and build speed will be impacted.",
148
+ fg="yellow",
149
+ )
150
+ )
151
+ _, context_url = await remote.upload_file.aio(archive)
144
152
  else:
145
153
  context_url = ""
146
154
 
@@ -150,6 +150,22 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
150
150
  )
151
151
 
152
152
 
153
+ async def convert_from_inputs_to_native(native_interface: NativeInterface, inputs: Inputs) -> Dict[str, Any]:
154
+ """
155
+ Converts the inputs from a run definition proto to a native Python dictionary.
156
+ :param native_interface: The native interface of the task.
157
+ :param inputs: The run definition inputs proto.
158
+ :return: A dictionary of input names to their native Python values.
159
+ """
160
+ if not inputs or not inputs.proto_inputs or not inputs.proto_inputs.literals:
161
+ return {}
162
+
163
+ literals = {named_literal.name: named_literal.value for named_literal in inputs.proto_inputs.literals}
164
+ return await TypeEngine.literal_map_to_kwargs(
165
+ literals_pb2.LiteralMap(literals=literals), native_interface.get_input_types()
166
+ )
167
+
168
+
153
169
  async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, task_name: str = "") -> Outputs:
154
170
  # Always make it a tuple even if it's just one item to simplify logic below
155
171
  if not isinstance(o, tuple):
@@ -1,4 +1,5 @@
1
- from typing import List, Optional, Tuple
1
+ import importlib
2
+ from typing import List, Optional, Tuple, Type
2
3
 
3
4
  import flyte.errors
4
5
  from flyte._code_bundle import download_bundle
@@ -10,7 +11,6 @@ from flyte._task import TaskTemplate
10
11
  from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
11
12
 
12
13
  from .convert import Error, Inputs, Outputs
13
- from .task_serde import load_task
14
14
  from .taskrunner import (
15
15
  convert_and_run,
16
16
  extract_download_run_upload,
@@ -51,25 +51,68 @@ async def direct_dispatch(
51
51
  )
52
52
 
53
53
 
54
+ def load_class(qualified_name) -> Type:
55
+ """
56
+ Load a class from a qualified name. The qualified name should be in the format 'module.ClassName'.
57
+ :param qualified_name: The qualified name of the class to load.
58
+ :return: The class object.
59
+ """
60
+ module_name, class_name = qualified_name.rsplit(".", 1) # Split module and class
61
+ module = importlib.import_module(module_name) # Import the module
62
+ return getattr(module, class_name) # Retrieve the class
63
+
64
+
65
+ def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
66
+ """
67
+ Load a task from a resolver. This is a placeholder function.
68
+
69
+ :param resolver: The resolver to use to load the task.
70
+ :param resolver_args: Arguments to pass to the resolver.
71
+ :return: The loaded task.
72
+ """
73
+ resolver_class = load_class(resolver)
74
+ resolver_instance = resolver_class()
75
+ return resolver_instance.load_task(resolver_args)
76
+
77
+
78
+ def load_pkl_task(code_bundle: CodeBundle) -> TaskTemplate:
79
+ """
80
+ Loads a task from a pickled code bundle.
81
+ :param code_bundle: The code bundle to load the task from.
82
+ :return: The loaded task template.
83
+ """
84
+ logger.debug(f"Loading task from pkl: {code_bundle.downloaded_path}")
85
+ try:
86
+ import gzip
87
+
88
+ import cloudpickle
89
+
90
+ with gzip.open(str(code_bundle.downloaded_path), "rb") as f:
91
+ return cloudpickle.load(f)
92
+ except Exception as e:
93
+ logger.exception(f"Failed to load pickled task from {code_bundle.downloaded_path}. Reason: {e!s}")
94
+ raise
95
+
96
+
97
+ async def download_code_bundle(code_bundle: CodeBundle) -> CodeBundle:
98
+ """
99
+ Downloads the code bundle if it is not already downloaded.
100
+ :param code_bundle: The code bundle to download.
101
+ :return: The code bundle with the downloaded path.
102
+ """
103
+ logger.debug(f"Downloading {code_bundle}")
104
+ downloaded_path = await download_bundle(code_bundle)
105
+ return code_bundle.with_downloaded_path(downloaded_path)
106
+
107
+
54
108
  async def _download_and_load_task(
55
109
  code_bundle: CodeBundle | None, resolver: str | None = None, resolver_args: List[str] | None = None
56
110
  ) -> TaskTemplate:
57
111
  if code_bundle and (code_bundle.tgz or code_bundle.pkl):
58
112
  logger.debug(f"Downloading {code_bundle}")
59
- downloaded_path = await download_bundle(code_bundle)
60
- code_bundle = code_bundle.with_downloaded_path(downloaded_path)
113
+ code_bundle = await download_code_bundle(code_bundle)
61
114
  if code_bundle.pkl:
62
- try:
63
- logger.debug(f"Loading task from pkl: {code_bundle.downloaded_path}")
64
- import gzip
65
-
66
- import cloudpickle
67
-
68
- with gzip.open(str(code_bundle.downloaded_path), "rb") as f:
69
- return cloudpickle.load(f)
70
- except Exception as e:
71
- logger.exception(f"Failed to load pickled task from {code_bundle.downloaded_path}. Reason: {e!s}")
72
- raise
115
+ return load_pkl_task(code_bundle)
73
116
 
74
117
  if not resolver or not resolver_args:
75
118
  raise flyte.errors.RuntimeSystemError(
@@ -0,0 +1,121 @@
1
+ import hashlib
2
+ import typing
3
+ from venv import logger
4
+
5
+ from flyteidl.core import tasks_pb2
6
+
7
+ import flyte.errors
8
+ from flyte import ReusePolicy
9
+ from flyte._pod import _PRIMARY_CONTAINER_DEFAULT_NAME, _PRIMARY_CONTAINER_NAME_FIELD
10
+ from flyte.models import CodeBundle
11
+
12
+
13
+ def extract_unique_id_and_image(
14
+ env_name: str,
15
+ code_bundle: CodeBundle | None,
16
+ task: tasks_pb2.TaskTemplate,
17
+ reuse_policy: ReusePolicy,
18
+ ) -> typing.Tuple[str, str]:
19
+ """
20
+ Compute a unique ID for the task based on its name, version, image URI, and code bundle.
21
+ :param env_name: Name of the reusable environment.
22
+ :param reuse_policy: The reuse policy for the task.
23
+ :param task: The task template.
24
+ :param code_bundle: The code bundle associated with the task.
25
+ :return: A unique ID string and the image URI.
26
+ """
27
+ image = ""
28
+ container_ser = ""
29
+ if task.HasField("container"):
30
+ copied_container = tasks_pb2.Container()
31
+ copied_container.CopyFrom(task.container)
32
+ copied_container.args.clear() # Clear args to ensure deterministic serialization
33
+ container_ser = copied_container.SerializeToString(deterministic=True)
34
+ image = copied_container.image
35
+
36
+ if task.HasField("k8s_pod"):
37
+ # Clear args to ensure deterministic serialization
38
+ copied_k8s_pod = tasks_pb2.K8sPod()
39
+ copied_k8s_pod.CopyFrom(task.k8s_pod)
40
+ if task.config is not None:
41
+ primary_container_name = task.config[_PRIMARY_CONTAINER_NAME_FIELD]
42
+ else:
43
+ primary_container_name = _PRIMARY_CONTAINER_DEFAULT_NAME
44
+ for container in copied_k8s_pod.pod_spec["containers"]:
45
+ if "name" in container and container["name"] == primary_container_name:
46
+ image = container["image"]
47
+ del container["args"]
48
+ container_ser = copied_k8s_pod.SerializeToString(deterministic=True)
49
+
50
+ components = f"{env_name}:{container_ser}"
51
+ if isinstance(reuse_policy.replicas, tuple):
52
+ components += f":{reuse_policy.replicas[0]}:{reuse_policy.replicas[1]}"
53
+ else:
54
+ components += f":{reuse_policy.replicas}"
55
+ if reuse_policy.ttl is not None:
56
+ components += f":{reuse_policy.ttl.total_seconds()}"
57
+ if reuse_policy.reuse_salt is None and code_bundle is not None:
58
+ components += f":{code_bundle.computed_version}"
59
+ else:
60
+ components += f":{reuse_policy.reuse_salt}"
61
+ if task.security_context is not None:
62
+ security_ctx_str = task.security_context.SerializeToString(deterministic=True)
63
+ components += f":{security_ctx_str}"
64
+ if task.metadata.interruptible is not None:
65
+ components += f":{task.metadata.interruptible}"
66
+ if task.metadata.pod_template_name is not None:
67
+ components += f":{task.metadata.pod_template_name}"
68
+ sha256 = hashlib.sha256()
69
+ sha256.update(components.encode("utf-8"))
70
+ return sha256.hexdigest(), image
71
+
72
+
73
+ def add_reusable(
74
+ task: tasks_pb2.TaskTemplate,
75
+ reuse_policy: ReusePolicy,
76
+ code_bundle: CodeBundle | None,
77
+ parent_env_name: str | None = None,
78
+ ) -> tasks_pb2.TaskTemplate:
79
+ """
80
+ Convert a ReusePolicy to a custom configuration dictionary.
81
+
82
+ :param task: The task to which the reusable policy will be added.
83
+ :param reuse_policy: The reuse policy to apply.
84
+ :param code_bundle: The code bundle associated with the task.
85
+ :param parent_env_name: The name of the parent environment, if any.
86
+ :return: The modified task with the reusable policy added.
87
+ """
88
+ if reuse_policy is None:
89
+ return task
90
+
91
+ if task.HasField("custom"):
92
+ raise flyte.errors.RuntimeUserError(
93
+ "BadConfiguration", "Plugins do not support reusable policy. Only container tasks and pods."
94
+ )
95
+
96
+ logger.debug(f"Adding reusable policy for task: {task.id.name}")
97
+ name = parent_env_name if parent_env_name else ""
98
+ if parent_env_name is None:
99
+ name = task.id.name.split(".")[0]
100
+
101
+ version, image_uri = extract_unique_id_and_image(
102
+ env_name=name, code_bundle=code_bundle, task=task, reuse_policy=reuse_policy
103
+ )
104
+
105
+ task.custom = {
106
+ "name": name,
107
+ "version": version[:15], # Use only the first 15 characters for the version
108
+ "type": "actor",
109
+ "spec": {
110
+ "container_image": image_uri,
111
+ "backlog_length": None,
112
+ "parallelism": reuse_policy.concurrency,
113
+ "replica_count": reuse_policy.max_replicas,
114
+ "ttl_seconds": reuse_policy.ttl.total_seconds() if reuse_policy.ttl else None,
115
+ },
116
+ }
117
+
118
+ task.type = "actor"
119
+ logger.info(f"Reusable task {task.id.name} with config {task.custom}")
120
+
121
+ return task
@@ -0,0 +1,165 @@
1
+ import sys
2
+ from typing import Any, List, Tuple
3
+
4
+ from flyte._context import contextual_run
5
+ from flyte._internal.controllers import Controller
6
+ from flyte._internal.controllers import create_controller as _create_controller
7
+ from flyte._internal.imagebuild.image_builder import ImageCache
8
+ from flyte._internal.runtime.entrypoints import download_code_bundle, load_pkl_task, load_task
9
+ from flyte._internal.runtime.taskrunner import extract_download_run_upload
10
+ from flyte._logging import logger
11
+ from flyte._task import TaskTemplate
12
+ from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
13
+
14
+
15
+ async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
16
+ """
17
+ Downloads and loads the task from the code bundle or resolver.
18
+ :param tgz: The path to the task template in a tar.gz format.
19
+ :param destination: The path to save the downloaded task template.
20
+ :param version: The version of the task to load.
21
+ :return: The CodeBundle object.
22
+ """
23
+ logger.info(f"[rusty] Downloading tgz code bundle from {tgz} to {destination} with version {version}")
24
+ sys.path.insert(0, ".")
25
+
26
+ code_bundle = CodeBundle(
27
+ tgz=tgz,
28
+ destination=destination,
29
+ computed_version=version,
30
+ )
31
+ return await download_code_bundle(code_bundle)
32
+
33
+
34
+ async def download_load_pkl(destination: str, version: str, pkl: str) -> Tuple[CodeBundle, TaskTemplate]:
35
+ """
36
+ Downloads and loads the task from the code bundle or resolver.
37
+ :param pkl: The path to the task template in a pickle format.
38
+ :param destination: The path to save the downloaded task template.
39
+ :param version: The version of the task to load.
40
+ :return: The CodeBundle object.
41
+ """
42
+ logger.info(f"[rusty] Downloading pkl code bundle from {pkl} to {destination} with version {version}")
43
+ sys.path.insert(0, ".")
44
+
45
+ code_bundle = CodeBundle(
46
+ pkl=pkl,
47
+ destination=destination,
48
+ computed_version=version,
49
+ )
50
+ code_bundle = await download_code_bundle(code_bundle)
51
+ return code_bundle, load_pkl_task(code_bundle)
52
+
53
+
54
+ def load_task_from_code_bundle(resolver: str, resolver_args: List[str]) -> TaskTemplate:
55
+ """
56
+ Loads the task from the code bundle or resolver.
57
+ :param resolver: The resolver to use to load the task.
58
+ :param resolver_args: The arguments to pass to the resolver.
59
+ :return: The loaded task template.
60
+ """
61
+ logger.debug(f"[rusty] Loading task from code bundle {resolver} with args: {resolver_args}")
62
+ return load_task(resolver, *resolver_args)
63
+
64
+
65
+ async def create_controller(
66
+ endpoint: str = "host.docker.internal:8090",
67
+ insecure: bool = False,
68
+ api_key: str | None = None,
69
+ ) -> Controller:
70
+ """
71
+ Creates a controller instance for remote operations.
72
+ :param endpoint:
73
+ :param insecure:
74
+ :param api_key:
75
+ :return:
76
+ """
77
+ logger.info(f"[rusty] Creating controller with endpoint {endpoint}")
78
+ from flyte._initialize import init
79
+
80
+ # TODO Currently refrence tasks are not supported in Rusty.
81
+ await init.aio()
82
+ controller_kwargs: dict[str, Any] = {"insecure": insecure}
83
+ if api_key:
84
+ logger.info("Using api key from environment")
85
+ controller_kwargs["api_key"] = api_key
86
+ else:
87
+ controller_kwargs["endpoint"] = endpoint
88
+ if "localhost" in endpoint or "docker" in endpoint:
89
+ controller_kwargs["insecure"] = True
90
+ logger.debug(f"Using controller endpoint: {endpoint} with kwargs: {controller_kwargs}")
91
+
92
+ return _create_controller(ct="remote", **controller_kwargs)
93
+
94
+
95
+ async def run_task(
96
+ task: TaskTemplate,
97
+ controller: Controller,
98
+ org: str,
99
+ project: str,
100
+ domain: str,
101
+ run_name: str,
102
+ name: str,
103
+ raw_data_path: str,
104
+ output_path: str,
105
+ run_base_dir: str,
106
+ version: str,
107
+ image_cache: str | None = None,
108
+ checkpoint_path: str | None = None,
109
+ prev_checkpoint: str | None = None,
110
+ code_bundle: CodeBundle | None = None,
111
+ input_path: str | None = None,
112
+ ):
113
+ """
114
+ Runs the task with the provided parameters.
115
+ :param prev_checkpoint: Previous checkpoint path to resume from.
116
+ :param checkpoint_path: Checkpoint path to save the current state.
117
+ :param image_cache: Image cache to use for the task.
118
+ :param name: Action name to run.
119
+ :param run_name: Parent run name to use for the task.
120
+ :param domain: domain to run the task in.
121
+ :param project: project to run the task in.
122
+ :param org: organization to run the task in.
123
+ :param task: The task template to run.
124
+ :param raw_data_path: The path to the raw data.
125
+ :param output_path: The path to save the output.
126
+ :param run_base_dir: The base directory for the run.
127
+ :param version: The version of the task to run.
128
+ :param controller: The controller to use for the task.
129
+ :param code_bundle: Optional code bundle for the task.
130
+ :param input_path: Optional input path for the task.
131
+ :return: The loaded task template.
132
+ """
133
+ logger.info(f"[rusty] Running task {task.name}")
134
+ await contextual_run(
135
+ extract_download_run_upload,
136
+ task,
137
+ action=ActionID(name=name, org=org, project=project, domain=domain, run_name=run_name),
138
+ version=version,
139
+ controller=controller,
140
+ raw_data_path=RawDataPath(path=raw_data_path),
141
+ output_path=output_path,
142
+ run_base_dir=run_base_dir,
143
+ checkpoints=Checkpoints(prev_checkpoint_path=prev_checkpoint, checkpoint_path=checkpoint_path),
144
+ code_bundle=code_bundle,
145
+ input_path=input_path,
146
+ image_cache=ImageCache.from_transport(image_cache) if image_cache else None,
147
+ )
148
+ logger.info(f"[rusty] Finished task {task.name}")
149
+
150
+
151
+ async def ping(name: str) -> str:
152
+ """
153
+ A simple hello world function to test the Rusty entrypoint.
154
+ """
155
+ print(f"Received ping request from {name} in Rusty!")
156
+ return f"pong from Rusty to {name}!"
157
+
158
+
159
+ async def hello(name: str):
160
+ """
161
+ A simple hello world function to test the Rusty entrypoint.
162
+ :param name: The name to greet.
163
+ :return: A greeting message.
164
+ """
165
+ print(f"Received hello request in Rusty with name: {name}!")
@@ -4,10 +4,9 @@ It includes a Resolver interface for loading tasks, and functions to load classe
4
4
  """
5
5
 
6
6
  import copy
7
- import importlib
8
7
  import typing
9
8
  from datetime import timedelta
10
- from typing import Optional, Type, cast
9
+ from typing import Optional, cast
11
10
 
12
11
  from flyteidl.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
13
12
  from google.protobuf import duration_pb2, wrappers_pb2
@@ -21,39 +20,17 @@ from flyte._secret import SecretRequest, secrets_from_request
21
20
  from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
22
21
  from flyte.models import CodeBundle, SerializationContext
23
22
 
23
+ from ... import ReusePolicy
24
24
  from ..._retry import RetryStrategy
25
25
  from ..._timeout import TimeoutType, timeout_from_request
26
26
  from .resources_serde import get_proto_extended_resources, get_proto_resources
27
+ from .reuse import add_reusable
27
28
  from .types_serde import transform_native_to_typed_interface
28
29
 
29
30
  _MAX_ENV_NAME_LENGTH = 63 # Maximum length for environment names
30
31
  _MAX_TASK_SHORT_NAME_LENGTH = 63 # Maximum length for task short names
31
32
 
32
33
 
33
- def load_class(qualified_name) -> Type:
34
- """
35
- Load a class from a qualified name. The qualified name should be in the format 'module.ClassName'.
36
- :param qualified_name: The qualified name of the class to load.
37
- :return: The class object.
38
- """
39
- module_name, class_name = qualified_name.rsplit(".", 1) # Split module and class
40
- module = importlib.import_module(module_name) # Import the module
41
- return getattr(module, class_name) # Retrieve the class
42
-
43
-
44
- def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
45
- """
46
- Load a task from a resolver. This is a placeholder function.
47
-
48
- :param resolver: The resolver to use to load the task.
49
- :param resolver_args: Arguments to pass to the resolver.
50
- :return: The loaded task.
51
- """
52
- resolver_class = load_class(resolver)
53
- resolver_instance = resolver_class()
54
- return resolver_instance.load_task(resolver_args)
55
-
56
-
57
34
  def translate_task_to_wire(
58
35
  task: TaskTemplate,
59
36
  serialization_context: SerializationContext,
@@ -137,23 +114,22 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
137
114
  name=task.name,
138
115
  version=serialize_context.version,
139
116
  )
140
- # TODO, there will be tasks that do not have images, handle that case
141
- # if task.parent_env is None:
142
- # raise ValueError(f"Task {task.name} must have a parent environment")
143
117
 
144
118
  # TODO Add support for SQL, extra_config, custom
145
119
  extra_config: typing.Dict[str, str] = {}
146
- custom = task.custom_config(serialize_context)
147
120
 
148
- sql = None
149
121
  if task.pod_template and not isinstance(task.pod_template, str):
150
- container = None
151
122
  pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template)
152
123
  extra_config[_PRIMARY_CONTAINER_NAME_FIELD] = task.pod_template.primary_container_name
124
+ container = None
153
125
  else:
154
126
  container = _get_urun_container(serialize_context, task)
155
127
  pod = None
156
128
 
129
+ custom = task.custom_config(serialize_context)
130
+
131
+ sql = None
132
+
157
133
  # -------------- CACHE HANDLING ----------------------
158
134
  task_cache = cache_from_request(task.cache)
159
135
  cache_enabled = task_cache.is_enabled()
@@ -175,7 +151,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
175
151
  else:
176
152
  logger.debug(f"Cache disabled for task {task.name}")
177
153
 
178
- return tasks_pb2.TaskTemplate(
154
+ task_template = tasks_pb2.TaskTemplate(
179
155
  id=task_id,
180
156
  type=task.task_type,
181
157
  metadata=tasks_pb2.TaskMetadata(
@@ -195,7 +171,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
195
171
  generates_deck=wrappers_pb2.BoolValue(value=task.report),
196
172
  ),
197
173
  interface=transform_native_to_typed_interface(task.native_interface),
198
- custom=custom,
174
+ custom=custom if len(custom) > 0 else None,
199
175
  container=container,
200
176
  task_type_version=task.task_type_version,
201
177
  security_context=get_security_context(task.secrets),
@@ -205,6 +181,20 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
205
181
  extended_resources=get_proto_extended_resources(task.resources),
206
182
  )
207
183
 
184
+ if task.reusable is not None:
185
+ if not isinstance(task.reusable, ReusePolicy):
186
+ raise flyte.errors.RuntimeUserError(
187
+ "BadConfig", f"Expected ReusePolicy, got {type(task.reusable)} for task {task.name}"
188
+ )
189
+ env_name = None
190
+ if task.parent_env is not None:
191
+ env = task.parent_env()
192
+ if env is not None:
193
+ env_name = env.name
194
+ return add_reusable(task_template, task.reusable, serialize_context.code_bundle, env_name)
195
+
196
+ return task_template
197
+
208
198
 
209
199
  def _get_urun_container(
210
200
  serialize_context: SerializationContext, task_template: TaskTemplate
@@ -2,6 +2,8 @@ from dataclasses import dataclass
2
2
  from datetime import timedelta
3
3
  from typing import Optional, Tuple, Union
4
4
 
5
+ from flyte._logging import logger
6
+
5
7
 
6
8
  @dataclass
7
9
  class ReusePolicy:
@@ -14,12 +16,65 @@ class ReusePolicy:
14
16
  Caution: It is important to note that the environment is shared, so managing memory and resources is important.
15
17
 
16
18
  :param replicas: Either a single int representing number of replicas or a tuple of two ints representing
17
- the min and max
19
+ the min and max.
18
20
  :param idle_ttl: The maximum idle duration for an environment replica, specified as either seconds (int) or a
19
21
  timedelta. If not set, the environment's global default will be used.
20
22
  When a replica remains idle — meaning no tasks are running — for this duration, it will be automatically
21
23
  terminated.
24
+ :param concurrency: The maximum number of tasks that can run concurrently in one instance of the environment.
25
+ Concurrency of greater than 1 is only supported only for `async` tasks.
26
+ :param reuse_salt: Optional string used to control environment reuse.
27
+ If set, the environment will be reused even if the code bundle changes.
28
+ To force a new environment, either set this to `None` or change its value.
29
+
30
+ Example:
31
+ reuse_salt = "v1" # Environment is reused
32
+ reuse_salt = "v2" # Forces environment recreation
22
33
  """
23
34
 
24
- replicas: Union[int, Tuple[int, int]] = 1
35
+ replicas: Union[int, Tuple[int, int]] = 2
25
36
  idle_ttl: Optional[Union[int, timedelta]] = None
37
+ reuse_salt: str | None = None
38
+ concurrency: int = 1
39
+
40
+ def __post_init__(self):
41
+ if self.replicas is None:
42
+ raise ValueError("replicas cannot be None")
43
+ if isinstance(self.replicas, int):
44
+ self.replicas = (self.replicas, self.replicas)
45
+ elif not isinstance(self.replicas, tuple):
46
+ raise ValueError("replicas must be an int or a tuple of two ints")
47
+ elif len(self.replicas) != 2:
48
+ raise ValueError("replicas must be an int or a tuple of two ints")
49
+
50
+ if self.idle_ttl:
51
+ if isinstance(self.idle_ttl, int):
52
+ self.idle_ttl = timedelta(seconds=int(self.idle_ttl))
53
+ elif not isinstance(self.idle_ttl, timedelta):
54
+ raise ValueError("idle_ttl must be an int (seconds) or a timedelta")
55
+
56
+ if self.replicas[1] == 1 and self.concurrency == 1:
57
+ logger.warning(
58
+ "It is recommended to use a minimum of 2 replicas, to avoid starvation. "
59
+ "Starvation can occur if a task is running and no other replicas are available to handle new tasks."
60
+ "Options, increase concurrency, increase replicas or turn-off reuse for the parent task, "
61
+ "that runs child tasks."
62
+ )
63
+
64
+ @property
65
+ def ttl(self) -> timedelta | None:
66
+ """
67
+ Returns the idle TTL as a timedelta. If idle_ttl is not set, returns the global default.
68
+ """
69
+ if self.idle_ttl is None:
70
+ return None
71
+ if isinstance(self.idle_ttl, timedelta):
72
+ return self.idle_ttl
73
+ return timedelta(seconds=self.idle_ttl)
74
+
75
+ @property
76
+ def max_replicas(self) -> int:
77
+ """
78
+ Returns the maximum number of replicas.
79
+ """
80
+ return self.replicas[1] if isinstance(self.replicas, tuple) else self.replicas
flyte/_secret.py CHANGED
@@ -45,6 +45,27 @@ class Secret:
45
45
  if not re.match(pattern, self.as_env_var):
46
46
  raise ValueError(f"Invalid environment variable name: {self.as_env_var}, must match {pattern}")
47
47
 
48
+ def stable_hash(self) -> str:
49
+ """
50
+ Deterministic, process-independent hash (as hex string).
51
+ """
52
+ import hashlib
53
+
54
+ data = (
55
+ self.key,
56
+ self.group or "",
57
+ str(self.mount) if self.mount else "",
58
+ self.as_env_var or "",
59
+ )
60
+ joined = "|".join(data)
61
+ return hashlib.sha256(joined.encode("utf-8")).hexdigest()
62
+
63
+ def __hash__(self) -> int:
64
+ """
65
+ Deterministic hash function for the Secret class.
66
+ """
67
+ return int(self.stable_hash()[:16], 16)
68
+
48
69
 
49
70
  SecretRequest = Union[str, Secret, List[str | Secret]]
50
71
 
@@ -59,3 +80,12 @@ def secrets_from_request(secrets: SecretRequest) -> List[Secret]:
59
80
  return [secrets]
60
81
  else:
61
82
  return [Secret(key=s) if isinstance(s, str) else s for s in secrets]
83
+
84
+
85
+ if __name__ == "__main__":
86
+ # Example usage
87
+ secret1 = Secret(key="MY_SECRET", mount=pathlib.Path("/path/to/secret"), as_env_var="MY_SECRET_ENV")
88
+ secret2 = Secret(
89
+ key="ANOTHER_SECRET",
90
+ )
91
+ print(hash(secret1), hash(secret2))