flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (204)
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
@@ -0,0 +1,78 @@
1
+ import shutil
2
+ from pathlib import Path
3
+ from typing import List, Optional
4
+
5
+ from flyte._image import DockerIgnore, Image
6
+ from flyte._logging import logger
7
+
8
+
9
def copy_files_to_context(src: Path, context_path: Path, ignore_patterns: Optional[list[str]] = None) -> Path:
    """
    Copy ``src`` (a file or a directory) into the Docker build context and return its path inside the context.

    This helper ensures that absolute paths that users specify are converted correctly to a path in the
    context directory. Doing this prevents collisions while ensuring files are available in the context.

    For example, if a user has
        img.with_requirements(Path("/Users/username/requirements.txt"))
           .with_requirements(Path("requirements.txt"))
           .with_requirements(Path("../requirements.txt"))

    copying with this function ensures that the Docker context folder has all three files.

    Absolute paths, and paths containing ``..``, are placed under ``<context_path>/_flyte_abs_context/``
    using their *normalized* absolute path. Normalization collapses ``..`` segments, so the destination
    cannot climb out of ``_flyte_abs_context`` and no stray empty directories are created. Other relative
    paths are placed at ``<context_path>/<src>``.

    :param src: The source path to copy.
    :param context_path: The context path where the files should be copied to.
    :param ignore_patterns: Optional glob patterns to skip when ``src`` is a directory. ``.idea`` and ``.venv``
        are always ignored in addition to these.
    :return: The destination path inside the context.
    """
    import os

    if src.is_absolute() or ".." in str(src):
        # Path.absolute() does not collapse "..", so normalize lexically before re-rooting the path.
        normalized = os.path.normpath(str(src.absolute()))
        # Swap the leading "/" for the "_flyte_abs_context" prefix so the path becomes relative to the context.
        dst_path = context_path / normalized.replace("/", "./_flyte_abs_context/", 1)
    else:
        dst_path = context_path / src
    dst_path.parent.mkdir(parents=True, exist_ok=True)
    if src.is_dir():
        default_ignore_patterns = [".idea", ".venv"]
        # Build a fresh list so neither the caller's argument nor any default is mutated.
        patterns = list(set(list(ignore_patterns or []) + default_ignore_patterns))
        shutil.copytree(src, dst_path, dirs_exist_ok=True, ignore=shutil.ignore_patterns(*patterns))
    else:
        shutil.copy(src, dst_path)
    return dst_path
36
+
37
+
38
def get_and_list_dockerignore(image: Image) -> List[str]:
    """
    Collect the patterns of the ``.dockerignore`` file that applies to ``image``.

    The file location comes from the last ``DockerIgnore`` layer of the image whose path is not blank.
    If the image has no such layer, ``.dockerignore`` inside the configured root directory (from the init
    config) is used instead.

    Blank lines and lines starting with ``#`` are skipped, and every pattern is whitespace-stripped.
    When no file can be located, or reading it fails, the problem is logged and an empty list is returned.

    :param image: The Image object whose layers are inspected.
    :return: The list of dockerignore patterns.
    """
    from flyte._initialize import _get_init_config

    candidate: Optional[Path] = None
    for layer in image._layers:
        if isinstance(layer, DockerIgnore) and layer.path.strip():
            # Keep scanning: a later DockerIgnore layer overrides an earlier one.
            candidate = Path(layer.path)

    init_config = _get_init_config()
    root_dir = init_config.root_dir if init_config else None
    if not candidate and root_dir:
        candidate = Path(root_dir) / ".dockerignore"

    if not candidate or not candidate.exists() or not candidate.is_file():
        logger.info(f".dockerignore file not found at path: {candidate}")
        return []

    try:
        with open(candidate, "r", encoding="utf-8") as handle:
            collected = [
                entry
                for entry in (raw.strip() for raw in handle)
                if entry and not entry.startswith("#")
            ]
    except Exception as e:
        logger.error(f"Failed to read .dockerignore file at {candidate}: {e}")
        return []
    return collected
