flyte 0.2.0b1__py3-none-any.whl → 2.0.0b46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266) hide show
  1. flyte/__init__.py +83 -30
  2. flyte/_bin/connect.py +61 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +87 -19
  5. flyte/_bin/serve.py +351 -0
  6. flyte/_build.py +3 -2
  7. flyte/_cache/cache.py +6 -5
  8. flyte/_cache/local_cache.py +216 -0
  9. flyte/_code_bundle/_ignore.py +31 -5
  10. flyte/_code_bundle/_packaging.py +42 -11
  11. flyte/_code_bundle/_utils.py +57 -34
  12. flyte/_code_bundle/bundle.py +130 -27
  13. flyte/_constants.py +1 -0
  14. flyte/_context.py +21 -5
  15. flyte/_custom_context.py +73 -0
  16. flyte/_debug/constants.py +37 -0
  17. flyte/_debug/utils.py +17 -0
  18. flyte/_debug/vscode.py +315 -0
  19. flyte/_deploy.py +396 -75
  20. flyte/_deployer.py +109 -0
  21. flyte/_environment.py +94 -11
  22. flyte/_excepthook.py +37 -0
  23. flyte/_group.py +2 -1
  24. flyte/_hash.py +1 -16
  25. flyte/_image.py +544 -231
  26. flyte/_initialize.py +456 -316
  27. flyte/_interface.py +40 -5
  28. flyte/_internal/controllers/__init__.py +22 -8
  29. flyte/_internal/controllers/_local_controller.py +159 -35
  30. flyte/_internal/controllers/_trace.py +18 -10
  31. flyte/_internal/controllers/remote/__init__.py +38 -9
  32. flyte/_internal/controllers/remote/_action.py +82 -12
  33. flyte/_internal/controllers/remote/_client.py +6 -2
  34. flyte/_internal/controllers/remote/_controller.py +290 -64
  35. flyte/_internal/controllers/remote/_core.py +155 -95
  36. flyte/_internal/controllers/remote/_informer.py +40 -20
  37. flyte/_internal/controllers/remote/_service_protocol.py +2 -2
  38. flyte/_internal/imagebuild/__init__.py +2 -10
  39. flyte/_internal/imagebuild/docker_builder.py +391 -84
  40. flyte/_internal/imagebuild/image_builder.py +111 -55
  41. flyte/_internal/imagebuild/remote_builder.py +409 -0
  42. flyte/_internal/imagebuild/utils.py +79 -0
  43. flyte/_internal/resolvers/_app_env_module.py +92 -0
  44. flyte/_internal/resolvers/_task_module.py +5 -38
  45. flyte/_internal/resolvers/app_env.py +26 -0
  46. flyte/_internal/resolvers/common.py +8 -1
  47. flyte/_internal/resolvers/default.py +2 -2
  48. flyte/_internal/runtime/convert.py +319 -36
  49. flyte/_internal/runtime/entrypoints.py +106 -18
  50. flyte/_internal/runtime/io.py +71 -23
  51. flyte/_internal/runtime/resources_serde.py +21 -7
  52. flyte/_internal/runtime/reuse.py +125 -0
  53. flyte/_internal/runtime/rusty.py +196 -0
  54. flyte/_internal/runtime/task_serde.py +239 -66
  55. flyte/_internal/runtime/taskrunner.py +48 -8
  56. flyte/_internal/runtime/trigger_serde.py +162 -0
  57. flyte/_internal/runtime/types_serde.py +7 -16
  58. flyte/_keyring/file.py +115 -0
  59. flyte/_link.py +30 -0
  60. flyte/_logging.py +241 -42
  61. flyte/_map.py +312 -0
  62. flyte/_metrics.py +59 -0
  63. flyte/_module.py +74 -0
  64. flyte/_pod.py +30 -0
  65. flyte/_resources.py +296 -33
  66. flyte/_retry.py +1 -7
  67. flyte/_reusable_environment.py +72 -7
  68. flyte/_run.py +462 -132
  69. flyte/_secret.py +47 -11
  70. flyte/_serve.py +333 -0
  71. flyte/_task.py +245 -56
  72. flyte/_task_environment.py +219 -97
  73. flyte/_task_plugins.py +47 -0
  74. flyte/_tools.py +8 -8
  75. flyte/_trace.py +15 -24
  76. flyte/_trigger.py +1027 -0
  77. flyte/_utils/__init__.py +12 -1
  78. flyte/_utils/asyn.py +3 -1
  79. flyte/_utils/async_cache.py +139 -0
  80. flyte/_utils/coro_management.py +5 -4
  81. flyte/_utils/description_parser.py +19 -0
  82. flyte/_utils/docker_credentials.py +173 -0
  83. flyte/_utils/helpers.py +45 -19
  84. flyte/_utils/module_loader.py +123 -0
  85. flyte/_utils/org_discovery.py +57 -0
  86. flyte/_utils/uv_script_parser.py +8 -1
  87. flyte/_version.py +16 -3
  88. flyte/app/__init__.py +27 -0
  89. flyte/app/_app_environment.py +362 -0
  90. flyte/app/_connector_environment.py +40 -0
  91. flyte/app/_deploy.py +130 -0
  92. flyte/app/_parameter.py +343 -0
  93. flyte/app/_runtime/__init__.py +3 -0
  94. flyte/app/_runtime/app_serde.py +383 -0
  95. flyte/app/_types.py +113 -0
  96. flyte/app/extras/__init__.py +9 -0
  97. flyte/app/extras/_auth_middleware.py +217 -0
  98. flyte/app/extras/_fastapi.py +93 -0
  99. flyte/app/extras/_model_loader/__init__.py +3 -0
  100. flyte/app/extras/_model_loader/config.py +7 -0
  101. flyte/app/extras/_model_loader/loader.py +288 -0
  102. flyte/cli/__init__.py +12 -0
  103. flyte/cli/_abort.py +28 -0
  104. flyte/cli/_build.py +114 -0
  105. flyte/cli/_common.py +493 -0
  106. flyte/cli/_create.py +371 -0
  107. flyte/cli/_delete.py +45 -0
  108. flyte/cli/_deploy.py +401 -0
  109. flyte/cli/_gen.py +316 -0
  110. flyte/cli/_get.py +446 -0
  111. flyte/cli/_option.py +33 -0
  112. flyte/{_cli → cli}/_params.py +57 -17
  113. flyte/cli/_plugins.py +209 -0
  114. flyte/cli/_prefetch.py +292 -0
  115. flyte/cli/_run.py +690 -0
  116. flyte/cli/_serve.py +338 -0
  117. flyte/cli/_update.py +86 -0
  118. flyte/cli/_user.py +20 -0
  119. flyte/cli/main.py +246 -0
  120. flyte/config/__init__.py +2 -167
  121. flyte/config/_config.py +215 -163
  122. flyte/config/_internal.py +10 -1
  123. flyte/config/_reader.py +225 -0
  124. flyte/connectors/__init__.py +11 -0
  125. flyte/connectors/_connector.py +330 -0
  126. flyte/connectors/_server.py +194 -0
  127. flyte/connectors/utils.py +159 -0
  128. flyte/errors.py +134 -2
  129. flyte/extend.py +24 -0
  130. flyte/extras/_container.py +69 -56
  131. flyte/git/__init__.py +3 -0
  132. flyte/git/_config.py +279 -0
  133. flyte/io/__init__.py +8 -1
  134. flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
  135. flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
  136. flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
  137. flyte/io/_dir.py +575 -113
  138. flyte/io/_file.py +587 -141
  139. flyte/io/_hashing_io.py +342 -0
  140. flyte/io/extend.py +7 -0
  141. flyte/models.py +635 -0
  142. flyte/prefetch/__init__.py +22 -0
  143. flyte/prefetch/_hf_model.py +563 -0
  144. flyte/remote/__init__.py +14 -3
  145. flyte/remote/_action.py +879 -0
  146. flyte/remote/_app.py +346 -0
  147. flyte/remote/_auth_metadata.py +42 -0
  148. flyte/remote/_client/_protocols.py +62 -4
  149. flyte/remote/_client/auth/_auth_utils.py +19 -0
  150. flyte/remote/_client/auth/_authenticators/base.py +8 -2
  151. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  152. flyte/remote/_client/auth/_authenticators/factory.py +4 -0
  153. flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
  154. flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
  155. flyte/remote/_client/auth/_channel.py +47 -18
  156. flyte/remote/_client/auth/_client_config.py +5 -3
  157. flyte/remote/_client/auth/_keyring.py +15 -2
  158. flyte/remote/_client/auth/_token_client.py +3 -3
  159. flyte/remote/_client/controlplane.py +206 -18
  160. flyte/remote/_common.py +66 -0
  161. flyte/remote/_data.py +107 -22
  162. flyte/remote/_logs.py +116 -33
  163. flyte/remote/_project.py +21 -19
  164. flyte/remote/_run.py +164 -631
  165. flyte/remote/_secret.py +72 -29
  166. flyte/remote/_task.py +387 -46
  167. flyte/remote/_trigger.py +368 -0
  168. flyte/remote/_user.py +43 -0
  169. flyte/report/_report.py +10 -6
  170. flyte/storage/__init__.py +13 -1
  171. flyte/storage/_config.py +237 -0
  172. flyte/storage/_parallel_reader.py +289 -0
  173. flyte/storage/_storage.py +268 -59
  174. flyte/syncify/__init__.py +56 -0
  175. flyte/syncify/_api.py +414 -0
  176. flyte/types/__init__.py +39 -0
  177. flyte/types/_interface.py +22 -7
  178. flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
  179. flyte/types/_string_literals.py +8 -9
  180. flyte/types/_type_engine.py +226 -126
  181. flyte/types/_utils.py +1 -1
  182. flyte-2.0.0b46.data/scripts/debug.py +38 -0
  183. flyte-2.0.0b46.data/scripts/runtime.py +194 -0
  184. flyte-2.0.0b46.dist-info/METADATA +352 -0
  185. flyte-2.0.0b46.dist-info/RECORD +221 -0
  186. flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
  187. flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
  188. flyte/_api_commons.py +0 -3
  189. flyte/_cli/_common.py +0 -299
  190. flyte/_cli/_create.py +0 -42
  191. flyte/_cli/_delete.py +0 -23
  192. flyte/_cli/_deploy.py +0 -140
  193. flyte/_cli/_get.py +0 -235
  194. flyte/_cli/_run.py +0 -174
  195. flyte/_cli/main.py +0 -98
  196. flyte/_datastructures.py +0 -342
  197. flyte/_internal/controllers/pbhash.py +0 -39
  198. flyte/_protos/common/authorization_pb2.py +0 -66
  199. flyte/_protos/common/authorization_pb2.pyi +0 -108
  200. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  201. flyte/_protos/common/identifier_pb2.py +0 -71
  202. flyte/_protos/common/identifier_pb2.pyi +0 -82
  203. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  204. flyte/_protos/common/identity_pb2.py +0 -48
  205. flyte/_protos/common/identity_pb2.pyi +0 -72
  206. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  207. flyte/_protos/common/list_pb2.py +0 -36
  208. flyte/_protos/common/list_pb2.pyi +0 -69
  209. flyte/_protos/common/list_pb2_grpc.py +0 -4
  210. flyte/_protos/common/policy_pb2.py +0 -37
  211. flyte/_protos/common/policy_pb2.pyi +0 -27
  212. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  213. flyte/_protos/common/role_pb2.py +0 -37
  214. flyte/_protos/common/role_pb2.pyi +0 -53
  215. flyte/_protos/common/role_pb2_grpc.py +0 -4
  216. flyte/_protos/common/runtime_version_pb2.py +0 -28
  217. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  218. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  219. flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
  220. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
  221. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  222. flyte/_protos/secret/definition_pb2.py +0 -49
  223. flyte/_protos/secret/definition_pb2.pyi +0 -93
  224. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  225. flyte/_protos/secret/payload_pb2.py +0 -62
  226. flyte/_protos/secret/payload_pb2.pyi +0 -94
  227. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  228. flyte/_protos/secret/secret_pb2.py +0 -38
  229. flyte/_protos/secret/secret_pb2.pyi +0 -6
  230. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  231. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  232. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  233. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  234. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  235. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  236. flyte/_protos/workflow/queue_service_pb2.py +0 -106
  237. flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
  238. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  239. flyte/_protos/workflow/run_definition_pb2.py +0 -128
  240. flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
  241. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  242. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  243. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  244. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  245. flyte/_protos/workflow/run_service_pb2.py +0 -133
  246. flyte/_protos/workflow/run_service_pb2.pyi +0 -175
  247. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
  248. flyte/_protos/workflow/state_service_pb2.py +0 -58
  249. flyte/_protos/workflow/state_service_pb2.pyi +0 -71
  250. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  251. flyte/_protos/workflow/task_definition_pb2.py +0 -72
  252. flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
  253. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  254. flyte/_protos/workflow/task_service_pb2.py +0 -44
  255. flyte/_protos/workflow/task_service_pb2.pyi +0 -31
  256. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
  257. flyte/io/_dataframe.py +0 -0
  258. flyte/io/pickle/__init__.py +0 -0
  259. flyte/remote/_console.py +0 -18
  260. flyte-0.2.0b1.dist-info/METADATA +0 -179
  261. flyte-0.2.0b1.dist-info/RECORD +0 -204
  262. flyte-0.2.0b1.dist-info/entry_points.txt +0 -3
  263. /flyte/{_cli → _debug}/__init__.py +0 -0
  264. /flyte/{_protos → _keyring}/__init__.py +0 -0
  265. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
  266. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  import importlib
