flyte 0.1.0__py3-none-any.whl → 0.2.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (219)
  1. flyte/__init__.py +78 -2
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/runtime.py +152 -0
  4. flyte/_build.py +26 -0
  5. flyte/_cache/__init__.py +12 -0
  6. flyte/_cache/cache.py +145 -0
  7. flyte/_cache/defaults.py +9 -0
  8. flyte/_cache/policy_function_body.py +42 -0
  9. flyte/_code_bundle/__init__.py +8 -0
  10. flyte/_code_bundle/_ignore.py +113 -0
  11. flyte/_code_bundle/_packaging.py +187 -0
  12. flyte/_code_bundle/_utils.py +323 -0
  13. flyte/_code_bundle/bundle.py +209 -0
  14. flyte/_context.py +152 -0
  15. flyte/_deploy.py +243 -0
  16. flyte/_doc.py +29 -0
  17. flyte/_docstring.py +32 -0
  18. flyte/_environment.py +84 -0
  19. flyte/_excepthook.py +37 -0
  20. flyte/_group.py +32 -0
  21. flyte/_hash.py +23 -0
  22. flyte/_image.py +762 -0
  23. flyte/_initialize.py +492 -0
  24. flyte/_interface.py +84 -0
  25. flyte/_internal/__init__.py +3 -0
  26. flyte/_internal/controllers/__init__.py +128 -0
  27. flyte/_internal/controllers/_local_controller.py +193 -0
  28. flyte/_internal/controllers/_trace.py +41 -0
  29. flyte/_internal/controllers/remote/__init__.py +60 -0
  30. flyte/_internal/controllers/remote/_action.py +146 -0
  31. flyte/_internal/controllers/remote/_client.py +47 -0
  32. flyte/_internal/controllers/remote/_controller.py +494 -0
  33. flyte/_internal/controllers/remote/_core.py +410 -0
  34. flyte/_internal/controllers/remote/_informer.py +361 -0
  35. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  36. flyte/_internal/imagebuild/__init__.py +11 -0
  37. flyte/_internal/imagebuild/docker_builder.py +427 -0
  38. flyte/_internal/imagebuild/image_builder.py +246 -0
  39. flyte/_internal/imagebuild/remote_builder.py +0 -0
  40. flyte/_internal/resolvers/__init__.py +0 -0
  41. flyte/_internal/resolvers/_task_module.py +54 -0
  42. flyte/_internal/resolvers/common.py +31 -0
  43. flyte/_internal/resolvers/default.py +28 -0
  44. flyte/_internal/runtime/__init__.py +0 -0
  45. flyte/_internal/runtime/convert.py +342 -0
  46. flyte/_internal/runtime/entrypoints.py +135 -0
  47. flyte/_internal/runtime/io.py +136 -0
  48. flyte/_internal/runtime/resources_serde.py +138 -0
  49. flyte/_internal/runtime/task_serde.py +330 -0
  50. flyte/_internal/runtime/taskrunner.py +191 -0
  51. flyte/_internal/runtime/types_serde.py +54 -0
  52. flyte/_logging.py +135 -0
  53. flyte/_map.py +215 -0
  54. flyte/_pod.py +19 -0
  55. flyte/_protos/__init__.py +0 -0
  56. flyte/_protos/common/authorization_pb2.py +66 -0
  57. flyte/_protos/common/authorization_pb2.pyi +108 -0
  58. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  59. flyte/_protos/common/identifier_pb2.py +71 -0
  60. flyte/_protos/common/identifier_pb2.pyi +82 -0
  61. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  62. flyte/_protos/common/identity_pb2.py +48 -0
  63. flyte/_protos/common/identity_pb2.pyi +72 -0
  64. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  65. flyte/_protos/common/list_pb2.py +36 -0
  66. flyte/_protos/common/list_pb2.pyi +71 -0
  67. flyte/_protos/common/list_pb2_grpc.py +4 -0
  68. flyte/_protos/common/policy_pb2.py +37 -0
  69. flyte/_protos/common/policy_pb2.pyi +27 -0
  70. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  71. flyte/_protos/common/role_pb2.py +37 -0
  72. flyte/_protos/common/role_pb2.pyi +53 -0
  73. flyte/_protos/common/role_pb2_grpc.py +4 -0
  74. flyte/_protos/common/runtime_version_pb2.py +28 -0
  75. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  76. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  77. flyte/_protos/logs/dataplane/payload_pb2.py +100 -0
  78. flyte/_protos/logs/dataplane/payload_pb2.pyi +177 -0
  79. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  80. flyte/_protos/secret/definition_pb2.py +49 -0
  81. flyte/_protos/secret/definition_pb2.pyi +93 -0
  82. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  83. flyte/_protos/secret/payload_pb2.py +62 -0
  84. flyte/_protos/secret/payload_pb2.pyi +94 -0
  85. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  86. flyte/_protos/secret/secret_pb2.py +38 -0
  87. flyte/_protos/secret/secret_pb2.pyi +6 -0
  88. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  89. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  90. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  91. flyte/_protos/workflow/common_pb2.py +27 -0
  92. flyte/_protos/workflow/common_pb2.pyi +14 -0
  93. flyte/_protos/workflow/common_pb2_grpc.py +4 -0
  94. flyte/_protos/workflow/environment_pb2.py +29 -0
  95. flyte/_protos/workflow/environment_pb2.pyi +12 -0
  96. flyte/_protos/workflow/environment_pb2_grpc.py +4 -0
  97. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  98. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  99. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  100. flyte/_protos/workflow/queue_service_pb2.py +105 -0
  101. flyte/_protos/workflow/queue_service_pb2.pyi +146 -0
  102. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  103. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  104. flyte/_protos/workflow/run_definition_pb2.pyi +314 -0
  105. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  106. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  107. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  108. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  109. flyte/_protos/workflow/run_service_pb2.py +129 -0
  110. flyte/_protos/workflow/run_service_pb2.pyi +171 -0
  111. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  112. flyte/_protos/workflow/state_service_pb2.py +66 -0
  113. flyte/_protos/workflow/state_service_pb2.pyi +75 -0
  114. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  115. flyte/_protos/workflow/task_definition_pb2.py +79 -0
  116. flyte/_protos/workflow/task_definition_pb2.pyi +81 -0
  117. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  118. flyte/_protos/workflow/task_service_pb2.py +60 -0
  119. flyte/_protos/workflow/task_service_pb2.pyi +59 -0
  120. flyte/_protos/workflow/task_service_pb2_grpc.py +138 -0
  121. flyte/_resources.py +226 -0
  122. flyte/_retry.py +32 -0
  123. flyte/_reusable_environment.py +25 -0
  124. flyte/_run.py +482 -0
  125. flyte/_secret.py +61 -0
  126. flyte/_task.py +449 -0
  127. flyte/_task_environment.py +183 -0
  128. flyte/_timeout.py +47 -0
  129. flyte/_tools.py +27 -0
  130. flyte/_trace.py +120 -0
  131. flyte/_utils/__init__.py +26 -0
  132. flyte/_utils/asyn.py +119 -0
  133. flyte/_utils/async_cache.py +139 -0
  134. flyte/_utils/coro_management.py +23 -0
  135. flyte/_utils/file_handling.py +72 -0
  136. flyte/_utils/helpers.py +134 -0
  137. flyte/_utils/lazy_module.py +54 -0
  138. flyte/_utils/org_discovery.py +57 -0
  139. flyte/_utils/uv_script_parser.py +49 -0
  140. flyte/_version.py +21 -0
  141. flyte/cli/__init__.py +3 -0
  142. flyte/cli/_abort.py +28 -0
  143. flyte/cli/_common.py +337 -0
  144. flyte/cli/_create.py +145 -0
  145. flyte/cli/_delete.py +23 -0
  146. flyte/cli/_deploy.py +152 -0
  147. flyte/cli/_gen.py +163 -0
  148. flyte/cli/_get.py +310 -0
  149. flyte/cli/_params.py +538 -0
  150. flyte/cli/_run.py +231 -0
  151. flyte/cli/main.py +166 -0
  152. flyte/config/__init__.py +3 -0
  153. flyte/config/_config.py +216 -0
  154. flyte/config/_internal.py +64 -0
  155. flyte/config/_reader.py +207 -0
  156. flyte/connectors/__init__.py +0 -0
  157. flyte/errors.py +172 -0
  158. flyte/extras/__init__.py +5 -0
  159. flyte/extras/_container.py +263 -0
  160. flyte/io/__init__.py +27 -0
  161. flyte/io/_dir.py +448 -0
  162. flyte/io/_file.py +467 -0
  163. flyte/io/_structured_dataset/__init__.py +129 -0
  164. flyte/io/_structured_dataset/basic_dfs.py +219 -0
  165. flyte/io/_structured_dataset/structured_dataset.py +1061 -0
  166. flyte/models.py +391 -0
  167. flyte/remote/__init__.py +26 -0
  168. flyte/remote/_client/__init__.py +0 -0
  169. flyte/remote/_client/_protocols.py +133 -0
  170. flyte/remote/_client/auth/__init__.py +12 -0
  171. flyte/remote/_client/auth/_auth_utils.py +14 -0
  172. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  173. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  174. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  175. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  176. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  177. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  178. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  179. flyte/remote/_client/auth/_channel.py +215 -0
  180. flyte/remote/_client/auth/_client_config.py +83 -0
  181. flyte/remote/_client/auth/_default_html.py +32 -0
  182. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  183. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  184. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  185. flyte/remote/_client/auth/_keyring.py +143 -0
  186. flyte/remote/_client/auth/_token_client.py +260 -0
  187. flyte/remote/_client/auth/errors.py +16 -0
  188. flyte/remote/_client/controlplane.py +95 -0
  189. flyte/remote/_console.py +18 -0
  190. flyte/remote/_data.py +159 -0
  191. flyte/remote/_logs.py +176 -0
  192. flyte/remote/_project.py +85 -0
  193. flyte/remote/_run.py +970 -0
  194. flyte/remote/_secret.py +132 -0
  195. flyte/remote/_task.py +391 -0
  196. flyte/report/__init__.py +3 -0
  197. flyte/report/_report.py +178 -0
  198. flyte/report/_template.html +124 -0
  199. flyte/storage/__init__.py +29 -0
  200. flyte/storage/_config.py +233 -0
  201. flyte/storage/_remote_fs.py +34 -0
  202. flyte/storage/_storage.py +271 -0
  203. flyte/storage/_utils.py +5 -0
  204. flyte/syncify/__init__.py +56 -0
  205. flyte/syncify/_api.py +371 -0
  206. flyte/types/__init__.py +36 -0
  207. flyte/types/_interface.py +40 -0
  208. flyte/types/_pickle.py +118 -0
  209. flyte/types/_renderer.py +162 -0
  210. flyte/types/_string_literals.py +120 -0
  211. flyte/types/_type_engine.py +2287 -0
  212. flyte/types/_utils.py +80 -0
  213. flyte-0.2.0a0.dist-info/METADATA +249 -0
  214. flyte-0.2.0a0.dist-info/RECORD +218 -0
  215. {flyte-0.1.0.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +2 -1
  216. flyte-0.2.0a0.dist-info/entry_points.txt +3 -0
  217. flyte-0.2.0a0.dist-info/top_level.txt +1 -0
  218. flyte-0.1.0.dist-info/METADATA +0 -6
  219. flyte-0.1.0.dist-info/RECORD +0 -5
@@ -0,0 +1,494 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import concurrent.futures
5
+ import os
6
+ import threading
7
+ from collections import defaultdict
8
+ from collections.abc import Callable
9
+ from pathlib import Path
10
+ from typing import Any, AsyncIterable, Awaitable, DefaultDict, Tuple, TypeVar
11
+
12
+ import flyte
13
+ import flyte.errors
14
+ import flyte.storage as storage
15
+ import flyte.types as types
16
+ from flyte._code_bundle import build_pkl_bundle
17
+ from flyte._context import internal_ctx
18
+ from flyte._internal.controllers import TraceInfo
19
+ from flyte._internal.controllers.remote._action import Action
20
+ from flyte._internal.controllers.remote._core import Controller
21
+ from flyte._internal.controllers.remote._service_protocol import ClientSet
22
+ from flyte._internal.runtime import convert, io
23
+ from flyte._internal.runtime.task_serde import translate_task_to_wire
24
+ from flyte._logging import logger
25
+ from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
26
+ from flyte._task import TaskTemplate
27
+ from flyte._utils.helpers import _selector_policy
28
+ from flyte.models import ActionID, NativeInterface, SerializationContext
29
+
30
+ R = TypeVar("R")
31
+
32
+
33
async def upload_inputs_with_retry(serialized_inputs: AsyncIterable[bytes] | bytes, inputs_uri: str) -> None:
    """
    Write serialized task inputs to ``inputs_uri`` using ``storage.put_stream``.

    Despite the name, a single upload attempt is made today (see the TODO below).

    Args:
        serialized_inputs: Serialized inputs, either a bytes object or an async iterable of byte chunks.
        inputs_uri: Destination URI for the inputs.

    Raises:
        flyte.errors.RuntimeSystemError: Wraps any exception raised by the upload; the original exception is
            chained as ``__cause__``.
    """
    # TODO Add retry decorator to this
    try:
        await storage.put_stream(serialized_inputs, to_path=inputs_uri)
    except Exception as upload_error:
        logger.exception("Failed to upload inputs")
        error_type = type(upload_error).__name__
        raise flyte.errors.RuntimeSystemError(error_type, str(upload_error)) from upload_error
50
+
51
+
52
async def handle_action_failure(action: Action, task_name: str) -> Exception:
    """
    Build the exception describing a failed action.

    Error source, in order of precedence:

    - ``action.err``, then ``action.client_err``.
    - If neither is set and the action is in ``PHASE_FAILED``, the error file under
      ``<run_output_base>/<action name>/1`` is loaded.
    - If loading that file fails, a ``RuntimeSystemError`` describing the load failure is used instead.

    The exception is *returned*, not raised, so the caller decides how to raise it.

    Args:
        action: The updated action.
        task_name: The name of the task, used in the fallback error message.

    Returns:
        The native exception produced by ``convert.convert_error_to_native``. If that conversion yields
        nothing, returns ``RuntimeSystemError("UnableToConvertError", ...)`` instead.
    """
    err = action.err or action.client_err
    if not err and action.phase == run_definition_pb2.PHASE_FAILED:
        logger.error(f"Server reported failure for action {action.name}, checking error file.")
        try:
            # NOTE(review): the trailing "/1" pins attempt 1 — confirm this is right when the action was retried.
            error_path = io.error_path(f"{action.run_output_base}/{action.action_id.name}/1")
            err = await io.load_error(error_path)
        except Exception as e:
            # logger.exception already records the active traceback. Do not pass `e` as a positional argument:
            # the message has no %-placeholder, so logging would fail to format it.
            logger.exception("Failed to load error file")
            err = flyte.errors.RuntimeSystemError(type(e).__name__, f"Failed to load error file: {e}")
    else:
        logger.error(f"Server reported failure for action {action.action_id.name}, error: {err}")

    exc = convert.convert_error_to_native(err)
    if not exc:
        return flyte.errors.RuntimeSystemError("UnableToConvertError", f"Error in task {task_name}: {err}")
    return exc
79
+
80
+
81
async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri: str) -> Any:
    """
    Read the outputs stored under ``realized_outputs_uri`` and turn them into native Python values.

    Args:
        iface: Native interface describing the expected outputs.
        realized_outputs_uri: Output location reported for the finished action.

    Returns:
        Whatever ``convert.convert_outputs_to_native`` produces for ``iface``.
    """
    raw_outputs = await io.load_outputs(io.outputs_path(realized_outputs_uri))
    return await convert.convert_outputs_to_native(iface, raw_outputs)
