flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (204)
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
@@ -0,0 +1,583 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import concurrent.futures
5
+ import os
6
+ import threading
7
+ from collections import defaultdict
8
+ from collections.abc import Callable
9
+ from pathlib import Path
10
+ from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
11
+
12
+ from flyteidl2.common import identifier_pb2
13
+ from flyteidl2.workflow import run_definition_pb2
14
+
15
+ import flyte
16
+ import flyte.errors
17
+ import flyte.storage as storage
18
+ from flyte._code_bundle import build_pkl_bundle
19
+ from flyte._context import internal_ctx
20
+ from flyte._internal.controllers import TraceInfo
21
+ from flyte._internal.controllers.remote._action import Action
22
+ from flyte._internal.controllers.remote._core import Controller
23
+ from flyte._internal.controllers.remote._service_protocol import ClientSet
24
+ from flyte._internal.runtime import convert, io
25
+ from flyte._internal.runtime.task_serde import translate_task_to_wire
26
+ from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
27
+ from flyte._logging import logger
28
+ from flyte._task import TaskTemplate
29
+ from flyte._utils.helpers import _selector_policy
30
+ from flyte.models import MAX_INLINE_IO_BYTES, ActionID, NativeInterface, SerializationContext
31
+ from flyte.remote._task import TaskDetails
32
+
33
# Generic return-type variable; it is not referenced elsewhere in this module.
R = TypeVar("R")

# Byte ceiling applied to trace inputs/outputs; it reuses the inline IO limit from flyte.models.
MAX_TRACE_BYTES = MAX_INLINE_IO_BYTES
36
+
37
+
38
async def upload_inputs_with_retry(serialized_inputs: bytes, inputs_uri: str, max_bytes: int) -> None:
    """
    Upload serialized inputs to ``inputs_uri`` after checking them against a size ceiling.

    Note: despite the name, a single upload attempt is made; no retry is performed yet
    (see the TODO below).

    Args:
        serialized_inputs: The serialized inputs payload to upload.
        inputs_uri: The destination URI.
        max_bytes: Maximum allowed size of ``serialized_inputs`` in bytes. Larger payloads are
            rejected before any upload is attempted.

    Raises:
        InlineIOMaxBytesBreached: If ``serialized_inputs`` is larger than ``max_bytes``.
        RuntimeSystemError: If the upload fails.
    """
    size = len(serialized_inputs)
    if size > max_bytes:
        raise flyte.errors.InlineIOMaxBytesBreached(
            f"Inputs exceed max_bytes limit of {max_bytes / 1024 / 1024} MB,"
            f" actual size: {size / 1024 / 1024} MB"
        )
    try:
        # TODO Add retry decorator to this
        await storage.put_stream(serialized_inputs, to_path=inputs_uri)
    except Exception as e:
        # Include the destination so the failing location is visible in the logs.
        logger.exception(f"Failed to upload inputs to {inputs_uri}")
        raise flyte.errors.RuntimeSystemError(type(e).__name__, str(e)) from e
61
+
62
+
63
async def handle_action_failure(action: Action, task_name: str) -> Exception:
    """
    Build the native exception describing why ``action`` failed.

    The error is taken from ``action.err`` (or ``action.client_err``). If neither is set and the
    server reports ``PHASE_FAILED``, the error file under the action's output directory is loaded
    instead.

    Args:
        action: The action whose failure should be reported.
        task_name: The name of the task, used in the fallback error message.

    Returns:
        The converted native exception. If the error could not be converted, a
        ``RuntimeSystemError`` is returned instead. The exception is returned, not raised;
        callers are expected to raise it.
    """
    err = action.err or action.client_err
    if not err and action.phase == run_definition_pb2.PHASE_FAILED:
        logger.error(f"Server reported failure for action {action.name}, checking error file.")
        try:
            # NOTE(review): the trailing "/1" presumably selects the first attempt; confirm this
            # is correct for actions that were retried.
            error_path = io.error_path(f"{action.run_output_base}/{action.action_id.name}/1")
            err = await io.load_error(error_path)
        except Exception as e:
            # Put the exception text into the message itself. Passing it as a %-style argument
            # without a placeholder makes the log record fail to format.
            logger.exception(f"Failed to load error file: {e}")
            err = flyte.errors.RuntimeSystemError(type(e).__name__, f"Failed to load error file: {e}")
    else:
        logger.error(f"Server reported failure for action {action.action_id.name}, error: {err}")

    exc = convert.convert_error_to_native(err)
    if not exc:
        return flyte.errors.RuntimeSystemError("UnableToConvertError", f"Error in task {task_name}: {err}")
    return exc
90
+
91
+
92
async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri: str, max_bytes: int) -> Any:
    """
    Read the outputs file that belongs to ``realized_outputs_uri`` and convert it to native values.

    Args:
        iface: Native interface describing the expected outputs.
        realized_outputs_uri: Output location reported for the action. The outputs file path is
            derived from it with ``io.outputs_path``.
        max_bytes: Upper bound on bytes read from the outputs file, forwarded to ``io.load_outputs``.

    Returns:
        The value produced by ``convert.convert_outputs_to_native`` for ``iface``.
    """
    return await convert.convert_outputs_to_native(
        iface,
        await io.load_outputs(io.outputs_path(realized_outputs_uri), max_bytes=max_bytes),
    )
107
+
108
+
109
def unique_action_name(action_id: ActionID) -> str:
    """
    Build the key used to index per-parent-action state in ``RemoteController``.

    The key has the form ``"<action name>_<run name>"``.
    """
    return "{}_{}".format(action_id.name, action_id.run_name)
111
+
112
+
113
class RemoteController(Controller):
    """
    Controller that wraps the core ``Controller`` and performs the IO, serialization and
    deserialization needed to run sub-actions (tasks, task references and traces) remotely.

    The number of sub-actions a single parent action may have in flight is bounded by a semaphore.
    The semaphore size comes from the ``_F_P_CNC`` environment variable (default ``1000``).
    """

    def __init__(
        self,
        client_coro: Awaitable[ClientSet],
        workers: int = 20,
        max_system_retries: int = 10,
    ):
        """
        :param client_coro: Awaitable resolving to the ``ClientSet``; forwarded to the core controller.
        :param workers: Forwarded unchanged to the core controller.
        :param max_system_retries: Forwarded unchanged to the core controller.
        """
        super().__init__(
            client_coro=client_coro,
            workers=workers,
            max_system_retries=max_system_retries,
        )
        default_parent_concurrency = int(os.getenv("_F_P_CNC", "1000"))
        self._default_parent_concurrency = default_parent_concurrency
        # Keyed by unique_action_name(parent): bounds how many sub-actions one parent runs at once.
        self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
            lambda: asyncio.Semaphore(default_parent_concurrency)
        )
        # Keyed by unique_action_name(parent), then id(task object): invocations seen so far.
        self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        # Created lazily by submit_sync(): a private event loop driven by a daemon thread.
        self._submit_loop: asyncio.AbstractEventLoop | None = None
        self._submit_thread: threading.Thread | None = None

    # ------------------------------------------------------------------ private helpers

    @staticmethod
    def _require_task_context():
        """Return the current task context, raising ``RuntimeSystemError`` if none is set."""
        tctx = internal_ctx().data.task_context
        if tctx is None:
            raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
        return tctx

    @staticmethod
    def _action_identifier(current_action_id: ActionID, name: str) -> identifier_pb2.ActionIdentifier:
        """Build the identifier of sub-action ``name`` inside the run of ``current_action_id``."""
        return identifier_pb2.ActionIdentifier(
            name=name,
            run=identifier_pb2.RunIdentifier(
                name=current_action_id.run_name,
                project=current_action_id.project,
                domain=current_action_id.domain,
                org=current_action_id.org,
            ),
        )

    @staticmethod
    def _compute_cache_key(task_name: str, inputs_hash, pb_interface, metadata, proto_inputs):
        """Return the cache key for a discoverable task, or ``None`` when it is not discoverable."""
        if not (metadata and metadata.discoverable):
            return None
        return convert.generate_cache_key_hash(
            task_name,
            inputs_hash,
            pb_interface,
            metadata.discovery_version,
            # An empty repeated field becomes an empty list.
            list(metadata.cache_ignore_input_vars),
            proto_inputs,
        )

    async def _submit_and_wait(self, action: Action, task_name: str):
        """
        Submit ``action`` through the core controller and wait for it to reach a final state.

        If the wait is cancelled, the action is also cancelled on the server before re-raising.
        """
        try:
            logger.info(
                f"Submitting action Run:[{action.run_name}], Parent:[{action.parent_action_name}], "
                f"task:[{task_name}], action:[{action.name}]"
            )
            n = await self.submit_action(action)
            logger.info(f"Action for task [{task_name}] action id: {action.name}, completed!")
        except asyncio.CancelledError:
            # If the action is cancelled, we need to cancel the action on the server as well
            logger.info(f"Action {action.action_id.name} cancelled, cancelling on server")
            await self.cancel_action(action)
            raise
        return n

    @staticmethod
    async def _raise_if_unsuccessful(n, action: Action, task_name: str, current_action_id: ActionID) -> None:
        """
        Raise if the finished sub-action ``n`` did not succeed.

        :raises RunAbortedError: if the sub-action was aborted; the current action is aborted too.
        :raises TaskTimeoutError: if the sub-action timed out.
        :raises Exception: the error built by ``handle_action_failure`` if the sub-action failed.
        """
        if n.phase == run_definition_pb2.PHASE_ABORTED:
            logger.warning(f"Action {n.action_id.name} was aborted, aborting current Action {current_action_id.name}")
            raise flyte.errors.RunAbortedError(
                f"Action {n.action_id.name} was aborted, aborting current Action {current_action_id.name}"
            )

        if n.phase == run_definition_pb2.PHASE_TIMED_OUT:
            logger.warning(
                f"Action {n.action_id.name} timed out, raising timeout exception Action {current_action_id.name}"
            )
            raise flyte.errors.TaskTimeoutError(
                f"Action {n.action_id.name} timed out, raising exception in current Action {current_action_id.name}"
            )

        if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED:
            exc = await handle_action_failure(action, task_name)
            raise exc

    @staticmethod
    async def _outputs_or_none(n, native_interface: NativeInterface, max_bytes: int) -> Any:
        """
        Load and convert the outputs of the finished sub-action ``n``.

        Returns ``None`` when the interface declares no outputs.

        :raises RuntimeSystemError: if outputs are declared but no output location was reported.
        """
        if not native_interface.outputs:
            return None
        if not n.realized_outputs_uri:
            raise flyte.errors.RuntimeSystemError(
                "RuntimeError",
                f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.",
            )
        return await load_and_convert_outputs(native_interface, n.realized_outputs_uri, max_bytes=max_bytes)

    # ------------------------------------------------------------------ public API

    def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
        """
        Return the next call number (starting at 1) for ``task_obj`` within parent action ``action_id``.

        Counters are kept per parent action and keyed by ``id(task_obj)``. The returned number is
        used when deriving the sub-action id.
        """
        uniq = unique_action_name(action_id)
        current_action_sequencer = self._parent_action_task_call_sequence[uniq]
        current_task_id = id(task_obj)
        new_seq = current_action_sequencer[current_task_id] + 1
        current_action_sequencer[current_task_id] = new_seq
        name = ""
        if hasattr(task_obj, "__name__"):
            name = task_obj.__name__
        elif hasattr(task_obj, "name"):
            name = task_obj.name
        logger.info(f"For action {uniq}, task {name} call sequence is {new_seq}")
        return new_seq

    async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Serialize ``_task`` and its inputs, submit it as a sub-action of the current action, and
        wait for the result.

        :param _task_call_seq: Call number obtained from ``generate_task_call_sequence``.
        :param _task: The task to run.
        :return: The task's native outputs, or ``None`` if it declares none.
        """
        tctx = self._require_task_context()
        current_action_id = tctx.action

        # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
        # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
        code_bundle = tctx.code_bundle

        if tctx.interactive_mode or (code_bundle and code_bundle.pkl):
            logger.debug(f"Building new pkl bundle for task {_task.name}")
            code_bundle = await build_pkl_bundle(
                _task,
                upload_to_controlplane=False,
                upload_from_dataplane_base_path=tctx.run_base_dir,
            )

        inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)

        root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
        # Don't set the output path in the serialization context because the node executor will set it.
        new_serialization_context = SerializationContext(
            project=current_action_id.project,
            domain=current_action_id.domain,
            org=current_action_id.org,
            code_bundle=code_bundle,
            version=tctx.version,
            image_cache=tctx.compiled_image_cache,
            root_dir=root_dir,
        )

        task_spec = translate_task_to_wire(_task, new_serialization_context)
        inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, task_spec, inputs_hash, _task_call_seq
        )
        logger.info(f"Sub action {sub_action_id} output path {sub_action_output_path}")

        serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_bytes=_task.max_inline_io_bytes)

        cache_key = self._compute_cache_key(
            _task.name,
            inputs_hash,
            task_spec.task_template.interface,
            task_spec.task_template.metadata,
            inputs.proto_inputs,
        )

        # Drop references to the intermediates before the (potentially long) wait below.
        del serialized_inputs, inputs_hash

        action = Action.from_task(
            sub_action_id=self._action_identifier(current_action_id, sub_action_id.name),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=task_spec,
            inputs_uri=inputs_uri,
            run_output_base=tctx.run_base_dir,
            cache_key=cache_key,
            queue=_task.queue,
        )

        n = await self._submit_and_wait(action, _task.name)
        await self._raise_if_unsuccessful(n, action, _task.name, current_action_id)
        return await self._outputs_or_none(n, _task.native_interface, _task.max_inline_io_bytes)

    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a task to the remote controller. This creates a new action on the queue service.
        """
        tctx = self._require_task_context()
        current_action_id = tctx.action
        # The call number is assigned before waiting on the parent's semaphore.
        task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
        async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
            return await self._submit(task_call_seq, _task, *args, **kwargs)

    def _sync_thread_loop_runner(self) -> None:
        """Run the submit loop forever and close it on exit; meant to be the target of a separate thread."""
        loop = self._submit_loop
        assert loop is not None
        try:
            loop.run_forever()
        finally:
            loop.close()

    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
        """
        Schedule ``submit`` on a dedicated background event loop and return a
        ``concurrent.futures.Future`` for its result. Non-async callers can block on it with ``.result()``.

        The thread and loop are created on first use and reused afterwards. No lock guards this lazy
        initialisation, because the method is expected to be called from a single, non-async thread.
        This is the trivial case of the thread-pool pattern used by the LocalController.

        :param _task: The task to submit.
        :param args: Positional task arguments.
        :param kwargs: Keyword task arguments.
        :return: Future resolving to the task's outputs.
        """
        if self._submit_thread is None:
            # Please see LocalController for the general implementation of this pattern.
            def exc_handler(loop, context):
                logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")

            with _selector_policy():
                self._submit_loop = asyncio.new_event_loop()
                self._submit_loop.set_exception_handler(exc_handler)

            self._submit_thread = threading.Thread(
                name=f"remote-controller-{os.getpid()}-submitter",
                daemon=True,
                target=self._sync_thread_loop_runner,
            )
            self._submit_thread.start()

        coro = self.submit(_task, *args, **kwargs)
        assert self._submit_loop is not None, "Submit loop should always have been initialized by now"
        return asyncio.run_coroutine_threadsafe(coro, self._submit_loop)

    async def finalize_parent_action(self, action_id: ActionID):
        """
        Finalize the parent action ``action_id`` once it has finished.

        Delegates to the core controller's ``_finalize_parent_action``. Then drops the per-parent
        semaphore and call-sequence counters kept by this controller.
        """
        run_id = identifier_pb2.RunIdentifier(
            name=action_id.run_name,
            project=action_id.project,
            domain=action_id.domain,
            org=action_id.org,
        )
        await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
        key = unique_action_name(action_id)
        self._parent_action_semaphore.pop(key, None)
        self._parent_action_task_call_sequence.pop(key, None)

    async def get_action_outputs(
        self, _interface: NativeInterface, _func: Callable, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        Look up a previously recorded trace for this call and return its result, if any.

        The inputs are serialized and uploaded first. The returned ``TraceInfo`` therefore carries an
        inputs path that ``record_trace`` can reuse.

        :param _interface: Native interface of the traced function.
        :param _func: The traced callable; its ``__name__`` names the trace.
        :param args: Positional arguments of the call.
        :param kwargs: Keyword arguments of the call.
        :return: ``(TraceInfo, True)`` when a prior result (outputs or an error) was found.
            Otherwise ``(TraceInfo, False)``, meaning the function must be executed.
        """
        tctx = self._require_task_context()
        current_action_id = tctx.action

        func_name = _func.__name__
        invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)

        inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
        serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
        inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)

        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, func_name, inputs_hash, invoke_seq_num
        )

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_bytes=MAX_TRACE_BYTES)
        # Drop the serialized payload before the lookup below.
        del serialized_inputs

        prev_action = await self.get_action(
            self._action_identifier(current_action_id, sub_action_id.name),
            current_action_id.name,
        )

        if prev_action is None:
            return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False

        if prev_action.phase == run_definition_pb2.PHASE_FAILED:
            if prev_action.has_error():
                exc = convert.convert_error_to_native(prev_action.err)
                if not exc:
                    # Never report a failed trace as replayable without an exception to re-raise.
                    exc = flyte.errors.RuntimeSystemError(
                        "UnableToConvertError", f"Error in trace {func_name}: {prev_action.err}"
                    )
                return (
                    TraceInfo(func_name, sub_action_id, _interface, inputs_uri, error=exc),
                    True,
                )
            logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!")
        elif prev_action.realized_outputs_uri is not None:
            o = await io.load_outputs(prev_action.realized_outputs_uri, max_bytes=MAX_TRACE_BYTES)
            outputs = await convert.convert_outputs_to_native(_interface, o)
            return (
                TraceInfo(func_name, sub_action_id, _interface, inputs_uri, output=outputs),
                True,
            )

        return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False

    async def record_trace(self, info: TraceInfo):
        """
        Submit an already-executed trace as a sub-action of the current action.

        If the traced function's interface declares outputs, the trace's result is uploaded before
        submission. That result is the error when ``info.error`` is set, otherwise the outputs.

        :param info: Trace details, typically the ``TraceInfo`` returned by ``get_action_outputs``
            and filled in after the function ran.
        """
        tctx = self._require_task_context()
        current_action_id = tctx.action
        sub_run_output_path = storage.join(tctx.run_base_dir, info.action.name)
        outputs_file_path: str = ""

        # NOTE(review): errors are only uploaded when the interface declares outputs; confirm this
        # is intended for output-less traces.
        if info.interface.has_outputs():
            if info.error:
                err = convert.convert_from_native_to_error(info.error)
                await io.upload_error(err.err, sub_run_output_path)
            else:
                outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
                outputs_file_path = io.outputs_path(sub_run_output_path)
                await io.upload_outputs(outputs, sub_run_output_path, max_bytes=MAX_TRACE_BYTES)

        typed_interface = transform_native_to_typed_interface(info.interface)

        trace_action = Action.from_trace(
            parent_action_name=current_action_id.name,
            action_id=self._action_identifier(current_action_id, info.action.name),
            inputs_uri=info.inputs_path,
            outputs_uri=outputs_file_path,
            friendly_name=info.name,
            group_data=tctx.group_data,
            run_output_base=tctx.run_base_dir,
            start_time=info.start_time,
            end_time=info.end_time,
            typed_interface=typed_interface if typed_interface else None,
        )

        async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
            logger.info(
                f"Submitting Trace action Run:[{trace_action.run_name}],"
                f" Parent:[{trace_action.parent_action_name}],"
                f" Trace fn:[{info.name}], action:[{info.action.name}]"
            )
            await self.submit_action(trace_action)
            logger.info(f"Trace Action for [{info.name}] action id: {info.action.name}, completed!")

    async def _submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args, **kwargs) -> Any:
        """
        Submit an already-registered task (``TaskDetails``) as a sub-action and wait for the result.

        Terminal phases are handled exactly as in ``_submit``: abort, timeout and failure raise.

        :param invoke_seq_num: Call number obtained from ``generate_task_call_sequence``.
        :param _task: Details of the referenced task.
        :return: The task's native outputs, or ``None`` if it declares none.
        """
        tctx = self._require_task_context()
        current_action_id = tctx.action
        task_name = _task.name

        native_interface = _task.interface
        task_template = _task.pb2.spec.task_template
        pb_interface = task_template.interface

        inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
        inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
            tctx, task_name, inputs_hash, invoke_seq_num
        )

        serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(serialized_inputs, inputs_uri, _task.max_inline_io_bytes)

        # cache key - task name, task signature, inputs, cache version
        cache_key = self._compute_cache_key(
            task_name,
            inputs_hash,
            pb_interface,
            task_template.metadata,
            inputs.proto_inputs,
        )

        # Drop references to the intermediates before the (potentially long) wait below.
        del serialized_inputs, inputs_hash

        action = Action.from_task(
            sub_action_id=self._action_identifier(current_action_id, sub_action_id.name),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=_task.pb2.spec,
            inputs_uri=inputs_uri,
            run_output_base=tctx.run_base_dir,
            cache_key=cache_key,
            queue=None,
        )

        n = await self._submit_and_wait(action, task_name)
        # Previously aborted / timed-out task refs were not detected here and could surface as a
        # missing-output error or a silent ``None``; they now raise like regular tasks.
        await self._raise_if_unsuccessful(n, action, task_name, current_action_id)
        return await self._outputs_or_none(n, native_interface, _task.max_inline_io_bytes)

    async def submit_task_ref(self, _task: TaskDetails, *args, **kwargs) -> Any:
        """
        Submit a reference to an already-registered task as a sub-action of the current action.

        Concurrency is bounded by the current parent action's semaphore.
        """
        tctx = self._require_task_context()
        current_action_id = tctx.action
        # The call number is assigned before waiting on the parent's semaphore.
        task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
        async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
            return await self._submit_task_ref(task_call_seq, _task, *args, **kwargs)