flyte 0.0.1b3__py3-none-any.whl → 0.2.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (319) hide show
  1. flyte/__init__.py +20 -4
  2. flyte/_bin/runtime.py +33 -7
  3. flyte/_build.py +3 -2
  4. flyte/_cache/cache.py +1 -2
  5. flyte/_code_bundle/_packaging.py +1 -1
  6. flyte/_code_bundle/_utils.py +0 -16
  7. flyte/_code_bundle/bundle.py +43 -12
  8. flyte/_context.py +8 -2
  9. flyte/_deploy.py +56 -15
  10. flyte/_environment.py +45 -4
  11. flyte/_excepthook.py +37 -0
  12. flyte/_group.py +2 -1
  13. flyte/_image.py +8 -4
  14. flyte/_initialize.py +112 -254
  15. flyte/_interface.py +3 -3
  16. flyte/_internal/controllers/__init__.py +19 -6
  17. flyte/_internal/controllers/_local_controller.py +83 -8
  18. flyte/_internal/controllers/_trace.py +2 -1
  19. flyte/_internal/controllers/remote/__init__.py +27 -7
  20. flyte/_internal/controllers/remote/_action.py +7 -2
  21. flyte/_internal/controllers/remote/_client.py +5 -1
  22. flyte/_internal/controllers/remote/_controller.py +159 -26
  23. flyte/_internal/controllers/remote/_core.py +13 -5
  24. flyte/_internal/controllers/remote/_informer.py +4 -4
  25. flyte/_internal/controllers/remote/_service_protocol.py +6 -6
  26. flyte/_internal/imagebuild/docker_builder.py +12 -1
  27. flyte/_internal/imagebuild/image_builder.py +16 -11
  28. flyte/_internal/runtime/convert.py +164 -21
  29. flyte/_internal/runtime/entrypoints.py +1 -1
  30. flyte/_internal/runtime/io.py +3 -3
  31. flyte/_internal/runtime/task_serde.py +140 -20
  32. flyte/_internal/runtime/taskrunner.py +4 -3
  33. flyte/_internal/runtime/types_serde.py +1 -1
  34. flyte/_logging.py +12 -1
  35. flyte/_map.py +215 -0
  36. flyte/_pod.py +19 -0
  37. flyte/_protos/common/list_pb2.py +3 -3
  38. flyte/_protos/common/list_pb2.pyi +2 -0
  39. flyte/_protos/logs/dataplane/payload_pb2.py +28 -24
  40. flyte/_protos/logs/dataplane/payload_pb2.pyi +11 -2
  41. flyte/_protos/workflow/common_pb2.py +27 -0
  42. flyte/_protos/workflow/common_pb2.pyi +14 -0
  43. flyte/_protos/workflow/environment_pb2.py +29 -0
  44. flyte/_protos/workflow/environment_pb2.pyi +12 -0
  45. flyte/_protos/workflow/queue_service_pb2.py +40 -41
  46. flyte/_protos/workflow/queue_service_pb2.pyi +35 -30
  47. flyte/_protos/workflow/queue_service_pb2_grpc.py +15 -15
  48. flyte/_protos/workflow/run_definition_pb2.py +61 -61
  49. flyte/_protos/workflow/run_definition_pb2.pyi +8 -4
  50. flyte/_protos/workflow/run_service_pb2.py +20 -24
  51. flyte/_protos/workflow/run_service_pb2.pyi +2 -6
  52. flyte/_protos/workflow/state_service_pb2.py +36 -28
  53. flyte/_protos/workflow/state_service_pb2.pyi +19 -15
  54. flyte/_protos/workflow/state_service_pb2_grpc.py +28 -28
  55. flyte/_protos/workflow/task_definition_pb2.py +29 -22
  56. flyte/_protos/workflow/task_definition_pb2.pyi +21 -5
  57. flyte/_protos/workflow/task_service_pb2.py +27 -11
  58. flyte/_protos/workflow/task_service_pb2.pyi +29 -1
  59. flyte/_protos/workflow/task_service_pb2_grpc.py +34 -0
  60. flyte/_run.py +166 -95
  61. flyte/_task.py +110 -28
  62. flyte/_task_environment.py +55 -72
  63. flyte/_trace.py +6 -14
  64. flyte/_utils/__init__.py +6 -0
  65. flyte/_utils/async_cache.py +139 -0
  66. flyte/_utils/coro_management.py +0 -2
  67. flyte/_utils/helpers.py +45 -19
  68. flyte/_utils/org_discovery.py +57 -0
  69. flyte/_version.py +2 -2
  70. flyte/cli/__init__.py +3 -0
  71. flyte/cli/_abort.py +28 -0
  72. flyte/{_cli → cli}/_common.py +73 -23
  73. flyte/cli/_create.py +145 -0
  74. flyte/{_cli → cli}/_delete.py +4 -4
  75. flyte/{_cli → cli}/_deploy.py +26 -14
  76. flyte/cli/_gen.py +163 -0
  77. flyte/{_cli → cli}/_get.py +98 -23
  78. {union/_cli → flyte/cli}/_params.py +106 -147
  79. flyte/{_cli → cli}/_run.py +99 -20
  80. flyte/cli/main.py +166 -0
  81. flyte/config/__init__.py +3 -0
  82. flyte/config/_config.py +216 -0
  83. flyte/config/_internal.py +64 -0
  84. flyte/config/_reader.py +207 -0
  85. flyte/errors.py +29 -0
  86. flyte/extras/_container.py +33 -43
  87. flyte/io/__init__.py +17 -1
  88. flyte/io/_dir.py +2 -2
  89. flyte/io/_file.py +3 -4
  90. flyte/io/{structured_dataset → _structured_dataset}/basic_dfs.py +1 -1
  91. flyte/io/{structured_dataset → _structured_dataset}/structured_dataset.py +1 -1
  92. flyte/{_datastructures.py → models.py} +56 -7
  93. flyte/remote/__init__.py +2 -1
  94. flyte/remote/_client/_protocols.py +2 -0
  95. flyte/remote/_client/auth/_auth_utils.py +14 -0
  96. flyte/remote/_client/auth/_channel.py +34 -3
  97. flyte/remote/_client/auth/_token_client.py +3 -3
  98. flyte/remote/_client/controlplane.py +13 -13
  99. flyte/remote/_console.py +1 -1
  100. flyte/remote/_data.py +10 -6
  101. flyte/remote/_logs.py +89 -29
  102. flyte/remote/_project.py +8 -9
  103. flyte/remote/_run.py +228 -131
  104. flyte/remote/_secret.py +12 -12
  105. flyte/remote/_task.py +179 -15
  106. flyte/report/_report.py +4 -4
  107. flyte/storage/__init__.py +5 -0
  108. flyte/storage/_config.py +233 -0
  109. flyte/storage/_storage.py +23 -3
  110. flyte/syncify/__init__.py +56 -0
  111. flyte/syncify/_api.py +371 -0
  112. flyte/types/__init__.py +23 -0
  113. flyte/types/_interface.py +22 -7
  114. flyte/{io/pickle/transformer.py → types/_pickle.py} +2 -1
  115. flyte/types/_type_engine.py +95 -18
  116. flyte-0.2.0a0.dist-info/METADATA +249 -0
  117. flyte-0.2.0a0.dist-info/RECORD +218 -0
  118. {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/entry_points.txt +1 -1
  119. flyte/_api_commons.py +0 -3
  120. flyte/_cli/__init__.py +0 -0
  121. flyte/_cli/_create.py +0 -42
  122. flyte/_cli/main.py +0 -72
  123. flyte/_internal/controllers/pbhash.py +0 -39
  124. flyte/io/_dataframe.py +0 -0
  125. flyte/io/pickle/__init__.py +0 -0
  126. flyte-0.0.1b3.dist-info/METADATA +0 -179
  127. flyte-0.0.1b3.dist-info/RECORD +0 -390
  128. union/__init__.py +0 -54
  129. union/_api_commons.py +0 -3
  130. union/_bin/__init__.py +0 -0
  131. union/_bin/runtime.py +0 -113
  132. union/_build.py +0 -25
  133. union/_cache/__init__.py +0 -12
  134. union/_cache/cache.py +0 -141
  135. union/_cache/defaults.py +0 -9
  136. union/_cache/policy_function_body.py +0 -42
  137. union/_cli/__init__.py +0 -0
  138. union/_cli/_common.py +0 -263
  139. union/_cli/_create.py +0 -40
  140. union/_cli/_delete.py +0 -23
  141. union/_cli/_deploy.py +0 -120
  142. union/_cli/_get.py +0 -162
  143. union/_cli/_run.py +0 -150
  144. union/_cli/main.py +0 -72
  145. union/_code_bundle/__init__.py +0 -8
  146. union/_code_bundle/_ignore.py +0 -113
  147. union/_code_bundle/_packaging.py +0 -187
  148. union/_code_bundle/_utils.py +0 -342
  149. union/_code_bundle/bundle.py +0 -176
  150. union/_context.py +0 -146
  151. union/_datastructures.py +0 -295
  152. union/_deploy.py +0 -185
  153. union/_doc.py +0 -29
  154. union/_docstring.py +0 -26
  155. union/_environment.py +0 -43
  156. union/_group.py +0 -31
  157. union/_hash.py +0 -23
  158. union/_image.py +0 -760
  159. union/_initialize.py +0 -585
  160. union/_interface.py +0 -84
  161. union/_internal/__init__.py +0 -3
  162. union/_internal/controllers/__init__.py +0 -77
  163. union/_internal/controllers/_local_controller.py +0 -77
  164. union/_internal/controllers/pbhash.py +0 -39
  165. union/_internal/controllers/remote/__init__.py +0 -40
  166. union/_internal/controllers/remote/_action.py +0 -131
  167. union/_internal/controllers/remote/_client.py +0 -43
  168. union/_internal/controllers/remote/_controller.py +0 -169
  169. union/_internal/controllers/remote/_core.py +0 -341
  170. union/_internal/controllers/remote/_informer.py +0 -260
  171. union/_internal/controllers/remote/_service_protocol.py +0 -44
  172. union/_internal/imagebuild/__init__.py +0 -11
  173. union/_internal/imagebuild/docker_builder.py +0 -416
  174. union/_internal/imagebuild/image_builder.py +0 -243
  175. union/_internal/imagebuild/remote_builder.py +0 -0
  176. union/_internal/resolvers/__init__.py +0 -0
  177. union/_internal/resolvers/_task_module.py +0 -31
  178. union/_internal/resolvers/common.py +0 -24
  179. union/_internal/resolvers/default.py +0 -27
  180. union/_internal/runtime/__init__.py +0 -0
  181. union/_internal/runtime/convert.py +0 -163
  182. union/_internal/runtime/entrypoints.py +0 -121
  183. union/_internal/runtime/io.py +0 -136
  184. union/_internal/runtime/resources_serde.py +0 -134
  185. union/_internal/runtime/task_serde.py +0 -202
  186. union/_internal/runtime/taskrunner.py +0 -179
  187. union/_internal/runtime/types_serde.py +0 -53
  188. union/_logging.py +0 -124
  189. union/_protos/__init__.py +0 -0
  190. union/_protos/common/authorization_pb2.py +0 -66
  191. union/_protos/common/authorization_pb2.pyi +0 -106
  192. union/_protos/common/identifier_pb2.py +0 -71
  193. union/_protos/common/identifier_pb2.pyi +0 -82
  194. union/_protos/common/identity_pb2.py +0 -48
  195. union/_protos/common/identity_pb2.pyi +0 -72
  196. union/_protos/common/identity_pb2_grpc.py +0 -4
  197. union/_protos/common/list_pb2.py +0 -36
  198. union/_protos/common/list_pb2.pyi +0 -69
  199. union/_protos/common/list_pb2_grpc.py +0 -4
  200. union/_protos/common/policy_pb2.py +0 -37
  201. union/_protos/common/policy_pb2.pyi +0 -27
  202. union/_protos/common/policy_pb2_grpc.py +0 -4
  203. union/_protos/common/role_pb2.py +0 -37
  204. union/_protos/common/role_pb2.pyi +0 -51
  205. union/_protos/common/role_pb2_grpc.py +0 -4
  206. union/_protos/common/runtime_version_pb2.py +0 -28
  207. union/_protos/common/runtime_version_pb2.pyi +0 -24
  208. union/_protos/common/runtime_version_pb2_grpc.py +0 -4
  209. union/_protos/logs/dataplane/payload_pb2.py +0 -96
  210. union/_protos/logs/dataplane/payload_pb2.pyi +0 -168
  211. union/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  212. union/_protos/secret/definition_pb2.py +0 -49
  213. union/_protos/secret/definition_pb2.pyi +0 -93
  214. union/_protos/secret/definition_pb2_grpc.py +0 -4
  215. union/_protos/secret/payload_pb2.py +0 -62
  216. union/_protos/secret/payload_pb2.pyi +0 -94
  217. union/_protos/secret/payload_pb2_grpc.py +0 -4
  218. union/_protos/secret/secret_pb2.py +0 -38
  219. union/_protos/secret/secret_pb2.pyi +0 -6
  220. union/_protos/secret/secret_pb2_grpc.py +0 -198
  221. union/_protos/validate/validate/validate_pb2.py +0 -76
  222. union/_protos/workflow/node_execution_service_pb2.py +0 -26
  223. union/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  224. union/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  225. union/_protos/workflow/queue_service_pb2.py +0 -75
  226. union/_protos/workflow/queue_service_pb2.pyi +0 -103
  227. union/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  228. union/_protos/workflow/run_definition_pb2.py +0 -100
  229. union/_protos/workflow/run_definition_pb2.pyi +0 -256
  230. union/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  231. union/_protos/workflow/run_logs_service_pb2.py +0 -41
  232. union/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  233. union/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  234. union/_protos/workflow/run_service_pb2.py +0 -133
  235. union/_protos/workflow/run_service_pb2.pyi +0 -173
  236. union/_protos/workflow/run_service_pb2_grpc.py +0 -412
  237. union/_protos/workflow/state_service_pb2.py +0 -58
  238. union/_protos/workflow/state_service_pb2.pyi +0 -69
  239. union/_protos/workflow/state_service_pb2_grpc.py +0 -138
  240. union/_protos/workflow/task_definition_pb2.py +0 -72
  241. union/_protos/workflow/task_definition_pb2.pyi +0 -65
  242. union/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  243. union/_protos/workflow/task_service_pb2.py +0 -44
  244. union/_protos/workflow/task_service_pb2.pyi +0 -31
  245. union/_protos/workflow/task_service_pb2_grpc.py +0 -104
  246. union/_resources.py +0 -226
  247. union/_retry.py +0 -32
  248. union/_reusable_environment.py +0 -25
  249. union/_run.py +0 -374
  250. union/_secret.py +0 -61
  251. union/_task.py +0 -354
  252. union/_task_environment.py +0 -186
  253. union/_timeout.py +0 -47
  254. union/_tools.py +0 -27
  255. union/_utils/__init__.py +0 -11
  256. union/_utils/asyn.py +0 -119
  257. union/_utils/file_handling.py +0 -71
  258. union/_utils/helpers.py +0 -46
  259. union/_utils/lazy_module.py +0 -54
  260. union/_utils/uv_script_parser.py +0 -49
  261. union/_version.py +0 -21
  262. union/connectors/__init__.py +0 -0
  263. union/errors.py +0 -128
  264. union/extras/__init__.py +0 -5
  265. union/extras/_container.py +0 -263
  266. union/io/__init__.py +0 -11
  267. union/io/_dataframe.py +0 -0
  268. union/io/_dir.py +0 -425
  269. union/io/_file.py +0 -418
  270. union/io/pickle/__init__.py +0 -0
  271. union/io/pickle/transformer.py +0 -117
  272. union/io/structured_dataset/__init__.py +0 -122
  273. union/io/structured_dataset/basic_dfs.py +0 -219
  274. union/io/structured_dataset/structured_dataset.py +0 -1057
  275. union/py.typed +0 -0
  276. union/remote/__init__.py +0 -23
  277. union/remote/_client/__init__.py +0 -0
  278. union/remote/_client/_protocols.py +0 -129
  279. union/remote/_client/auth/__init__.py +0 -12
  280. union/remote/_client/auth/_authenticators/__init__.py +0 -0
  281. union/remote/_client/auth/_authenticators/base.py +0 -391
  282. union/remote/_client/auth/_authenticators/client_credentials.py +0 -73
  283. union/remote/_client/auth/_authenticators/device_code.py +0 -120
  284. union/remote/_client/auth/_authenticators/external_command.py +0 -77
  285. union/remote/_client/auth/_authenticators/factory.py +0 -200
  286. union/remote/_client/auth/_authenticators/pkce.py +0 -515
  287. union/remote/_client/auth/_channel.py +0 -184
  288. union/remote/_client/auth/_client_config.py +0 -83
  289. union/remote/_client/auth/_default_html.py +0 -32
  290. union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  291. union/remote/_client/auth/_grpc_utils/auth_interceptor.py +0 -204
  292. union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +0 -144
  293. union/remote/_client/auth/_keyring.py +0 -154
  294. union/remote/_client/auth/_token_client.py +0 -258
  295. union/remote/_client/auth/errors.py +0 -16
  296. union/remote/_client/controlplane.py +0 -86
  297. union/remote/_data.py +0 -149
  298. union/remote/_logs.py +0 -74
  299. union/remote/_project.py +0 -86
  300. union/remote/_run.py +0 -820
  301. union/remote/_secret.py +0 -132
  302. union/remote/_task.py +0 -193
  303. union/report/__init__.py +0 -3
  304. union/report/_report.py +0 -178
  305. union/report/_template.html +0 -124
  306. union/storage/__init__.py +0 -24
  307. union/storage/_remote_fs.py +0 -34
  308. union/storage/_storage.py +0 -247
  309. union/storage/_utils.py +0 -5
  310. union/types/__init__.py +0 -11
  311. union/types/_renderer.py +0 -162
  312. union/types/_string_literals.py +0 -120
  313. union/types/_type_engine.py +0 -2131
  314. union/types/_utils.py +0 -80
  315. /union/_protos/common/authorization_pb2_grpc.py → /flyte/_protos/workflow/common_pb2_grpc.py +0 -0
  316. /union/_protos/common/identifier_pb2_grpc.py → /flyte/_protos/workflow/environment_pb2_grpc.py +0 -0
  317. /flyte/io/{structured_dataset → _structured_dataset}/__init__.py +0 -0
  318. {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +0 -0
  319. {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/top_level.txt +0 -0
@@ -1,26 +1,79 @@
1
- from typing import Any, Tuple, TypeVar
1
+ import asyncio
2
+ import atexit
3
+ import concurrent.futures
4
+ import os
5
+ import threading
6
+ from typing import Any, Callable, Tuple, TypeVar
2
7
 
3
8
  import flyte.errors
4
9
  from flyte._context import internal_ctx
5
- from flyte._datastructures import ActionID, NativeInterface, RawDataPath
6
10
  from flyte._internal.controllers import TraceInfo
7
11
  from flyte._internal.runtime import convert
8
12
  from flyte._internal.runtime.entrypoints import direct_dispatch
9
13
  from flyte._logging import log, logger
10
14
  from flyte._protos.workflow import task_definition_pb2
11
15
  from flyte._task import TaskTemplate
16
+ from flyte._utils.helpers import _selector_policy
17
+ from flyte.models import ActionID, NativeInterface
12
18
 
13
19
  R = TypeVar("R")
14
20
 
15
21
 
22
+ class _TaskRunner:
23
+ """A task runner that runs an asyncio event loop on a background thread."""
24
+
25
+ def __init__(self) -> None:
26
+ self.__loop: asyncio.AbstractEventLoop | None = None
27
+ self.__runner_thread: threading.Thread | None = None
28
+ self.__lock = threading.Lock()
29
+ atexit.register(self._close)
30
+
31
+ def _close(self) -> None:
32
+ if self.__loop:
33
+ self.__loop.stop()
34
+
35
+ def _execute(self) -> None:
36
+ loop = self.__loop
37
+ assert loop is not None
38
+ try:
39
+ loop.run_forever()
40
+ finally:
41
+ loop.close()
42
+
43
+ def get_exc_handler(self):
44
+ def exc_handler(loop, context):
45
+ logger.error(
46
+ f"Taskrunner for {self.__runner_thread.name if self.__runner_thread else 'no thread'} caught"
47
+ f" exception in {loop}: {context}"
48
+ )
49
+
50
+ return exc_handler
51
+
52
+ def get_run_future(self, coro: Any) -> concurrent.futures.Future:
53
+ """Synchronously run a coroutine on a background thread."""
54
+ name = f"{threading.current_thread().name} : loop-runner"
55
+ with self.__lock:
56
+ if self.__loop is None:
57
+ with _selector_policy():
58
+ self.__loop = asyncio.new_event_loop()
59
+
60
+ exc_handler = self.get_exc_handler()
61
+ self.__loop.set_exception_handler(exc_handler)
62
+ self.__runner_thread = threading.Thread(target=self._execute, daemon=True, name=name)
63
+ self.__runner_thread.start()
64
+ fut = asyncio.run_coroutine_threadsafe(coro, self.__loop)
65
+ return fut
66
+
67
+
16
68
  class LocalController:
17
69
  def __init__(self):
18
70
  logger.debug("LocalController init")
71
+ self._runner_map: dict[str, _TaskRunner] = {}
19
72
 
20
73
  @log
21
74
  async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
22
75
  """
23
- Submit a node to the controller
76
+ Main entrypoint for submitting a task to the local controller.
24
77
  """
25
78
  ctx = internal_ctx()
26
79
  tctx = ctx.data.task_context
@@ -28,8 +81,12 @@ class LocalController:
28
81
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
29
82
 
30
83
  inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
31
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _task.name, inputs)
32
- sub_action_raw_data_path = RawDataPath(path=sub_action_output_path)
84
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
85
+
86
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
87
+ tctx, _task.name, serialized_inputs, 0
88
+ )
89
+ sub_action_raw_data_path = tctx.raw_data_path
33
90
 
34
91
  out, err = await direct_dispatch(
35
92
  _task,
@@ -54,6 +111,18 @@ class LocalController:
54
111
  return result
55
112
  return out
56
113
 
114
+ def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
115
+ name = threading.current_thread().name + f"PID:{os.getpid()}"
116
+ coro = self.submit(_task, *args, **kwargs)
117
+ if name not in self._runner_map:
118
+ if len(self._runner_map) > 100:
119
+ logger.warning(
120
+ "More than 100 event loop runners created!!! This could be a case of runaway recursion..."
121
+ )
122
+ self._runner_map[name] = _TaskRunner()
123
+
124
+ return self._runner_map[name].get_run_future(coro)
125
+
57
126
  async def finalize_parent_action(self, action: ActionID):
58
127
  pass
59
128
 
@@ -64,7 +133,7 @@ class LocalController:
64
133
  pass
65
134
 
66
135
  async def get_action_outputs(
67
- self, _interface: NativeInterface, _func_name: str, *args, **kwargs
136
+ self, _interface: NativeInterface, _func: Callable, *args, **kwargs
68
137
  ) -> Tuple[TraceInfo, bool]:
69
138
  """
70
139
  This method returns the outputs of the action, if it is available.
@@ -79,8 +148,13 @@ class LocalController:
79
148
  if _interface.inputs:
80
149
  converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
81
150
  assert converted_inputs
151
+
152
+ serialized_inputs = converted_inputs.proto_inputs.SerializeToString(deterministic=True)
82
153
  action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
83
- tctx, _func_name, converted_inputs
154
+ tctx,
155
+ _func.__name__,
156
+ serialized_inputs,
157
+ 0,
84
158
  )
85
159
  assert action_output_path
86
160
  return (
@@ -88,6 +162,7 @@ class LocalController:
88
162
  action=action_id,
89
163
  interface=_interface,
90
164
  inputs_path=action_output_path,
165
+ name=_func.__name__,
91
166
  ),
92
167
  True,
93
168
  )
@@ -105,7 +180,7 @@ class LocalController:
105
180
 
106
181
  if info.interface.outputs and info.output:
107
182
  # If the result is not an AsyncGenerator, convert it directly
108
- converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
183
+ converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface, info.name)
109
184
  assert converted_outputs
110
185
  elif info.error:
111
186
  # If there is an error, convert it to a native error
@@ -2,7 +2,7 @@ from dataclasses import dataclass
2
2
  from datetime import timedelta
3
3
  from typing import Any, Optional
4
4
 
5
- from flyte._datastructures import ActionID, NativeInterface
5
+ from flyte.models import ActionID, NativeInterface
6
6
 
7
7
 
8
8
  @dataclass
@@ -18,6 +18,7 @@ class TraceInfo:
18
18
  duration: Optional[timedelta] = None
19
19
  output: Optional[Any] = None
20
20
  error: Optional[Exception] = None
21
+ name: str = ""
21
22
 
22
23
  def add_outputs(self, output: Any, duration: timedelta):
23
24
  """
@@ -10,13 +10,13 @@ __all__ = ["RemoteController", "create_remote_controller"]
10
10
  def create_remote_controller(
11
11
  *,
12
12
  api_key: str | None = None,
13
- auth_type: AuthType = "Pkce",
14
- endpoint: str,
15
- client_config: ClientConfig | None = None,
16
- headless: bool = False,
13
+ endpoint: str | None = None,
17
14
  insecure: bool = False,
18
15
  insecure_skip_verify: bool = False,
19
16
  ca_cert_file_path: str | None = None,
17
+ client_config: ClientConfig | None = None,
18
+ auth_type: AuthType = "Pkce",
19
+ headless: bool = False,
20
20
  command: List[str] | None = None,
21
21
  proxy_command: List[str] | None = None,
22
22
  client_id: str | None = None,
@@ -27,13 +27,33 @@ def create_remote_controller(
27
27
  """
28
28
  Create a new instance of the remote controller.
29
29
  """
30
+ assert endpoint or api_key, "Either endpoint or api_key must be provided when initializing remote controller"
30
31
  from ._client import ControllerClient
31
32
  from ._controller import RemoteController
32
33
 
34
+ if endpoint:
35
+ client_coro = ControllerClient.for_endpoint(
36
+ endpoint,
37
+ insecure=insecure,
38
+ insecure_skip_verify=insecure_skip_verify,
39
+ ca_cert_file_path=ca_cert_file_path,
40
+ client_id=client_id,
41
+ client_credentials_secret=client_credentials_secret,
42
+ auth_type=auth_type,
43
+ )
44
+ elif api_key:
45
+ client_coro = ControllerClient.for_api_key(
46
+ api_key,
47
+ insecure=insecure,
48
+ insecure_skip_verify=insecure_skip_verify,
49
+ ca_cert_file_path=ca_cert_file_path,
50
+ client_id=client_id,
51
+ client_credentials_secret=client_credentials_secret,
52
+ auth_type=auth_type,
53
+ )
54
+
33
55
  controller = RemoteController(
34
- client_coro=ControllerClient.for_endpoint(
35
- endpoint=endpoint, insecure=insecure, insecure_skip_verify=insecure_skip_verify
36
- ),
56
+ client_coro=client_coro,
37
57
  workers=10,
38
58
  max_system_retries=5,
39
59
  )
@@ -4,8 +4,8 @@ from dataclasses import dataclass
4
4
 
5
5
  from flyteidl.core import execution_pb2
6
6
 
7
- from flyte._datastructures import GroupData
8
7
  from flyte._protos.workflow import run_definition_pb2, state_service_pb2, task_definition_pb2
8
+ from flyte.models import GroupData
9
9
 
10
10
 
11
11
  @dataclass
@@ -28,6 +28,7 @@ class Action:
28
28
  started: bool = False
29
29
  retries: int = 0
30
30
  client_err: Exception | None = None # This error is set when something goes wrong in the controller.
31
+ cache_key: str | None = None # None means no caching, otherwise it is the version of the cache.
31
32
 
32
33
  @property
33
34
  def name(self) -> str:
@@ -91,6 +92,8 @@ class Action:
91
92
  if not self.started:
92
93
  self.task = action.task
93
94
 
95
+ self.cache_key = action.cache_key
96
+
94
97
  def set_client_error(self, exc: Exception):
95
98
  self.client_err = exc
96
99
 
@@ -106,6 +109,7 @@ class Action:
106
109
  task_spec: task_definition_pb2.TaskSpec,
107
110
  inputs_uri: str,
108
111
  run_output_base: str,
112
+ cache_key: str | None = None,
109
113
  ) -> Action:
110
114
  return cls(
111
115
  action_id=sub_action_id,
@@ -115,6 +119,7 @@ class Action:
115
119
  task=task_spec,
116
120
  inputs_uri=inputs_uri,
117
121
  run_output_base=run_output_base,
122
+ cache_key=cache_key,
118
123
  )
119
124
 
120
125
  @classmethod
@@ -130,7 +135,7 @@ class Action:
130
135
  """
131
136
  from flyte._logging import logger
132
137
 
133
- logger.info(f"In Action from_state {obj.action_id} {obj.phase} {obj.output_uri}")
138
+ logger.debug(f"In Action from_state {obj.action_id} {obj.phase} {obj.output_uri}")
134
139
  return cls(
135
140
  action_id=obj.action_id,
136
141
  parent_action_name=parent_action_name,
@@ -20,7 +20,11 @@ class ControllerClient:
20
20
 
21
21
  @classmethod
22
22
  async def for_endpoint(cls, endpoint: str, insecure: bool = False, **kwargs) -> ControllerClient:
23
- return cls(await create_channel(endpoint, insecure=insecure, **kwargs))
23
+ return cls(await create_channel(endpoint, None, insecure=insecure, **kwargs))
24
+
25
+ @classmethod
26
+ async def for_api_key(cls, api_key: str, insecure: bool = False, **kwargs) -> ControllerClient:
27
+ return cls(await create_channel(None, api_key, insecure=insecure, **kwargs))
24
28
 
25
29
  @property
26
30
  def state_service(self) -> StateService:
@@ -1,9 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ import concurrent.futures
5
+ import os
6
+ import threading
4
7
  from collections import defaultdict
8
+ from collections.abc import Callable
5
9
  from pathlib import Path
6
- from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
10
+ from typing import Any, AsyncIterable, Awaitable, DefaultDict, Tuple, TypeVar
7
11
 
8
12
  import flyte
9
13
  import flyte.errors
@@ -11,7 +15,6 @@ import flyte.storage as storage
11
15
  import flyte.types as types
12
16
  from flyte._code_bundle import build_pkl_bundle
13
17
  from flyte._context import internal_ctx
14
- from flyte._datastructures import ActionID, NativeInterface, SerializationContext
15
18
  from flyte._internal.controllers import TraceInfo
16
19
  from flyte._internal.controllers.remote._action import Action
17
20
  from flyte._internal.controllers.remote._core import Controller
@@ -21,16 +24,18 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
21
24
  from flyte._logging import logger
22
25
  from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
23
26
  from flyte._task import TaskTemplate
27
+ from flyte._utils.helpers import _selector_policy
28
+ from flyte.models import ActionID, NativeInterface, SerializationContext
24
29
 
25
30
  R = TypeVar("R")
26
31
 
27
32
 
28
- async def upload_inputs_with_retry(inputs: convert.Inputs, inputs_uri: str) -> None:
33
+ async def upload_inputs_with_retry(serialized_inputs: AsyncIterable[bytes] | bytes, inputs_uri: str) -> None:
29
34
  """
30
35
  Upload inputs to the specified URI with error handling.
31
36
 
32
37
  Args:
33
- inputs: The inputs to upload
38
+ serialized_inputs: The serialized inputs to upload
34
39
  inputs_uri: The destination URI
35
40
 
36
41
  Raises:
@@ -38,9 +43,9 @@ async def upload_inputs_with_retry(inputs: convert.Inputs, inputs_uri: str) -> N
38
43
  """
39
44
  try:
40
45
  # TODO Add retry decorator to this
41
- await io.upload_inputs(inputs, inputs_uri)
46
+ await storage.put_stream(serialized_inputs, to_path=inputs_uri)
42
47
  except Exception as e:
43
- logger.exception("Failed to upload inputs", e)
48
+ logger.exception("Failed to upload inputs")
44
49
  raise flyte.errors.RuntimeSystemError(type(e).__name__, str(e)) from e
45
50
 
46
51
 
@@ -89,6 +94,10 @@ async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri:
89
94
  return await convert.convert_outputs_to_native(iface, outputs)
90
95
 
91
96
 
97
+ def unique_action_name(action_id: ActionID) -> str:
98
+ return f"{action_id.name}_{action_id.run_name}"
99
+
100
+
92
101
  class RemoteController(Controller):
93
102
  """
94
103
  This a specialized controller that wraps the core controller and performs IO, serialization and deserialization
@@ -99,7 +108,7 @@ class RemoteController(Controller):
99
108
  client_coro: Awaitable[ClientSet],
100
109
  workers: int,
101
110
  max_system_retries: int,
102
- default_parent_concurrency: int = 100,
111
+ default_parent_concurrency: int = 1000,
103
112
  ):
104
113
  """ """
105
114
  super().__init__(
@@ -111,31 +120,44 @@ class RemoteController(Controller):
111
120
  self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
112
121
  lambda: asyncio.Semaphore(default_parent_concurrency)
113
122
  )
123
+ self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
124
+ lambda: defaultdict(int)
125
+ )
126
+ self._submit_loop: asyncio.AbstractEventLoop | None = None
127
+ self._submit_thread: threading.Thread | None = None
114
128
 
115
- async def _submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
129
+ def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
130
+ """
131
+ Generate a task call sequence for the given task object and action ID.
132
+ This is used to track the number of times a task is called within an action.
133
+ """
134
+ current_action_sequencer = self._parent_action_task_call_sequence[unique_action_name(action_id)]
135
+ current_task_id = id(task_obj)
136
+ v = current_action_sequencer[current_task_id]
137
+ new_seq = v + 1
138
+ current_action_sequencer[current_task_id] = new_seq
139
+ return new_seq
140
+
141
+ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any:
116
142
  ctx = internal_ctx()
117
143
  tctx = ctx.data.task_context
118
144
  if tctx is None:
119
145
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
120
146
  current_action_id = tctx.action
121
147
 
122
- inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
123
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _task.name, inputs)
124
-
125
148
  # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