2
2
  from pathlib import Path
3
- from typing import List, Optional
3
+ from typing import List
4
4
 
5
5
  from flyte._internal.resolvers._task_module import extract_task_module
6
6
  from flyte._internal.resolvers.common import Resolver
@@ -23,6 +23,6 @@ class DefaultTaskResolver(Resolver):
23
23
  task_def = getattr(task_module, task_name)
24
24
  return task_def
25
25
 
26
- def loader_args(self, task: TaskTemplate, root_dir: Optional[Path] = None) -> List[str]: # type:ignore
26
+ def loader_args(self, task: TaskTemplate, root_dir: Path) -> List[str]: # type:ignore
27
27
  t, m = extract_task_module(task, root_dir)
28
28
  return ["mod", m, "instance", t]
@@ -1,30 +1,40 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
4
+ import base64
5
+ import hashlib
6
+ import inspect
3
7
  from dataclasses import dataclass
4
- from typing import Any, Dict, Tuple, Union
8
+ from types import NoneType
9
+ from typing import Any, Dict, List, Tuple, Union, get_args
5
10
 
6
- from flyteidl.core import execution_pb2, literals_pb2
11
+ from flyteidl2.core import execution_pb2, interface_pb2, literals_pb2
12
+ from flyteidl2.task import common_pb2, task_definition_pb2
7
13
 