File without changes
@@ -0,0 +1,21 @@
1
+ import pathlib
2
+ from typing import Tuple
3
+
4
+ from flyte._module import extract_obj_module
5
+ from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
6
+
7
+
8
def extract_task_module(task: TaskTemplate, /, source_dir: pathlib.Path) -> Tuple[str, str]:
    """
    Identify a task by its function name and the module that defines it.

    Only ``AsyncFunctionTaskTemplate`` instances are supported. Any other template raises
    ``NotImplementedError``.

    :param task: The task template to extract the module from.
    :param source_dir: The source directory used to compute relative module paths.
    :return: A tuple ``(entity name, module name)``.
    """
    if not isinstance(task, AsyncFunctionTaskTemplate):
        raise NotImplementedError(f"Task module {task.name} not implemented.")

    wrapped_func = task.func
    module_name, _ = extract_obj_module(wrapped_func, source_dir)
    return wrapped_func.__name__, module_name
@@ -0,0 +1,31 @@
1
+ from asyncio import Protocol
2
+ from pathlib import Path
3
+ from typing import List, Optional
4
+
5
+ from flyte._task import TaskTemplate
6
+
7
+
8
class Resolver(Protocol):
    """
    Interface for task resolvers.

    A resolver turns a task into a list of "loader args", and can later load the task back from those args.

    NOTE(review): ``Protocol`` in this module is imported from ``asyncio`` (the asyncio transport-protocol
    base class), not from ``typing``. The structural-interface intent suggests ``typing.Protocol`` was meant.
    Confirm before switching: with a non-runtime-checkable ``typing.Protocol``, ``isinstance(x, Resolver)``
    would raise ``TypeError``.
    """

    @property
    def import_path(self) -> str:
        """
        Dotted Python import path of the resolver implementation.

        The base implementation returns an empty string.
        """
        no_path = ""
        return no_path

    def load_task(self, loader_args: List[str]) -> TaskTemplate:
        """
        Return the single TaskTemplate identified by ``loader_args``, or raise if it cannot be found.

        The base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError()

    def loader_args(self, t: TaskTemplate, root_dir: Optional[Path]) -> List[str]:
        """
        Produce the strings that identify task ``t`` to this resolver; ``load_task`` consumes them later.

        Each string should be free of spaces and special characters. The base implementation returns
        an empty list.
        """
        no_args: List[str] = list()
        return no_args
@@ -0,0 +1,28 @@
1
+ import importlib
2
+ from pathlib import Path
3
+ from typing import List
4
+
5
+ from flyte._internal.resolvers._task_module import extract_task_module
6
+ from flyte._internal.resolvers.common import Resolver
7
+ from flyte._task import TaskTemplate
8
+
9
+
10
class DefaultTaskResolver(Resolver):
    """
    Resolver that locates a task by importing its module and reading the task by attribute name.

    Loader args have the shape ``["mod", <module name>, "instance", <task attribute name>]``.
    """

    @property
    def import_path(self) -> str:
        # Fully-qualified location of this class, so the resolver itself can be found again at load time.
        return "flyte._internal.resolvers.default.DefaultTaskResolver"

    def load_task(self, loader_args: List[str]) -> TaskTemplate:
        # Positions 0 and 2 hold the literal tags "mod"/"instance"; anything past index 3 is ignored.
        _mod_tag, module_name, _instance_tag, attr_name, *_extra = loader_args

        loaded_module = importlib.import_module(name=module_name)
        return getattr(loaded_module, attr_name)

    def loader_args(self, task: TaskTemplate, root_dir: Path) -> List[str]:  # type:ignore
        entity_name, module_name = extract_task_module(task, root_dir)
        return ["mod", module_name, "instance", entity_name]
File without changes
@@ -0,0 +1,486 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import base64
5
+ import hashlib
6
+ import inspect
7
+ from dataclasses import dataclass
8
+ from types import NoneType
9
+ from typing import Any, Dict, List, Tuple, Union, get_args
10
+
11
+ from flyteidl2.core import execution_pb2, interface_pb2, literals_pb2
12
+ from flyteidl2.task import common_pb2, task_definition_pb2
13
+
14
+ import flyte.errors
15
+ import flyte.storage as storage
16
+ from flyte._context import ctx
17
+ from flyte.models import ActionID, NativeInterface, TaskContext
18
+ from flyte.types import TypeEngine, TypeTransformerFailedError
19
+
20
+
21
@dataclass(frozen=True)
class Inputs:
    """Immutable wrapper around a ``common_pb2.Inputs`` proto."""

    proto_inputs: common_pb2.Inputs

    @classmethod
    def empty(cls) -> "Inputs":
        """Create an ``Inputs`` that wraps a default-constructed (empty) proto."""
        blank_proto = common_pb2.Inputs()
        return cls(proto_inputs=blank_proto)

    @property
    def context(self) -> Dict[str, str]:
        """Return the proto's context key/value pairs as a dict; for duplicate keys the last pair wins."""
        pairs = self.proto_inputs.context
        return dict((pair.key, pair.value) for pair in pairs)