126
149
  # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
127
150
  code_bundle = tctx.code_bundle
128
151
 
129
152
  if code_bundle and code_bundle.pkl:
130
- logger.debug(f"Building new pkl bundle for task {sub_action_id.name}")
153
+ logger.debug(f"Building new pkl bundle for task {_task.name}")
131
154
  code_bundle = await build_pkl_bundle(
132
155
  _task,
133
156
  upload_to_controlplane=False,
134
- upload_from_dataplane_path=io.pkl_path(sub_action_output_path),
157
+ upload_from_dataplane_base_path=tctx.run_base_dir,
135
158
  )
136
159
 
137
- inputs_uri = io.inputs_path(sub_action_output_path)
138
- await upload_inputs_with_retry(inputs, inputs_uri)
160
+ inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
139
161
 
140
162
  root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
141
163
  # Don't set output path in sec context because node executor will set it
@@ -146,12 +168,41 @@ class RemoteController(Controller):
146
168
  code_bundle=code_bundle,
147
169
  version=tctx.version,
148
170
  # supplied version.
149
- input_path=inputs_uri,
171
+ # input_path=inputs_uri,
150
172
  image_cache=tctx.compiled_image_cache,
151
173
  root_dir=root_dir,
152
174
  )
153
175
 
154
176
  task_spec = translate_task_to_wire(_task, new_serialization_context)