8
14
  import flyte.errors
9
15
  import flyte.storage as storage
10
- from flyte._datastructures import ActionID, NativeInterface, TaskContext
11
- from flyte._internal.controllers import pbhash
12
- from flyte._protos.workflow import run_definition_pb2
13
- from flyte.types import TypeEngine
16
+ from flyte._context import ctx
17
+ from flyte.models import ActionID, NativeInterface, TaskContext
18
+ from flyte.types import TypeEngine, TypeTransformerFailedError
14
19
 
15
20
 
16
21
  @dataclass(frozen=True)
17
22
  class Inputs:
18
- proto_inputs: run_definition_pb2.Inputs
23
+ proto_inputs: common_pb2.Inputs
19
24
 
20
25
  @classmethod
21
26
  def empty(cls) -> "Inputs":
22
- return cls(proto_inputs=run_definition_pb2.Inputs())
27
+ return cls(proto_inputs=common_pb2.Inputs())
28
+
29
+ @property
30
+ def context(self) -> Dict[str, str]:
31
+ """Get the context as a dictionary."""
32
+ return {kv.key: kv.value for kv in self.proto_inputs.context}
23
33
 
24
34
 
25
35
  @dataclass(frozen=True)
26
36
  class Outputs:
27
- proto_outputs: run_definition_pb2.Outputs
37
+ proto_outputs: common_pb2.Outputs
28
38
 
