flyte 0.1.0__py3-none-any.whl → 0.2.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (205)
  1. flyte/__init__.py +62 -2
  2. flyte/_api_commons.py +3 -0
  3. flyte/_bin/__init__.py +0 -0
  4. flyte/_bin/runtime.py +126 -0
  5. flyte/_build.py +25 -0
  6. flyte/_cache/__init__.py +12 -0
  7. flyte/_cache/cache.py +146 -0
  8. flyte/_cache/defaults.py +9 -0
  9. flyte/_cache/policy_function_body.py +42 -0
  10. flyte/_cli/__init__.py +0 -0
  11. flyte/_cli/_common.py +299 -0
  12. flyte/_cli/_create.py +42 -0
  13. flyte/_cli/_delete.py +23 -0
  14. flyte/_cli/_deploy.py +140 -0
  15. flyte/_cli/_get.py +235 -0
  16. flyte/_cli/_params.py +538 -0
  17. flyte/_cli/_run.py +174 -0
  18. flyte/_cli/main.py +98 -0
  19. flyte/_code_bundle/__init__.py +8 -0
  20. flyte/_code_bundle/_ignore.py +113 -0
  21. flyte/_code_bundle/_packaging.py +187 -0
  22. flyte/_code_bundle/_utils.py +339 -0
  23. flyte/_code_bundle/bundle.py +178 -0
  24. flyte/_context.py +146 -0
  25. flyte/_datastructures.py +342 -0
  26. flyte/_deploy.py +202 -0
  27. flyte/_doc.py +29 -0
  28. flyte/_docstring.py +32 -0
  29. flyte/_environment.py +43 -0
  30. flyte/_group.py +31 -0
  31. flyte/_hash.py +23 -0
  32. flyte/_image.py +757 -0
  33. flyte/_initialize.py +643 -0
  34. flyte/_interface.py +84 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +115 -0
  37. flyte/_internal/controllers/_local_controller.py +118 -0
  38. flyte/_internal/controllers/_trace.py +40 -0
  39. flyte/_internal/controllers/pbhash.py +39 -0
  40. flyte/_internal/controllers/remote/__init__.py +40 -0
  41. flyte/_internal/controllers/remote/_action.py +141 -0
  42. flyte/_internal/controllers/remote/_client.py +43 -0
  43. flyte/_internal/controllers/remote/_controller.py +361 -0
  44. flyte/_internal/controllers/remote/_core.py +402 -0
  45. flyte/_internal/controllers/remote/_informer.py +361 -0
  46. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  47. flyte/_internal/imagebuild/__init__.py +11 -0
  48. flyte/_internal/imagebuild/docker_builder.py +416 -0
  49. flyte/_internal/imagebuild/image_builder.py +241 -0
  50. flyte/_internal/imagebuild/remote_builder.py +0 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +54 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +205 -0
  57. flyte/_internal/runtime/entrypoints.py +135 -0
  58. flyte/_internal/runtime/io.py +136 -0
  59. flyte/_internal/runtime/resources_serde.py +138 -0
  60. flyte/_internal/runtime/task_serde.py +210 -0
  61. flyte/_internal/runtime/taskrunner.py +190 -0
  62. flyte/_internal/runtime/types_serde.py +54 -0
  63. flyte/_logging.py +124 -0
  64. flyte/_protos/__init__.py +0 -0
  65. flyte/_protos/common/authorization_pb2.py +66 -0
  66. flyte/_protos/common/authorization_pb2.pyi +108 -0
  67. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  68. flyte/_protos/common/identifier_pb2.py +71 -0
  69. flyte/_protos/common/identifier_pb2.pyi +82 -0
  70. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  71. flyte/_protos/common/identity_pb2.py +48 -0
  72. flyte/_protos/common/identity_pb2.pyi +72 -0
  73. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  74. flyte/_protos/common/list_pb2.py +36 -0
  75. flyte/_protos/common/list_pb2.pyi +69 -0
  76. flyte/_protos/common/list_pb2_grpc.py +4 -0
  77. flyte/_protos/common/policy_pb2.py +37 -0
  78. flyte/_protos/common/policy_pb2.pyi +27 -0
  79. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  80. flyte/_protos/common/role_pb2.py +37 -0
  81. flyte/_protos/common/role_pb2.pyi +53 -0
  82. flyte/_protos/common/role_pb2_grpc.py +4 -0
  83. flyte/_protos/common/runtime_version_pb2.py +28 -0
  84. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  85. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  86. flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
  87. flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  88. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  89. flyte/_protos/secret/definition_pb2.py +49 -0
  90. flyte/_protos/secret/definition_pb2.pyi +93 -0
  91. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  92. flyte/_protos/secret/payload_pb2.py +62 -0
  93. flyte/_protos/secret/payload_pb2.pyi +94 -0
  94. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  95. flyte/_protos/secret/secret_pb2.py +38 -0
  96. flyte/_protos/secret/secret_pb2.pyi +6 -0
  97. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  98. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  99. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  100. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  101. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  102. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  103. flyte/_protos/workflow/queue_service_pb2.py +106 -0
  104. flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
  105. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  106. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  107. flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
  108. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  109. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  110. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  111. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  112. flyte/_protos/workflow/run_service_pb2.py +133 -0
  113. flyte/_protos/workflow/run_service_pb2.pyi +175 -0
  114. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  115. flyte/_protos/workflow/state_service_pb2.py +58 -0
  116. flyte/_protos/workflow/state_service_pb2.pyi +71 -0
  117. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  118. flyte/_protos/workflow/task_definition_pb2.py +72 -0
  119. flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
  120. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  121. flyte/_protos/workflow/task_service_pb2.py +44 -0
  122. flyte/_protos/workflow/task_service_pb2.pyi +31 -0
  123. flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
  124. flyte/_resources.py +226 -0
  125. flyte/_retry.py +32 -0
  126. flyte/_reusable_environment.py +25 -0
  127. flyte/_run.py +410 -0
  128. flyte/_secret.py +61 -0
  129. flyte/_task.py +367 -0
  130. flyte/_task_environment.py +200 -0
  131. flyte/_timeout.py +47 -0
  132. flyte/_tools.py +27 -0
  133. flyte/_trace.py +128 -0
  134. flyte/_utils/__init__.py +20 -0
  135. flyte/_utils/asyn.py +119 -0
  136. flyte/_utils/coro_management.py +25 -0
  137. flyte/_utils/file_handling.py +72 -0
  138. flyte/_utils/helpers.py +108 -0
  139. flyte/_utils/lazy_module.py +54 -0
  140. flyte/_utils/uv_script_parser.py +49 -0
  141. flyte/_version.py +21 -0
  142. flyte/config/__init__.py +168 -0
  143. flyte/config/_config.py +196 -0
  144. flyte/config/_internal.py +64 -0
  145. flyte/connectors/__init__.py +0 -0
  146. flyte/errors.py +143 -0
  147. flyte/extras/__init__.py +5 -0
  148. flyte/extras/_container.py +273 -0
  149. flyte/io/__init__.py +11 -0
  150. flyte/io/_dataframe.py +0 -0
  151. flyte/io/_dir.py +448 -0
  152. flyte/io/_file.py +468 -0
  153. flyte/io/pickle/__init__.py +0 -0
  154. flyte/io/pickle/transformer.py +117 -0
  155. flyte/io/structured_dataset/__init__.py +129 -0
  156. flyte/io/structured_dataset/basic_dfs.py +219 -0
  157. flyte/io/structured_dataset/structured_dataset.py +1061 -0
  158. flyte/remote/__init__.py +25 -0
  159. flyte/remote/_client/__init__.py +0 -0
  160. flyte/remote/_client/_protocols.py +131 -0
  161. flyte/remote/_client/auth/__init__.py +12 -0
  162. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  163. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  164. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  165. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  166. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  167. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  168. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  169. flyte/remote/_client/auth/_channel.py +184 -0
  170. flyte/remote/_client/auth/_client_config.py +83 -0
  171. flyte/remote/_client/auth/_default_html.py +32 -0
  172. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  173. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  174. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  175. flyte/remote/_client/auth/_keyring.py +143 -0
  176. flyte/remote/_client/auth/_token_client.py +260 -0
  177. flyte/remote/_client/auth/errors.py +16 -0
  178. flyte/remote/_client/controlplane.py +95 -0
  179. flyte/remote/_console.py +18 -0
  180. flyte/remote/_data.py +155 -0
  181. flyte/remote/_logs.py +116 -0
  182. flyte/remote/_project.py +86 -0
  183. flyte/remote/_run.py +873 -0
  184. flyte/remote/_secret.py +132 -0
  185. flyte/remote/_task.py +227 -0
  186. flyte/report/__init__.py +3 -0
  187. flyte/report/_report.py +178 -0
  188. flyte/report/_template.html +124 -0
  189. flyte/storage/__init__.py +24 -0
  190. flyte/storage/_remote_fs.py +34 -0
  191. flyte/storage/_storage.py +251 -0
  192. flyte/storage/_utils.py +5 -0
  193. flyte/types/__init__.py +13 -0
  194. flyte/types/_interface.py +25 -0
  195. flyte/types/_renderer.py +162 -0
  196. flyte/types/_string_literals.py +120 -0
  197. flyte/types/_type_engine.py +2211 -0
  198. flyte/types/_utils.py +80 -0
  199. flyte-0.2.0b0.dist-info/METADATA +179 -0
  200. flyte-0.2.0b0.dist-info/RECORD +204 -0
  201. {flyte-0.1.0.dist-info → flyte-0.2.0b0.dist-info}/WHEEL +2 -1
  202. flyte-0.2.0b0.dist-info/entry_points.txt +3 -0
  203. flyte-0.2.0b0.dist-info/top_level.txt +1 -0
  204. flyte-0.1.0.dist-info/METADATA +0 -6
  205. flyte-0.1.0.dist-info/RECORD +0 -5