177
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
178
+
179
+ inputs_hash = convert.generate_inputs_hash(serialized_inputs)
180
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
181
+ tctx, task_spec, inputs_hash, _task_call_seq
182
+ )
183
+
184
+ inputs_uri = io.inputs_path(sub_action_output_path)
185
+ await upload_inputs_with_retry(serialized_inputs, inputs_uri)
186
+
187
+ md = task_spec.task_template.metadata
188
+ ignored_input_vars = []
189
+ if len(md.cache_ignore_input_vars) > 0:
190
+ ignored_input_vars = list(md.cache_ignore_input_vars)
191
+ cache_key = None
192
+ if task_spec.task_template.metadata and task_spec.task_template.metadata.discoverable:
193
+ discovery_version = task_spec.task_template.metadata.discovery_version
194
+ cache_key = convert.generate_cache_key_hash(
195
+ _task.name,
196
+ inputs_hash,
197
+ task_spec.task_template.interface,
198
+ discovery_version,
199
+ ignored_input_vars,
200
+ inputs.proto_inputs,
201
+ )
202
+
203
+ # Clear to free memory
204
+ serialized_inputs = None # type: ignore
205
+ inputs_hash = None # type: ignore
155
206
 
156
207
  action = Action.from_task(
157
208
  sub_action_id=run_definition_pb2.ActionIdentifier(
@@ -168,6 +219,7 @@ class RemoteController(Controller):
168
219
  task_spec=task_spec,
169
220
  inputs_uri=inputs_uri,
170
221
  run_output_base=tctx.run_base_dir,
222
+ cache_key=cache_key,
171
223
  )
172
224
 
173
225
  try:
@@ -205,8 +257,51 @@ class RemoteController(Controller):
205
257
  if tctx is None:
206
258
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
207
259
  current_action_id = tctx.action
208
- async with self._parent_action_semaphore[current_action_id.name]:
209
- return await self._submit(_task, *args, **kwargs)
260
+ task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
261
+ async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
262
+ return await self._submit(task_call_seq, _task, *args, **kwargs)
263
+
264
def _sync_thread_loop_runner(self) -> None:
    """
    Thread target that drives the loop stored in ``self._submit_loop``.

    The loop is run until it is stopped and is closed on the way out, whether ``run_forever``
    returned normally or raised. Meant to be invoked in a separate thread.
    """
    submit_loop = self._submit_loop
    assert submit_loop is not None
    try:
        submit_loop.run_forever()
    finally:
        # Always release the loop's resources once it is no longer running.
        submit_loop.close()
273
+
274
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
    """
    Schedule ``submit`` from synchronous code and return a concurrent future for its result.

    On first use, a dedicated event loop is created and a daemon thread is started to run it. Both are
    stored on the controller, so later calls reuse them. No lock guards this lazy setup; per the original
    author's note this function is single threaded and non-async. This is the trivial/degenerate case of
    the thread pool in the LocalController. Please see additional comments in protocol.

    :param _task: task forwarded to ``submit``
    :param args: positional arguments forwarded to ``submit``
    :param kwargs: keyword arguments forwarded to ``submit``
    :return: concurrent future for the scheduled ``submit`` coroutine
    """
    if self._submit_thread is None:
        # Please see LocalController for the general implementation of this pattern.
        def _log_loop_exception(loop, context):
            logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")

        with _selector_policy():
            self._submit_loop = asyncio.new_event_loop()
            self._submit_loop.set_exception_handler(_log_loop_exception)

        submitter = threading.Thread(
            target=self._sync_thread_loop_runner,
            name=f"remote-controller-{os.getpid()}-submitter",
            daemon=True,
        )
        self._submit_thread = submitter
        submitter.start()

    pending = self.submit(_task, *args, **kwargs)
    assert self._submit_loop is not None, "Submit loop should always have been initialized by now"
    # Hand the coroutine over to the submitter thread's loop from the calling thread.
    return asyncio.run_coroutine_threadsafe(pending, self._submit_loop)
210
305
 
211
306
  async def finalize_parent_action(self, action_id: ActionID):
212
307
  """
@@ -220,16 +315,17 @@ class RemoteController(Controller):
220
315
  org=action_id.org,
221
316
  )
222
317
  await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
223
- self._parent_action_semaphore.pop(action_id.name, None)
318
+ self._parent_action_semaphore.pop(unique_action_name(action_id), None)
319
+ self._parent_action_task_call_sequence.pop(unique_action_name(action_id), None)
224
320
 
225
321
  async def get_action_outputs(
226
- self, _interface: NativeInterface, _func_name: str, *args, **kwargs
322
+ self, _interface: NativeInterface, _func: Callable, *args, **kwargs
227
323
  ) -> Tuple[TraceInfo, bool]:
228
324
  """
229
325
  This method returns the outputs of the action, if it is available.
230
326
  If not available it raises a NotFoundError.
231
327
  :param _interface: NativeInterface
232
- :param _func_name: Function name
328
+ :param _func: Function name
233
329
  :param args: Arguments
234
330
  :param kwargs: Keyword arguments
235
331
  :return:
@@ -240,11 +336,19 @@ class RemoteController(Controller):
240
336
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
241
337
  current_action_id = tctx.action
242
338
 
339
+ func_name = _func.__name__
340
+ invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
243
341
  inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
244
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _func_name, inputs)
342
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
343
+
344
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
345
+ tctx, func_name, serialized_inputs, invoke_seq_num
346
+ )
245
347
 
246
348
  inputs_uri = io.inputs_path(sub_action_output_path)
247
- await upload_inputs_with_retry(inputs, inputs_uri)
349
+ await upload_inputs_with_retry(serialized_inputs, inputs_uri)
350
+ # Clear to free memory
351
+ serialized_inputs = None # type: ignore
248
352
 
249
353
  prev_action = await self.get_action(
250
354
  run_definition_pb2.ActionIdentifier(
@@ -310,12 +414,40 @@ class RemoteController(Controller):
310
414
  current_action_id = tctx.action
311
415
  task_name = _task.spec.task_template.id.name
312
416
 
313
- native_interface = types.guess_interface(_task.spec.task_template.interface)
417
+ invoke_seq_num = self.generate_task_call_sequence(_task, current_action_id)
418
+
419
+ native_interface = types.guess_interface(
420
+ _task.spec.task_template.interface, default_inputs=_task.spec.default_inputs
421
+ )
314
422
  inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
315
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, task_name, inputs)
423
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
424
+ inputs_hash = convert.generate_inputs_hash(serialized_inputs)
425
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
426
+ tctx, task_name, inputs_hash, invoke_seq_num
427
+ )
316
428
 
317
429
  inputs_uri = io.inputs_path(sub_action_output_path)
318
- await upload_inputs_with_retry(inputs, inputs_uri)
430
+ await upload_inputs_with_retry(serialized_inputs, inputs_uri)
431
+ # cache key - task name, task signature, inputs, cache version
432
+ cache_key = None
433
+ md = _task.spec.task_template.metadata
434
+ ignored_input_vars = []
435
+ if len(md.cache_ignore_input_vars) > 0:
436
+ ignored_input_vars = list(md.cache_ignore_input_vars)
437
+ if _task.spec.task_template.metadata and _task.spec.task_template.metadata.discoverable:
438
+ discovery_version = _task.spec.task_template.metadata.discovery_version
439
+ cache_key = convert.generate_cache_key_hash(
440
+ task_name,
441
+ inputs_hash,
442
+ _task.spec.task_template.interface,
443
+ discovery_version,
444
+ ignored_input_vars,
445
+ inputs.proto_inputs,
446
+ )
447
+
448
+ # Clear to free memory
449
+ serialized_inputs = None # type: ignore
450
+ inputs_hash = None # type: ignore
319
451
 