29
39
 
30
40
  @dataclass
@@ -56,43 +66,160 @@ async def convert_inputs_to_native(inputs: Inputs, python_interface: NativeInter
56
66
  return native_vals
57
67
 
58
68
 
59
- async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwargs) -> Inputs:
69
+ async def convert_upload_default_inputs(interface: NativeInterface) -> List[common_pb2.NamedParameter]:
70
+ """
71
+ Converts the default inputs of a NativeInterface to a list of NamedParameters for upload.
72
+ This is used to upload default inputs to the Flyte backend.
73
+ """
74
+ if not interface.inputs:
75
+ return []
76
+
77
+ vars = []
78
+ literal_coros = []
79
+ for input_name, (input_type, default_value) in interface.inputs.items():
80
+ if default_value and default_value is not inspect.Parameter.empty:
81
+ lt = TypeEngine.to_literal_type(input_type)
82
+ literal_coros.append(TypeEngine.to_literal(default_value, input_type, lt))
83
+ vars.append((input_name, lt))
84
+
85
+ literals: List[literals_pb2.Literal] = await asyncio.gather(*literal_coros, return_exceptions=True)
86
+ named_params = []
87
+ for (name, lt), literal in zip(vars, literals):
88
+ if isinstance(literal, Exception):
89
+ raise RuntimeError(f"Failed to convert default value for parameter '{name}'") from literal
90
+ param = interface_pb2.Parameter(
91
+ var=interface_pb2.Variable(
92
+ type=lt,
93
+ ),
94
+ default=literal,
95
+ )
96
+ named_params.append(
97
+ common_pb2.NamedParameter(
98
+ name=name,
99
+ parameter=param,
100
+ ),
101
+ )
102
+ return named_params
103
+
104
+
105
+ def is_optional_type(tp) -> bool:
106
+ """
107
+ True if the *annotation* `tp` is equivalent to Optional[…].
108
+ Works for Optional[T], Union[T, None], and T | None.
109
+ """
110
+ return NoneType in get_args(tp) # fastest check
111
+
112
+
113
+ async def convert_from_native_to_inputs(
114
+ interface: NativeInterface, *args, custom_context: Dict[str, str] | None = None, **kwargs
115
+ ) -> Inputs:
60
116
  kwargs = interface.convert_to_kwargs(*args, **kwargs)
61
- if len(kwargs) == 0:
62
- return Inputs.empty()
117
+
118
+ missing = [key for key in interface.required_inputs() if key not in kwargs]
119
+ if missing:
120
+ raise ValueError(f"Missing required inputs: {', '.join(missing)}")
121
+
122
+ # Read custom_context from TaskContext if available (inside task execution)
123
+ # Otherwise use the passed parameter (for remote run initiation)
124
+ context_kvs = None
125
+ tctx = ctx()
126
+ if tctx and tctx.custom_context:
127
+ # Inside a task - read from TaskContext
128
+ context_to_use = tctx.custom_context
129
+ context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in context_to_use.items()]
130
+ elif custom_context:
131
+ # Remote run initiation
132
+ context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in custom_context.items()]
133
+
134
+ if len(interface.inputs) == 0:
135
+ # Handle context even for empty inputs
136
+ return Inputs(proto_inputs=common_pb2.Inputs(context=context_kvs))
137
+
63
138
  # fill in defaults if missing
139
+ type_hints: Dict[str, type] = {}
140
+ already_converted_kwargs: Dict[str, literals_pb2.Literal] = {}
64
141
  for input_name, (input_type, default_value) in interface.inputs.items():
65
- if input_name not in kwargs:
66
- if default_value is not None:
142
+ if input_name in kwargs:
143
+ type_hints[input_name] = input_type
144
+ elif (
145
+ (default_value is not None and default_value is not inspect.Signature.empty)
146
+ or (default_value is None and is_optional_type(input_type))
147
+ or input_type is None
148
+ or input_type is type(None)
149
+ ):
150
+ if default_value == NativeInterface.has_default:
151
+ if interface._remote_defaults is None or input_name not in interface._remote_defaults:
152
+ raise ValueError(f"Input '{input_name}' has a default value but it is not set in the interface.")
153
+ already_converted_kwargs[input_name] = interface._remote_defaults[input_name]
154
+ elif input_type is None or input_type is type(None):
155
+ # If the type is 'None' or 'class<None>', we assume it's a placeholder for no type
156
+ kwargs[input_name] = None
157
+ type_hints[input_name] = NoneType
158
+ else:
67
159
  kwargs[input_name] = default_value