@@ -0,0 +1,361 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from collections import defaultdict
5
+ from pathlib import Path
6
+ from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
7
+
8
+ import flyte
9
+ import flyte.errors
10
+ import flyte.storage as storage
11
+ import flyte.types as types
12
+ from flyte._code_bundle import build_pkl_bundle
13
+ from flyte._context import internal_ctx
14
+ from flyte._datastructures import ActionID, NativeInterface, SerializationContext
15
+ from flyte._internal.controllers import TraceInfo
16
+ from flyte._internal.controllers.remote._action import Action
17
+ from flyte._internal.controllers.remote._core import Controller
18
+ from flyte._internal.controllers.remote._service_protocol import ClientSet
19
+ from flyte._internal.runtime import convert, io
20
+ from flyte._internal.runtime.task_serde import translate_task_to_wire
21
+ from flyte._logging import logger
22
+ from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
23
+ from flyte._task import TaskTemplate
24
+
25
# Generic type variable for return values (not referenced by the code visible in this module section).
R = TypeVar("R")
26
+
27
+
28
async def upload_inputs_with_retry(inputs: convert.Inputs, inputs_uri: str) -> None:
    """
    Upload inputs to the specified URI, converting any failure into a RuntimeSystemError.

    Despite the name, a single upload attempt is currently made (retrying is still a TODO).

    Args:
        inputs: The inputs to upload
        inputs_uri: The destination URI

    Raises:
        RuntimeSystemError: If the upload fails; the original exception is chained as the cause.
    """
    try:
        # TODO Add retry decorator to this
        await io.upload_inputs(inputs, inputs_uri)
    except Exception as e:
        # Only pass %-args that match placeholders in the message; logger.exception already records the
        # active exception and traceback, so the exception object must not be passed positionally.
        logger.exception("Failed to upload inputs to %s", inputs_uri)
        raise flyte.errors.RuntimeSystemError(type(e).__name__, str(e)) from e