@dataclass(frozen=True)
class Outputs:
    """Immutable wrapper around a ``common_pb2.Outputs`` proto."""

    proto_outputs: common_pb2.Outputs


@dataclass
class Error:
    """Mutable wrapper around an ``execution_pb2.ExecutionError`` proto."""

    err: execution_pb2.ExecutionError
43
+
44
+
45
+ # ------------------------------- CONVERT Methods ------------------------------- #
46
+
47
+
48
+ def _clean_error_code(code: str) -> Tuple[str, str | None]:
49
+ """
50
+ The error code may have a server injected code and is of the form `RetriesExhausedError|<code>` or `<code>`.
51
+
52
+ :param code:
53
+ :return: "user code", optional server code
54
+ """
55
+ if "|" in code:
56
+ server_code, user_code = code.split("|", 1)
57
+ return user_code.strip(), server_code.strip()
58
+ return code.strip(), None
59
+
60
+
61
async def convert_inputs_to_native(inputs: Inputs, python_interface: NativeInterface) -> Dict[str, Any]:
    """
    Decode the named literals inside ``inputs`` into native Python values.

    :param inputs: Wrapped inputs proto whose literals are decoded.
    :param python_interface: Interface providing the expected Python type of each input.
    :return: Mapping of input name to native value, as produced by ``TypeEngine.literal_map_to_kwargs``.
    """
    by_name: Dict[str, literals_pb2.Literal] = {}
    for named in inputs.proto_inputs.literals:
        by_name[named.name] = named.value
    literal_map = literals_pb2.LiteralMap(literals=by_name)
    return await TypeEngine.literal_map_to_kwargs(literal_map, python_interface.get_input_types())
67
+
68
+
69
async def convert_upload_default_inputs(interface: NativeInterface) -> List[common_pb2.NamedParameter]:
    """
    Convert the default values declared on a NativeInterface into ``NamedParameter`` protos for upload.

    This is used to upload default inputs to the Flyte backend.

    Inputs without a default (``inspect.Parameter.empty``) and inputs whose default is ``None`` are skipped.
    Every other default is converted, including falsy ones such as ``0``, ``False`` or ``""``. All literal
    conversions run concurrently.

    :param interface: The native interface whose input defaults are converted.
    :return: One ``NamedParameter`` (type + default literal) per input that has a default, in interface order.
    """
    if not interface.inputs:
        return []

    pending = []  # (input name, literal type), kept in the same order as literal_coros
    literal_coros = []
    for input_name, (input_type, default_value) in interface.inputs.items():
        # Use identity checks only: a truthiness test drops legitimate falsy defaults (0, False, "") and
        # raises for objects with an ambiguous truth value (e.g. numpy arrays, DataFrames).
        if default_value is None or default_value is inspect.Parameter.empty:
            continue
        lt = TypeEngine.to_literal_type(input_type)
        literal_coros.append(TypeEngine.to_literal(default_value, input_type, lt))
        pending.append((input_name, lt))

    literals: List[literals_pb2.Literal] = await asyncio.gather(*literal_coros)
    return [
        common_pb2.NamedParameter(
            name=name,
            parameter=interface_pb2.Parameter(
                var=interface_pb2.Variable(type=lt),
                default=literal,
            ),
        )
        for (name, lt), literal in zip(pending, literals)
    ]