68
- # todo: fill in Nones for optional inputs
69
- if len(kwargs) < len(interface.inputs):
70
- raise ValueError(
71
- f"Received {len(kwargs)} inputs but interface has {len(interface.inputs)}. "
72
- f"Please provide all required inputs."
73
- )
74
- literal_map = await TypeEngine.dict_to_literal_map(kwargs, interface.get_input_types())
160
+ type_hints[input_name] = input_type
161
+
162
+ literal_map = await TypeEngine.dict_to_literal_map(kwargs, type_hints)
163
+ if len(already_converted_kwargs) > 0:
164
+ copied_literals: Dict[str, literals_pb2.Literal] = {}
165
+ for k, v in literal_map.literals.items():
166
+ copied_literals[k] = v
167
+ # Add the already converted kwargs to the literal map
168
+ for k, v in already_converted_kwargs.items():
169
+ copied_literals[k] = v
170
+ literal_map = literals_pb2.LiteralMap(literals=copied_literals)
171
+
172
+ # Make sure we the interface, not literal_map or kwargs, because those may have a different order
75
173
  return Inputs(
76
- proto_inputs=run_definition_pb2.Inputs(
77
- literals=[run_definition_pb2.NamedLiteral(name=k, value=v) for k, v in literal_map.literals.items()]
174
+ proto_inputs=common_pb2.Inputs(
175
+ literals=[common_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()],
176
+ context=context_kvs,
78
177
  )
79
178
  )
80
179
 
81
180
 
82
- async def convert_from_native_to_outputs(o: Any, interface: NativeInterface) -> Outputs:
181
+ async def convert_from_inputs_to_native(native_interface: NativeInterface, inputs: Inputs) -> Dict[str, Any]:
182
+ """
183
+ Converts the inputs from a run definition proto to a native Python dictionary.
184
+ :param native_interface: The native interface of the task.
185
+ :param inputs: The run definition inputs proto.
186
+ :return: A dictionary of input names to their native Python values.
187
+ """
188
+ if not inputs or not inputs.proto_inputs or not inputs.proto_inputs.literals:
189
+ return {}
190
+
191
+ literals = {named_literal.name: named_literal.value for named_literal in inputs.proto_inputs.literals}
192
+ return await TypeEngine.literal_map_to_kwargs(
193
+ literals_pb2.LiteralMap(literals=literals), native_interface.get_input_types()
194
+ )
195
+
196
+
197
+ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, task_name: str = "") -> Outputs:
83
198
  # Always make it a tuple even if it's just one item to simplify logic below
84
199
  if not isinstance(o, tuple):
85
200
  o = (o,)
86
201
 
87
- assert len(interface.outputs) == len(interface.outputs), (
88
- f"Received {len(o)} outputs but interface has {len(interface.outputs)}"
89
- )
202
+ if len(interface.outputs) == 0:
203
+ if len(o) != 0:
204
+ if len(o) == 1 and o[0] is not None:
205
+ raise flyte.errors.RuntimeDataValidationError(
206
+ "o0",
207
+ f"Expected no outputs but got {o},did you miss a return type annotation?",
208
+ task_name,
209
+ )
210
+ else:
211
+ assert len(o) == len(interface.outputs), (
212
+ f"Received {len(o)} outputs but return annotation has {len(interface.outputs)} outputs specified. "
213
+ )
90
214
  named = []
91
215
  for (output_name, python_type), v in zip(interface.outputs.items(), o):
92
- lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
93
- named.append(run_definition_pb2.NamedLiteral(name=output_name, value=lit))
216
+ try:
217
+ lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
218
+ named.append(common_pb2.NamedLiteral(name=output_name, value=lit))
219
+ except TypeTransformerFailedError as e:
220
+ raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name)
94
221
 
95
- return Outputs(proto_outputs=run_definition_pb2.Outputs(literals=named))
222
+ return Outputs(proto_outputs=common_pb2.Outputs(literals=named))
96
223
 
97
224
 
98
225
  async def convert_outputs_to_native(interface: NativeInterface, outputs: Outputs) -> Union[Any, Tuple[Any, ...]]:
@@ -119,7 +246,7 @@ def convert_error_to_native(err: execution_pb2.ExecutionError | Exception | Erro
119
246
  if isinstance(err, Error):
120
247
  err = err.err
121
248
 
122
- user_code, server_code = _clean_error_code(err.code)
249
+ user_code, _server_code = _clean_error_code(err.code)
123
250
  match err.kind:
124
251
  case execution_pb2.ExecutionError.UNKNOWN:
125
252
  return flyte.errors.RuntimeUnknownError(code=user_code, message=err.message, worker=err.worker)
@@ -185,21 +312,177 @@ def convert_from_native_to_error(err: BaseException) -> Error:
185
312
  )
186
313
 
187
314
 
