flyte 2.0.0b8__py3-none-any.whl → 2.0.0b13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flyte/__init__.py CHANGED
@@ -2,37 +2,7 @@
2
2
  Flyte SDK for authoring compound AI applications, services and workflows.
3
3
  """
4
4
 
5
- __all__ = [
6
- "GPU",
7
- "TPU",
8
- "Cache",
9
- "CachePolicy",
10
- "CacheRequest",
11
- "Device",
12
- "Environment",
13
- "Image",
14
- "PodTemplate",
15
- "Resources",
16
- "RetryStrategy",
17
- "ReusePolicy",
18
- "Secret",
19
- "SecretRequest",
20
- "TaskEnvironment",
21
- "Timeout",
22
- "TimeoutType",
23
- "__version__",
24
- "build",
25
- "build_images",
26
- "ctx",
27
- "deploy",
28
- "group",
29
- "init",
30
- "init_from_config",
31
- "map",
32
- "run",
33
- "trace",
34
- "with_runcontext",
35
- ]
5
+ from __future__ import annotations
36
6
 
37
7
  import sys
38
8
 
@@ -60,8 +30,62 @@ from ._version import __version__
60
30
  sys.excepthook = custom_excepthook
61
31
 
62
32
 
33
+ def _silence_grpc_warnings():
34
+ """
35
+ Silences gRPC warnings that can clutter the output.
36
+ """
37
+ import os
38
+
39
+ # Set environment variables for gRPC, this reduces log spew and avoids unnecessary warnings
40
+ # before importing grpc
41
+ if "GRPC_VERBOSITY" not in os.environ:
42
+ os.environ["GRPC_VERBOSITY"] = "ERROR"
43
+ os.environ["GRPC_CPP_MIN_LOG_LEVEL"] = "ERROR"
44
+ # Disable fork support (stops "skipping fork() handlers")
45
+ os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "0"
46
+ # Reduce absl/glog verbosity
47
+ os.environ["GLOG_minloglevel"] = "2"
48
+ os.environ["ABSL_LOG"] = "0"
49
+
50
+
51
+ _silence_grpc_warnings()
52
+
53
+
63
54
def version() -> str:
    """
    Report which Flyte SDK version is installed.

    :return: The package-level ``__version__`` string.
    """
    sdk_version = __version__
    return sdk_version
59
+
60
+
61
+ __all__ = [
62
+ "GPU",
63
+ "TPU",
64
+ "Cache",
65
+ "CachePolicy",
66
+ "CacheRequest",
67
+ "Device",
68
+ "Environment",
69
+ "Image",
70
+ "PodTemplate",
71
+ "Resources",
72
+ "RetryStrategy",
73
+ "ReusePolicy",
74
+ "Secret",
75
+ "SecretRequest",
76
+ "TaskEnvironment",
77
+ "Timeout",
78
+ "TimeoutType",
79
+ "__version__",
80
+ "build",
81
+ "build_images",
82
+ "ctx",
83
+ "deploy",
84
+ "group",
85
+ "init",
86
+ "init_from_config",
87
+ "map",
88
+ "run",
89
+ "trace",
90
+ "with_runcontext",
91
+ ]
flyte/_context.py CHANGED
@@ -97,7 +97,7 @@ class Context:
97
97
def is_task_context(self) -> bool:
    """
    Tell whether this context is attached to a task.

    :return: bool -- True when ``self.data.task_context`` is set, False when it is None.
    """
    current_task_ctx = self.data.task_context
    return current_task_ctx is not None
103
103
 
flyte/_environment.py CHANGED
@@ -36,7 +36,7 @@ class Environment:
36
36
  :param name: Name of the environment
37
37
  :param image: Docker image to use for the environment. If set to "auto", will use the default image.
38
38
  :param resources: Resources to allocate for the environment.
39
- :param env: Environment variables to set for the environment.
39
+ :param env_vars: Environment variables to set for the environment.
40
40
  :param secrets: Secrets to inject into the environment.
41
41
  :param depends_on: Environment dependencies to hint, so when you deploy the environment, the dependencies are
42
42
  also deployed. This is useful when you have a set of environments that depend on each other.
@@ -47,7 +47,7 @@ class Environment:
47
47
  pod_template: Optional[Union[str, "V1PodTemplate"]] = None
48
48
  description: Optional[str] = None
49
49
  secrets: Optional[SecretRequest] = None
50
- env: Optional[Dict[str, str]] = None
50
+ env_vars: Optional[Dict[str, str]] = None
51
51
  resources: Optional[Resources] = None
52
52
  image: Union[str, Image, Literal["auto"]] = "auto"
53
53
 
@@ -75,7 +75,7 @@ class Environment:
75
75
  name: str,
76
76
  image: Optional[Union[str, Image, Literal["auto"]]] = None,
77
77
  resources: Optional[Resources] = None,
78
- env: Optional[Dict[str, str]] = None,
78
+ env_vars: Optional[Dict[str, str]] = None,
79
79
  secrets: Optional[SecretRequest] = None,
80
80
  depends_on: Optional[List[Environment]] = None,
81
81
  **kwargs: Any,
@@ -94,8 +94,8 @@ class Environment:
94
94
  kwargs["resources"] = self.resources
95
95
  if self.secrets is not None:
96
96
  kwargs["secrets"] = self.secrets
97
- if self.env is not None:
98
- kwargs["env"] = self.env
97
+ if self.env_vars is not None:
98
+ kwargs["env_vars"] = self.env_vars
99
99
  if self.pod_template is not None:
100
100
  kwargs["pod_template"] = self.pod_template
101
101
  if self.description is not None:
flyte/_image.py CHANGED
@@ -279,7 +279,7 @@ class DockerIgnore(Layer):
279
279
  @dataclass(frozen=True, repr=True)
280
280
  class CopyConfig(Layer):
281
281
  path_type: CopyConfigType = field(metadata={"identifier": True})
282
- src: Path = field(metadata={"identifier": True})
282
+ src: Path = field(metadata={"identifier": False})
283
283
  dst: str
284
284
  src_name: str = field(init=False)
285
285
 
@@ -451,11 +451,7 @@ class Image:
451
451
  # this default image definition may need to be updated once there is a released pypi version
452
452
  from flyte._version import __version__
453
453
 
454
- dev_mode = (
455
- (cls._is_editable_install() or (__version__ and "dev" in __version__))
456
- and not flyte_version
457
- and install_flyte
458
- )
454
+ dev_mode = (__version__ and "dev" in __version__) and not flyte_version and install_flyte
459
455
  if install_flyte is False:
460
456
  preset_tag = f"py{python_version[0]}.{python_version[1]}"
461
457
  else:
@@ -507,13 +503,6 @@ class Image:
507
503
 
508
504
  return image
509
505
 
510
- @staticmethod
511
- def _is_editable_install():
512
- """Internal hacky function to see if the current install is editable or not."""
513
- curr = Path(__file__)
514
- pyproject = curr.parent.parent.parent / "pyproject.toml"
515
- return pyproject.exists()
516
-
517
506
  @classmethod
518
507
  def from_debian_base(
519
508
  cls,
flyte/_initialize.py CHANGED
@@ -11,7 +11,6 @@ from flyte.errors import InitializationError
11
11
  from flyte.syncify import syncify
12
12
 
13
13
  from ._logging import initialize_logger, logger
14
- from ._tools import ipython_check
15
14
 
16
15
  if TYPE_CHECKING:
17
16
  from flyte._internal.imagebuild import ImageBuildEngine
@@ -173,6 +172,7 @@ async def init(
173
172
 
174
173
  :return: None
175
174
  """