101
+
102
+
103
+ def is_optional_type(tp) -> bool:
104
+ """
105
+ True if the *annotation* `tp` is equivalent to Optional[…].
106
+ Works for Optional[T], Union[T, None], and T | None.
107
+ """
108
+ return NoneType in get_args(tp) # fastest check
109
+
110
+
111
async def convert_from_native_to_inputs(
    interface: NativeInterface, *args, custom_context: Dict[str, str] | None = None, **kwargs
) -> Inputs:
    """
    Convert native call arguments into an ``Inputs`` proto, ordered by the interface's declared inputs.

    Inputs that were not supplied are filled from their defaults:
    - remote defaults (``NativeInterface.has_default``) are taken pre-converted from
      ``interface._remote_defaults``;
    - ``None``-typed inputs become ``None``;
    - any other default is converted by ``TypeEngine`` together with the supplied arguments.

    The custom context attached to the proto comes from the active ``TaskContext`` when it has one (inside
    task execution). Otherwise ``custom_context`` is used (e.g. when initiating a remote run).

    :param interface: Native interface of the task being invoked.
    :param args: Positional call arguments.
    :param custom_context: Context key/value pairs used when no task context provides one.
    :param kwargs: Keyword call arguments.
    :raises ValueError: If a required input is missing, or a remote default is declared but not available.
    :return: The wrapped ``Inputs`` proto.
    """
    kwargs = interface.convert_to_kwargs(*args, **kwargs)

    not_provided = [name for name in interface.required_inputs() if name not in kwargs]
    if not_provided:
        raise ValueError(f"Missing required inputs: {', '.join(not_provided)}")

    # The TaskContext's context wins when present; the explicit argument is the fallback.
    selected_context: Dict[str, str] | None = None
    task_ctx = ctx()
    if task_ctx and task_ctx.custom_context:
        selected_context = task_ctx.custom_context
    elif custom_context:
        selected_context = custom_context
    context_kvs = None
    if selected_context is not None:
        context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in selected_context.items()]

    if len(interface.inputs) == 0:
        # Nothing to convert, but the context still has to be propagated.
        return Inputs(proto_inputs=common_pb2.Inputs(context=context_kvs))

    type_hints: Dict[str, type] = {}
    preconverted: Dict[str, literals_pb2.Literal] = {}
    for name, (py_type, default) in interface.inputs.items():
        if name in kwargs:
            type_hints[name] = py_type
            continue

        none_typed = py_type is None or py_type is type(None)
        fillable = (
            (default is not None and default is not inspect.Signature.empty)
            or (default is None and is_optional_type(py_type))
            or none_typed
        )
        if not fillable:
            continue

        if default == NativeInterface.has_default:
            # The interface only records that a default exists; its literal must come from _remote_defaults.
            remote_defaults = interface._remote_defaults
            if remote_defaults is None or name not in remote_defaults:
                raise ValueError(f"Input '{name}' has a default value but it is not set in the interface.")
            preconverted[name] = remote_defaults[name]
        elif none_typed:
            # A 'None' / NoneType annotation is treated as a placeholder for "no type".
            kwargs[name] = None
            type_hints[name] = NoneType
        else:
            kwargs[name] = default
            type_hints[name] = py_type

    literal_map = await TypeEngine.dict_to_literal_map(kwargs, type_hints)
    if preconverted:
        merged: Dict[str, literals_pb2.Literal] = dict(literal_map.literals.items())
        merged.update(preconverted)
        literal_map = literals_pb2.LiteralMap(literals=merged)

    # Order by the interface, not literal_map or kwargs, because those may have a different order.
    ordered_literals = [
        common_pb2.NamedLiteral(name=name, value=literal_map.literals[name]) for name in interface.inputs.keys()
    ]
    return Inputs(proto_inputs=common_pb2.Inputs(literals=ordered_literals, context=context_kvs))
177
+
178
+
179
async def convert_from_inputs_to_native(native_interface: NativeInterface, inputs: Inputs) -> Dict[str, Any]:
    """
    Convert the inputs of a run definition proto into a dictionary of native Python values.

    Missing inputs, or inputs without literals, yield an empty dict.

    :param native_interface: The native interface of the task.
    :param inputs: The run definition inputs proto.
    :return: A dictionary mapping input names to their native Python values.
    """
    proto = inputs.proto_inputs if inputs else None
    if not proto or not proto.literals:
        return {}

    literal_map = literals_pb2.LiteralMap(literals={entry.name: entry.value for entry in proto.literals})
    return await TypeEngine.literal_map_to_kwargs(literal_map, native_interface.get_input_types())