188
- def generate_sub_action_id_and_output_path(tctx: TaskContext, task_name: str, inputs: Inputs) -> Tuple[ActionID, str]:
315
+ def hash_data(data: Union[str, bytes]) -> str:
316
+ """
317
+ Generate a hash for the given data. If the data is a string, it will be encoded to bytes before hashing.
318
+ :param data: The data to hash, can be a string or bytes.
319
+ :return: A hexadecimal string representation of the hash.
320
+ """
321
+ if isinstance(data, str):
322
+ data = data.encode("utf-8")
323
+ digest = hashlib.sha256(data).digest()
324
+ return base64.b64encode(digest).decode("utf-8")
325
+
326
+
327
+ def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
328
+ """
329
+ Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
330
+ :return: A hexadecimal string representation of the hash.
331
+ """
332
+ return hash_data(serialized_inputs)
333
+
334
+
335
+ def generate_inputs_repr_for_literal(literal: literals_pb2.Literal) -> bytes:
336
+ """
337
+ Generate a byte representation for a single literal that is meant to be hashed as part of the cache key
338
+ computation for an Action. This function should just serialize the literal deterministically, but will
339
+ use an existing hash value if present in the Literal. This is trivial, except we need to handle nested literals
340
+ (inside collections and maps), that may have the hash property set.
341
+
342
+ :param literal: The literal to get a hashable representation for.
343
+ :return: byte representation of the literal that can be fed into a hash function.
344
+ """
345
+ # If the literal has a hash value, use that instead of serializing the full literal
346
+ if literal.hash:
347
+ return literal.hash.encode("utf-8")
348
+
349
+ if literal.HasField("collection"):
350
+ buf = bytearray()
351
+ for nested_literal in literal.collection.literals:
352
+ if nested_literal.hash:
353
+ buf += nested_literal.hash.encode("utf-8")
354
+ else:
355
+ buf += generate_inputs_repr_for_literal(nested_literal)
356
+
357
+ b = bytes(buf)
358
+ return b
359
+
360
+ elif literal.HasField("map"):
361
+ buf = bytearray()
362
+ # Sort keys to ensure deterministic ordering
363
+ for key in sorted(literal.map.literals.keys()):
364
+ nested_literal = literal.map.literals[key]
365
+ buf += key.encode("utf-8")
366
+ if nested_literal.hash:
367
+ buf += nested_literal.hash.encode("utf-8")
368
+ else:
369
+ buf += generate_inputs_repr_for_literal(nested_literal)
370
+
371
+ b = bytes(buf)
372
+ return b
373
+
374
+ # For all other cases (scalars, etc.), just serialize the literal normally
375
+ return literal.SerializeToString(deterministic=True)
376
+
377
+
378
+ def generate_inputs_hash_for_named_literals(inputs: list[common_pb2.NamedLiteral]) -> str:
379
+ """
380
+ Generate a hash for the inputs using the new literal representation approach that respects
381
+ hash values already present in literals. This is used to uniquely identify the inputs for a task
382
+ when some literals may have precomputed hash values.
383
+
384
+ :param inputs: List of NamedLiteral inputs to hash.
385
+ :return: A base64-encoded string representation of the hash.
386
+ """
387
+ if not inputs:
388
+ return ""
389
+
390
+ # Build the byte representation by concatenating each literal's representation
391
+ combined_bytes = b""
392
+ for named_literal in inputs:
393
+ # Add the name to ensure order matters
394
+ name_bytes = named_literal.name.encode("utf-8")
395
+ literal_bytes = generate_inputs_repr_for_literal(named_literal.value)
396
+ # Combine name and literal bytes with a separator to avoid collisions
397
+ combined_bytes += name_bytes + b":" + literal_bytes + b";"
398
+
399
+ return hash_data(combined_bytes)
400
+
401
+
402
+ def generate_inputs_hash_from_proto(inputs: common_pb2.Inputs) -> str:
403
+ """
404
+ Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
405
+ :param inputs: The inputs to hash.
406
+ :return: A hexadecimal string representation of the hash.
407
+ """
408
+ if not inputs or not inputs.literals:
409
+ return ""
410
+ return generate_inputs_hash_for_named_literals(list(inputs.literals))
411
+
412
+
413
+ def generate_interface_hash(task_interface: interface_pb2.TypedInterface) -> str:
414
+ """
415
+ Generate a hash for the task interface. This is used to uniquely identify the task interface.
416
+ :param task_interface: The interface of the task.
417
+ :return: A hexadecimal string representation of the hash.
418
+ """
419
+ if not task_interface:
420
+ return ""
421
+ serialized_interface = task_interface.SerializeToString(deterministic=True)
422
+ return hash_data(serialized_interface)
423
+
424
+
425
+ def generate_cache_key_hash(
426
+ task_name: str,
427
+ inputs_hash: str,
428
+ task_interface: interface_pb2.TypedInterface,
429
+ cache_version: str,
430
+ ignored_input_vars: List[str],
431
+ proto_inputs: common_pb2.Inputs,
432
+ ) -> str:
433
+ """
434
+ Generate a cache key hash based on the inputs hash, task name, task interface, and cache version.
435
+ This is used to uniquely identify the cache key for a task.
436
+
437
+ :param task_name: The name of the task.
438
+ :param inputs_hash: The hash of the inputs.
439
+ :param task_interface: The interface of the task.
440
+ :param cache_version: The version of the cache.
441
+ :param ignored_input_vars: A list of input variable names to ignore when generating the cache key.
442
+ :param proto_inputs: The proto inputs for the task, only used if there are ignored inputs.
443
+ :return: A hexadecimal string representation of the cache key hash.
444
+ """
445
+ if ignored_input_vars:
446
+ filtered = [named_lit for named_lit in proto_inputs.literals if named_lit.name not in ignored_input_vars]
447
+ final = common_pb2.Inputs(literals=filtered)
448
+ final_inputs = generate_inputs_hash_from_proto(final)
449
+ else:
450
+ final_inputs = inputs_hash
451
+
452
+ interface_hash = generate_interface_hash(task_interface)
453
+
454
+ data = f"{final_inputs}{task_name}{interface_hash}{cache_version}"
455
+ return hash_data(data)
456
+
457
+
458
+ def generate_sub_action_id_and_output_path(
459
+ tctx: TaskContext,
460
+ task_spec_or_name: task_definition_pb2.TaskSpec | str,
461
+ inputs_hash: str,
462
+ invoke_seq: int,
463
+ ) -> Tuple[ActionID, str]:
189
464
  """
190
465
  Generate a sub-action ID and output path based on the current task context, task name, and inputs.
466
+
467
+ action name = current action name + task name + input hash + group name (if available)
191
468
  :param tctx:
192
- :param task_name:
193
- :param inputs:
469
+ :param task_spec_or_name: task specification or task name. Task name is only used in case of trace actions.
470
+ :param inputs_hash: Consistent hash string of the inputs
471
+ :param invoke_seq: The sequence number of the invocation, used to differentiate between multiple invocations.
194
472
  :return:
195
473
  """
