flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (211)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/_internal/runtime/convert.py

@@ -8,27 +8,33 @@ from dataclasses import dataclass
  from types import NoneType
  from typing import Any, Dict, List, Tuple, Union, get_args

- from flyteidl.core import execution_pb2, interface_pb2, literals_pb2
+ from flyteidl2.core import execution_pb2, interface_pb2, literals_pb2
+ from flyteidl2.task import common_pb2, task_definition_pb2

  import flyte.errors
  import flyte.storage as storage
- from flyte._protos.workflow import common_pb2, run_definition_pb2, task_definition_pb2
+ from flyte._context import ctx
  from flyte.models import ActionID, NativeInterface, TaskContext
  from flyte.types import TypeEngine, TypeTransformerFailedError


  @dataclass(frozen=True)
  class Inputs:
- proto_inputs: run_definition_pb2.Inputs
+ proto_inputs: common_pb2.Inputs

  @classmethod
  def empty(cls) -> "Inputs":
- return cls(proto_inputs=run_definition_pb2.Inputs())
+ return cls(proto_inputs=common_pb2.Inputs())
+
+ @property
+ def context(self) -> Dict[str, str]:
+ """Get the context as a dictionary."""
+ return {kv.key: kv.value for kv in self.proto_inputs.context}


  @dataclass(frozen=True)
  class Outputs:
- proto_outputs: run_definition_pb2.Outputs
+ proto_outputs: common_pb2.Outputs


  @dataclass
@@ -102,15 +108,30 @@ def is_optional_type(tp) -> bool:
  return NoneType in get_args(tp) # fastest check


- async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwargs) -> Inputs:
+ async def convert_from_native_to_inputs(
+ interface: NativeInterface, *args, custom_context: Dict[str, str] | None = None, **kwargs
+ ) -> Inputs:
  kwargs = interface.convert_to_kwargs(*args, **kwargs)

  missing = [key for key in interface.required_inputs() if key not in kwargs]
  if missing:
  raise ValueError(f"Missing required inputs: {', '.join(missing)}")

+ # Read custom_context from TaskContext if available (inside task execution)
+ # Otherwise use the passed parameter (for remote run initiation)
+ context_kvs = None
+ tctx = ctx()
+ if tctx and tctx.custom_context:
+ # Inside a task - read from TaskContext
+ context_to_use = tctx.custom_context
+ context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in context_to_use.items()]
+ elif custom_context:
+ # Remote run initiation
+ context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in custom_context.items()]
+
  if len(interface.inputs) == 0:
- return Inputs.empty()
+ # Handle context even for empty inputs
+ return Inputs(proto_inputs=common_pb2.Inputs(context=context_kvs))

  # fill in defaults if missing
  type_hints: Dict[str, type] = {}
@@ -122,13 +143,14 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
  (default_value is not None and default_value is not inspect.Signature.empty)
  or (default_value is None and is_optional_type(input_type))
  or input_type is None
+ or input_type is type(None)
  ):
  if default_value == NativeInterface.has_default:
  if interface._remote_defaults is None or input_name not in interface._remote_defaults:
  raise ValueError(f"Input '{input_name}' has a default value but it is not set in the interface.")
  already_converted_kwargs[input_name] = interface._remote_defaults[input_name]
- elif input_type is None:
- # If the type is None, we assume it's a placeholder for no type
+ elif input_type is None or input_type is type(None):
+ # If the type is 'None' or 'class<None>', we assume it's a placeholder for no type
  kwargs[input_name] = None
  type_hints[input_name] = NoneType
  else:
@@ -144,12 +166,12 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
  for k, v in already_converted_kwargs.items():
  copied_literals[k] = v
  literal_map = literals_pb2.LiteralMap(literals=copied_literals)
+
  # Make sure we the interface, not literal_map or kwargs, because those may have a different order
  return Inputs(
- proto_inputs=run_definition_pb2.Inputs(
- literals=[
- run_definition_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()
- ]
+ proto_inputs=common_pb2.Inputs(
+ literals=[common_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()],
+ context=context_kvs,
  )
  )

@@ -191,11 +213,11 @@ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, tas
  for (output_name, python_type), v in zip(interface.outputs.items(), o):
  try:
  lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
- named.append(run_definition_pb2.NamedLiteral(name=output_name, value=lit))
+ named.append(common_pb2.NamedLiteral(name=output_name, value=lit))
  except TypeTransformerFailedError as e:
  raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name)

- return Outputs(proto_outputs=run_definition_pb2.Outputs(literals=named))
+ return Outputs(proto_outputs=common_pb2.Outputs(literals=named))


  async def convert_outputs_to_native(interface: NativeInterface, outputs: Outputs) -> Union[Any, Tuple[Any, ...]]:
@@ -222,7 +244,7 @@ def convert_error_to_native(err: execution_pb2.ExecutionError | Exception | Erro
  if isinstance(err, Error):
  err = err.err

- user_code, server_code = _clean_error_code(err.code)
+ user_code, _server_code = _clean_error_code(err.code)
  match err.kind:
  case execution_pb2.ExecutionError.UNKNOWN:
  return flyte.errors.RuntimeUnknownError(code=user_code, message=err.message, worker=err.worker)
@@ -308,15 +330,82 @@ def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
  return hash_data(serialized_inputs)


- def generate_inputs_hash_from_proto(inputs: run_definition_pb2.Inputs) -> str:
+ def generate_inputs_repr_for_literal(literal: literals_pb2.Literal) -> bytes:
+ """
+ Generate a byte representation for a single literal that is meant to be hashed as part of the cache key
+ computation for an Action. This function should just serialize the literal deterministically, but will
+ use an existing hash value if present in the Literal. This is trivial, except we need to handle nested literals
+ (inside collections and maps), that may have the hash property set.
+
+ :param literal: The literal to get a hashable representation for.
+ :return: byte representation of the literal that can be fed into a hash function.
+ """
+ # If the literal has a hash value, use that instead of serializing the full literal
+ if literal.hash:
+ return literal.hash.encode("utf-8")
+
+ if literal.HasField("collection"):
+ buf = bytearray()
+ for nested_literal in literal.collection.literals:
+ if nested_literal.hash:
+ buf += nested_literal.hash.encode("utf-8")
+ else:
+ buf += generate_inputs_repr_for_literal(nested_literal)
+
+ b = bytes(buf)
+ return b
+
+ elif literal.HasField("map"):
+ buf = bytearray()
+ # Sort keys to ensure deterministic ordering
+ for key in sorted(literal.map.literals.keys()):
+ nested_literal = literal.map.literals[key]
+ buf += key.encode("utf-8")
+ if nested_literal.hash:
+ buf += nested_literal.hash.encode("utf-8")
+ else:
+ buf += generate_inputs_repr_for_literal(nested_literal)
+
+ b = bytes(buf)
+ return b
+
+ # For all other cases (scalars, etc.), just serialize the literal normally
+ return literal.SerializeToString(deterministic=True)
+
+
+ def generate_inputs_hash_for_named_literals(inputs: list[common_pb2.NamedLiteral]) -> str:
+ """
+ Generate a hash for the inputs using the new literal representation approach that respects
+ hash values already present in literals. This is used to uniquely identify the inputs for a task
+ when some literals may have precomputed hash values.
+
+ :param inputs: List of NamedLiteral inputs to hash.
+ :return: A base64-encoded string representation of the hash.
+ """
+ if not inputs:
+ return ""
+
+ # Build the byte representation by concatenating each literal's representation
+ combined_bytes = b""
+ for named_literal in inputs:
+ # Add the name to ensure order matters
+ name_bytes = named_literal.name.encode("utf-8")
+ literal_bytes = generate_inputs_repr_for_literal(named_literal.value)
+ # Combine name and literal bytes with a separator to avoid collisions
+ combined_bytes += name_bytes + b":" + literal_bytes + b";"
+
+ return hash_data(combined_bytes)
+
+
+ def generate_inputs_hash_from_proto(inputs: common_pb2.Inputs) -> str:
  """
  Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
  :param inputs: The inputs to hash.
  :return: A hexadecimal string representation of the hash.
  """
- if not inputs:
+ if not inputs or not inputs.literals:
  return ""
- return generate_inputs_hash(inputs.SerializeToString(deterministic=True))
+ return generate_inputs_hash_for_named_literals(list(inputs.literals))


  def generate_interface_hash(task_interface: interface_pb2.TypedInterface) -> str:
@@ -337,7 +426,7 @@ def generate_cache_key_hash(
  task_interface: interface_pb2.TypedInterface,
  cache_version: str,
  ignored_input_vars: List[str],
- proto_inputs: run_definition_pb2.Inputs,
+ proto_inputs: common_pb2.Inputs,
  ) -> str:
  """
  Generate a cache key hash based on the inputs hash, task name, task interface, and cache version.
@@ -353,7 +442,7 @@
  """
  if ignored_input_vars:
  filtered = [named_lit for named_lit in proto_inputs.literals if named_lit.name not in ignored_input_vars]
- final = run_definition_pb2.Inputs(literals=filtered)
+ final = common_pb2.Inputs(literals=filtered)
  final_inputs = generate_inputs_hash_from_proto(final)
  else:
  final_inputs = inputs_hash
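The cache-key change above replaces whole-proto serialization with a per-input byte layout: for each named input, name + b":" + repr + b";", where repr is the literal's precomputed hash if it is set, otherwise its deterministic serialization, recursing into collections and into map entries sorted by key. A minimal standalone sketch of that layout on plain Python values rather than the literals_pb2 protos; sha256 here stands in for the SDK's hash_data helper and is an assumption of this illustration:

import hashlib
from typing import Optional, Union

Value = Union[str, int, float, list, dict]


def repr_for_value(value: Value, precomputed_hash: Optional[str] = None) -> bytes:
    # Mirrors generate_inputs_repr_for_literal: prefer an existing hash,
    # otherwise serialize deterministically; recurse into lists and into
    # dict entries sorted by key.
    if precomputed_hash:
        return precomputed_hash.encode("utf-8")
    if isinstance(value, list):
        return b"".join(repr_for_value(v) for v in value)
    if isinstance(value, dict):
        return b"".join(k.encode("utf-8") + repr_for_value(v) for k, v in sorted(value.items()))
    return repr(value).encode("utf-8")  # stand-in for SerializeToString(deterministic=True)


def inputs_hash(named_inputs: list) -> str:
    # Mirrors generate_inputs_hash_for_named_literals: name + b":" + repr + b";"
    combined = b""
    for name, value in named_inputs:
        combined += name.encode("utf-8") + b":" + repr_for_value(value) + b";"
    return hashlib.sha256(combined).hexdigest()  # stand-in for hash_data()


print(inputs_hash([("x", 3), ("files", ["a.csv", "b.csv"])]))

The point of the per-literal hash short-circuit, per the docstring above, is that an input whose literal already carries a content hash can contribute to the cache key without re-serializing its full value.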
flyte/_internal/runtime/entrypoints.py

@@ -1,4 +1,6 @@
  import importlib
+ import os
+ import traceback
  from typing import List, Optional, Tuple, Type

  import flyte.errors
@@ -10,6 +12,7 @@ from flyte._logging import log, logger
  from flyte._task import TaskTemplate
  from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath

+ from ..._utils import adjust_sys_path
  from .convert import Error, Inputs, Outputs
  from .taskrunner import (
  convert_and_run,
@@ -72,7 +75,26 @@ def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
  """
  resolver_class = load_class(resolver)
  resolver_instance = resolver_class()
- return resolver_instance.load_task(resolver_args)
+ try:
+ return resolver_instance.load_task(resolver_args)
+ except ModuleNotFoundError as e:
+ cwd = os.getcwd()
+ files = []
+ try:
+ for root, dirs, filenames in os.walk(cwd):
+ for name in dirs + filenames:
+ rel_path = os.path.relpath(os.path.join(root, name), cwd)
+ files.append(rel_path)
+ except Exception as list_err:
+ files = [f"(Failed to list directory: {list_err})"]
+
+ msg = (
+ "\n\nFull traceback:\n" + "".join(traceback.format_exc()) + f"\n[ImportError Diagnostics]\n"
+ f"Module '{e.name}' not found in either the Python virtual environment or the current working directory.\n"
+ f"Current working directory: {cwd}\n"
+ f"Files found under current directory:\n" + "\n".join(f" - {f}" for f in files)
+ )
+ raise ModuleNotFoundError(msg) from e


  def load_pkl_task(code_bundle: CodeBundle) -> TaskTemplate:
@@ -100,6 +122,7 @@ async def download_code_bundle(code_bundle: CodeBundle) -> CodeBundle:
  :param code_bundle: The code bundle to download.
  :return: The code bundle with the downloaded path.
  """
+ adjust_sys_path()
  logger.debug(f"Downloading {code_bundle}")
  downloaded_path = await download_bundle(code_bundle)
  return code_bundle.with_downloaded_path(downloaded_path)
@@ -142,6 +165,7 @@ async def load_and_run_task(
  code_bundle: CodeBundle | None = None,
  input_path: str | None = None,
  image_cache: ImageCache | None = None,
+ interactive_mode: bool = False,
  ):
  """
  This method is invoked from the runtime/CLI and is used to run a task. This creates the context tree,
@@ -159,6 +183,7 @@ async def load_and_run_task(
  :param code_bundle: The code bundle to use for the task.
  :param input_path: The input path to use for the task.
  :param image_cache: Mappings of Image identifiers to image URIs.
+ :param interactive_mode: Whether to run the task in interactive mode.
  """
  task = await _download_and_load_task(code_bundle, resolver, resolver_args)

@@ -175,4 +200,5 @@ async def load_and_run_task(
  code_bundle=code_bundle,
  input_path=input_path,
  image_cache=image_cache,
+ interactive_mode=interactive_mode,
  )
flyte/_internal/runtime/io.py

@@ -5,10 +5,12 @@ It uses the storage module to handle the actual uploading and downloading of fil
  TODO: Convert to use streaming apis
  """

- from flyteidl.core import errors_pb2, execution_pb2
+ from flyteidl.core import errors_pb2
+ from flyteidl2.core import execution_pb2
+ from flyteidl2.task import common_pb2

  import flyte.storage as storage
- from flyte._protos.workflow import run_definition_pb2
+ from flyte.models import PathRewrite

  from .convert import Inputs, Outputs, _clean_error_code

@@ -69,7 +71,7 @@ async def upload_outputs(outputs: Outputs, output_path: str, max_bytes: int = -1
  await storage.put_stream(data_iterable=outputs.proto_outputs.SerializeToString(), to_path=output_uri)


- async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
+ async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str) -> str:
  """
  :param err: execution_pb2.ExecutionError
  :param output_prefix: The output prefix of the remote uri.
@@ -86,17 +88,18 @@ async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
  )
  )
  error_uri = error_path(output_prefix)
- await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)
+ return await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)


  # ------------------------------- DOWNLOAD Methods ------------------------------- #
- async def load_inputs(path: str, max_bytes: int = -1) -> Inputs:
+ async def load_inputs(path: str, max_bytes: int = -1, path_rewrite_config: PathRewrite | None = None) -> Inputs:
  """
  :param path: Input file to be downloaded
  :param max_bytes: Maximum number of bytes to read from the input file. Default is -1, which means no limit.
+ :param path_rewrite_config: If provided, rewrites paths in the input blobs according to the configuration.
  :return: Inputs object
  """
- lm = run_definition_pb2.Inputs()
+ lm = common_pb2.Inputs()

  if max_bytes == -1:
  proto_str = b"".join([c async for c in storage.get_stream(path=path)])
@@ -115,6 +118,16 @@ async def load_inputs(path: str, max_bytes: int = -1) -> Inputs:
  proto_str = b"".join(proto_bytes)

  lm.ParseFromString(proto_str)
+
+ if path_rewrite_config is not None:
+ for inp in lm.literals:
+ if inp.value.HasField("scalar") and inp.value.scalar.HasField("blob"):
+ scalar_blob = inp.value.scalar.blob
+ if scalar_blob.uri.startswith(path_rewrite_config.old_prefix):
+ scalar_blob.uri = scalar_blob.uri.replace(
+ path_rewrite_config.old_prefix, path_rewrite_config.new_prefix, 1
+ )
+
  return Inputs(proto_inputs=lm)


@@ -125,7 +138,7 @@ async def load_outputs(path: str, max_bytes: int = -1) -> Outputs:
  If -1, reads the entire file.
  :return: Outputs object
  """
- lm = run_definition_pb2.Outputs()
+ lm = common_pb2.Outputs()

  if max_bytes == -1:
  proto_str = b"".join([c async for c in storage.get_stream(path=path)])
@@ -157,7 +170,7 @@ async def load_error(path: str) -> execution_pb2.ExecutionError:
  err.ParseFromString(proto_str)

  if err.error is not None:
- user_code, server_code = _clean_error_code(err.error.code)
+ user_code, _server_code = _clean_error_code(err.error.code)
  return execution_pb2.ExecutionError(
  code=user_code,
  message=err.error.message,
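The path_rewrite_config branch in load_inputs above rewrites only blob URIs that start with the configured old prefix, and replaces just the first occurrence, so later path segments that happen to repeat the prefix are untouched. A small illustration of that prefix swap; the bucket names and prefixes below are made up for the example:

# Hypothetical prefixes; in the SDK these come from PathRewrite.old_prefix / new_prefix.
old_prefix = "s3://remote-bucket/project"
new_prefix = "/mnt/local-cache/project"

uris = [
    "s3://remote-bucket/project/inputs/data.parquet",  # rewritten
    "s3://other-bucket/project/inputs/data.parquet",   # left alone: prefix does not match
]

rewritten = [
    uri.replace(old_prefix, new_prefix, 1) if uri.startswith(old_prefix) else uri
    for uri in uris
]
print(rewritten)
# ['/mnt/local-cache/project/inputs/data.parquet', 's3://other-bucket/project/inputs/data.parquet']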
flyte/_internal/runtime/resources_serde.py

@@ -1,8 +1,8 @@
- from typing import List, Optional, Tuple
+ from typing import Dict, List, Optional, Tuple

- from flyteidl.core import tasks_pb2
+ from flyteidl2.core import tasks_pb2

- from flyte._resources import CPUBaseType, Resources
+ from flyte._resources import CPUBaseType, DeviceClass, Resources

  ACCELERATOR_DEVICE_MAP = {
  "A100": "nvidia-tesla-a100",
@@ -24,6 +24,14 @@ ACCELERATOR_DEVICE_MAP = {
  "V6E": "tpu-v6e-slice",
  }

+ _DeviceClassToProto: Dict[DeviceClass, "tasks_pb2.GPUAccelerator.DeviceClass"] = {
+ "GPU": tasks_pb2.GPUAccelerator.NVIDIA_GPU,
+ "TPU": tasks_pb2.GPUAccelerator.GOOGLE_TPU,
+ "NEURON": tasks_pb2.GPUAccelerator.AMAZON_NEURON,
+ "AMD_GPU": tasks_pb2.GPUAccelerator.AMD_GPU,
+ "HABANA_GAUDI": tasks_pb2.GPUAccelerator.HABANA_GAUDI,
+ }
+

  def _get_cpu_resource_entry(cpu: CPUBaseType) -> tasks_pb2.Resources.ResourceEntry:
  return tasks_pb2.Resources.ResourceEntry(
@@ -54,11 +62,17 @@ def _get_gpu_extended_resource_entry(resources: Resources) -> Optional[tasks_pb2
  device = resources.get_device()
  if device is None:
  return None
- if device.device not in ACCELERATOR_DEVICE_MAP:
- raise ValueError(f"GPU of type {device.device} unknown, cannot map to device name")
+
+ device_class = _DeviceClassToProto.get(device.device_class, tasks_pb2.GPUAccelerator.NVIDIA_GPU)
+ if device.device is None:
+ raise RuntimeError("Device type must be specified for GPU string.")
+ else:
+ device_type = device.device
+ device_type = ACCELERATOR_DEVICE_MAP.get(device_type, device_type)
  return tasks_pb2.GPUAccelerator(
- device=ACCELERATOR_DEVICE_MAP[device.device],
+ device=device_type,
  partition_size=device.partition if device.partition else None,
+ device_class=device_class,
  )

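A behavior change worth noting in _get_gpu_extended_resource_entry above: unknown accelerator names no longer raise a ValueError. ACCELERATOR_DEVICE_MAP.get(device_type, device_type) translates known shorthands and passes anything else through verbatim, and the device class falls back to NVIDIA_GPU when unmapped. A standalone sketch of just the name lookup; the "nvidia-l40s" value is a hypothetical example:

# Excerpt of the map from this file; the full map has more entries.
ACCELERATOR_DEVICE_MAP = {
    "A100": "nvidia-tesla-a100",
    "V6E": "tpu-v6e-slice",
}


def resolve_device_name(device: str) -> str:
    # Known shorthands are translated; unknown names now pass through unchanged
    # instead of raising ValueError as the pre-b30 code did.
    return ACCELERATOR_DEVICE_MAP.get(device, device)


print(resolve_device_name("A100"))         # -> nvidia-tesla-a100
print(resolve_device_name("nvidia-l40s"))  # -> nvidia-l40s (passed through)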
 
flyte/_internal/runtime/reuse.py

@@ -2,7 +2,7 @@ import hashlib
  import typing
  from venv import logger

- from flyteidl.core import tasks_pb2
+ from flyteidl2.core import tasks_pb2

  import flyte.errors
  from flyte import ReusePolicy
flyte/_internal/runtime/rusty.py

@@ -1,5 +1,4 @@
  import asyncio
- import sys
  import time
  from typing import Any, List, Tuple

@@ -11,7 +10,8 @@ from flyte._internal.runtime.entrypoints import download_code_bundle, load_pkl_t
  from flyte._internal.runtime.taskrunner import extract_download_run_upload
  from flyte._logging import logger
  from flyte._task import TaskTemplate
- from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
+ from flyte._utils import adjust_sys_path
+ from flyte.models import ActionID, Checkpoints, CodeBundle, PathRewrite, RawDataPath


  async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
@@ -23,7 +23,7 @@ async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
  :return: The CodeBundle object.
  """
  logger.info(f"[rusty] Downloading tgz code bundle from {tgz} to {destination} with version {version}")
- sys.path.insert(0, ".")
+ adjust_sys_path()

  code_bundle = CodeBundle(
  tgz=tgz,
@@ -42,7 +42,7 @@ async def download_load_pkl(destination: str, version: str, pkl: str) -> Tuple[C
  :return: The CodeBundle object.
  """
  logger.info(f"[rusty] Downloading pkl code bundle from {pkl} to {destination} with version {version}")
- sys.path.insert(0, ".")
+ adjust_sys_path()

  code_bundle = CodeBundle(
  pkl=pkl,
@@ -115,6 +115,7 @@ async def run_task(
  prev_checkpoint: str | None = None,
  code_bundle: CodeBundle | None = None,
  input_path: str | None = None,
+ path_rewrite_cfg: str | None = None,
  ):
  """
  Runs the task with the provided parameters.
@@ -134,6 +135,7 @@ async def run_task(
  :param controller: The controller to use for the task.
  :param code_bundle: Optional code bundle for the task.
  :param input_path: Optional input path for the task.
+ :param path_rewrite_cfg: Optional path rewrite configuration.
  :return: The loaded task template.
  """
  start_time = time.time()
@@ -144,6 +146,19 @@ async def run_task(
  f" at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}"
  )

+ path_rewrite = PathRewrite.from_str(path_rewrite_cfg) if path_rewrite_cfg else None
+ if path_rewrite:
+ import flyte.storage as storage
+
+ if not await storage.exists(path_rewrite.new_prefix):
+ logger.error(
+ f"[rusty] Path rewrite failed for path {path_rewrite.new_prefix}, "
+ f"not found, reverting to original path {path_rewrite.old_prefix}"
+ )
+ path_rewrite = None
+ else:
+ logger.info(f"[rusty] Using path rewrite: {path_rewrite}")
+
  try:
  await contextual_run(
  extract_download_run_upload,
@@ -151,7 +166,7 @@
  action=ActionID(name=name, org=org, project=project, domain=domain, run_name=run_name),
  version=version,
  controller=controller,
- raw_data_path=RawDataPath(path=raw_data_path),
+ raw_data_path=RawDataPath(path=raw_data_path, path_rewrite=path_rewrite),
  output_path=output_path,
  run_base_dir=run_base_dir,
  checkpoints=Checkpoints(prev_checkpoint_path=prev_checkpoint, checkpoint_path=checkpoint_path),
flyte/_internal/runtime/task_serde.py

@@ -8,14 +8,14 @@ import typing
  from datetime import timedelta
  from typing import Optional, cast

- from flyteidl.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
+ from flyteidl2.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
+ from flyteidl2.task import common_pb2, environment_pb2, task_definition_pb2
  from google.protobuf import duration_pb2, wrappers_pb2

  import flyte.errors
  from flyte._cache.cache import VersionParameters, cache_from_request
  from flyte._logging import logger
  from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
- from flyte._protos.workflow import common_pb2, environment_pb2, task_definition_pb2
  from flyte._secret import SecretRequest, secrets_from_request
  from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
  from flyte.models import CodeBundle, SerializationContext
@@ -54,7 +54,7 @@ def translate_task_to_wire(
  return task_definition_pb2.TaskSpec(
  task_template=tt,
  default_inputs=default_inputs,
- short_name=task.friendly_name[:_MAX_TASK_SHORT_NAME_LENGTH],
+ short_name=task.short_name[:_MAX_TASK_SHORT_NAME_LENGTH],
  environment=env,
  )

@@ -119,7 +119,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
  version=serialize_context.version,
  )

- # TODO Add support for SQL, extra_config, custom
+ # TODO Add support for extra_config, custom
  extra_config: typing.Dict[str, str] = {}

  if task.pod_template and not isinstance(task.pod_template, str):
@@ -132,7 +132,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)

  custom = task.custom_config(serialize_context)

- sql = None
+ sql = task.sql(serialize_context)

  # -------------- CACHE HANDLING ----------------------
  task_cache = cache_from_request(task.cache)
@@ -145,7 +145,6 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
  logger.debug(f"Detected pkl bundle for task {task.name}, using computed version as cache version")
  cache_version = serialize_context.code_bundle.computed_version
  else:
- version_parameters = None
  if isinstance(task, AsyncFunctionTaskTemplate):
  version_parameters = VersionParameters(func=task.func, image=task.image)
  else:
@@ -171,8 +170,9 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
  retries=get_proto_retry_strategy(task.retries),
  timeout=get_proto_timeout(task.timeout),
  pod_template_name=(task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None),
- interruptible=task.interruptable,
+ interruptible=task.interruptible,
  generates_deck=wrappers_pb2.BoolValue(value=task.report),
+ debuggable=task.debuggable,
  ),
  interface=transform_native_to_typed_interface(task.native_interface),
  custom=custom if len(custom) > 0 else None,
@@ -209,25 +209,40 @@ def _get_urun_container(
  else None
  )
  resources = get_proto_resources(task_template.resources)
- # pr: under what conditions should this return None?
+
  if isinstance(task_template.image, str):
  raise flyte.errors.RuntimeSystemError("BadConfig", "Image is not a valid image")
- image_id = task_template.image.identifier
+
+ env_name = task_template.parent_env_name
+ if env_name is None:
+ raise flyte.errors.RuntimeSystemError("BadConfig", f"Task {task_template.name} has no parent environment name")
+
  if not serialize_context.image_cache:
  # This computes the image uri, computing hashes as necessary so can fail if done remotely.
  img_uri = task_template.image.uri
- elif serialize_context.image_cache and image_id not in serialize_context.image_cache.image_lookup:
- img_uri = task_template.image.uri
- from flyte._version import __version__
-
- logger.warning(
- f"Image {task_template.image} not found in the image cache: {serialize_context.image_cache.image_lookup}.\n"
- f"This typically occurs when the Flyte SDK version (`{__version__}`) used in the task environment "
- f"differs from the version used to compile or deploy it.\n"
- f"Ensure both environments use the same Flyte SDK version to avoid inconsistencies in image resolution."
+ elif serialize_context.image_cache and env_name not in serialize_context.image_cache.image_lookup:
+ raise flyte.errors.RuntimeUserError(
+ "MissingEnvironment",
+ f"Environment '{env_name}' not found in image cache.\n\n"
+ "💡 To fix this:\n"
+ " 1. If your parent environment calls a task in another environment,"
+ " declare that dependency using 'depends_on=[...]'.\n"
+ " Example:\n"
+ " env1 = flyte.TaskEnvironment(\n"
+ " name='outer',\n"
+ " image=flyte.Image.from_debian_base().with_pip_packages('requests'),\n"
+ " depends_on=[env2, env3],\n"
+ " )\n"
+ " 2. If you're using os.getenv() to set the environment name,"
+ " make sure the runtime environment has the same environment variable defined.\n"
+ " Example:\n"
+ " env = flyte.TaskEnvironment(\n"
+ ' name=os.getenv("my-name"),\n'
+ ' env_vars={"my-name": os.getenv("my-name")},\n'
+ " )\n",
  )
  else:
- img_uri = serialize_context.image_cache.image_lookup[image_id]
+ img_uri = serialize_context.image_cache.image_lookup[env_name]

  return tasks_pb2.Container(
  image=img_uri,