flyte 2.0.0b22__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/runtime.py +43 -5
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +216 -0
  5. flyte/_code_bundle/_ignore.py +1 -1
  6. flyte/_code_bundle/_packaging.py +4 -4
  7. flyte/_code_bundle/_utils.py +14 -8
  8. flyte/_code_bundle/bundle.py +13 -5
  9. flyte/_constants.py +1 -0
  10. flyte/_context.py +4 -1
  11. flyte/_custom_context.py +73 -0
  12. flyte/_debug/constants.py +0 -1
  13. flyte/_debug/vscode.py +6 -1
  14. flyte/_deploy.py +223 -59
  15. flyte/_environment.py +5 -0
  16. flyte/_excepthook.py +1 -1
  17. flyte/_image.py +144 -82
  18. flyte/_initialize.py +95 -12
  19. flyte/_interface.py +2 -0
  20. flyte/_internal/controllers/_local_controller.py +65 -24
  21. flyte/_internal/controllers/_trace.py +1 -1
  22. flyte/_internal/controllers/remote/_action.py +13 -11
  23. flyte/_internal/controllers/remote/_client.py +1 -1
  24. flyte/_internal/controllers/remote/_controller.py +9 -4
  25. flyte/_internal/controllers/remote/_core.py +16 -16
  26. flyte/_internal/controllers/remote/_informer.py +4 -4
  27. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  28. flyte/_internal/imagebuild/docker_builder.py +139 -84
  29. flyte/_internal/imagebuild/image_builder.py +7 -13
  30. flyte/_internal/imagebuild/remote_builder.py +65 -13
  31. flyte/_internal/imagebuild/utils.py +51 -3
  32. flyte/_internal/resolvers/_task_module.py +5 -38
  33. flyte/_internal/resolvers/default.py +2 -2
  34. flyte/_internal/runtime/convert.py +42 -20
  35. flyte/_internal/runtime/entrypoints.py +24 -1
  36. flyte/_internal/runtime/io.py +21 -8
  37. flyte/_internal/runtime/resources_serde.py +20 -6
  38. flyte/_internal/runtime/reuse.py +1 -1
  39. flyte/_internal/runtime/rusty.py +20 -5
  40. flyte/_internal/runtime/task_serde.py +33 -27
  41. flyte/_internal/runtime/taskrunner.py +10 -1
  42. flyte/_internal/runtime/trigger_serde.py +160 -0
  43. flyte/_internal/runtime/types_serde.py +1 -1
  44. flyte/_keyring/file.py +39 -9
  45. flyte/_logging.py +79 -12
  46. flyte/_map.py +31 -12
  47. flyte/_module.py +70 -0
  48. flyte/_pod.py +2 -2
  49. flyte/_resources.py +213 -31
  50. flyte/_run.py +107 -41
  51. flyte/_task.py +66 -10
  52. flyte/_task_environment.py +96 -24
  53. flyte/_task_plugins.py +4 -2
  54. flyte/_trigger.py +1000 -0
  55. flyte/_utils/__init__.py +2 -1
  56. flyte/_utils/asyn.py +3 -1
  57. flyte/_utils/docker_credentials.py +173 -0
  58. flyte/_utils/module_loader.py +17 -2
  59. flyte/_version.py +3 -3
  60. flyte/cli/_abort.py +3 -3
  61. flyte/cli/_build.py +1 -3
  62. flyte/cli/_common.py +78 -7
  63. flyte/cli/_create.py +178 -3
  64. flyte/cli/_delete.py +23 -1
  65. flyte/cli/_deploy.py +49 -11
  66. flyte/cli/_get.py +79 -34
  67. flyte/cli/_params.py +8 -6
  68. flyte/cli/_plugins.py +209 -0
  69. flyte/cli/_run.py +127 -11
  70. flyte/cli/_serve.py +64 -0
  71. flyte/cli/_update.py +37 -0
  72. flyte/cli/_user.py +17 -0
  73. flyte/cli/main.py +30 -4
  74. flyte/config/_config.py +2 -0
  75. flyte/config/_internal.py +1 -0
  76. flyte/config/_reader.py +3 -3
  77. flyte/connectors/__init__.py +11 -0
  78. flyte/connectors/_connector.py +270 -0
  79. flyte/connectors/_server.py +197 -0
  80. flyte/connectors/utils.py +135 -0
  81. flyte/errors.py +10 -1
  82. flyte/extend.py +8 -1
  83. flyte/extras/_container.py +6 -1
  84. flyte/git/_config.py +11 -9
  85. flyte/io/__init__.py +2 -0
  86. flyte/io/_dataframe/__init__.py +2 -0
  87. flyte/io/_dataframe/basic_dfs.py +1 -1
  88. flyte/io/_dataframe/dataframe.py +12 -8
  89. flyte/io/_dir.py +551 -120
  90. flyte/io/_file.py +538 -141
  91. flyte/models.py +57 -12
  92. flyte/remote/__init__.py +6 -1
  93. flyte/remote/_action.py +18 -16
  94. flyte/remote/_client/_protocols.py +39 -4
  95. flyte/remote/_client/auth/_channel.py +10 -6
  96. flyte/remote/_client/controlplane.py +17 -5
  97. flyte/remote/_console.py +3 -2
  98. flyte/remote/_data.py +4 -3
  99. flyte/remote/_logs.py +3 -3
  100. flyte/remote/_run.py +47 -7
  101. flyte/remote/_secret.py +26 -17
  102. flyte/remote/_task.py +21 -9
  103. flyte/remote/_trigger.py +306 -0
  104. flyte/remote/_user.py +33 -0
  105. flyte/storage/__init__.py +6 -1
  106. flyte/storage/_parallel_reader.py +274 -0
  107. flyte/storage/_storage.py +185 -103
  108. flyte/types/__init__.py +16 -0
  109. flyte/types/_interface.py +2 -2
  110. flyte/types/_pickle.py +17 -4
  111. flyte/types/_string_literals.py +8 -9
  112. flyte/types/_type_engine.py +26 -19
  113. flyte/types/_utils.py +1 -1
  114. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/runtime.py +43 -5
  115. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/METADATA +8 -1
  116. flyte-2.0.0b30.dist-info/RECORD +192 -0
  117. flyte/_protos/__init__.py +0 -0
  118. flyte/_protos/common/authorization_pb2.py +0 -66
  119. flyte/_protos/common/authorization_pb2.pyi +0 -108
  120. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  121. flyte/_protos/common/identifier_pb2.py +0 -99
  122. flyte/_protos/common/identifier_pb2.pyi +0 -120
  123. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  124. flyte/_protos/common/identity_pb2.py +0 -48
  125. flyte/_protos/common/identity_pb2.pyi +0 -72
  126. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  127. flyte/_protos/common/list_pb2.py +0 -36
  128. flyte/_protos/common/list_pb2.pyi +0 -71
  129. flyte/_protos/common/list_pb2_grpc.py +0 -4
  130. flyte/_protos/common/policy_pb2.py +0 -37
  131. flyte/_protos/common/policy_pb2.pyi +0 -27
  132. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  133. flyte/_protos/common/role_pb2.py +0 -37
  134. flyte/_protos/common/role_pb2.pyi +0 -53
  135. flyte/_protos/common/role_pb2_grpc.py +0 -4
  136. flyte/_protos/common/runtime_version_pb2.py +0 -28
  137. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  138. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  139. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  140. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  141. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  142. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  143. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  144. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  145. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  146. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  147. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  148. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  149. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  150. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  151. flyte/_protos/secret/definition_pb2.py +0 -49
  152. flyte/_protos/secret/definition_pb2.pyi +0 -93
  153. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  154. flyte/_protos/secret/payload_pb2.py +0 -62
  155. flyte/_protos/secret/payload_pb2.pyi +0 -94
  156. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  157. flyte/_protos/secret/secret_pb2.py +0 -38
  158. flyte/_protos/secret/secret_pb2.pyi +0 -6
  159. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  160. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  161. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  162. flyte/_protos/workflow/common_pb2.py +0 -27
  163. flyte/_protos/workflow/common_pb2.pyi +0 -14
  164. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  165. flyte/_protos/workflow/environment_pb2.py +0 -29
  166. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  167. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  168. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  169. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  170. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  171. flyte/_protos/workflow/queue_service_pb2.py +0 -111
  172. flyte/_protos/workflow/queue_service_pb2.pyi +0 -168
  173. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  174. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  175. flyte/_protos/workflow/run_definition_pb2.pyi +0 -352
  176. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  177. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  178. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  179. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  180. flyte/_protos/workflow/run_service_pb2.py +0 -137
  181. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  182. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  183. flyte/_protos/workflow/state_service_pb2.py +0 -67
  184. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  185. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  186. flyte/_protos/workflow/task_definition_pb2.py +0 -82
  187. flyte/_protos/workflow/task_definition_pb2.pyi +0 -88
  188. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  189. flyte/_protos/workflow/task_service_pb2.py +0 -60
  190. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  191. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  192. flyte-2.0.0b22.dist-info/RECORD +0 -250
  193. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/debug.py +0 -0
  194. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  195. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +0 -0
  196. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  197. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,11 @@