175
+ from flyte._tools import ipython_check
176
176
  from flyte._utils import get_cwd_editable_install, org_from_endpoint, sanitize_endpoint
177
177
 
178
178
  interactive_mode = ipython_check()
@@ -61,7 +61,7 @@ ENV PATH="/root/.venv/bin:$$PATH" \
61
61
 
62
62
  UV_PACKAGE_INSTALL_COMMAND_TEMPLATE = Template("""\
63
63
  RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=uv \
64
- --mount=type=bind,target=requirements_uv.txt,src=requirements_uv.txt \
64
+ $REQUIREMENTS_MOUNT \
65
65
  $SECRET_MOUNT \
66
66
  uv pip install --python $$UV_PYTHON $PIP_INSTALL_ARGS
67
67
  """)
@@ -129,6 +129,8 @@ class Handler(Protocol):
129
129
  class PipAndRequirementsHandler:
130
130
  @staticmethod
131
131
  async def handle(layer: PipPackages, context_path: Path, dockerfile: str) -> str:
132
+ secret_mounts = _get_secret_mounts_layer(layer.secret_mounts)
133
+
132
134
  # Set pip_install_args based on the layer type - either a requirements file or a list of packages
133
135
  if isinstance(layer, Requirements):
134
136
  if not layer.file.exists():
