flyte 2.0.0b23-py3-none-any.whl → 2.0.0b25-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of flyte might be problematic.

Files changed (162)
  1. flyte/__init__.py +11 -2
  2. flyte/_cache/local_cache.py +4 -3
  3. flyte/_code_bundle/_utils.py +3 -3
  4. flyte/_code_bundle/bundle.py +12 -5
  5. flyte/_context.py +4 -1
  6. flyte/_custom_context.py +73 -0
  7. flyte/_deploy.py +31 -7
  8. flyte/_image.py +48 -16
  9. flyte/_initialize.py +69 -26
  10. flyte/_internal/controllers/_local_controller.py +1 -0
  11. flyte/_internal/controllers/_trace.py +1 -1
  12. flyte/_internal/controllers/remote/_action.py +9 -10
  13. flyte/_internal/controllers/remote/_client.py +1 -1
  14. flyte/_internal/controllers/remote/_controller.py +4 -2
  15. flyte/_internal/controllers/remote/_core.py +10 -13
  16. flyte/_internal/controllers/remote/_informer.py +3 -3
  17. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  18. flyte/_internal/imagebuild/docker_builder.py +45 -59
  19. flyte/_internal/imagebuild/remote_builder.py +51 -11
  20. flyte/_internal/imagebuild/utils.py +51 -3
  21. flyte/_internal/runtime/convert.py +39 -18
  22. flyte/_internal/runtime/io.py +8 -7
  23. flyte/_internal/runtime/resources_serde.py +20 -6
  24. flyte/_internal/runtime/reuse.py +1 -1
  25. flyte/_internal/runtime/task_serde.py +7 -10
  26. flyte/_internal/runtime/taskrunner.py +10 -1
  27. flyte/_internal/runtime/trigger_serde.py +13 -13
  28. flyte/_internal/runtime/types_serde.py +1 -1
  29. flyte/_keyring/file.py +2 -2
  30. flyte/_map.py +65 -13
  31. flyte/_pod.py +2 -2
  32. flyte/_resources.py +175 -31
  33. flyte/_run.py +37 -21
  34. flyte/_task.py +27 -6
  35. flyte/_task_environment.py +37 -10
  36. flyte/_utils/module_loader.py +2 -2
  37. flyte/_version.py +3 -3
  38. flyte/cli/_common.py +47 -5
  39. flyte/cli/_create.py +4 -0
  40. flyte/cli/_deploy.py +8 -0
  41. flyte/cli/_get.py +4 -0
  42. flyte/cli/_params.py +4 -4
  43. flyte/cli/_run.py +50 -7
  44. flyte/cli/_update.py +4 -3
  45. flyte/config/_config.py +2 -0
  46. flyte/config/_internal.py +1 -0
  47. flyte/config/_reader.py +3 -3
  48. flyte/errors.py +1 -1
  49. flyte/extend.py +4 -0
  50. flyte/extras/_container.py +6 -1
  51. flyte/git/_config.py +11 -9
  52. flyte/io/_dataframe/basic_dfs.py +1 -1
  53. flyte/io/_dataframe/dataframe.py +12 -8
  54. flyte/io/_dir.py +48 -15
  55. flyte/io/_file.py +48 -11
  56. flyte/models.py +12 -8
  57. flyte/remote/_action.py +18 -16
  58. flyte/remote/_client/_protocols.py +4 -3
  59. flyte/remote/_client/auth/_channel.py +1 -1
  60. flyte/remote/_client/controlplane.py +4 -8
  61. flyte/remote/_data.py +4 -3
  62. flyte/remote/_logs.py +3 -3
  63. flyte/remote/_run.py +5 -5
  64. flyte/remote/_secret.py +20 -13
  65. flyte/remote/_task.py +7 -8
  66. flyte/remote/_trigger.py +25 -27
  67. flyte/storage/_parallel_reader.py +274 -0
  68. flyte/storage/_storage.py +66 -2
  69. flyte/types/_interface.py +2 -2
  70. flyte/types/_pickle.py +1 -1
  71. flyte/types/_string_literals.py +8 -9
  72. flyte/types/_type_engine.py +25 -17
  73. flyte/types/_utils.py +1 -1
  74. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/METADATA +2 -1
  75. flyte-2.0.0b25.dist-info/RECORD +184 -0
  76. flyte/_protos/__init__.py +0 -0
  77. flyte/_protos/common/authorization_pb2.py +0 -66
  78. flyte/_protos/common/authorization_pb2.pyi +0 -108
  79. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  80. flyte/_protos/common/identifier_pb2.py +0 -117
  81. flyte/_protos/common/identifier_pb2.pyi +0 -142
  82. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  83. flyte/_protos/common/identity_pb2.py +0 -48
  84. flyte/_protos/common/identity_pb2.pyi +0 -72
  85. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  86. flyte/_protos/common/list_pb2.py +0 -36
  87. flyte/_protos/common/list_pb2.pyi +0 -71
  88. flyte/_protos/common/list_pb2_grpc.py +0 -4
  89. flyte/_protos/common/policy_pb2.py +0 -37
  90. flyte/_protos/common/policy_pb2.pyi +0 -27
  91. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  92. flyte/_protos/common/role_pb2.py +0 -37
  93. flyte/_protos/common/role_pb2.pyi +0 -53
  94. flyte/_protos/common/role_pb2_grpc.py +0 -4
  95. flyte/_protos/common/runtime_version_pb2.py +0 -28
  96. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  97. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  98. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  99. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  100. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  101. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  102. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  103. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  104. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  105. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  106. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  107. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  108. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  109. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  110. flyte/_protos/secret/definition_pb2.py +0 -49
  111. flyte/_protos/secret/definition_pb2.pyi +0 -93
  112. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  113. flyte/_protos/secret/payload_pb2.py +0 -62
  114. flyte/_protos/secret/payload_pb2.pyi +0 -94
  115. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  116. flyte/_protos/secret/secret_pb2.py +0 -38
  117. flyte/_protos/secret/secret_pb2.pyi +0 -6
  118. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  119. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  120. flyte/_protos/workflow/common_pb2.py +0 -38
  121. flyte/_protos/workflow/common_pb2.pyi +0 -63
  122. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  123. flyte/_protos/workflow/environment_pb2.py +0 -29
  124. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  125. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  126. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  127. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  128. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  129. flyte/_protos/workflow/queue_service_pb2.py +0 -117
  130. flyte/_protos/workflow/queue_service_pb2.pyi +0 -182
  131. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -206
  132. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  133. flyte/_protos/workflow/run_definition_pb2.pyi +0 -354
  134. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  135. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  136. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  137. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  138. flyte/_protos/workflow/run_service_pb2.py +0 -147
  139. flyte/_protos/workflow/run_service_pb2.pyi +0 -203
  140. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -480
  141. flyte/_protos/workflow/state_service_pb2.py +0 -67
  142. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  143. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  144. flyte/_protos/workflow/task_definition_pb2.py +0 -86
  145. flyte/_protos/workflow/task_definition_pb2.pyi +0 -105
  146. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  147. flyte/_protos/workflow/task_service_pb2.py +0 -61
  148. flyte/_protos/workflow/task_service_pb2.pyi +0 -62
  149. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  150. flyte/_protos/workflow/trigger_definition_pb2.py +0 -66
  151. flyte/_protos/workflow/trigger_definition_pb2.pyi +0 -117
  152. flyte/_protos/workflow/trigger_definition_pb2_grpc.py +0 -4
  153. flyte/_protos/workflow/trigger_service_pb2.py +0 -96
  154. flyte/_protos/workflow/trigger_service_pb2.pyi +0 -110
  155. flyte/_protos/workflow/trigger_service_pb2_grpc.py +0 -281
  156. flyte-2.0.0b23.dist-info/RECORD +0 -262
  157. {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/debug.py +0 -0
  158. {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/runtime.py +0 -0
  159. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/WHEEL +0 -0
  160. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/entry_points.txt +0 -0
  161. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/licenses/LICENSE +0 -0
  162. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/top_level.txt +0 -0

flyte/_internal/imagebuild/remote_builder.py CHANGED
@@ -16,6 +16,7 @@ import flyte.errors
 from flyte import Image, remote
 from flyte._code_bundle._utils import tar_strip_file_attributes
 from flyte._image import (
+    _BASE_REGISTRY,
     AptPackages,
     Architecture,
     Commands,
@@ -24,6 +25,7 @@ from flyte._image import (
     Env,
     PipOption,
     PipPackages,
+    PoetryProject,
     PythonWheels,
     Requirements,
     UVProject,
@@ -31,14 +33,14 @@ from flyte._image import (
     WorkDir,
 )
 from flyte._internal.imagebuild.image_builder import ImageBuilder, ImageChecker
-from flyte._internal.imagebuild.utils import copy_files_to_context
+from flyte._internal.imagebuild.utils import copy_files_to_context, get_and_list_dockerignore
 from flyte._internal.runtime.task_serde import get_security_context
 from flyte._logging import logger
 from flyte._secret import Secret
 from flyte.remote import ActionOutputs, Run

 if TYPE_CHECKING:
-    from flyte._protos.imagebuilder import definition_pb2 as image_definition_pb2
+    from flyteidl2.imagebuilder import definition_pb2 as image_definition_pb2

 IMAGE_TASK_NAME = "build-image"
 IMAGE_TASK_PROJECT = "system"
@@ -68,10 +70,11 @@ class RemoteImageChecker(ImageChecker):
         image_name = f"{repository.split('/')[-1]}:{tag}"

         try:
+            from flyteidl2.imagebuilder import definition_pb2 as image_definition__pb2
+            from flyteidl2.imagebuilder import payload_pb2 as image_payload__pb2
+            from flyteidl2.imagebuilder import service_pb2_grpc as image_service_pb2_grpc
+
             from flyte._initialize import _get_init_config
-            from flyte._protos.imagebuilder import definition_pb2 as image_definition__pb2
-            from flyte._protos.imagebuilder import payload_pb2 as image_payload__pb2
-            from flyte._protos.imagebuilder import service_pb2_grpc as image_service_pb2_grpc

             cfg = _get_init_config()
             if cfg is None:
@@ -96,7 +99,7 @@ class RemoteImageBuilder(ImageBuilder):
         return [RemoteImageChecker]

     async def build_image(self, image: Image, dry_run: bool = False) -> str:
-        from flyte._protos.workflow import run_definition_pb2
+        from flyteidl2.workflow import run_definition_pb2

         image_name = f"{image.name}:{image._final_tag}"
         spec, context = await _validate_configuration(image)
@@ -110,10 +113,15 @@ class RemoteImageBuilder(ImageBuilder):
         ).override.aio(secrets=_get_build_secrets_from_image(image))

         logger.warning("[bold blue]🐳 Submitting a new build...[/bold blue]")
+        if image.registry and image.registry != _BASE_REGISTRY:
+            target_image = f"{image.registry}/{image_name}"
+        else:
+            # Use the default system registry in the backend.
+            target_image = image_name
         run = cast(
             Run,
             await flyte.with_runcontext(project=IMAGE_TASK_PROJECT, domain=IMAGE_TASK_DOMAIN).run.aio(
-                entity, spec=spec, context=context, target_image=image_name
+                entity, spec=spec, context=context, target_image=target_image
             ),
         )
         logger.warning(f"⏳ Waiting for build to finish at: [bold cyan link={run.url}]{run.url}[/bold cyan link]")
@@ -180,7 +188,7 @@ async def _validate_configuration(image: Image) -> Tuple[str, Optional[str]]:


 def _get_layers_proto(image: Image, context_path: Path) -> "image_definition_pb2.ImageSpec":
-    from flyte._protos.imagebuilder import definition_pb2 as image_definition_pb2
+    from flyteidl2.imagebuilder import definition_pb2 as image_definition_pb2

     if image.dockerfile is not None:
         raise flyte.errors.ImageBuildError(
@@ -190,7 +198,7 @@ def _get_layers_proto(image: Image, context_path: Path) -> "image_definition_pb2
     layers = []
     for layer in image._layers:
         secret_mounts = None
-        pip_options = None
+        pip_options = image_definition_pb2.PipOptions()

         if isinstance(layer, PipOption):
             pip_options = image_definition_pb2.PipOptions(
@@ -256,12 +264,18 @@ def _get_layers_proto(image: Image, context_path: Path) -> "image_definition_pb2
                 if "tool.uv.index" in line:
                     raise ValueError("External sources are not supported in pyproject.toml")

-            if layer.extra_args and "--no-install-project" in layer.extra_args:
+            if layer.project_install_mode == "dependencies_only":
                 # Copy pyproject itself
                 pyproject_dst = copy_files_to_context(layer.pyproject, context_path)
+                if pip_options.extra_args:
+                    if "--no-install-project" not in pip_options.extra_args:
+                        pip_options.extra_args += " --no-install-project"
+                else:
+                    pip_options.extra_args = " --no-install-project"
             else:
                 # Copy the entire project
-                pyproject_dst = copy_files_to_context(layer.pyproject.parent, context_path)
+                docker_ignore_patterns = get_and_list_dockerignore(image)
+                pyproject_dst = copy_files_to_context(layer.pyproject.parent, context_path, docker_ignore_patterns)

             uv_layer = image_definition_pb2.Layer(
                 uv_project=image_definition_pb2.UVProject(
@@ -272,6 +286,27 @@ def _get_layers_proto(image: Image, context_path: Path) -> "image_definition_pb2
                 )
             )
             layers.append(uv_layer)
+        elif isinstance(layer, PoetryProject):
+            for line in layer.pyproject.read_text().splitlines():
+                if "tool.poetry.source" in line:
+                    raise ValueError("External sources are not supported in pyproject.toml")
+
+            if layer.extra_args and "--no-root" in layer.extra_args:
+                # Copy pyproject itself
+                pyproject_dst = copy_files_to_context(layer.pyproject, context_path)
+            else:
+                # Copy the entire project
+                pyproject_dst = copy_files_to_context(layer.pyproject.parent, context_path)
+
+            poetry_layer = image_definition_pb2.Layer(
+                poetry_project=image_definition_pb2.PoetryProject(
+                    pyproject=str(pyproject_dst.relative_to(context_path)),
+                    poetry_lock=str(copy_files_to_context(layer.poetry_lock, context_path).relative_to(context_path)),
+                    extra_args=layer.extra_args,
+                    secret_mounts=secret_mounts,
+                )
+            )
+            layers.append(poetry_layer)
         elif isinstance(layer, Commands):
             commands_layer = image_definition_pb2.Layer(
                 commands=image_definition_pb2.Commands(
@@ -330,4 +365,9 @@ def _get_build_secrets_from_image(image: Image) -> Optional[typing.List[Secret]]
         else:
             raise ValueError(f"Unsupported secret_mount type: {type(secret_mount)}")

+    image_registry_secret = image._image_registry_secret
+    if image_registry_secret:
+        secrets.append(
+            Secret(key=image_registry_secret.key, group=image_registry_secret.group, mount=DEFAULT_SECRET_DIR)
+        )
     return secrets
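
The build_image change above only prefixes the target image with a registry when the Image specifies one that differs from the base registry; otherwise the bare name:tag is submitted and the backend's default system registry is used. A minimal sketch of that selection, assuming an Image-like object exposing registry, name, and _final_tag:

# Sketch only; mirrors the target_image logic added in build_image above.
def resolve_target_image(image, base_registry: str) -> str:
    image_name = f"{image.name}:{image._final_tag}"
    if image.registry and image.registry != base_registry:
        # Push to the user-specified registry.
        return f"{image.registry}/{image_name}"
    # Fall back to the backend's default system registry.
    return image_name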

flyte/_internal/imagebuild/utils.py CHANGED
@@ -1,8 +1,12 @@
 import shutil
 from pathlib import Path
+from typing import List, Optional

+from flyte._image import DockerIgnore, Image
+from flyte._logging import logger

-def copy_files_to_context(src: Path, context_path: Path) -> Path:
+
+def copy_files_to_context(src: Path, context_path: Path, ignore_patterns: list[str] = []) -> Path:
     """
     This helper function ensures that absolute paths that users specify are converted correctly to a path in the
     context directory. Doing this prevents collisions while ensuring files are available in the context.
@@ -23,8 +27,52 @@ def copy_files_to_context(src: Path, context_path: Path) -> Path:
     dst_path = context_path / src
     dst_path.parent.mkdir(parents=True, exist_ok=True)
     if src.is_dir():
-        # TODO: Add support dockerignore
-        shutil.copytree(src, dst_path, dirs_exist_ok=True, ignore=shutil.ignore_patterns(".idea", ".venv"))
+        default_ignore_patterns = [".idea", ".venv"]
+        ignore_patterns = list(set(ignore_patterns + default_ignore_patterns))
+        shutil.copytree(src, dst_path, dirs_exist_ok=True, ignore=shutil.ignore_patterns(*ignore_patterns))
     else:
         shutil.copy(src, dst_path)
     return dst_path
+
+
+def get_and_list_dockerignore(image: Image) -> List[str]:
+    """
+    Get and parse dockerignore patterns from .dockerignore file.
+
+    This function first looks for a DockerIgnore layer in the image's layers. If found, it uses
+    the path specified in that layer. If no DockerIgnore layer is found, it falls back to looking
+    for a .dockerignore file in the root_path directory.
+
+    :param image: The Image object
+    """
+    from flyte._initialize import _get_init_config
+
+    # Look for DockerIgnore layer in the image layers
+    dockerignore_path: Optional[Path] = None
+    patterns: List[str] = []
+
+    for layer in image._layers:
+        if isinstance(layer, DockerIgnore) and layer.path.strip():
+            dockerignore_path = Path(layer.path)
+    # If DockerIgnore layer not specified, set dockerignore_path under root_path
+    init_config = _get_init_config()
+    root_path = init_config.root_dir if init_config else None
+    if not dockerignore_path and root_path:
+        dockerignore_path = Path(root_path) / ".dockerignore"
+    # Return empty list if no .dockerignore file found
+    if not dockerignore_path or not dockerignore_path.exists() or not dockerignore_path.is_file():
+        logger.info(f".dockerignore file not found at path: {dockerignore_path}")
+        return patterns
+
+    try:
+        with open(dockerignore_path, "r", encoding="utf-8") as f:
+            for line in f:
+                stripped_line = line.strip()
+                # Skip empty lines, whitespace-only lines, and comments
+                if not stripped_line or stripped_line.startswith("#"):
+                    continue
+                patterns.append(stripped_line)
+    except Exception as e:
+        logger.error(f"Failed to read .dockerignore file at {dockerignore_path}: {e}")
+        return []
+    return patterns
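
The new get_and_list_dockerignore feeds its patterns into copy_files_to_context, which hands them to shutil.ignore_patterns. A small sketch of that flow (hypothetical paths and helper name, not part of the package):

import shutil
from pathlib import Path

def copy_project(src: Path, context: Path, patterns: list[str]) -> Path:
    """Copy a project directory into a build context, honoring ignore patterns."""
    dst = context / src.name
    # shutil.ignore_patterns matches each entry as a glob against file/directory *names*,
    # not full paths, so path-anchored .dockerignore entries (e.g. "docs/build") are not
    # interpreted exactly the way Docker would interpret them.
    ignore = shutil.ignore_patterns(*(patterns + [".idea", ".venv"]))
    shutil.copytree(src, dst, dirs_exist_ok=True, ignore=ignore)
    return dst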

flyte/_internal/runtime/convert.py CHANGED
@@ -8,27 +8,33 @@ from dataclasses import dataclass
 from types import NoneType
 from typing import Any, Dict, List, Tuple, Union, get_args

-from flyteidl.core import execution_pb2, interface_pb2, literals_pb2
+from flyteidl2.core import execution_pb2, interface_pb2, literals_pb2
+from flyteidl2.task import common_pb2, task_definition_pb2

 import flyte.errors
 import flyte.storage as storage
-from flyte._protos.workflow import common_pb2, run_definition_pb2, task_definition_pb2
+from flyte._context import ctx
 from flyte.models import ActionID, NativeInterface, TaskContext
 from flyte.types import TypeEngine, TypeTransformerFailedError


 @dataclass(frozen=True)
 class Inputs:
-    proto_inputs: run_definition_pb2.Inputs
+    proto_inputs: common_pb2.Inputs

     @classmethod
     def empty(cls) -> "Inputs":
-        return cls(proto_inputs=run_definition_pb2.Inputs())
+        return cls(proto_inputs=common_pb2.Inputs())
+
+    @property
+    def context(self) -> Dict[str, str]:
+        """Get the context as a dictionary."""
+        return {kv.key: kv.value for kv in self.proto_inputs.context}


 @dataclass(frozen=True)
 class Outputs:
-    proto_outputs: run_definition_pb2.Outputs
+    proto_outputs: common_pb2.Outputs


 @dataclass
@@ -102,15 +108,30 @@ def is_optional_type(tp) -> bool:
     return NoneType in get_args(tp)  # fastest check


-async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwargs) -> Inputs:
+async def convert_from_native_to_inputs(
+    interface: NativeInterface, *args, custom_context: Dict[str, str] | None = None, **kwargs
+) -> Inputs:
     kwargs = interface.convert_to_kwargs(*args, **kwargs)

     missing = [key for key in interface.required_inputs() if key not in kwargs]
     if missing:
         raise ValueError(f"Missing required inputs: {', '.join(missing)}")

+    # Read custom_context from TaskContext if available (inside task execution)
+    # Otherwise use the passed parameter (for remote run initiation)
+    context_kvs = None
+    tctx = ctx()
+    if tctx and tctx.custom_context:
+        # Inside a task - read from TaskContext
+        context_to_use = tctx.custom_context
+        context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in context_to_use.items()]
+    elif custom_context:
+        # Remote run initiation
+        context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in custom_context.items()]
+
     if len(interface.inputs) == 0:
-        return Inputs.empty()
+        # Handle context even for empty inputs
+        return Inputs(proto_inputs=common_pb2.Inputs(context=context_kvs))

     # fill in defaults if missing
     type_hints: Dict[str, type] = {}
@@ -144,12 +165,12 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
     for k, v in already_converted_kwargs.items():
         copied_literals[k] = v
     literal_map = literals_pb2.LiteralMap(literals=copied_literals)
+
     # Make sure we the interface, not literal_map or kwargs, because those may have a different order
     return Inputs(
-        proto_inputs=run_definition_pb2.Inputs(
-            literals=[
-                run_definition_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()
-            ]
+        proto_inputs=common_pb2.Inputs(
+            literals=[common_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()],
+            context=context_kvs,
         )
     )

@@ -191,11 +212,11 @@ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, tas
     for (output_name, python_type), v in zip(interface.outputs.items(), o):
         try:
             lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
-            named.append(run_definition_pb2.NamedLiteral(name=output_name, value=lit))
+            named.append(common_pb2.NamedLiteral(name=output_name, value=lit))
         except TypeTransformerFailedError as e:
             raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name)

-    return Outputs(proto_outputs=run_definition_pb2.Outputs(literals=named))
+    return Outputs(proto_outputs=common_pb2.Outputs(literals=named))


 async def convert_outputs_to_native(interface: NativeInterface, outputs: Outputs) -> Union[Any, Tuple[Any, ...]]:
@@ -222,7 +243,7 @@ def convert_error_to_native(err: execution_pb2.ExecutionError | Exception | Erro
     if isinstance(err, Error):
         err = err.err

-    user_code, server_code = _clean_error_code(err.code)
+    user_code, _server_code = _clean_error_code(err.code)
     match err.kind:
         case execution_pb2.ExecutionError.UNKNOWN:
             return flyte.errors.RuntimeUnknownError(code=user_code, message=err.message, worker=err.worker)
@@ -351,7 +372,7 @@ def generate_inputs_repr_for_literal(literal: literals_pb2.Literal) -> bytes:
     return literal.SerializeToString(deterministic=True)


-def generate_inputs_hash_for_named_literals(inputs: list[run_definition_pb2.NamedLiteral]) -> str:
+def generate_inputs_hash_for_named_literals(inputs: list[common_pb2.NamedLiteral]) -> str:
     """
     Generate a hash for the inputs using the new literal representation approach that respects
     hash values already present in literals. This is used to uniquely identify the inputs for a task
@@ -375,7 +396,7 @@ def generate_inputs_hash_for_named_literals(inputs: list[run_definition_pb2.Name
     return hash_data(combined_bytes)


-def generate_inputs_hash_from_proto(inputs: run_definition_pb2.Inputs) -> str:
+def generate_inputs_hash_from_proto(inputs: common_pb2.Inputs) -> str:
     """
     Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
     :param inputs: The inputs to hash.
@@ -404,7 +425,7 @@ def generate_cache_key_hash(
     task_interface: interface_pb2.TypedInterface,
     cache_version: str,
     ignored_input_vars: List[str],
-    proto_inputs: run_definition_pb2.Inputs,
+    proto_inputs: common_pb2.Inputs,
 ) -> str:
     """
     Generate a cache key hash based on the inputs hash, task name, task interface, and cache version.
@@ -420,7 +441,7 @@ def generate_cache_key_hash(
     """
     if ignored_input_vars:
         filtered = [named_lit for named_lit in proto_inputs.literals if named_lit.name not in ignored_input_vars]
-        final = run_definition_pb2.Inputs(literals=filtered)
+        final = common_pb2.Inputs(literals=filtered)
         final_inputs = generate_inputs_hash_from_proto(final)
     else:
         final_inputs = inputs_hash
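
The context plumbing above serializes a plain dict into repeated KeyValuePair entries on common_pb2.Inputs, and the new Inputs.context property reverses it. A minimal round-trip sketch, assuming flyteidl2 is installed and exposes the messages used in this diff:

from flyteidl2.core import literals_pb2
from flyteidl2.task import common_pb2

ctx_map = {"tenant": "acme", "experiment": "run-42"}  # hypothetical custom context
kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in ctx_map.items()]
proto = common_pb2.Inputs(context=kvs)  # literals may be empty when a task takes no inputs

# Consumer side, as done by the new Inputs.context property:
assert {kv.key: kv.value for kv in proto.context} == ctx_map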

flyte/_internal/runtime/io.py CHANGED
@@ -5,10 +5,11 @@ It uses the storage module to handle the actual uploading and downloading of fil
 TODO: Convert to use streaming apis
 """

-from flyteidl.core import errors_pb2, execution_pb2
+from flyteidl.core import errors_pb2
+from flyteidl2.core import execution_pb2
+from flyteidl2.task import common_pb2

 import flyte.storage as storage
-from flyte._protos.workflow import run_definition_pb2
 from flyte.models import PathRewrite

 from .convert import Inputs, Outputs, _clean_error_code
@@ -70,7 +71,7 @@ async def upload_outputs(outputs: Outputs, output_path: str, max_bytes: int = -1
     await storage.put_stream(data_iterable=outputs.proto_outputs.SerializeToString(), to_path=output_uri)


-async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
+async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str) -> str:
     """
     :param err: execution_pb2.ExecutionError
     :param output_prefix: The output prefix of the remote uri.
@@ -87,7 +88,7 @@ async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
         )
     )
     error_uri = error_path(output_prefix)
-    await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)
+    return await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)


 # ------------------------------- DOWNLOAD Methods ------------------------------- #
@@ -98,7 +99,7 @@ async def load_inputs(path: str, max_bytes: int = -1, path_rewrite_config: PathR
     :param path_rewrite_config: If provided, rewrites paths in the input blobs according to the configuration.
     :return: Inputs object
     """
-    lm = run_definition_pb2.Inputs()
+    lm = common_pb2.Inputs()

     if max_bytes == -1:
         proto_str = b"".join([c async for c in storage.get_stream(path=path)])
@@ -137,7 +138,7 @@ async def load_outputs(path: str, max_bytes: int = -1) -> Outputs:
         If -1, reads the entire file.
     :return: Outputs object
     """
-    lm = run_definition_pb2.Outputs()
+    lm = common_pb2.Outputs()

     if max_bytes == -1:
         proto_str = b"".join([c async for c in storage.get_stream(path=path)])
@@ -169,7 +170,7 @@ async def load_error(path: str) -> execution_pb2.ExecutionError:
     err.ParseFromString(proto_str)

     if err.error is not None:
-        user_code, server_code = _clean_error_code(err.error.code)
+        user_code, _server_code = _clean_error_code(err.error.code)
         return execution_pb2.ExecutionError(
             code=user_code,
             message=err.error.message,

flyte/_internal/runtime/resources_serde.py CHANGED
@@ -1,8 +1,8 @@
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

-from flyteidl.core import tasks_pb2
+from flyteidl2.core import tasks_pb2

-from flyte._resources import CPUBaseType, Resources
+from flyte._resources import CPUBaseType, DeviceClass, Resources

 ACCELERATOR_DEVICE_MAP = {
     "A100": "nvidia-tesla-a100",
@@ -24,6 +24,14 @@ ACCELERATOR_DEVICE_MAP = {
     "V6E": "tpu-v6e-slice",
 }

+_DeviceClassToProto: Dict[DeviceClass, "tasks_pb2.GPUAccelerator.DeviceClass"] = {
+    "GPU": tasks_pb2.GPUAccelerator.NVIDIA_GPU,
+    "TPU": tasks_pb2.GPUAccelerator.GOOGLE_TPU,
+    "NEURON": tasks_pb2.GPUAccelerator.AMAZON_NEURON,
+    "AMD_GPU": tasks_pb2.GPUAccelerator.AMD_GPU,
+    "HABANA_GAUDI": tasks_pb2.GPUAccelerator.HABANA_GAUDI,
+}
+

 def _get_cpu_resource_entry(cpu: CPUBaseType) -> tasks_pb2.Resources.ResourceEntry:
     return tasks_pb2.Resources.ResourceEntry(
@@ -54,11 +62,17 @@ def _get_gpu_extended_resource_entry(resources: Resources) -> Optional[tasks_pb2
     device = resources.get_device()
     if device is None:
         return None
-    if device.device not in ACCELERATOR_DEVICE_MAP:
-        raise ValueError(f"GPU of type {device.device} unknown, cannot map to device name")
+
+    device_class = _DeviceClassToProto.get(device.device_class, tasks_pb2.GPUAccelerator.NVIDIA_GPU)
+    if device.device is None:
+        raise RuntimeError("Device type must be specified for GPU string.")
+    else:
+        device_type = device.device
+    device_type = ACCELERATOR_DEVICE_MAP.get(device_type, device_type)
     return tasks_pb2.GPUAccelerator(
-        device=ACCELERATOR_DEVICE_MAP[device.device],
+        device=device_type,
         partition_size=device.partition if device.partition else None,
+        device_class=device_class,
     )

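
With the DeviceClass mapping in place, accelerator shorthands still resolve through ACCELERATOR_DEVICE_MAP, but unknown device strings now pass through verbatim instead of raising. A small sketch of the lookup behavior (map truncated to two entries shown in the diff):

ACCELERATOR_DEVICE_MAP = {"A100": "nvidia-tesla-a100", "V6E": "tpu-v6e-slice"}

def resolve_device_name(device: str) -> str:
    # Known shorthands map to canonical names; anything else is forwarded unchanged.
    return ACCELERATOR_DEVICE_MAP.get(device, device)

assert resolve_device_name("A100") == "nvidia-tesla-a100"
assert resolve_device_name("nvidia-l40s") == "nvidia-l40s"  # passes through as-is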

flyte/_internal/runtime/reuse.py CHANGED
@@ -2,7 +2,7 @@ import hashlib
 import typing
 from venv import logger

-from flyteidl.core import tasks_pb2
+from flyteidl2.core import tasks_pb2

 import flyte.errors
 from flyte import ReusePolicy

flyte/_internal/runtime/task_serde.py CHANGED
@@ -8,14 +8,14 @@ import typing
 from datetime import timedelta
 from typing import Optional, cast

-from flyteidl.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
+from flyteidl2.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
+from flyteidl2.task import common_pb2, environment_pb2, task_definition_pb2
 from google.protobuf import duration_pb2, wrappers_pb2

 import flyte.errors
 from flyte._cache.cache import VersionParameters, cache_from_request
 from flyte._logging import logger
 from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
-from flyte._protos.workflow import common_pb2, environment_pb2, task_definition_pb2
 from flyte._secret import SecretRequest, secrets_from_request
 from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
 from flyte.models import CodeBundle, SerializationContext
@@ -172,6 +172,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
             pod_template_name=(task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None),
             interruptible=task.interruptible,
             generates_deck=wrappers_pb2.BoolValue(value=task.report),
+            debuggable=task.debuggable,
         ),
         interface=transform_native_to_typed_interface(task.native_interface),
         custom=custom if len(custom) > 0 else None,
@@ -208,17 +209,13 @@ def _get_urun_container(
         else None
     )
     resources = get_proto_resources(task_template.resources)
-    # pr: under what conditions should this return None?
+
     if isinstance(task_template.image, str):
         raise flyte.errors.RuntimeSystemError("BadConfig", "Image is not a valid image")

-    env_name = ""
-    if task_template.parent_env is not None:
-        task_env = task_template.parent_env()
-        if task_env is not None:
-            env_name = task_env.name
-        else:
-            raise flyte.errors.RuntimeSystemError("BadConfig", "Task template has no parent environment")
+    env_name = task_template.parent_env_name
+    if env_name is None:
+        raise flyte.errors.RuntimeSystemError("BadConfig", f"Task {task_template.name} has no parent environment name")

     if not serialize_context.image_cache:
         # This computes the image uri, computing hashes as necessary so can fail if done remotely.

flyte/_internal/runtime/taskrunner.py CHANGED
@@ -129,6 +129,14 @@ async def convert_and_run(
     in a context tree.
     """
     ctx = internal_ctx()
+
+    # Load inputs first to get context
+    if input_path:
+        inputs = await load_inputs(input_path, path_rewrite_config=raw_data_path.path_rewrite)
+
+    # Extract context from inputs
+    custom_context = inputs.context if inputs else {}
+
     tctx = TaskContext(
         action=action,
         checkpoints=checkpoints,
@@ -142,9 +150,10 @@ async def convert_and_run(
         report=flyte.report.Report(name=action.name),
         mode="remote" if not ctx.data.task_context else ctx.data.task_context.mode,
         interactive_mode=interactive_mode,
+        custom_context=custom_context,
     )
+
     with ctx.replace_task_context(tctx):
-        inputs = await load_inputs(input_path, path_rewrite_config=raw_data_path.path_rewrite) if input_path else inputs
         inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
         out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
         if err is not None:

flyte/_internal/runtime/trigger_serde.py CHANGED
@@ -1,12 +1,12 @@
 import asyncio
 from typing import Union

-from flyteidl.core import interface_pb2, literals_pb2
+from flyteidl2.core import interface_pb2, literals_pb2
+from flyteidl2.task import common_pb2, run_pb2, task_definition_pb2
 from google.protobuf import timestamp_pb2, wrappers_pb2

 import flyte.types
 from flyte import Cron, FixedRate, Trigger, TriggerTime
-from flyte._protos.workflow import common_pb2, run_definition_pb2, trigger_definition_pb2


 def _to_schedule(m: Union[Cron, FixedRate], kickoff_arg_name: str | None = None) -> common_pb2.Schedule:
@@ -36,7 +36,7 @@ async def process_default_inputs(
     task_name: str,
     task_inputs: interface_pb2.VariableMap,
     task_default_inputs: list[common_pb2.NamedParameter],
-) -> list[run_definition_pb2.NamedLiteral]:
+) -> list[common_pb2.NamedLiteral]:
     """
     Process default inputs and convert them to NamedLiteral objects.

@@ -68,10 +68,10 @@ async def process_default_inputs(
             keys.append(p.name)
             final_literals.append(p.parameter.default)

-    literals: list[run_definition_pb2.NamedLiteral] = []
+    literals: list[common_pb2.NamedLiteral] = []
     for k, lit in zip(keys, final_literals):
         literals.append(
-            run_definition_pb2.NamedLiteral(
+            common_pb2.NamedLiteral(
                 name=k,
                 value=lit,
             )
@@ -85,7 +85,7 @@ async def to_task_trigger(
     task_name: str,
     task_inputs: interface_pb2.VariableMap,
     task_default_inputs: list[common_pb2.NamedParameter],
-) -> trigger_definition_pb2.TaskTrigger:
+) -> task_definition_pb2.TaskTrigger:
     """
     Converts a Trigger object to a TaskTrigger protobuf object.
     Args:
@@ -98,15 +98,15 @@ async def to_task_trigger(
     """
     env = None
     if t.env_vars:
-        env = run_definition_pb2.Envs()
+        env = run_pb2.Envs()
         for k, v in t.env_vars.items():
             env.values.append(literals_pb2.KeyValuePair(key=k, value=v))

-    labels = run_definition_pb2.Labels(values=t.labels) if t.labels else None
+    labels = run_pb2.Labels(values=t.labels) if t.labels else None

-    annotations = run_definition_pb2.Annotations(values=t.annotations) if t.annotations else None
+    annotations = run_pb2.Annotations(values=t.annotations) if t.annotations else None

-    run_spec = run_definition_pb2.RunSpec(
+    run_spec = run_pb2.RunSpec(
         overwrite_cache=t.overwrite_cache,
         envs=env,
         interruptible=wrappers_pb2.BoolValue(value=t.interruptible) if t.interruptible is not None else None,
@@ -139,12 +139,12 @@ async def to_task_trigger(
         kickoff_arg_name=kickoff_arg_name,
     )

-    return trigger_definition_pb2.TaskTrigger(
+    return task_definition_pb2.TaskTrigger(
         name=t.name,
-        spec=trigger_definition_pb2.TaskTriggerSpec(
+        spec=task_definition_pb2.TaskTriggerSpec(
             active=t.auto_activate,
             run_spec=run_spec,
-            inputs=run_definition_pb2.Inputs(literals=literals),
+            inputs=common_pb2.Inputs(literals=literals),
         ),
         automation_spec=common_pb2.TriggerAutomationSpec(
             type=common_pb2.TriggerAutomationSpec.Type.TYPE_SCHEDULE,

flyte/_internal/runtime/types_serde.py CHANGED
@@ -1,6 +1,6 @@
 from typing import Dict, Optional, TypeVar

-from flyteidl.core import interface_pb2
+from flyteidl2.core import interface_pb2

 from flyte.models import NativeInterface
 from flyte.types._type_engine import TypeEngine

flyte/_keyring/file.py CHANGED
@@ -72,9 +72,9 @@ class SimplePlainTextKeyring(KeyringBackend):

     @property
     def file_path(self) -> Path:
-        from flyte._initialize import get_common_config
+        from flyte._initialize import get_init_config

-        config_path = get_common_config().source_config_path
+        config_path = get_init_config().source_config_path
         if config_path and str(config_path.parent) == ".flyte":
             # if the config is in a .flyte directory, use that as the path
             return config_path.parent / "keyring.cfg"