95
+
96
+
97
def unique_action_name(action_id: ActionID) -> str:
    """
    Combine the action name and run name into a single key of the form ``"<action name>_<run name>"``.

    ``RemoteController`` uses this key for its per-parent-action bookkeeping (semaphores and call sequences).
    """
    return "{}_{}".format(action_id.name, action_id.run_name)
99
+
100
+
101
class RemoteController(Controller):
    """
    This is a specialized controller that wraps the core controller and performs IO, serialization and
    deserialization.

    For every child action it:

    - serializes and uploads the inputs;
    - computes the cache key (for discoverable tasks);
    - submits the action through the core ``Controller`` and waits for it;
    - downloads and converts the outputs, or raises the converted error.

    Child actions of one parent action are bounded by a per-parent ``asyncio.Semaphore``.
    """

    def __init__(
        self,
        client_coro: Awaitable[ClientSet],
        workers: int,
        max_system_retries: int,
        default_parent_concurrency: int = 1000,
    ):
        """
        :param client_coro: Awaitable producing the client set; forwarded to ``Controller``.
        :param workers: Forwarded to ``Controller``.
        :param max_system_retries: Forwarded to ``Controller``.
        :param default_parent_concurrency: Size of the per-parent-action semaphore, i.e. how many child
            submissions of one parent action may be in flight at once.
        """
        super().__init__(
            client_coro=client_coro,
            workers=workers,
            max_system_retries=max_system_retries,
        )
        self._default_parent_concurrency = default_parent_concurrency
        # unique_action_name(parent) -> semaphore, created lazily on first use.
        self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
            lambda: asyncio.Semaphore(default_parent_concurrency)
        )
        # unique_action_name(parent) -> {id(task or function object) -> number of calls so far}.
        self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        # Event loop and thread used by submit_sync; created lazily on its first call.
        self._submit_loop: asyncio.AbstractEventLoop | None = None
        self._submit_thread: threading.Thread | None = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _current_task_context():
        """Return the current task context, raising ``RuntimeSystemError`` when it is not initialized."""
        tctx = internal_ctx().data.task_context
        if tctx is None:
            raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
        return tctx

    @staticmethod
    def _child_action_identifier(
        current_action_id: ActionID, sub_action_name: str
    ) -> run_definition_pb2.ActionIdentifier:
        """Build the wire identifier of a child action that belongs to the same run as ``current_action_id``."""
        return run_definition_pb2.ActionIdentifier(
            name=sub_action_name,
            run=run_definition_pb2.RunIdentifier(
                name=current_action_id.run_name,
                project=current_action_id.project,
                domain=current_action_id.domain,
                org=current_action_id.org,
            ),
        )

    @staticmethod
    def _child_cache_key(task_name: str, task_template: Any, inputs_hash: Any, proto_inputs: Any):
        """Return the cache key for a discoverable task template, or ``None`` when it is not discoverable."""
        # cache key - task name, task signature, inputs, cache version
        md = task_template.metadata
        if not (md and md.discoverable):
            return None
        return convert.generate_cache_key_hash(
            task_name,
            inputs_hash,
            task_template.interface,
            md.discovery_version,
            list(md.cache_ignore_input_vars),
            proto_inputs,
        )

    async def _await_child_action(self, action: Action, task_name: str):
        """
        Submit ``action`` through the core controller and wait for it to finish.

        If this coroutine is cancelled while waiting, the action is also cancelled on the server before the
        cancellation is re-raised. A finished action that has an error or is in ``PHASE_FAILED`` is converted
        with ``handle_action_failure`` and raised.

        :return: The finished action state returned by ``submit_action``.
        """
        try:
            logger.info(
                f"Submitting action Run:[{action.run_name}], Parent:[{action.parent_action_name}], "
                f"task:[{task_name}], action:[{action.name}]"
            )
            n = await self.submit_action(action)
            logger.info(f"Action for task [{task_name}] action id: {action.name}, completed!")
        except asyncio.CancelledError:
            # If the action is cancelled, we need to cancel the action on the server as well
            logger.info(f"Action {action.action_id.name} cancelled, cancelling on server")
            await self.cancel_action(action)
            raise

        if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED:
            exc = await handle_action_failure(action, task_name)
            raise exc
        return n

    @staticmethod
    async def _load_child_outputs(n, native_interface: NativeInterface) -> Any:
        """Load native outputs of a finished action, or return ``None`` if the interface declares no outputs."""
        if not native_interface.outputs:
            return None
        if not n.realized_outputs_uri:
            raise flyte.errors.RuntimeSystemError(
                "RuntimeError",
                f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.",
            )
        return await load_and_convert_outputs(native_interface, n.realized_outputs_uri)

    # ---------------------------------------------------------------- public API

    def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
        """
        Generate a task call sequence for the given task object and action ID.

        Counts, per parent action, how many times a task object has been called; the first call returns 1.
        Objects are keyed by ``id(task_obj)``.
        """
        current_action_sequencer = self._parent_action_task_call_sequence[unique_action_name(action_id)]
        current_task_id = id(task_obj)
        new_seq = current_action_sequencer[current_task_id] + 1
        current_action_sequencer[current_task_id] = new_seq
        return new_seq

    async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any:
        """Serialize, upload and submit one invocation of ``_task``, then wait for it and return its outputs."""
        tctx = self._current_task_context()
        current_action_id = tctx.action

        # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
        # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
        code_bundle = tctx.code_bundle

        if code_bundle and code_bundle.pkl:
            logger.debug(f"Building new pkl bundle for task {_task.name}")
            code_bundle = await build_pkl_bundle(
                _task,
                upload_to_controlplane=False,
                upload_from_dataplane_base_path=tctx.run_base_dir,
            )

        inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)

        root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
        # Don't set output path in sec context because node executor will set it
        new_serialization_context = SerializationContext(
            project=current_action_id.project,
            domain=current_action_id.domain,
            org=current_action_id.org,
            code_bundle=code_bundle,
            version=tctx.version,
            image_cache=tctx.compiled_image_cache,
            root_dir=root_dir,
        )

        task_spec = translate_task_to_wire(_task, new_serialization_context)
        serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)

        inputs_hash = convert.generate_inputs_hash(serialized_inputs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, task_spec, inputs_hash, _task_call_seq
        )

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(serialized_inputs, inputs_uri)

        cache_key = self._child_cache_key(_task.name, task_spec.task_template, inputs_hash, inputs.proto_inputs)

        # Clear to free memory
        serialized_inputs = None  # type: ignore
        inputs_hash = None  # type: ignore

        action = Action.from_task(
            sub_action_id=self._child_action_identifier(current_action_id, sub_action_id.name),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=task_spec,
            inputs_uri=inputs_uri,
            run_output_base=tctx.run_base_dir,
            cache_key=cache_key,
        )

        n = await self._await_child_action(action, _task.name)
        return await self._load_child_outputs(n, _task.native_interface)

    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a task to the remote controller. This creates a new action on the queue service.

        The submission is bounded by the per-parent-action semaphore and returns the task's native outputs
        (or ``None`` when it declares none).
        """
        tctx = self._current_task_context()
        current_action_id = tctx.action
        task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
        async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
            return await self._submit(task_call_seq, _task, *args, **kwargs)

    def _sync_thread_loop_runner(self) -> None:
        """This method runs the event loop and should be invoked in a separate thread."""

        loop = self._submit_loop
        assert loop is not None
        try:
            loop.run_forever()
        finally:
            loop.close()

    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
        """
        Schedule ``submit`` on a lazily created background event loop and return a ``concurrent.futures.Future``.

        The caller can block on the returned future (e.g. ``fut.result()``). The loop and its daemon thread are
        created on the first call and reused afterwards. No lock guards that lazy initialization; this assumes
        callers invoke this method from a single thread. It is the trivial/degenerate case of the thread pool
        in the LocalController. Please see additional comments in protocol.

        :param _task: Task to submit.
        :param args: Positional task arguments.
        :param kwargs: Keyword task arguments.
        :return: Future resolving to the result of ``submit``.
        """
        if self._submit_thread is None:
            # Please see LocalController for the general implementation of this pattern.
            def exc_handler(loop, context):
                logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")

            with _selector_policy():
                self._submit_loop = asyncio.new_event_loop()
                self._submit_loop.set_exception_handler(exc_handler)

            self._submit_thread = threading.Thread(
                name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner
            )
            self._submit_thread.start()

        coro = self.submit(_task, *args, **kwargs)
        assert self._submit_loop is not None, "Submit loop should always have been initialized by now"
        fut = asyncio.run_coroutine_threadsafe(coro, self._submit_loop)
        return fut

    async def finalize_parent_action(self, action_id: ActionID):
        """
        Finalize a parent action once it has finished.

        Delegates to ``Controller._finalize_parent_action`` for the parent's run, then drops the per-parent
        semaphore and call-sequence counters kept by this controller.
        """
        run_id = run_definition_pb2.RunIdentifier(
            name=action_id.run_name,
            project=action_id.project,
            domain=action_id.domain,
            org=action_id.org,
        )
        await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
        self._parent_action_semaphore.pop(unique_action_name(action_id), None)
        self._parent_action_task_call_sequence.pop(unique_action_name(action_id), None)

    async def get_action_outputs(
        self, _interface: NativeInterface, _func: Callable, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        Look up a previously recorded result of a traced call.

        The inputs are serialized and uploaded, and the matching sub-action is looked up with ``get_action``.

        :param _interface: NativeInterface of the traced function.
        :param _func: The traced function (its ``__name__`` participates in the sub-action id).
        :param args: Arguments
        :param kwargs: Keyword arguments
        :return: ``(TraceInfo, found)``. ``found`` is True when a previous output or error was recovered and
            stored on the TraceInfo. It is False when there is no previous action, or the previous failure had
            no error, or no outputs were recorded; in those cases the call should be executed.
        """
        tctx = self._current_task_context()
        current_action_id = tctx.action

        func_name = _func.__name__
        invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
        inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
        serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
        # Derive the sub-action id from the inputs hash, exactly like _submit and submit_task_ref,
        # instead of from the raw serialized bytes.
        inputs_hash = convert.generate_inputs_hash(serialized_inputs)

        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, func_name, inputs_hash, invoke_seq_num
        )

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(serialized_inputs, inputs_uri)
        # Clear to free memory
        serialized_inputs = None  # type: ignore

        prev_action = await self.get_action(
            self._child_action_identifier(current_action_id, sub_action_id.name),
            current_action_id.name,
        )

        if prev_action is None:
            return TraceInfo(sub_action_id, _interface, inputs_uri), False

        if prev_action.phase == run_definition_pb2.PHASE_FAILED:
            if prev_action.has_error():
                exc = convert.convert_error_to_native(prev_action.err)
                return TraceInfo(sub_action_id, _interface, inputs_uri, error=exc), True
            else:
                logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!")
        elif prev_action.realized_outputs_uri is not None:
            outputs_file_path = io.outputs_path(prev_action.realized_outputs_uri)
            o = await io.load_outputs(outputs_file_path)
            outputs = await convert.convert_outputs_to_native(_interface, o)
            return TraceInfo(sub_action_id, _interface, inputs_uri, output=outputs), True

        return TraceInfo(sub_action_id, _interface, inputs_uri), False

    async def record_trace(self, info: TraceInfo):
        """
        Record the result of a traced call under the current task's output path.

        Behavior depends on whether the interface declares outputs:

        - If it does and the output is not ``None``, the output is converted and uploaded.
        - Otherwise, if an error is set, the error is converted and uploaded.
        - If neither is present, ``RuntimeSystemError("BadTraceInfo")`` is raised.

        :param info: Trace information for the finished call.
        """
        tctx = self._current_task_context()

        current_output_path = tctx.output_path
        sub_run_output_path = storage.join(current_output_path, info.action.name)

        # NOTE(review): nothing (not even an error) is recorded when the interface has no outputs — confirm intended.
        if info.interface.has_outputs():
            # Compare against None so falsy-but-valid results (0, "", False, []) are recorded
            # instead of falling through to the error branches.
            if info.output is not None:
                outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
                outputs_file_path = io.outputs_path(sub_run_output_path)
                await io.upload_outputs(outputs, outputs_file_path)
            elif info.error:
                err = convert.convert_from_native_to_error(info.error)
                error_path = io.error_path(sub_run_output_path)
                await io.upload_error(err.err, error_path)
            else:
                raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error")

    async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
        """
        Submit an already-registered task (``TaskDetails``) as a child action and return its native outputs.

        Like ``submit``, the submission is bounded by the per-parent-action concurrency semaphore.
        """
        tctx = self._current_task_context()
        current_action_id = tctx.action
        invoke_seq_num = self.generate_task_call_sequence(_task, current_action_id)
        async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
            return await self._submit_task_ref(invoke_seq_num, _task, *args, **kwargs)

    async def _submit_task_ref(
        self, _task_call_seq: int, _task: task_definition_pb2.TaskDetails, *args, **kwargs
    ) -> Any:
        """Serialize, upload and submit one invocation of a task reference, then wait and return its outputs."""
        tctx = self._current_task_context()
        current_action_id = tctx.action
        task_name = _task.spec.task_template.id.name

        native_interface = types.guess_interface(
            _task.spec.task_template.interface, default_inputs=_task.spec.default_inputs
        )
        inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
        serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
        inputs_hash = convert.generate_inputs_hash(serialized_inputs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, task_name, inputs_hash, _task_call_seq
        )

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(serialized_inputs, inputs_uri)

        cache_key = self._child_cache_key(task_name, _task.spec.task_template, inputs_hash, inputs.proto_inputs)

        # Clear to free memory
        serialized_inputs = None  # type: ignore
        inputs_hash = None  # type: ignore

        action = Action.from_task(
            sub_action_id=self._child_action_identifier(current_action_id, sub_action_id.name),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=_task.spec,
            inputs_uri=inputs_uri,
            run_output_base=tctx.run_base_dir,
            cache_key=cache_key,
        )

        n = await self._await_child_action(action, task_name)
        return await self._load_child_outputs(n, native_interface)