1
- import inspect
2
- import os
3
1
  import pathlib
4
- import sys
5
2
  from typing import Tuple
6
3
 
4
+ from flyte._module import extract_obj_module
7
5
  from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
8
6
 
9
7
 
10
- def extract_task_module(task: TaskTemplate, /, source_dir: pathlib.Path | None = None) -> Tuple[str, str]:
8
+ def extract_task_module(task: TaskTemplate, /, source_dir: pathlib.Path) -> Tuple[str, str]:
11
9
  """
12
10
  Extract the task module from the task template.
13
11
 
@@ -15,40 +13,9 @@ def extract_task_module(task: TaskTemplate, /, source_dir: pathlib.Path | None =
15
13
  :param source_dir: The source directory to use for relative paths.
16
14
  :return: A tuple containing the entity name, module
17
15
  """
18
- entity_name = task.name
19
16
  if isinstance(task, AsyncFunctionTaskTemplate):
20
- entity_module = inspect.getmodule(task.func)
21
- if entity_module is None:
22
- raise ValueError(f"Task {entity_name} has no module.")
23
-
24
- fp = entity_module.__file__
25
- if fp is None:
26
- raise ValueError(f"Task {entity_name} has no module.")
27
-
28
- file_path = pathlib.Path(fp)
29
- # Get the relative path to the current directory
30
- # Will raise ValueError if the file is not in the source directory
31
- relative_path = file_path.relative_to(str(source_dir))
32
-
33
- if relative_path == pathlib.Path("."):
34
- entity_module_name = entity_module.__name__
35
- else:
36
- # Replace file separators with dots and remove the '.py' extension
37
- dotted_path = os.path.splitext(str(relative_path))[0].replace(os.sep, ".")
38
- entity_module_name = dotted_path
39
-
40
17
  entity_name = task.func.__name__