45
+
46
+
47
async def handle_action_failure(action: Action, task_name: str) -> Exception:
    """
    Build the exception that represents the failure of an action.

    The error attached to the action (``action.err``, falling back to ``action.client_err``) is used when present.
    If there is none but the server reported ``PHASE_FAILED``, the error file is loaded from the action's output
    directory. If loading that file fails, a RuntimeSystemError describing the load failure is used instead.

    Args:
        action: The updated action
        task_name: The name of the task

    Returns:
        The converted native exception, or a RuntimeSystemError("UnableToConvertError", ...) when the error could
        not be converted. The exception is returned, not raised; callers in this module raise it themselves.
    """
    err = action.err or action.client_err
    if not err and action.phase == run_definition_pb2.PHASE_FAILED:
        logger.error(f"Server reported failure for action {action.name}, checking error file.")
        try:
            # NOTE(review): the trailing "/1" segment looks like a hard-coded attempt number — confirm that
            # retried actions also write their error file at this location.
            error_path = io.error_path(f"{action.run_output_base}/{action.action_id.name}/1")
            err = await io.load_error(error_path)
        except Exception as e:
            # logger.exception already attaches the active traceback; no extra positional args are needed
            # (the message has no %-placeholders).
            logger.exception("Failed to load error file")
            err = flyte.errors.RuntimeSystemError(type(e).__name__, f"Failed to load error file: {e}")
    else:
        logger.error(f"Server reported failure for action {action.action_id.name}, error: {err}")

    exc = convert.convert_error_to_native(err)
    if not exc:
        return flyte.errors.RuntimeSystemError("UnableToConvertError", f"Error in task {task_name}: {err}")
    return exc