193
+
194
+
195
async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, task_name: str = "") -> Outputs:
    """
    Convert a task's native return value into an ``Outputs`` proto.

    :param o: The value returned by the task. A non-tuple value is treated as a single output.
    :param interface: Native interface declaring the task's outputs, in order.
    :param task_name: Task name used in validation errors.
    :raises flyte.errors.RuntimeDataValidationError: If a value is returned although no outputs are declared,
        or an output fails type conversion.
    :raises AssertionError: If the number of returned values does not match the declared outputs. This is
        raised explicitly, not via ``assert``, so the check also runs under ``python -O``.
    :return: The wrapped ``Outputs`` proto with one named literal per declared output.
    """
    # Always make it a tuple, even if it's just one item, to simplify the logic below.
    if not isinstance(o, tuple):
        o = (o,)

    if len(interface.outputs) == 0:
        # A bare None (or an empty tuple) is fine when no outputs are declared.
        # NOTE(review): tuples of 2+ values are currently accepted silently here — confirm that is intended.
        if len(o) == 1 and o[0] is not None:
            raise flyte.errors.RuntimeDataValidationError(
                "o0",
                f"Expected no outputs but got {o}, did you miss a return type annotation?",
                task_name,
            )
    elif len(o) != len(interface.outputs):
        raise AssertionError(
            f"Received {len(o)} outputs but return annotation has {len(interface.outputs)} outputs specified. "
        )

    named = []
    for (output_name, python_type), v in zip(interface.outputs.items(), o):
        try:
            lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
            named.append(common_pb2.NamedLiteral(name=output_name, value=lit))
        except TypeTransformerFailedError as e:
            raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name) from e

    return Outputs(proto_outputs=common_pb2.Outputs(literals=named))
221
+
222
+
223
async def convert_outputs_to_native(interface: NativeInterface, outputs: Outputs) -> Union[Any, Tuple[Any, ...]]:
    """
    Decode an ``Outputs`` proto back into native Python value(s).

    :param interface: Native interface declaring the outputs and their order.
    :param outputs: The wrapped outputs proto.
    :return: ``None`` if there are no outputs, the bare value for one output, otherwise a tuple ordered
        like ``interface.outputs``.
    """
    literal_map = literals_pb2.LiteralMap(
        literals={entry.name: entry.value for entry in outputs.proto_outputs.literals}
    )
    native = await TypeEngine.literal_map_to_kwargs(literal_map, interface.outputs)

    count = len(native)
    if count == 0:
        return None
    if count == 1:
        (only_value,) = native.values()
        return only_value
    # Proto maps do not preserve ordering, so rebuild the tuple in declared interface order.
    return tuple(native[name] for name in interface.outputs.keys())
235
+
236
+
237
def convert_error_to_native(err: execution_pb2.ExecutionError | Exception | Error) -> Exception | None:
    """
    Translate an execution error into the matching ``flyte.errors`` exception.

    Python exceptions are returned unchanged, and ``Error`` wrappers are unwrapped first. USER errors are
    refined by substrings of the raw error code; matching of "OOM" is case-insensitive.

    :param err: A proto error, an ``Error`` wrapper, or an already-native exception.
    :return: The native exception, or ``None`` for a falsy input or an unrecognized error kind.
    """
    if not err:
        return None

    if isinstance(err, Exception):
        return err

    if isinstance(err, Error):
        err = err.err

    user_code, _server_code = _clean_error_code(err.code)
    details = {"code": user_code, "message": err.message, "worker": err.worker}
    kind = err.kind

    if kind == execution_pb2.ExecutionError.UNKNOWN:
        return flyte.errors.RuntimeUnknownError(**details)
    if kind == execution_pb2.ExecutionError.SYSTEM:
        return flyte.errors.RuntimeSystemError(**details)
    if kind != execution_pb2.ExecutionError.USER:
        return None

    # USER errors: the first matching marker in the raw code (not the cleaned one) decides the type.
    raw_code = err.code
    if "OOM" in raw_code.upper():
        return flyte.errors.OOMError(**details)
    if "Interrupted" in raw_code:
        return flyte.errors.TaskInterruptedError(**details)
    if "PrimaryContainerNotFound" in raw_code:
        return flyte.errors.PrimaryContainerNotFoundError(**details)
    if "RetriesExhausted" in raw_code:
        return flyte.errors.RetriesExhaustedError(**details)
    if "Unknown" in raw_code:
        return flyte.errors.RuntimeUnknownError(**details)
    if "InvalidImageName" in raw_code:
        return flyte.errors.InvalidImageNameError(**details)
    if "ImagePullBackOff" in raw_code:
        return flyte.errors.ImagePullBackOffError(**details)
    return flyte.errors.RuntimeUserError(**details)
272
+
273
+
274
def convert_from_native_to_error(err: BaseException) -> Error:
    """
    Wrap a Python exception into an ``Error`` carrying an ``ExecutionError`` proto.

    ``RuntimeUnknownError``, ``RuntimeUserError`` and ``RuntimeSystemError`` (checked in that order) map to
    the UNKNOWN, USER and SYSTEM kinds and keep their own ``code`` and ``worker``. Any other exception
    becomes UNKNOWN, with its class name as the code and ``"UNKNOWN"`` as the worker. The message is always
    ``str(err)``.

    :param err: The exception to convert.
    :return: The wrapped execution error.
    """
    if isinstance(err, flyte.errors.RuntimeUnknownError):
        kind = execution_pb2.ExecutionError.UNKNOWN
    elif isinstance(err, flyte.errors.RuntimeUserError):
        kind = execution_pb2.ExecutionError.USER
    elif isinstance(err, flyte.errors.RuntimeSystemError):
        kind = execution_pb2.ExecutionError.SYSTEM
    else:
        generic = execution_pb2.ExecutionError(
            kind=execution_pb2.ExecutionError.UNKNOWN,
            code=type(err).__name__,
            message=str(err),
            worker="UNKNOWN",
        )
        return Error(err=generic)

    # Only the flyte runtime errors above reach this point; they all expose ``code`` and ``worker``.
    typed = execution_pb2.ExecutionError(
        kind=kind,
        code=err.code,  # type: ignore[attr-defined]
        message=str(err),
        worker=err.worker,  # type: ignore[attr-defined]
    )
    return Error(err=typed)
311
+
312
+
313
def hash_data(data: Union[str, bytes]) -> str:
    """
    Return the SHA-256 digest of ``data``, base64-encoded (standard alphabet, with padding).

    Strings are UTF-8 encoded before hashing, so a string and its UTF-8 bytes hash identically.

    :param data: The data to hash, as str or bytes.
    :return: The base64 text of the 32-byte digest (44 characters). Note that this is not hexadecimal.
    """
    raw = data.encode("utf-8") if isinstance(data, str) else data
    encoded_digest = base64.b64encode(hashlib.sha256(raw).digest())
    return encoded_digest.decode("utf-8")
323
+
324
+
325
def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
    """
    Hash already-serialized task inputs; the result is used to uniquely identify those inputs.

    :param serialized_inputs: Serialized inputs, as str or bytes.
    :return: The base64-encoded SHA-256 digest computed by ``hash_data``.
    """
    inputs_digest = hash_data(serialized_inputs)
    return inputs_digest
331
+
332
+
333
def generate_inputs_repr_for_literal(literal: literals_pb2.Literal) -> bytes:
    """
    Produce the bytes fed into the cache-key hash for a single literal.

    Rules:
    - a literal that carries a precomputed ``hash`` is represented by that hash;
    - a collection is the concatenation of its elements' representations, in order;
    - a map is the concatenation of ``key + value representation`` over keys in sorted order;
    - anything else (scalars, etc.) is its deterministic protobuf serialization.

    Elements of collections and maps are handled recursively, so a nested hash is respected at any depth.

    NOTE(review): the pieces are joined without separators or length prefixes, so different structures could,
    in principle, produce the same bytes. Changing this would change every existing cache key, so it is
    documented here rather than altered.

    :param literal: The literal to get a hashable representation for.
    :return: Byte representation of the literal that can be fed into a hash function.
    """
    if literal.hash:
        return literal.hash.encode("utf-8")

    if literal.HasField("collection"):
        # The recursive call already returns an element's own hash when it has one.
        return b"".join(generate_inputs_repr_for_literal(item) for item in literal.collection.literals)

    if literal.HasField("map"):
        entries = literal.map.literals
        # Sorted keys make the ordering deterministic.
        return b"".join(
            key.encode("utf-8") + generate_inputs_repr_for_literal(entries[key]) for key in sorted(entries.keys())
        )

    return literal.SerializeToString(deterministic=True)
374
+
375
+
376
def generate_inputs_hash_for_named_literals(inputs: list[common_pb2.NamedLiteral]) -> str:
    """
    Hash a list of named literals, respecting any precomputed literal hashes.

    Each input contributes ``<name>:<literal representation>;``, and the pieces are concatenated in list
    order, so input order affects the result.

    :param inputs: List of NamedLiteral inputs to hash.
    :return: Base64-encoded SHA-256 digest, or ``""`` when there are no inputs.
    """
    if not inputs:
        return ""

    pieces = [
        entry.name.encode("utf-8") + b":" + generate_inputs_repr_for_literal(entry.value) + b";"
        for entry in inputs
    ]
    return hash_data(b"".join(pieces))