18
+ entity_module_name = extract_obj_module(task.func, source_dir)
19
+ return entity_name, entity_module_name
41
20
  else:
42
- raise NotImplementedError(f"Task module {entity_name} not implemented.")
43
-
44
- if entity_module_name == "__main__":
45
- """
46
- This case is for the case in which the task is run from the main module.
47
- """
48
- fp = sys.modules["__main__"].__file__
49
- if fp is None:
50
- raise ValueError(f"Task {entity_name} has no module.")
51
- main_path = pathlib.Path(fp)
52
- entity_module_name = main_path.stem
53
-
54
- return entity_name, entity_module_name
21
+ raise NotImplementedError(f"Task module {task.name} not implemented.")
@@ -1,6 +1,6 @@
1
1
  import importlib
2
2
  from pathlib import Path
3
- from typing import List, Optional
3
+ from typing import List
4
4
 
5
5
  from flyte._internal.resolvers._task_module import extract_task_module
6
6
  from flyte._internal.resolvers.common import Resolver
@@ -23,6 +23,6 @@ class DefaultTaskResolver(Resolver):
23
23
  task_def = getattr(task_module, task_name)
24
24
  return task_def
25
25
 
26
- def loader_args(self, task: TaskTemplate, root_dir: Optional[Path] = None) -> List[str]: # type:ignore
26
+ def loader_args(self, task: TaskTemplate, root_dir: Path) -> List[str]: # type:ignore
27
27
  t, m = extract_task_module(task, root_dir)
28
28
  return ["mod", m, "instance", t]
@@ -8,27 +8,33 @@ from dataclasses import dataclass
8
8
  from types import NoneType
9
9
  from typing import Any, Dict, List, Tuple, Union, get_args
10
10
 
11
- from flyteidl.core import execution_pb2, interface_pb2, literals_pb2
11
+ from flyteidl2.core import execution_pb2, interface_pb2, literals_pb2
12
+ from flyteidl2.task import common_pb2, task_definition_pb2
12
13
 
13
14
  import flyte.errors
14
15
  import flyte.storage as storage
15
- from flyte._protos.workflow import common_pb2, run_definition_pb2, task_definition_pb2
16
+ from flyte._context import ctx
16
17
  from flyte.models import ActionID, NativeInterface, TaskContext
17
18
  from flyte.types import TypeEngine, TypeTransformerFailedError
18
19
 
19
20
 
20
21
  @dataclass(frozen=True)
21
22
  class Inputs:
22
- proto_inputs: run_definition_pb2.Inputs
23
+ proto_inputs: common_pb2.Inputs
23
24
 
24
25
  @classmethod
25
26
  def empty(cls) -> "Inputs":
26
- return cls(proto_inputs=run_definition_pb2.Inputs())
27
+ return cls(proto_inputs=common_pb2.Inputs())
28
+
29
+ @property
30
+ def context(self) -> Dict[str, str]:
31
+ """Get the context as a dictionary."""
32
+ return {kv.key: kv.value for kv in self.proto_inputs.context}
27
33
 
28
34
 
29
35
  @dataclass(frozen=True)
30
36
  class Outputs:
31
- proto_outputs: run_definition_pb2.Outputs
37
+ proto_outputs: common_pb2.Outputs
32
38
 
33
39
 
34
40
  @dataclass
@@ -102,15 +108,30 @@ def is_optional_type(tp) -> bool:
102
108
  return NoneType in get_args(tp) # fastest check
103
109
 
104
110
 
105
- async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwargs) -> Inputs:
111
+ async def convert_from_native_to_inputs(
112
+ interface: NativeInterface, *args, custom_context: Dict[str, str] | None = None, **kwargs
113
+ ) -> Inputs:
106
114
  kwargs = interface.convert_to_kwargs(*args, **kwargs)
107
115
 
108
116
  missing = [key for key in interface.required_inputs() if key not in kwargs]
109
117
  if missing:
110
118
  raise ValueError(f"Missing required inputs: {', '.join(missing)}")
111
119
 