74
+
75
+
76
async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri: str) -> Any:
    """
    Read the outputs stored under ``realized_outputs_uri`` and turn them into native Python values.

    Args:
        iface: Native interface describing the expected outputs.
        realized_outputs_uri: Output location reported for the finished action.

    Returns:
        The native outputs produced by ``convert.convert_outputs_to_native``.
    """
    literal_outputs = await io.load_outputs(io.outputs_path(realized_outputs_uri))
    return await convert.convert_outputs_to_native(iface, literal_outputs)
90
+
91
+
92
class RemoteController(Controller):
    """
    This is a specialized controller that wraps the core controller and performs IO, serialization and
    deserialization.

    ``submit`` calls are throttled per parent action: an ``asyncio.Semaphore`` with ``default_parent_concurrency``
    permits is created lazily for each parent action name and removed in ``finalize_parent_action``.
    """

    def __init__(
        self,
        client_coro: Awaitable[ClientSet],
        workers: int,
        max_system_retries: int,
        default_parent_concurrency: int = 100,
    ):
        """
        :param client_coro: Awaitable resolving to the ClientSet; forwarded to the core ``Controller``.
        :param workers: Forwarded to the core ``Controller``.
        :param max_system_retries: Forwarded to the core ``Controller``.
        :param default_parent_concurrency: Maximum number of concurrent ``submit`` calls per parent action.
        """
        super().__init__(
            client_coro=client_coro,
            workers=workers,
            max_system_retries=max_system_retries,
        )
        self._default_parent_concurrency = default_parent_concurrency
        self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
            lambda: asyncio.Semaphore(default_parent_concurrency)
        )

    @staticmethod
    def _require_task_context():
        """Return the current task context, raising RuntimeSystemError("BadContext") if it is not set."""
        tctx = internal_ctx().data.task_context
        if tctx is None:
            raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
        return tctx

    @staticmethod
    def _build_sub_action_identifier(
        current_action_id: ActionID, sub_action_name: str
    ) -> run_definition_pb2.ActionIdentifier:
        """Build the wire identifier of a child action that belongs to the same run as ``current_action_id``."""
        return run_definition_pb2.ActionIdentifier(
            name=sub_action_name,
            run=run_definition_pb2.RunIdentifier(
                name=current_action_id.run_name,
                project=current_action_id.project,
                domain=current_action_id.domain,
                org=current_action_id.org,
            ),
        )

    async def _submit_action_and_wait(self, action: Action, task_name: str, native_interface: NativeInterface) -> Any:
        """
        Submit ``action`` to the core controller, wait for it to complete and return its native outputs.

        If this coroutine is cancelled while waiting, the action is cancelled on the server too and the
        cancellation is re-raised.

        :raises Exception: The native error of a failed action, as built by ``handle_action_failure``.
        :raises flyte.errors.RuntimeSystemError: If outputs are declared but no output path was reported.
        :return: The converted outputs, or None when the interface declares no outputs.
        """
        try:
            logger.info(
                f"Submitting action Run:[{action.run_name}], Parent:[{action.parent_action_name}], "
                f"task:[{task_name}], action:[{action.name}]"
            )
            n = await self.submit_action(action)
            logger.info(f"Action for task [{task_name}] action id: {action.name}, completed!")
        except asyncio.CancelledError:
            # If the action is cancelled, we need to cancel the action on the server as well
            logger.info(f"Action {action.action_id.name} cancelled, cancelling on server")
            await self.cancel_action(action)
            raise

        if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED:
            exc = await handle_action_failure(action, task_name)
            raise exc

        if native_interface.outputs:
            if not n.realized_outputs_uri:
                raise flyte.errors.RuntimeSystemError(
                    "RuntimeError",
                    f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.",
                )
            return await load_and_convert_outputs(native_interface, n.realized_outputs_uri)
        return None

    async def _submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Serialize ``_task`` with the given arguments, submit it as a child action of the current action and
        return its native outputs. Not throttled; ``submit`` wraps this with the per-parent semaphore.
        """
        tctx = self._require_task_context()
        current_action_id = tctx.action

        inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _task.name, inputs)

        # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
        # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
        code_bundle = tctx.code_bundle

        if code_bundle and code_bundle.pkl:
            logger.debug(f"Building new pkl bundle for task {sub_action_id.name}")
            code_bundle = await build_pkl_bundle(
                _task,
                upload_to_controlplane=False,
                upload_from_dataplane_path=io.pkl_path(sub_action_output_path),
            )

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(inputs, inputs_uri)

        root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
        # Don't set output path in sec context because node executor will set it
        new_serialization_context = SerializationContext(
            project=current_action_id.project,
            domain=current_action_id.domain,
            org=current_action_id.org,
            code_bundle=code_bundle,
            version=tctx.version,
            # supplied version.
            input_path=inputs_uri,
            image_cache=tctx.compiled_image_cache,
            root_dir=root_dir,
        )

        task_spec = translate_task_to_wire(_task, new_serialization_context)

        action = Action.from_task(
            sub_action_id=self._build_sub_action_identifier(current_action_id, sub_action_id.name),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=task_spec,
            inputs_uri=inputs_uri,
            run_output_base=tctx.run_base_dir,
        )

        return await self._submit_action_and_wait(action, _task.name, _task.native_interface)

    async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
        """
        Submit a task to the remote controller. This creates a new action on the queue service.

        At most ``default_parent_concurrency`` submissions run concurrently for the same parent action.
        """
        tctx = self._require_task_context()
        current_action_id = tctx.action
        async with self._parent_action_semaphore[current_action_id.name]:
            return await self._submit(_task, *args, **kwargs)

    async def finalize_parent_action(self, action_id: ActionID):
        """
        This method is invoked when the parent action is finished. It will finalize the run and upload the outputs
        to the control plane, then drop the parent's concurrency semaphore.
        """
        run_id = run_definition_pb2.RunIdentifier(
            name=action_id.run_name,
            project=action_id.project,
            domain=action_id.domain,
            org=action_id.org,
        )
        await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
        self._parent_action_semaphore.pop(action_id.name, None)

    async def get_action_outputs(
        self, _interface: NativeInterface, _func_name: str, *args, **kwargs
    ) -> Tuple[TraceInfo, bool]:
        """
        Look up a previously recorded action for ``_func_name`` called with these arguments and return its result.

        The inputs are converted, a sub action id is generated from the task context, function name and inputs,
        and the inputs are uploaded before the lookup.

        :param _interface: NativeInterface
        :param _func_name: Function name
        :param args: Arguments
        :param kwargs: Keyword arguments
        :return: ``(TraceInfo, True)`` when a previous result (outputs or error) was found; otherwise
            ``(TraceInfo, False)``, meaning the function should be (re-)executed.
        """
        tctx = self._require_task_context()
        current_action_id = tctx.action

        inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _func_name, inputs)

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(inputs, inputs_uri)

        prev_action = await self.get_action(
            self._build_sub_action_identifier(current_action_id, sub_action_id.name),
            current_action_id.name,
        )

        if prev_action is None:
            return TraceInfo(sub_action_id, _interface, inputs_uri), False

        if prev_action.phase == run_definition_pb2.PHASE_FAILED:
            if prev_action.has_error():
                exc = convert.convert_error_to_native(prev_action.err)
                return TraceInfo(sub_action_id, _interface, inputs_uri, error=exc), True
            logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!")
        elif prev_action.realized_outputs_uri:
            # Treat an empty URI like a missing one (same check as _submit_action_and_wait) instead of trying to
            # load outputs from an empty path.
            outputs = await load_and_convert_outputs(_interface, prev_action.realized_outputs_uri)
            return TraceInfo(sub_action_id, _interface, inputs_uri, output=outputs), True

        return TraceInfo(sub_action_id, _interface, inputs_uri), False

    async def record_trace(self, info: TraceInfo):
        """
        Record a completed trace by uploading its outputs or its error under
        ``<task output path>/<trace action name>``.

        :param info: Trace information carrying the native output or the raised error.
        :raises flyte.errors.RuntimeSystemError: ("BadTraceInfo") if the interface declares outputs but the trace
            has neither an output nor an error.
        """
        tctx = self._require_task_context()

        current_output_path = tctx.output_path
        sub_run_output_path = storage.join(current_output_path, info.action.name)

        # NOTE(review): nothing is recorded for interfaces without outputs, including their errors — confirm intended.
        if info.interface.has_outputs():
            # Compare against None: falsy results such as 0, False, "" or [] are valid outputs and must be recorded.
            if info.output is not None:
                outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
                outputs_file_path = io.outputs_path(sub_run_output_path)
                await io.upload_outputs(outputs, outputs_file_path)
            elif info.error:
                err = convert.convert_from_native_to_error(info.error)
                error_path = io.error_path(sub_run_output_path)
                await io.upload_error(err.err, error_path)
            else:
                raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error")

    async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
        """
        Submit an already-registered task (given by its TaskDetails) as a child action of the current action and
        return its native outputs. The native interface is guessed from the task template's interface.
        """
        # NOTE(review): unlike ``submit``, this path does not acquire the per-parent semaphore — confirm intended.
        tctx = self._require_task_context()
        current_action_id = tctx.action
        task_name = _task.spec.task_template.id.name

        native_interface = types.guess_interface(_task.spec.task_template.interface)
        inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
        sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, task_name, inputs)

        inputs_uri = io.inputs_path(sub_action_output_path)
        await upload_inputs_with_retry(inputs, inputs_uri)

        action = Action.from_task(
            sub_action_id=self._build_sub_action_identifier(current_action_id, sub_action_id.name),
            parent_action_name=current_action_id.name,
            group_data=tctx.group_data,
            task_spec=_task.spec,
            inputs_uri=inputs_uri,
            run_output_base=tctx.run_base_dir,
        )

        return await self._submit_action_and_wait(action, task_name, native_interface)