320
452
  action = Action.from_task(
321
453
  sub_action_id=run_definition_pb2.ActionIdentifier(
@@ -332,6 +464,7 @@ class RemoteController(Controller):
332
464
  task_spec=_task.spec,
333
465
  inputs_uri=inputs_uri,
334
466
  run_output_base=tctx.run_base_dir,
467
+ cache_key=cache_key,
335
468
  )
336
469
 
337
470
  try:
@@ -7,6 +7,7 @@ from asyncio import Event
7
7
  from typing import Awaitable, Coroutine, Optional
8
8
 
9
9
  import grpc.aio
10
+ from google.protobuf.wrappers_pb2 import StringValue
10
11
 
11
12
  import flyte.errors
12
13
  from flyte._logging import log, logger
@@ -32,7 +33,7 @@ class Controller:
32
33
  max_system_retries: int = 5,
33
34
  resource_log_interval_sec: float = 10.0,
34
35
  min_backoff_on_err_sec: float = 0.1,
35
- thread_wait_timeout_sec: float = 10.0,
36
+ thread_wait_timeout_sec: float = 5.0,
36
37
  enqueue_timeout_sec: float = 5.0,
37
38
  ):
38
39
  """
@@ -286,10 +287,11 @@ class Controller:
286
287
  if started:
287
288
  logger.info(f"Cancelling action: {action.name}")
288
289
  try:
289
- await self._queue_service.AbortQueuedAction(
290
- queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
291
- wait_for_ready=True,
292
- )
290
+ # TODO add support when the queue service supports aborting actions
291
+ # await self._queue_service.AbortQueuedAction(
292
+ # queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
293
+ # wait_for_ready=True,
294
+ # )
293
295
  logger.info(f"Successfully cancelled action: {action.name}")
294
296
  except grpc.aio.AioRpcError as e:
295
297
  if e.code() in [grpc.StatusCode.NOT_FOUND, grpc.StatusCode.FAILED_PRECONDITION]:
@@ -310,6 +312,11 @@ class Controller:
310
312
  if not action.is_started() and action.task is not None:
311
313
  logger.debug(f"Attempting to launch action: {action.name}")
312
314
  try:
315
+ cache_key = None
316
+ logger.info(f"Action {action.name} has cache version {action.cache_key}")
317
+ if action.cache_key:
318
+ cache_key = StringValue(value=action.cache_key)
319
+
313
320
  await self._queue_service.EnqueueAction(
314
321
  queue_service_pb2.EnqueueActionRequest(
315
322
  action_id=action.action_id,
@@ -323,6 +330,7 @@ class Controller:
323
330
  name=action.task.task_template.id.name,
324
331
  ),
325
332
  spec=action.task,
333
+ cache_key=cache_key,
326
334
  ),
327
335
  input_uri=action.inputs_uri,
328
336
  run_output_base=action.run_output_base,