120
+ # Read custom_context from TaskContext if available (inside task execution)
121
+ # Otherwise use the passed parameter (for remote run initiation)
122
+ context_kvs = None
123
+ tctx = ctx()
124
+ if tctx and tctx.custom_context:
125
+ # Inside a task - read from TaskContext
126
+ context_to_use = tctx.custom_context
127
+ context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in context_to_use.items()]
128
+ elif custom_context:
129
+ # Remote run initiation
130
+ context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in custom_context.items()]
131
+
112
132
  if len(interface.inputs) == 0:
113
- return Inputs.empty()
133
+ # Handle context even for empty inputs
134
+ return Inputs(proto_inputs=common_pb2.Inputs(context=context_kvs))
114
135
 
115
136
  # fill in defaults if missing
116
137
  type_hints: Dict[str, type] = {}
@@ -122,13 +143,14 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
122
143
  (default_value is not None and default_value is not inspect.Signature.empty)
123
144
  or (default_value is None and is_optional_type(input_type))
124
145
  or input_type is None
146
+ or input_type is type(None)
125
147
  ):
126
148
  if default_value == NativeInterface.has_default:
127
149
  if interface._remote_defaults is None or input_name not in interface._remote_defaults:
128
150
  raise ValueError(f"Input '{input_name}' has a default value but it is not set in the interface.")
129
151
  already_converted_kwargs[input_name] = interface._remote_defaults[input_name]
130
- elif input_type is None:
131
- # If the type is None, we assume it's a placeholder for no type
152
+ elif input_type is None or input_type is type(None):
153
+ # If the type is 'None' or 'class<None>', we assume it's a placeholder for no type
132
154
  kwargs[input_name] = None
133
155
  type_hints[input_name] = NoneType
134
156
  else:
@@ -144,12 +166,12 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
144
166
  for k, v in already_converted_kwargs.items():
145
167
  copied_literals[k] = v
146
168
  literal_map = literals_pb2.LiteralMap(literals=copied_literals)
169
+
147
170
  # Make sure we the interface, not literal_map or kwargs, because those may have a different order
148
171
  return Inputs(
149
- proto_inputs=run_definition_pb2.Inputs(
150
- literals=[
151
- run_definition_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()
152
- ]
172
+ proto_inputs=common_pb2.Inputs(
173
+ literals=[common_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()],
174
+ context=context_kvs,
153
175
  )
154
176
  )
155
177
 
@@ -191,11 +213,11 @@ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, tas
191
213
  for (output_name, python_type), v in zip(interface.outputs.items(), o):
192
214
  try:
193
215
  lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
194
- named.append(run_definition_pb2.NamedLiteral(name=output_name, value=lit))
216
+ named.append(common_pb2.NamedLiteral(name=output_name, value=lit))
195
217
  except TypeTransformerFailedError as e:
196
218
  raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name)
197
219
 
198
- return Outputs(proto_outputs=run_definition_pb2.Outputs(literals=named))
220
+ return Outputs(proto_outputs=common_pb2.Outputs(literals=named))
199
221
 
200
222
 
201
223
  async def convert_outputs_to_native(interface: NativeInterface, outputs: Outputs) -> Union[Any, Tuple[Any, ...]]:
@@ -222,7 +244,7 @@ def convert_error_to_native(err: execution_pb2.ExecutionError | Exception | Erro
222
244
  if isinstance(err, Error):
223
245
  err = err.err
224
246
 
225
- user_code, server_code = _clean_error_code(err.code)
247
+ user_code, _server_code = _clean_error_code(err.code)
226
248
  match err.kind:
227
249
  case execution_pb2.ExecutionError.UNKNOWN:
228
250
  return flyte.errors.RuntimeUnknownError(code=user_code, message=err.message, worker=err.worker)
@@ -351,7 +373,7 @@ def generate_inputs_repr_for_literal(literal: literals_pb2.Literal) -> bytes:
351
373
  return literal.SerializeToString(deterministic=True)
352
374
 
353
375
 
354
- def generate_inputs_hash_for_named_literals(inputs: list[run_definition_pb2.NamedLiteral]) -> str:
376
+ def generate_inputs_hash_for_named_literals(inputs: list[common_pb2.NamedLiteral]) -> str:
355
377
  """
356
378
  Generate a hash for the inputs using the new literal representation approach that respects
357
379
  hash values already present in literals. This is used to uniquely identify the inputs for a task
@@ -375,7 +397,7 @@ def generate_inputs_hash_for_named_literals(inputs: list[run_definition_pb2.Name
375
397
  return hash_data(combined_bytes)
376
398
 
377
399
 
378
- def generate_inputs_hash_from_proto(inputs: run_definition_pb2.Inputs) -> str:
400
+ def generate_inputs_hash_from_proto(inputs: common_pb2.Inputs) -> str:
379
401
  """
380
402
  Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
381
403
  :param inputs: The inputs to hash.
@@ -404,7 +426,7 @@ def generate_cache_key_hash(
404
426
  task_interface: interface_pb2.TypedInterface,
405
427
  cache_version: str,
406
428
  ignored_input_vars: List[str],
407
- proto_inputs: run_definition_pb2.Inputs,
429
+ proto_inputs: common_pb2.Inputs,
408
430
  ) -> str:
409
431
  """
410
432
  Generate a cache key hash based on the inputs hash, task name, task interface, and cache version.
@@ -420,7 +442,7 @@ def generate_cache_key_hash(
420
442
  """
421
443
  if ignored_input_vars:
422
444
  filtered = [named_lit for named_lit in proto_inputs.literals if named_lit.name not in ignored_input_vars]
423
- final = run_definition_pb2.Inputs(literals=filtered)
445
+ final = common_pb2.Inputs(literals=filtered)
424
446
  final_inputs = generate_inputs_hash_from_proto(final)
425
447
  else:
426
448
  final_inputs = inputs_hash
@@ -1,4 +1,6 @@
1
1
  import importlib
2
+ import os
3
+ import traceback
2
4
  from typing import List, Optional, Tuple, Type
3
5
 
4
6
  import flyte.errors
@@ -10,6 +12,7 @@ from flyte._logging import log, logger
10
12
  from flyte._task import TaskTemplate
11
13
  from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
12
14
 
15
+ from ..._utils import adjust_sys_path
13
16
  from .convert import Error, Inputs, Outputs
14
17
  from .taskrunner import (
15
18
  convert_and_run,
@@ -72,7 +75,26 @@ def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
72
75
  """
73
76
  resolver_class = load_class(resolver)
74
77
  resolver_instance = resolver_class()
75
- return resolver_instance.load_task(resolver_args)
78
+ try:
79
+ return resolver_instance.load_task(resolver_args)
80
+ except ModuleNotFoundError as e:
81
+ cwd = os.getcwd()
82
+ files = []
83
+ try:
84
+ for root, dirs, filenames in os.walk(cwd):
85
+ for name in dirs + filenames:
86
+ rel_path = os.path.relpath(os.path.join(root, name), cwd)
87
+ files.append(rel_path)
88
+ except Exception as list_err:
89
+ files = [f"(Failed to list directory: {list_err})"]
90
+
91
+ msg = (
92
+ "\n\nFull traceback:\n" + "".join(traceback.format_exc()) + f"\n[ImportError Diagnostics]\n"
93
+ f"Module '{e.name}' not found in either the Python virtual environment or the current working directory.\n"
94
+ f"Current working directory: {cwd}\n"
95
+ f"Files found under current directory:\n" + "\n".join(f" - {f}" for f in files)
96
+ )
97
+ raise ModuleNotFoundError(msg) from e
76
98
 
77
99
 
78
100
  def load_pkl_task(code_bundle: CodeBundle) -> TaskTemplate:
@@ -100,6 +122,7 @@ async def download_code_bundle(code_bundle: CodeBundle) -> CodeBundle:
100
122
  :param code_bundle: The code bundle to download.
101
123
  :return: The code bundle with the downloaded path.
102
124
  """
125
+ adjust_sys_path()
103
126
  logger.debug(f"Downloading {code_bundle}")
104
127
  downloaded_path = await download_bundle(code_bundle)
105
128
  return code_bundle.with_downloaded_path(downloaded_path)
@@ -5,10 +5,12 @@ It uses the storage module to handle the actual uploading and downloading of fil
5
5
  TODO: Convert to use streaming apis
6
6
  """
7
7
 
8
- from flyteidl.core import errors_pb2, execution_pb2
8
+ from flyteidl.core import errors_pb2
9
+ from flyteidl2.core import execution_pb2
10
+ from flyteidl2.task import common_pb2
9
11
 
10
12
  import flyte.storage as storage
11
- from flyte._protos.workflow import run_definition_pb2
13
+ from flyte.models import PathRewrite
12
14
 
13
15
  from .convert import Inputs, Outputs, _clean_error_code
14
16
 
@@ -69,7 +71,7 @@ async def upload_outputs(outputs: Outputs, output_path: str, max_bytes: int = -1
69
71
  await storage.put_stream(data_iterable=outputs.proto_outputs.SerializeToString(), to_path=output_uri)
70
72
 
71
73
 
72
- async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
74
+ async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str) -> str:
73
75
  """
74
76
  :param err: execution_pb2.ExecutionError
75
77
  :param output_prefix: The output prefix of the remote uri.
@@ -86,17 +88,18 @@ async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
86
88
  )
87
89
  )
88
90
  error_uri = error_path(output_prefix)
89
- await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)
91
+ return await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)
90
92
 
91
93
 
92
94
  # ------------------------------- DOWNLOAD Methods ------------------------------- #
93
- async def load_inputs(path: str, max_bytes: int = -1) -> Inputs:
95
+ async def load_inputs(path: str, max_bytes: int = -1, path_rewrite_config: PathRewrite | None = None) -> Inputs:
94
96
  """
95
97
  :param path: Input file to be downloaded
96
98
  :param max_bytes: Maximum number of bytes to read from the input file. Default is -1, which means no limit.
99
+ :param path_rewrite_config: If provided, rewrites paths in the input blobs according to the configuration.
97
100
  :return: Inputs object
98
101
  """