196
474
  current_action_id = tctx.action
197
475
  current_output_path = tctx.run_base_dir
198
- inputs_hash = pbhash.compute_hash_string(inputs.proto_inputs)
476
+ if isinstance(task_spec_or_name, task_definition_pb2.TaskSpec):
477
+ task_spec_or_name.task_template.interface
478
+ task_hash = hash_data(task_spec_or_name.SerializeToString(deterministic=True))
479
+ else:
480
+ task_hash = task_spec_or_name
199
481
  sub_action_id = current_action_id.new_sub_action_from(
200
- task_name=task_name,
482
+ task_hash=task_hash,
201
483
  input_hash=inputs_hash,
202
484
  group=tctx.group_data.name if tctx.group_data else None,
485
+ task_call_seq=invoke_seq,
203
486
  )
204
487
  sub_run_output_path = storage.join(current_output_path, sub_action_id.name)
205
488
  return sub_action_id, sub_run_output_path
@@ -1,16 +1,20 @@
1
- from typing import List, Optional, Tuple
1
+ import importlib
2
+ import os
3
+ import traceback
4
+ from typing import List, Optional, Tuple, Type
2
5
 
3
6
  import flyte.errors
4
7
  from flyte._code_bundle import download_bundle
5
8
  from flyte._context import contextual_run
6
- from flyte._datastructures import ActionID, Checkpoints, CodeBundle, RawDataPath
7
9
  from flyte._internal import Controller
8
10
  from flyte._internal.imagebuild.image_builder import ImageCache
9
11
  from flyte._logging import log, logger
12
+ from flyte._metrics import Stopwatch
10
13
  from flyte._task import TaskTemplate
14
+ from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
11
15
 
16
+ from ..._utils import adjust_sys_path
12
17
  from .convert import Error, Inputs, Outputs