@@ -138,17 +140,20 @@ class PipAndRequirementsHandler:
138
140
 
139
141
  # Copy the requirements file to the context path
140
142
  requirements_path = copy_files_to_context(layer.file, context_path)
143
+ rel_path = str(requirements_path.relative_to(context_path))
141
144
  pip_install_args = layer.get_pip_install_args()
142
- pip_install_args.extend(["--requirement", str(requirements_path)])
145
+ pip_install_args.extend(["--requirement", "requirements.txt"])
146
+ mount = f"--mount=type=bind,target=requirements.txt,src={rel_path}"
143
147
  else:
148
+ mount = ""
144
149
  requirements = list(layer.packages) if layer.packages else []
145
150
  reqs = " ".join(requirements)
146
151
  pip_install_args = layer.get_pip_install_args()
147
152
  pip_install_args.append(reqs)
148
153
 
149
- secret_mounts = _get_secret_mounts_layer(layer.secret_mounts)
150
154
  delta = UV_PACKAGE_INSTALL_COMMAND_TEMPLATE.substitute(
151
155
  SECRET_MOUNT=secret_mounts,
156
+ REQUIREMENTS_MOUNT=mount,
152
157
  PIP_INSTALL_ARGS=" ".join(pip_install_args),
153
158
  )
154
159
 
@@ -24,6 +24,7 @@ from flyte._image import (
24
24
  Requirements,
25
25
  UVProject,
26
26
  UVScript,
27
+ WorkDir,
27
28
  )
28
29
  from flyte._internal.imagebuild.image_builder import ImageBuilder, ImageChecker
29
30
  from flyte._internal.imagebuild.utils import copy_files_to_context
@@ -34,6 +35,7 @@ if TYPE_CHECKING:
34
35
  from flyte._protos.imagebuilder import definition_pb2 as image_definition_pb2
35
36
 
36
37
  IMAGE_TASK_NAME = "build-image"
38
+ OPTIMIZE_TASK_NAME = "optimize_task"
37
39
  IMAGE_TASK_PROJECT = "system"
38
40
  IMAGE_TASK_DOMAIN = "production"
39
41
 
@@ -117,6 +119,19 @@ class RemoteImageBuilder(ImageBuilder):
117
119
 
118
120
  if run_details.action_details.raw_phase == run_definition_pb2.PHASE_SUCCEEDED:
119
121
  logger.warning(click.style(f"✅ Build completed in {elapsed}!", bold=True, fg="green"))
122
+ try:
123
+ entity = remote.Task.get(
124
+ name=OPTIMIZE_TASK_NAME,
125
+ project=IMAGE_TASK_PROJECT,
126
+ domain=IMAGE_TASK_DOMAIN,
127
+ auto_version="latest",
128
+ )
129
+ await flyte.with_runcontext(project=IMAGE_TASK_PROJECT, domain=IMAGE_TASK_DOMAIN).run.aio(
130
+ entity, spec=spec, context=context, target_image=image_name
131
+ )
132
+ except Exception as e:
133
+ # Ignore the error if optimize is not enabled in the backend.
134
+ logger.warning(f"Failed to run optimize task with error: {e}")
120
135
  else:
121
136
  raise flyte.errors.ImageBuildError(f"❌ Build failed in {elapsed} at {click.style(run.url, fg='cyan')}")
122
137
 
@@ -257,6 +272,11 @@ def _get_layers_proto(image: Image, context_path: Path) -> "image_definition_pb2
257
272
  )
258
273
  )
259
274
  layers.append(env_layer)
275
+ elif isinstance(layer, WorkDir):
276
+ workdir_layer = image_definition_pb2.Layer(
277
+ workdir=image_definition_pb2.WorkDir(workdir=layer.workdir),
278
+ )
279
+ layers.append(workdir_layer)
260
280
 
261
281
  return image_definition_pb2.ImageSpec(
262
282
  base_image=image.base_image,
@@ -54,10 +54,10 @@ def extract_unique_id_and_image(
54
54
  components += f":{reuse_policy.replicas}"
55
55
  if reuse_policy.ttl is not None:
56
56
  components += f":{reuse_policy.ttl.total_seconds()}"
57
- if reuse_policy.reuse_salt is None and code_bundle is not None:
57
+ if reuse_policy.get_scaledown_ttl() is not None:
58
+ components += f":{reuse_policy.get_scaledown_ttl()}"
59
+ if code_bundle is not None:
58
60
  components += f":{code_bundle.computed_version}"
59
- else:
60
- components += f":{reuse_policy.reuse_salt}"
61
61
  if task.security_context is not None:
62
62
  security_ctx_str = task.security_context.SerializeToString(deterministic=True)
63
63
  components += f":{security_ctx_str}"
@@ -102,6 +102,8 @@ def add_reusable(
102
102
  env_name=name, code_bundle=code_bundle, task=task, reuse_policy=reuse_policy
103
103
  )
104
104
 
105
+ scaledown_ttl = reuse_policy.get_scaledown_ttl()
106
+
105
107
  task.custom = {
106
108
  "name": name,
107
109
  "version": version[:15], # Use only the first 15 characters for the version
@@ -110,8 +112,10 @@ def add_reusable(
110
112
  "container_image": image_uri,
111
113
  "backlog_length": None,
112
114
  "parallelism": reuse_policy.concurrency,
115
+ "min_replica_count": reuse_policy.min_replicas,
113
116
  "replica_count": reuse_policy.max_replicas,
114
117
  "ttl_seconds": reuse_policy.ttl.total_seconds() if reuse_policy.ttl else None,
118
+ "scaledown_ttl_seconds": scaledown_ttl.total_seconds() if scaledown_ttl else None,
115
119
  },
116
120
  }
117
121
 
@@ -204,7 +204,9 @@ def _get_urun_container(
204
204
  serialize_context: SerializationContext, task_template: TaskTemplate
205
205
  ) -> Optional[tasks_pb2.Container]:
206
206
  env = (
207
- [literals_pb2.KeyValuePair(key=k, value=v) for k, v in task_template.env.items()] if task_template.env else None
207
+ [literals_pb2.KeyValuePair(key=k, value=v) for k, v in task_template.env_vars.items()]
208
+ if task_template.env_vars
209
+ else None
208
210
  )
209
211
  resources = get_proto_resources(task_template.resources)
210
212
  # pr: under what conditions should this return None?
flyte/_logging.py CHANGED
@@ -4,7 +4,9 @@ import logging
4
4
  import os
5
5
  from typing import Optional
6
6
 
7
- from ._tools import ipython_check, is_in_cluster
7
+ import flyte
8
+
9
+ from ._tools import ipython_check
8
10
 
9
11
  DEFAULT_LOG_LEVEL = logging.WARNING
10
12
 
@@ -42,7 +44,8 @@ def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
42
44
  """
43
45
  Upgrades the global loggers to use Rich logging.
44
46
  """
45
- if is_in_cluster():
47
+ ctx = flyte.ctx()
48
+ if ctx and ctx.is_in_cluster():
46
49
  return None
47
50
  if not ipython_check() and is_rich_logging_disabled():
48
51
  return None
@@ -1,6 +1,6 @@
1
1
  from dataclasses import dataclass
2
2
  from datetime import timedelta
3
- from typing import Optional, Tuple, Union
3
+ from typing import Tuple, Union
4
4
 
5
5
  from flyte._logging import logger
6
6
 
@@ -17,25 +17,22 @@ class ReusePolicy:
17
17
 
18
18
  :param replicas: Either a single int representing number of replicas or a tuple of two ints representing
19
19
  the min and max.
20
- :param idle_ttl: The maximum idle duration for an environment replica, specified as either seconds (int) or a
21
- timedelta. If not set, the environment's global default will be used.
20
+ :param idle_ttl: The maximum idle duration for an environment, specified as either seconds (int) or a
21
+ timedelta, after which all replicas in the environment are shutdown.
22
+ If not set, the default is configured in the backend (can be as low as 90s).
22
23
  When a replica remains idle — meaning no tasks are running — for this duration, it will be automatically
23
- terminated.
24
+ terminated, also referred to as environment idle timeout.
24
25
  :param concurrency: The maximum number of tasks that can run concurrently in one instance of the environment.
25
- Concurrency of greater than 1 is only supported only for `async` tasks.
26
- :param reuse_salt: Optional string used to control environment reuse.
27
- If set, the environment will be reused even if the code bundle changes.
28
- To force a new environment, either set this to `None` or change its value.
29
-
30
- Example:
31
- reuse_salt = "v1" # Environment is reused
32
- reuse_salt = "v2" # Forces environment recreation
26
+ Concurrency of greater than 1 is only supported for `async` tasks.
27
+ :param scaledown_ttl: The minimum time to wait before scaling down each replica, specified as either seconds (int)
28
+ or a timedelta. This is useful to prevent rapid scaling down of replicas when tasks are running
29
+ frequently. If not set, the default is configured in the backend.
33
30
  """
34
31
 
35
32
  replicas: Union[int, Tuple[int, int]] = 2
36
- idle_ttl: Optional[Union[int, timedelta]] = None
37
- reuse_salt: str | None = None
33
+ idle_ttl: Union[int, timedelta] = 30 # seconds
38
34
  concurrency: int = 1
35
+ scaledown_ttl: Union[int, timedelta] = 30 # seconds
39
36
 
40
37
  def __post_init__(self):
41
38
  if self.replicas is None:
@@ -47,11 +44,12 @@ class ReusePolicy:
47
44
  elif len(self.replicas) != 2:
48
45
  raise ValueError("replicas must be an int or a tuple of two ints")
49
46
 
50
- if self.idle_ttl:
51
- if isinstance(self.idle_ttl, int):
52
- self.idle_ttl = timedelta(seconds=int(self.idle_ttl))
53
- elif not isinstance(self.idle_ttl, timedelta):
54
- raise ValueError("idle_ttl must be an int (seconds) or a timedelta")
47
+ if isinstance(self.idle_ttl, int):
48
+ self.idle_ttl = timedelta(seconds=int(self.idle_ttl))
49
+ elif not isinstance(self.idle_ttl, timedelta):
50
+ raise ValueError("idle_ttl must be an int (seconds) or a timedelta")
51
+ if self.idle_ttl.total_seconds() < 30:
52
+ raise ValueError("idle_ttl must be at least 30 seconds")
55
53
 
56
54
  if self.replicas[1] == 1 and self.concurrency == 1:
57
55
  logger.warning(
@@ -61,6 +59,13 @@ class ReusePolicy:
61
59
  "that runs child tasks."
62
60
  )
63
61
 
62
+ if isinstance(self.scaledown_ttl, int):
63
+ self.scaledown_ttl = timedelta(seconds=int(self.scaledown_ttl))
64
+ elif not isinstance(self.scaledown_ttl, timedelta):
65
+ raise ValueError("scaledown_ttl must be an int (seconds) or a timedelta")
66
+ if self.scaledown_ttl.total_seconds() < 30:
67
+ raise ValueError("scaledown_ttl must be at least 30 seconds")
68
+
64
69
  @property
65
70
  def ttl(self) -> timedelta | None:
66
71
  """
@@ -72,6 +77,23 @@ class ReusePolicy:
72
77
  return self.idle_ttl
73
78
  return timedelta(seconds=self.idle_ttl)
74
79
 
80
@property
def min_replicas(self) -> int:
    """
    Lower bound of the replica range.

    :return: The first element of ``replicas`` when it is a tuple, otherwise ``replicas`` itself
        (a plain int).
    """
    if not isinstance(self.replicas, tuple):
        return self.replicas
    return self.replicas[0]
86
+
87
+ def get_scaledown_ttl(self) -> timedelta | None:
88
+ """
89
+ Returns the scaledown TTL as a timedelta. If scaledown_ttl is not set, returns None.
90
+ """
91
+ if self.scaledown_ttl is None:
92
+ return None
93
+ if isinstance(self.scaledown_ttl, timedelta):
94
+ return self.scaledown_ttl
95
+ return timedelta(seconds=int(self.scaledown_ttl))
96
+
75
97
  @property
76
98
  def max_replicas(self) -> int:
77
99
  """
flyte/_run.py CHANGED
@@ -19,8 +19,6 @@ from flyte._initialize import (
19
19
  )
20
20
  from flyte._logging import logger
21
21
  from flyte._task import P, R, TaskTemplate
22
- from flyte._tools import ipython_check
23
- from flyte.errors import InitializationError
24
22
  from flyte.models import (
25
23
  ActionID,
26
24
  Checkpoints,
@@ -89,13 +87,15 @@ class _Runner:
89
87
  overwrite_cache: bool = False,
90
88
  project: str | None = None,
91
89
  domain: str | None = None,
92
- env: Dict[str, str] | None = None,
90
+ env_vars: Dict[str, str] | None = None,
93
91
  labels: Dict[str, str] | None = None,
94
92
  annotations: Dict[str, str] | None = None,
95
93
  interruptible: bool = False,
96
94
  log_level: int | None = None,
97
95
  disable_run_cache: bool = False,
98
96
  ):
97
+ from flyte._tools import ipython_check
98
+
99
99
  init_config = _get_init_config()
100
100
  client = init_config.client if init_config else None
101
101
  if not force_mode and client is not None:
@@ -116,7 +116,7 @@ class _Runner:
116
116
  self._overwrite_cache = overwrite_cache
117
117
  self._project = project
118
118
  self._domain = domain
119
- self._env = env
119
+ self._env_vars = env_vars
120
120
  self._labels = labels
121
121
  self._annotations = annotations
122
122
  self._interruptible = interruptible
@@ -198,7 +198,7 @@ class _Runner:
198
198
  task_spec = translate_task_to_wire(obj, s_ctx)
199
199
  inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
200
200
 
201
- env = self._env or {}
201
+ env = self._env_vars or {}
202
202
  if self._log_level:
203
203
  env["LOG_LEVEL"] = str(self._log_level)
204
204
  else:
@@ -207,7 +207,7 @@ class _Runner:
207
207
  if not self._dry_run:
208
208
  if get_client() is None:
209
209
  # This can only happen, if the user forces flyte.run(mode="remote") without initializing the client
210
- raise InitializationError(
210
+ raise flyte.errors.InitializationError(
211
211
  "ClientNotInitializedError",
212
212
  "user",
213
213
  "flyte.run requires client to be initialized. "
@@ -542,7 +542,7 @@ def with_runcontext(
542
542
  overwrite_cache: bool = False,
543
543
  project: str | None = None,
544
544
  domain: str | None = None,
545
- env: Dict[str, str] | None = None,
545
+ env_vars: Dict[str, str] | None = None,
546
546
  labels: Dict[str, str] | None = None,
547
547
  annotations: Dict[str, str] | None = None,
548
548
  interruptible: bool = False,
@@ -582,7 +582,7 @@ def with_runcontext(
582
582
  :param overwrite_cache: Optional If true, the cache will be overwritten for the run
583
583
  :param project: Optional The project to use for the run
584
584
  :param domain: Optional The domain to use for the run
585
- :param env: Optional Environment variables to set for the run
585
+ :param env_vars: Optional Environment variables to set for the run
586
586
  :param labels: Optional Labels to set for the run
587
587
  :param annotations: Optional Annotations to set for the run
588
588
  :param interruptible: Optional If true, the run can be interrupted by the user.
@@ -606,7 +606,7 @@ def with_runcontext(
606
606
  raw_data_path=raw_data_path,
607
607
  run_base_dir=run_base_dir,
608
608
  overwrite_cache=overwrite_cache,
609
- env=env,
609
+ env_vars=env_vars,
610
610
  labels=labels,
611
611
  annotations=annotations,
612
612
  interruptible=interruptible,
flyte/_secret.py CHANGED
@@ -17,14 +17,14 @@ class Secret:
17
17
 
18
18
  Example:
19
19
  ```python
20
- @task(secrets="MY_SECRET")
20
+ @task(secrets="my-secret")
21
21
  async def my_task():
22
- os.environ["MY_SECRET"] # This will be set to the value of the secret
22
+ # This will be set to the value of the secret. Note: The env var is always uppercase, and - is replaced with _.
23
+ os.environ["MY_SECRET"]
23
24
 
24
- @task(secrets=Secret("MY_SECRET", mount="/path/to/secret"))
25
+ @task(secrets=Secret("my-openai-api-key", as_env_var="OPENAI_API_KEY"))
25
26
  async def my_task2():
26
- async with open("/path/to/secret") as f:
27
- secret_value = f.read()
27
+ os.environ["OPENAI_API_KEY"]
28
28
  ```
29
29
 
30
30
  TODO: Add support for secret versioning (some stores) and secret groups (some stores) and mounting as files.
@@ -32,6 +32,7 @@ class Secret:
32
32
  :param key: The name of the secret in the secret store.
33
33
  :param group: The group of the secret in the secret store.
34
34
  :param mount: Use this to specify the path where the secret should be mounted.
35
+ TODO: support arbitrary mount paths. Today only "/etc/flyte/secrets" is supported
35
36
  :param as_env_var: The name of the environment variable that the secret should be mounted as.
36
37
  """
37
38
 
@@ -41,6 +42,9 @@ class Secret:
41
42
  as_env_var: Optional[str] = None
42
43
 
43
44
  def __post_init__(self):
45
+ if not self.mount and not self.as_env_var:
46
+ self.as_env_var = f"{self.group}_{self.key}" if self.group else self.key
47
+ self.as_env_var = self.as_env_var.replace("-", "_").upper()
44
48
  if self.as_env_var is not None:
45
49
  pattern = r"^[A-Z_][A-Z0-9_]*$"
46
50
  if not re.match(pattern, self.as_env_var):
flyte/_task.py CHANGED
@@ -76,7 +76,7 @@ class TaskTemplate(Generic[P, R]):
76
76
  :param reusable: Optional The reusability policy for the task, defaults to None, which means the task environment
77
77
  will not be reused across task invocations.
78
78
  :param docs: Optional The documentation for the task, if not provided the function docstring will be used.
79
- :param env: Optional The environment variables to set for the task.
79
+ :param env_vars: Optional The environment variables to set for the task.
80
80
  :param secrets: Optional The secrets that will be injected into the task at runtime.
81
81
  :param timeout: Optional The timeout for the task.
82
82
  :param max_inline_io_bytes: Maximum allowed size (in bytes) for all inputs and outputs passed directly to the task
@@ -95,7 +95,7 @@ class TaskTemplate(Generic[P, R]):
95
95
  retries: Union[int, RetryStrategy] = 0
96
96
  reusable: Union[ReusePolicy, None] = None
97
97
  docs: Optional[Documentation] = None
98
- env: Optional[Dict[str, str]] = None
98
+ env_vars: Optional[Dict[str, str]] = None
99
99
  secrets: Optional[SecretRequest] = None
100
100
  timeout: Optional[TimeoutType] = None
101
101
  pod_template: Optional[Union[str, PodTemplate]] = None
@@ -314,14 +314,16 @@ class TaskTemplate(Generic[P, R]):
314
314
  def override(
315
315
  self,
316
316
  *,
317
+ friendly_name: Optional[str] = None,
317
318
  resources: Optional[Resources] = None,
318
- cache: CacheRequest = "auto",
319
+ cache: Optional[CacheRequest] = None,
319
320
  retries: Union[int, RetryStrategy] = 0,
320
321
  timeout: Optional[TimeoutType] = None,
321
322
  reusable: Union[ReusePolicy, Literal["off"], None] = None,
322
- env: Optional[Dict[str, str]] = None,
323
+ env_vars: Optional[Dict[str, str]] = None,
323
324
  secrets: Optional[SecretRequest] = None,
324
325
  max_inline_io_bytes: int | None = None,
326
+ pod_template: Optional[Union[str, PodTemplate]] = None,
325
327
  **kwargs: Any,
326
328
  ) -> TaskTemplate:
327
329
  """
@@ -344,7 +346,7 @@ class TaskTemplate(Generic[P, R]):
344
346
  " Reusable tasks will use the parent env's resources. You can disable reusability and"
345
347
  " override resources if needed. (set reusable='off')"
346
348
  )
347
- if env is not None:
349
+ if env_vars is not None:
348
350
  raise ValueError(
349
351
  "Cannot override env when reusable is set."
350
352
  " Reusable tasks will use the parent env's env. You can disable reusability and "
@@ -358,7 +360,7 @@ class TaskTemplate(Generic[P, R]):
358
360
  )
359
361
 
360
362
  resources = resources or self.resources
361
- env = env or self.env
363
+ env_vars = env_vars or self.env_vars
362
364
  secrets = secrets or self.secrets
363
365
 
364
366
  for k, v in kwargs.items():
@@ -373,14 +375,17 @@ class TaskTemplate(Generic[P, R]):
373
375
 
374
376
  return replace(
375
377
  self,
378
+ friendly_name=friendly_name or self.friendly_name,
376
379
  resources=resources,
377
380
  cache=cache,
378
381
  retries=retries,
379
382
  timeout=timeout,
380
383
  reusable=cast(Optional[ReusePolicy], reusable),
381
- env=env,
384
+ env_vars=env_vars,
382
385
  secrets=secrets,
383
386
  max_inline_io_bytes=max_inline_io_bytes,
387
+ pod_template=pod_template,
388
+ **kwargs,
384
389
  )
385
390
 
386
391