99
- lm = run_definition_pb2.Inputs()
102
+ lm = common_pb2.Inputs()
100
103
 
101
104
  if max_bytes == -1:
102
105
  proto_str = b"".join([c async for c in storage.get_stream(path=path)])
@@ -115,6 +118,16 @@ async def load_inputs(path: str, max_bytes: int = -1) -> Inputs:
115
118
  proto_str = b"".join(proto_bytes)
116
119
 
117
120
  lm.ParseFromString(proto_str)
121
+
122
+ if path_rewrite_config is not None:
123
+ for inp in lm.literals:
124
+ if inp.value.HasField("scalar") and inp.value.scalar.HasField("blob"):
125
+ scalar_blob = inp.value.scalar.blob
126
+ if scalar_blob.uri.startswith(path_rewrite_config.old_prefix):
127
+ scalar_blob.uri = scalar_blob.uri.replace(
128
+ path_rewrite_config.old_prefix, path_rewrite_config.new_prefix, 1
129
+ )
130
+
118
131
  return Inputs(proto_inputs=lm)
119
132
 
120
133
 
@@ -125,7 +138,7 @@ async def load_outputs(path: str, max_bytes: int = -1) -> Outputs:
125
138
  If -1, reads the entire file.
126
139
  :return: Outputs object
127
140
  """
128
- lm = run_definition_pb2.Outputs()
141
+ lm = common_pb2.Outputs()
129
142
 
130
143
  if max_bytes == -1:
131
144
  proto_str = b"".join([c async for c in storage.get_stream(path=path)])
@@ -157,7 +170,7 @@ async def load_error(path: str) -> execution_pb2.ExecutionError:
157
170
  err.ParseFromString(proto_str)
158
171
 
159
172
  if err.error is not None:
160
- user_code, server_code = _clean_error_code(err.error.code)
173
+ user_code, _server_code = _clean_error_code(err.error.code)
161
174
  return execution_pb2.ExecutionError(
162
175
  code=user_code,
163
176
  message=err.error.message,
@@ -1,8 +1,8 @@
1
- from typing import List, Optional, Tuple
1
+ from typing import Dict, List, Optional, Tuple
2
2
 
3
- from flyteidl.core import tasks_pb2
3
+ from flyteidl2.core import tasks_pb2
4
4
 
5
- from flyte._resources import CPUBaseType, Resources
5
+ from flyte._resources import CPUBaseType, DeviceClass, Resources
6
6
 
7
7
  ACCELERATOR_DEVICE_MAP = {
8
8
  "A100": "nvidia-tesla-a100",
@@ -24,6 +24,14 @@ ACCELERATOR_DEVICE_MAP = {
24
24
  "V6E": "tpu-v6e-slice",
25
25
  }
26
26
 
27
+ _DeviceClassToProto: Dict[DeviceClass, "tasks_pb2.GPUAccelerator.DeviceClass"] = {
28
+ "GPU": tasks_pb2.GPUAccelerator.NVIDIA_GPU,
29
+ "TPU": tasks_pb2.GPUAccelerator.GOOGLE_TPU,
30
+ "NEURON": tasks_pb2.GPUAccelerator.AMAZON_NEURON,
31
+ "AMD_GPU": tasks_pb2.GPUAccelerator.AMD_GPU,
32
+ "HABANA_GAUDI": tasks_pb2.GPUAccelerator.HABANA_GAUDI,
33
+ }
34
+
27
35
 
28
36
  def _get_cpu_resource_entry(cpu: CPUBaseType) -> tasks_pb2.Resources.ResourceEntry:
29
37
  return tasks_pb2.Resources.ResourceEntry(
@@ -54,11 +62,17 @@ def _get_gpu_extended_resource_entry(resources: Resources) -> Optional[tasks_pb2
54
62
  device = resources.get_device()
55
63
  if device is None:
56
64
  return None
57
- if device.device not in ACCELERATOR_DEVICE_MAP:
58
- raise ValueError(f"GPU of type {device.device} unknown, cannot map to device name")
65
+
66
+ device_class = _DeviceClassToProto.get(device.device_class, tasks_pb2.GPUAccelerator.NVIDIA_GPU)
67
+ if device.device is None:
68
+ raise RuntimeError("Device type must be specified for GPU string.")
69
+ else:
70
+ device_type = device.device
71
+ device_type = ACCELERATOR_DEVICE_MAP.get(device_type, device_type)
59
72
  return tasks_pb2.GPUAccelerator(
60
- device=ACCELERATOR_DEVICE_MAP[device.device],
73
+ device=device_type,
61
74
  partition_size=device.partition if device.partition else None,
75
+ device_class=device_class,
62
76
  )
63
77
 
64
78
 
@@ -2,7 +2,7 @@ import hashlib
2
2
  import typing
3
3
  from venv import logger
4
4
 
5
- from flyteidl.core import tasks_pb2
5
+ from flyteidl2.core import tasks_pb2
6
6
 
7
7
  import flyte.errors
8
8
  from flyte import ReusePolicy
@@ -1,5 +1,4 @@
1
1
  import asyncio
2
- import sys
3
2
  import time
4
3
  from typing import Any, List, Tuple
5
4
 
@@ -11,7 +10,8 @@ from flyte._internal.runtime.entrypoints import download_code_bundle, load_pkl_t
11
10
  from flyte._internal.runtime.taskrunner import extract_download_run_upload
12
11
  from flyte._logging import logger
13
12
  from flyte._task import TaskTemplate
14
- from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
13
+ from flyte._utils import adjust_sys_path
14
+ from flyte.models import ActionID, Checkpoints, CodeBundle, PathRewrite, RawDataPath
15
15
 
16
16
 
17
17
  async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
@@ -23,7 +23,7 @@ async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
23
23
  :return: The CodeBundle object.
24
24
  """