398
+
399
+
400
def generate_inputs_hash_from_proto(inputs: common_pb2.Inputs) -> str:
    """
    Hash the named literals of an ``Inputs`` proto; the result is used to uniquely identify the inputs.

    :param inputs: The inputs proto to hash.
    :return: Base64-encoded SHA-256 digest, or ``""`` if ``inputs`` is missing or has no literals.
    """
    named_literals = inputs.literals if inputs else None
    if not named_literals:
        return ""
    return generate_inputs_hash_for_named_literals(list(named_literals))
409
+
410
+
411
def generate_interface_hash(task_interface: interface_pb2.TypedInterface) -> str:
    """
    Hash a task's typed interface, based on its deterministic protobuf serialization.

    :param task_interface: The interface of the task.
    :return: Base64-encoded SHA-256 digest, or ``""`` for a falsy (e.g. ``None``) interface.
    """
    if not task_interface:
        return ""
    return hash_data(task_interface.SerializeToString(deterministic=True))
421
+
422
+
423
def generate_cache_key_hash(
    task_name: str,
    inputs_hash: str,
    task_interface: interface_pb2.TypedInterface,
    cache_version: str,
    ignored_input_vars: List[str],
    proto_inputs: common_pb2.Inputs,
) -> str:
    """
    Compute the cache key for a task invocation.

    The key is the hash of the concatenation: inputs hash, task name, interface hash, cache version.

    :param task_name: The name of the task.
    :param inputs_hash: The precomputed hash of all inputs. It is used unless some inputs are ignored.
    :param task_interface: The interface of the task.
    :param cache_version: The version of the cache.
    :param ignored_input_vars: Input names to exclude. When non-empty, the inputs hash is recomputed from
        ``proto_inputs`` without them.
    :param proto_inputs: The proto inputs for the task; only read when there are ignored inputs.
    :return: Base64-encoded SHA-256 digest of the cache key.
    """
    if not ignored_input_vars:
        effective_inputs_hash = inputs_hash
    else:
        kept = [entry for entry in proto_inputs.literals if entry.name not in ignored_input_vars]
        effective_inputs_hash = generate_inputs_hash_from_proto(common_pb2.Inputs(literals=kept))

    interface_digest = generate_interface_hash(task_interface)
    key_material = f"{effective_inputs_hash}{task_name}{interface_digest}{cache_version}"
    return hash_data(key_material)
454
+
455
+
456
def generate_sub_action_id_and_output_path(
    tctx: TaskContext,
    task_spec_or_name: task_definition_pb2.TaskSpec | str,
    inputs_hash: str,
    invoke_seq: int,
) -> Tuple[ActionID, str]:
    """
    Derive a sub-action ID and its output path from the current task context, the task, and the inputs.

    The sub-action is created by ``ActionID.new_sub_action_from`` on the current action. Its inputs are:
    - the task hash;
    - the inputs hash;
    - the current group name (if any);
    - the invocation sequence number.

    The output path is the current run base directory joined with the sub-action's name.

    :param tctx: Context of the currently executing task (provides the parent action, run base dir and group).
    :param task_spec_or_name: Task specification or task name. A task name is only used for trace actions.
        A spec is hashed from its deterministic serialization; a name is used verbatim as the task hash.
    :param inputs_hash: Consistent hash string of the inputs.
    :param invoke_seq: The sequence number of the invocation, used to differentiate between multiple
        invocations.
    :return: A tuple ``(sub-action ID, sub-action output path)``.
    """
    current_action_id = tctx.action
    current_output_path = tctx.run_base_dir
    if isinstance(task_spec_or_name, task_definition_pb2.TaskSpec):
        task_hash = hash_data(task_spec_or_name.SerializeToString(deterministic=True))
    else:
        task_hash = task_spec_or_name
    sub_action_id = current_action_id.new_sub_action_from(
        task_hash=task_hash,
        input_hash=inputs_hash,
        group=tctx.group_data.name if tctx.group_data else None,
        task_call_seq=invoke_seq,
    )
    sub_run_output_path = storage.join(current_output_path, sub_action_id.name)
    return sub_action_id, sub_run_output_path