13
- from .task_serde import load_task
14
18
  from .taskrunner import (
15
19
  convert_and_run,
16
20
  extract_download_run_upload,
@@ -51,25 +55,95 @@ async def direct_dispatch(
51
55
  )
52
56
 
53
57
 
58
def load_class(qualified_name: str) -> Type:
    """
    Load a class from a qualified name. The qualified name should be in the format 'module.ClassName'.

    :param qualified_name: The qualified name of the class to load.
    :return: The class object.
    :raises ValueError: If the name has no module part (no '.' separator, or nothing before it).
    :raises ModuleNotFoundError: If the module part cannot be imported.
    :raises AttributeError: If the imported module has no attribute named after the class part.
    """
    # rpartition never fails to unpack, so a malformed name can be reported with a clear message.
    module_name, _, class_name = qualified_name.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a qualified name in the format 'module.ClassName', got {qualified_name!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
67
+
68
+
69
def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
    """
    Load a task by instantiating a resolver class and delegating to its ``load_task`` method.

    :param resolver: Qualified name ('module.ClassName') of the resolver class. It is instantiated without
        arguments.
    :param resolver_args: Arguments for the resolver; they are passed to its ``load_task`` as a single tuple.
    :return: The loaded task.
    :raises ModuleNotFoundError: If the resolver's ``load_task`` fails to import a module. The re-raised error
        carries the original traceback, the current working directory and a listing of everything below it,
        and keeps the ``name`` and ``path`` attributes of the original error.
    """
    resolver_class = load_class(resolver)
    resolver_instance = resolver_class()
    try:
        return resolver_instance.load_task(resolver_args)
    except ModuleNotFoundError as e:
        cwd = os.getcwd()
        try:
            files = [
                os.path.relpath(os.path.join(root, name), cwd)
                for root, dirs, filenames in os.walk(cwd)
                for name in dirs + filenames
            ]
        except Exception as list_err:
            # Listing is best-effort diagnostics only; never let it mask the import failure.
            files = [f"(Failed to list directory: {list_err})"]

        file_listing = "\n".join(f" - {f}" for f in files)
        msg = (
            "\n\nFull traceback:\n"
            f"{traceback.format_exc()}"
            "\n[ImportError Diagnostics]\n"
            f"Module '{e.name}' not found in either the Python virtual environment or the current working directory.\n"
            f"Current working directory: {cwd}\n"
            "Files found under current directory:\n"
            f"{file_listing}"
        )
        # Preserve the structured attributes so callers can still inspect which module was missing.
        raise ModuleNotFoundError(msg, name=e.name, path=e.path) from e
99
+
100
+
101
def load_pkl_task(code_bundle: CodeBundle) -> TaskTemplate:
    """
    Deserialize a task from the gzip-compressed pickle of a downloaded code bundle.

    :param code_bundle: The code bundle whose ``downloaded_path`` points at the gzipped pickle file.
    :return: The object unpickled from the bundle.
    :raises Exception: Any failure while importing, opening or unpickling is logged and re-raised unchanged.
    """
    logger.debug(f"Loading task from pkl: {code_bundle.downloaded_path}")
    try:
        # The imports live inside the try block, so an import failure is logged like any other load error.
        import gzip

        import cloudpickle

        # NOTE: unpickling can execute arbitrary code; the bundle must come from a trusted source.
        pkl_path = str(code_bundle.downloaded_path)
        with gzip.open(pkl_path, "rb") as pkl_file:
            return cloudpickle.load(pkl_file)
    except Exception as load_err:
        logger.exception(f"Failed to load pickled task from {code_bundle.downloaded_path}. Reason: {load_err!s}")
        raise
118
+
119
+
120
async def download_code_bundle(code_bundle: CodeBundle) -> CodeBundle:
    """
    Download a code bundle and return the bundle updated with its downloaded path.

    The bundle's destination is handed to ``adjust_sys_path`` before the download starts, and the download
    itself is timed with a ``Stopwatch`` named "download_code_bundle".

    :param code_bundle: The code bundle to download.
    :return: The result of ``code_bundle.with_downloaded_path`` for the path returned by the download.
    """
    destination = str(code_bundle.destination)
    adjust_sys_path([destination])
    logger.debug(f"Downloading {code_bundle}")

    timer = Stopwatch("download_code_bundle")
    timer.start()
    local_path = await download_bundle(code_bundle)
    timer.stop()

    return code_bundle.with_downloaded_path(local_path)
133
+
134
+
54
135
  async def _download_and_load_task(
55
136
  code_bundle: CodeBundle | None, resolver: str | None = None, resolver_args: List[str] | None = None
56
137
  ) -> TaskTemplate:
57
138
  if code_bundle and (code_bundle.tgz or code_bundle.pkl):
58
139
  logger.debug(f"Downloading {code_bundle}")
59
- downloaded_path = await download_bundle(code_bundle)
60
- code_bundle = code_bundle.with_downloaded_path(downloaded_path)
140
+ code_bundle = await download_code_bundle(code_bundle)
61
141
  if code_bundle.pkl:
62
- try:
63
- logger.debug(f"Loading task from pkl: {code_bundle.downloaded_path}")
64
- import gzip
65
-
66
- import cloudpickle
67
-
68
- with gzip.open(str(code_bundle.downloaded_path), "rb") as f:
69
- return cloudpickle.load(f)
70
- except Exception as e:
71
- logger.exception(f"Failed to load pickled task from {code_bundle.downloaded_path}. Reason: {e!s}")
72
- raise
142
+ sw = Stopwatch("load_pkl_task")
143
+ sw.start()
144
+ result = load_pkl_task(code_bundle)
145
+ sw.stop()
146
+ return result
73
147
 
74
148
  if not resolver or not resolver_args:
75
149
  raise flyte.errors.RuntimeSystemError(
@@ -78,11 +152,19 @@ async def _download_and_load_task(
78
152
  logger.debug(
79
153
  f"Loading task from tgz: {code_bundle.downloaded_path}, resolver: {resolver}, args: {resolver_args}"
80
154
  )
81
- return load_task(resolver, *resolver_args)
155
+ sw = Stopwatch("load_task_from_tgz")
156
+ sw.start()
157
+ result = load_task(resolver, *resolver_args)
158
+ sw.stop()
159
+ return result
82
160
  if not resolver or not resolver_args:
83
161
  raise flyte.errors.RuntimeSystemError("MalformedCommand", "Resolver and resolver args are required. for task")
84
162
  logger.debug(f"No code bundle provided, loading task from resolver: {resolver}, args: {resolver_args}")
85
- return load_task(resolver, *resolver_args)
163
+ sw = Stopwatch("load_task_from_resolver")
164
+ sw.start()
165
+ result = load_task(resolver, *resolver_args)
166
+ sw.stop()
167
+ return result
86
168
 
87
169
 
88
170
  @log
@@ -99,6 +181,7 @@ async def load_and_run_task(
99
181
  code_bundle: CodeBundle | None = None,
100
182
  input_path: str | None = None,
101
183
  image_cache: ImageCache | None = None,
184
+ interactive_mode: bool = False,
102
185
  ):
103
186
  """
104
187
  This method is invoked from the runtime/CLI and is used to run a task. This creates the context tree,
@@ -116,7 +199,10 @@ async def load_and_run_task(
116
199
  :param code_bundle: The code bundle to use for the task.
117
200
  :param input_path: The input path to use for the task.
118
201
  :param image_cache: Mappings of Image identifiers to image URIs.
202
+ :param interactive_mode: Whether to run the task in interactive mode.
119
203
  """
204
+ sw = Stopwatch("load_and_run_task_total")
205
+ sw.start()
120
206
  task = await _download_and_load_task(code_bundle, resolver, resolver_args)
121
207
 
122
208
  await contextual_run(
@@ -132,4 +218,6 @@ async def load_and_run_task(
132
218
  code_bundle=code_bundle,
133
219
  input_path=input_path,
134
220
  image_cache=image_cache,
221
+ interactive_mode=interactive_mode,
135
222
  )
223
+ sw.stop()