25
25
  logger.info(f"[rusty] Downloading tgz code bundle from {tgz} to {destination} with version {version}")
26
- sys.path.insert(0, ".")
26
+ adjust_sys_path()
27
27
 
28
28
  code_bundle = CodeBundle(
29
29
  tgz=tgz,
@@ -42,7 +42,7 @@ async def download_load_pkl(destination: str, version: str, pkl: str) -> Tuple[C
42
42
  :return: The CodeBundle object.
43
43
  """
44
44
  logger.info(f"[rusty] Downloading pkl code bundle from {pkl} to {destination} with version {version}")
45
- sys.path.insert(0, ".")
45
+ adjust_sys_path()
46
46
 
47
47
  code_bundle = CodeBundle(
48
48
  pkl=pkl,
@@ -115,6 +115,7 @@ async def run_task(
115
115
  prev_checkpoint: str | None = None,
116
116
  code_bundle: CodeBundle | None = None,
117
117
  input_path: str | None = None,
118
+ path_rewrite_cfg: str | None = None,
118
119
  ):
119
120
  """
120
121
  Runs the task with the provided parameters.
@@ -134,6 +135,7 @@ async def run_task(
134
135
  :param controller: The controller to use for the task.
135
136
  :param code_bundle: Optional code bundle for the task.
136
137
  :param input_path: Optional input path for the task.
138
+ :param path_rewrite_cfg: Optional path rewrite configuration.
137
139
  :return: The loaded task template.
138
140
  """
139
141
  start_time = time.time()
@@ -144,6 +146,19 @@ async def run_task(
144
146
  f" at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}"
145
147
  )
146
148
 
149
+ path_rewrite = PathRewrite.from_str(path_rewrite_cfg) if path_rewrite_cfg else None
150
+ if path_rewrite:
151
+ import flyte.storage as storage
152
+
153
+ if not await storage.exists(path_rewrite.new_prefix):
154
+ logger.error(
155
+ f"[rusty] Path rewrite failed for path {path_rewrite.new_prefix}, "
156
+ f"not found, reverting to original path {path_rewrite.old_prefix}"
157
+ )
158
+ path_rewrite = None
159
+ else:
160
+ logger.info(f"[rusty] Using path rewrite: {path_rewrite}")
161
+
147
162
  try:
148
163
  await contextual_run(
149
164
  extract_download_run_upload,
@@ -151,7 +166,7 @@ async def run_task(
151
166
  action=ActionID(name=name, org=org, project=project, domain=domain, run_name=run_name),
152
167
  version=version,
153
168
  controller=controller,
154
- raw_data_path=RawDataPath(path=raw_data_path),
169
+ raw_data_path=RawDataPath(path=raw_data_path, path_rewrite=path_rewrite),
155
170
  output_path=output_path,
156
171
  run_base_dir=run_base_dir,
157
172
  checkpoints=Checkpoints(prev_checkpoint_path=prev_checkpoint, checkpoint_path=checkpoint_path),
@@ -4,19 +4,18 @@ It includes a Resolver interface for loading tasks, and functions to load classe
4
4
  """
5
5
 
6
6
  import copy
7
- import sys
8
7
  import typing
9
8
  from datetime import timedelta
10
9
  from typing import Optional, cast
11
10
 
12
- from flyteidl.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
11
+ from flyteidl2.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
12
+ from flyteidl2.task import common_pb2, environment_pb2, task_definition_pb2
13
13
  from google.protobuf import duration_pb2, wrappers_pb2
14
14
 
15
15
  import flyte.errors
16
16
  from flyte._cache.cache import VersionParameters, cache_from_request
17
17
  from flyte._logging import logger
18
18
  from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
19
- from flyte._protos.workflow import common_pb2, environment_pb2, task_definition_pb2
20
19
  from flyte._secret import SecretRequest, secrets_from_request
21
20
  from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
22
21
  from flyte.models import CodeBundle, SerializationContext
@@ -120,7 +119,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
120
119
  version=serialize_context.version,
121
120
  )
122
121
 
123
- # TODO Add support for SQL, extra_config, custom
122
+ # TODO Add support for extra_config, custom
124
123
  extra_config: typing.Dict[str, str] = {}
125
124
 
126
125
  if task.pod_template and not isinstance(task.pod_template, str):
@@ -133,7 +132,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
133
132
 
134
133
  custom = task.custom_config(serialize_context)
135
134
 
136
- sql = None
135
+ sql = task.sql(serialize_context)
137
136
 
138
137
  # -------------- CACHE HANDLING ----------------------
139
138
  task_cache = cache_from_request(task.cache)
@@ -171,8 +170,9 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
171
170
  retries=get_proto_retry_strategy(task.retries),
172
171
  timeout=get_proto_timeout(task.timeout),
173
172
  pod_template_name=(task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None),
174
- interruptible=task.interruptable,
173
+ interruptible=task.interruptible,
175
174
  generates_deck=wrappers_pb2.BoolValue(value=task.report),
175
+ debuggable=task.debuggable,
176
176
  ),
177
177
  interface=transform_native_to_typed_interface(task.native_interface),
178
178
  custom=custom if len(custom) > 0 else None,
@@ -209,34 +209,40 @@ def _get_urun_container(
209
209
  else None
210
210
  )
211
211
  resources = get_proto_resources(task_template.resources)
212
- # pr: under what conditions should this return None?
212
+
213
213
  if isinstance(task_template.image, str):
214
214
  raise flyte.errors.RuntimeSystemError("BadConfig", "Image is not a valid image")
215
- image_id = task_template.image.identifier
215
+
216
+ env_name = task_template.parent_env_name
217
+ if env_name is None:
218
+ raise flyte.errors.RuntimeSystemError("BadConfig", f"Task {task_template.name} has no parent environment name")
219
+
216
220
  if not serialize_context.image_cache:
217
221
  # This computes the image uri, computing hashes as necessary so can fail if done remotely.
218
222
  img_uri = task_template.image.uri
219
- elif serialize_context.image_cache and image_id not in serialize_context.image_cache.image_lookup:
220
- img_uri = task_template.image.uri
221
-
222
- logger.warning(
223
- f"Image {task_template.image} not found in the image cache: {serialize_context.image_cache.image_lookup}."
223
+ elif serialize_context.image_cache and env_name not in serialize_context.image_cache.image_lookup:
224
+ raise flyte.errors.RuntimeUserError(
225
+ "MissingEnvironment",
226
+ f"Environment '{env_name}' not found in image cache.\n\n"
227
+ "💡 To fix this:\n"
228
+ " 1. If your parent environment calls a task in another environment,"
229
+ " declare that dependency using 'depends_on=[...]'.\n"
230
+ " Example:\n"
231
+ " env1 = flyte.TaskEnvironment(\n"
232
+ " name='outer',\n"
233
+ " image=flyte.Image.from_debian_base().with_pip_packages('requests'),\n"
234
+ " depends_on=[env2, env3],\n"
235
+ " )\n"
236
+ " 2. If you're using os.getenv() to set the environment name,"
237
+ " make sure the runtime environment has the same environment variable defined.\n"
238
+ " Example:\n"
239
+ " env = flyte.TaskEnvironment(\n"
240
+ ' name=os.getenv("my-name"),\n'
241
+ ' env_vars={"my-name": os.getenv("my-name")},\n'
242
+ " )\n",
224
243
  )
225
244
  else:
226
- python_version_str = "{}.{}".format(sys.version_info.major, sys.version_info.minor)
227
- version_lookup = serialize_context.image_cache.image_lookup[image_id]
228
- if python_version_str in version_lookup:
229
- img_uri = version_lookup[python_version_str]
230
- elif version_lookup:
231
- # Fallback: try to get any available version
232
- fallback_py_version, img_uri = next(iter(version_lookup.items()))
233
- logger.warning(
234
- f"Image {task_template.image} for python version {python_version_str} "
235
- f"not found in the image cache: {serialize_context.image_cache.image_lookup}.\n"
236
- f"Fall back using image {img_uri} for python version {fallback_py_version} ."
237
- )
238
- else:
239
- img_uri = task_template.image.uri
245
+ img_uri = serialize_context.image_cache.image_lookup[env_name]
240
246
 
241
247
  return tasks_pb2.Container(
242
248
  image=img_uri,
@@ -129,6 +129,14 @@ async def convert_and_run(
129
129
  in a context tree.
130
130
  """
131
131
  ctx = internal_ctx()
132
+
133
+ # Load inputs first to get context
134
+ if input_path:
135
+ inputs = await load_inputs(input_path, path_rewrite_config=raw_data_path.path_rewrite)
136
+
137
+ # Extract context from inputs
138
+ custom_context = inputs.context if inputs else {}
139
+
132
140
  tctx = TaskContext(
133
141
  action=action,
134
142
  checkpoints=checkpoints,
@@ -142,9 +150,10 @@ async def convert_and_run(
142
150
  report=flyte.report.Report(name=action.name),
143
151
  mode="remote" if not ctx.data.task_context else ctx.data.task_context.mode,
144
152
  interactive_mode=interactive_mode,
153
+ custom_context=custom_context,
145
154
  )
155
+
146
156
  with ctx.replace_task_context(tctx):
147
- inputs = await load_inputs(input_path) if input_path else inputs
148
157
  inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
149
158
  out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
150
159
  if err is not None: