flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the package registry page for more details.

Files changed (204)
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
@@ -0,0 +1,738 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from collections import UserDict
5
+ from dataclasses import dataclass
6
+ from datetime import datetime, timedelta, timezone
7
+ from typing import (
8
+ Any,
9
+ AsyncGenerator,
10
+ AsyncIterator,
11
+ Dict,
12
+ Iterator,
13
+ List,
14
+ Literal,
15
+ Tuple,
16
+ Union,
17
+ cast,
18
+ )
19
+
20
+ import grpc
21
+ import rich.pretty
22
+ import rich.repr
23
+ from flyteidl2.common import identifier_pb2, list_pb2
24
+ from flyteidl2.task import common_pb2
25
+ from flyteidl2.workflow import run_definition_pb2, run_service_pb2
26
+ from flyteidl2.workflow.run_service_pb2 import WatchActionDetailsResponse
27
+ from rich.console import Console
28
+ from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
29
+
30
+ from flyte import types
31
+ from flyte._initialize import ensure_client, get_client, get_init_config
32
+ from flyte.remote._common import ToJSONMixin
33
+ from flyte.remote._logs import Logs
34
+ from flyte.syncify import syncify
35
+
36
# How long Action.watch()/Action.wait() keep following an action:
#   "terminal"   - until the action reaches a terminal phase
#   "running"    - stop at the first update reporting the action as running
#   "logs-ready" - stop at the first update where logs are available for the latest attempt
WaitFor = Literal[
    "terminal",
    "running",
    "logs-ready",
]
37
+
38
+
39
def _action_time_phase(
    action: run_definition_pb2.Action | run_definition_pb2.ActionDetails,
) -> rich.repr.Result:
    """
    Rich representation of the action time and phase.

    Yields ``start_time`` (ISO-8601, UTC), ``end_time`` (ISO-8601, or ``None`` while the action is
    not in a terminal phase), ``run_time`` in whole seconds (measured up to now for non-terminal
    actions), ``phase`` and, for ``ActionDetails`` messages only, ``error``.

    :param action: An ``Action`` or ``ActionDetails`` protobuf message.
    """
    start_time = action.status.start_time.ToDatetime().replace(tzinfo=timezone.utc)
    yield "start_time", start_time.isoformat()
    if _action_done_check(action.status.phase):
        end_time = action.status.end_time.ToDatetime().replace(tzinfo=timezone.utc)
        yield "end_time", end_time.isoformat()
    else:
        end_time = datetime.now(timezone.utc)
        yield "end_time", None
    # total_seconds() rather than .seconds: .seconds drops whole days and wraps a negative delta
    # (e.g. clock skew) to a value close to 86399.
    yield "run_time", f"{int((end_time - start_time).total_seconds())} secs"
    yield "phase", run_definition_pb2.Phase.Name(action.status.phase)
    if isinstance(action, run_definition_pb2.ActionDetails):
        error = f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA"
        yield "error", error
65
+
66
+
67
def _action_rich_repr(action: run_definition_pb2.Action) -> rich.repr.Result:
    """
    Yield the (label, value) pairs shown in the rich repr of an ``Action`` protobuf.

    Order: run name, task or trace identity, action name, timing/phase fields, then group,
    parent and attempt count.
    """
    meta = action.metadata
    yield "name", action.id.run.name
    if meta.HasField("task"):
        yield from (("task name", meta.task.id.name), ("type", meta.task.task_type))
    elif meta.HasField("trace"):
        yield from (("trace", meta.trace.name), ("type", "trace"))
    yield "action name", action.id.name
    yield from _action_time_phase(action)
    yield from (
        ("group", meta.group),
        ("parent", meta.parent),
        ("attempts", action.status.attempts),
    )
83
+
84
+
85
def _attempt_rich_repr(
    action: List[run_definition_pb2.ActionAttempt],
) -> rich.repr.Result:
    """
    Yield attempt number, phase name and log availability for every attempt, in list order.

    :param action: The attempts to describe (despite its name, a list of ``ActionAttempt`` messages).
    """
    for entry in action:
        yield "attempt", entry.attempt
        phase_name = run_definition_pb2.Phase.Name(entry.phase)
        yield "phase", phase_name
        yield "logs_available", entry.logs_available
92
+
93
+
94
def _action_details_rich_repr(
    action: run_definition_pb2.ActionDetails,
) -> rich.repr.Result:
    """
    Rich representation of the action details.

    Every label is emitted once: ``_action_time_phase`` already yields ``phase`` and (for
    ``ActionDetails``) ``error``, so they are not repeated here.
    """
    yield "name", action.id.run.name
    yield from _action_time_phase(action)
    if action.HasField("task"):
        template = action.task.task_template
        yield "task", template.id.name
        yield "task_type", template.type
        yield "task_version", template.id.version
    # Attempt count, consistent with _action_rich_repr and the ActionDetails.attempts property
    # (action.attempts is the repeated per-attempt message list, not a count).
    yield "attempts", action.status.attempts
    yield "group", action.metadata.group
    yield "parent", action.metadata.parent
111
+
112
+
113
def _action_done_check(phase: run_definition_pb2.Phase) -> bool:
    """
    Return True when ``phase`` is terminal: failed, succeeded, aborted or timed out.
    """
    terminal_phases = (
        run_definition_pb2.PHASE_FAILED,
        run_definition_pb2.PHASE_SUCCEEDED,
        run_definition_pb2.PHASE_ABORTED,
        run_definition_pb2.PHASE_TIMED_OUT,
    )
    return phase in terminal_phases
123
+
124
+
125
@dataclass
class Action(ToJSONMixin):
    """
    A class representing an action. It is used to manage the run of a task and its state on the remote Union API.

    ``pb2`` is the ``Action`` protobuf. ``_details`` caches the ``ActionDetails`` once they have been fetched
    or observed (by ``get``, ``details`` or ``watch``).
    """

    pb2: run_definition_pb2.Action
    _details: ActionDetails | None = None

    @syncify
    @classmethod
    async def listall(
        cls,
        for_run_name: str,
        filters: str | None = None,
        sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
    ) -> Union[Iterator[Action], AsyncIterator[Action]]:
        """
        Get all actions for a given run.

        The run is looked up in the org/project/domain of the current init config. Actions are fetched
        in pages of 100 until the server returns no continuation token.

        :param for_run_name: The name of the run.
        :param filters: Filters to apply to the action list. NOTE(review): currently not sent to the server.
        :param sort_by: The sorting criteria for the action list, in the format (field, order).
            Defaults to ``("created_at", "asc")``.
        :return: An iterator of actions.
        """
        ensure_client()
        token = None
        sort_by = sort_by or ("created_at", "asc")
        sort_pb2 = list_pb2.Sort(
            key=sort_by[0],
            direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
        )
        cfg = get_init_config()
        while True:
            req = list_pb2.ListRequest(
                limit=100,
                token=token,
                sort_by=sort_pb2,
            )
            resp = await get_client().run_service.ListActions(
                run_service_pb2.ListActionsRequest(
                    request=req,
                    run_id=identifier_pb2.RunIdentifier(
                        org=cfg.org,
                        project=cfg.project,
                        domain=cfg.domain,
                        name=for_run_name,
                    ),
                )
            )
            token = resp.token
            for r in resp.actions:
                yield cls(r)
            if not token:
                break

    @syncify
    @classmethod
    async def get(
        cls,
        uri: str | None = None,
        /,
        run_name: str | None = None,
        name: str | None = None,
    ) -> Action:
        """
        Get an action by the name of its run and its own name.

        The run is looked up in the org/project/domain of the current init config. The returned action
        has its details cached.

        :param uri: The URI of the action. NOTE(review): currently not used for the lookup.
        :param run_name: The name of the run the action belongs to.
        :param name: The name of the action.
        :raises ValueError: If ``run_name`` or ``name`` is not provided.
        """
        if not run_name or not name:
            raise ValueError("Both run_name and name must be provided to get an action")
        ensure_client()
        cfg = get_init_config()
        details: ActionDetails = await ActionDetails.get_details.aio(
            identifier_pb2.ActionIdentifier(
                run=identifier_pb2.RunIdentifier(
                    org=cfg.org,
                    project=cfg.project,
                    domain=cfg.domain,
                    name=run_name,
                ),
                name=name,
            ),
        )
        return cls(
            pb2=run_definition_pb2.Action(
                id=details.action_id,
                metadata=details.pb2.metadata,
                status=details.pb2.status,
            ),
            _details=details,
        )

    @property
    def phase(self) -> str:
        """
        Get the phase of the action.
        """
        return run_definition_pb2.Phase.Name(self.pb2.status.phase)

    @property
    def raw_phase(self) -> run_definition_pb2.Phase:
        """
        Get the raw phase of the action.
        """
        return self.pb2.status.phase

    @property
    def name(self) -> str:
        """
        Get the name of the action.
        """
        return self.action_id.name

    @property
    def run_name(self) -> str:
        """
        Get the name of the run.
        """
        return self.action_id.run.name

    @property
    def task_name(self) -> str | None:
        """
        Get the name of the task, or None if the action's metadata carries no task id.
        """
        if self.pb2.metadata.HasField("task") and self.pb2.metadata.task.HasField("id"):
            return self.pb2.metadata.task.id.name
        return None

    @property
    def action_id(self) -> identifier_pb2.ActionIdentifier:
        """
        Get the action ID.
        """
        return self.pb2.id

    @syncify
    async def show_logs(
        self,
        attempt: int | None = None,
        max_lines: int = 30,
        show_ts: bool = False,
        raw: bool = False,
        filter_system: bool = False,
    ):
        """
        Show the logs of the action.

        If the action is neither running nor done, this first waits until logs are available.

        :param attempt: The attempt to show logs for; defaults to the latest attempt.
        :param max_lines: Passed through to ``Logs.create_viewer``.
        :param show_ts: Passed through to ``Logs.create_viewer``.
        :param raw: Passed through to ``Logs.create_viewer``.
        :param filter_system: Passed through to ``Logs.create_viewer``.
        """
        details = await self.details()
        if not details.is_running and not details.done():
            # TODO we can short circuit here if the attempt is not the last one and it is done!
            await self.wait(wait_for="logs-ready")
            details = await self.details()
        if not attempt:
            attempt = details.attempts
        return await Logs.create_viewer(
            action_id=self.action_id,
            attempt=attempt,
            max_lines=max_lines,
            show_ts=show_ts,
            raw=raw,
            filter_system=filter_system,
        )

    async def details(self) -> ActionDetails:
        """
        Return the cached details of the action, fetching them from the server on first use.
        """
        if not self._details:
            self._details = await ActionDetails.get_details.aio(self.action_id)
        return cast(ActionDetails, self._details)

    async def watch(
        self, cache_data_on_done: bool = False, wait_for: WaitFor = "terminal"
    ) -> AsyncGenerator[ActionDetails, None]:
        """
        Watch the action for updates, yielding each new ``ActionDetails``.

        Every update is cached as this action's details and its status is copied into ``self.pb2``,
        so ``phase``/``done()`` reflect the latest state observed.

        :param cache_data_on_done: After the watch ends on a terminal update, fetch and cache the
            action's outputs. This only runs if the consumer keeps iterating until the generator finishes.
        :param wait_for: Stop after the first update that is running (``"running"``), has logs
            available (``"logs-ready"``), or is terminal (always).
        """
        ad = None
        async for ad in ActionDetails.watch.aio(self.action_id):
            if ad is None:
                return
            self._details = ad
            # Keep the Action protobuf in sync; otherwise done()/phase stay stale after wait().
            self.pb2.status.CopyFrom(ad.status)
            yield ad
            if wait_for == "running" and ad.is_running:
                break
            elif wait_for == "logs-ready" and ad.logs_available():
                break
            if ad.done():
                break
        if cache_data_on_done and ad and ad.done():
            await cast(ActionDetails, self._details).outputs()

    async def wait(self, quiet: bool = False, wait_for: WaitFor = "terminal") -> None:
        """
        Wait for the run to complete, displaying a rich progress panel with status transitions,
        time elapsed, and error details in case of failure.

        :param quiet: Disable the progress panel and the final status messages.
        :param wait_for: Return once the action is running (``"running"``), has logs available
            (``"logs-ready"``), or has reached a terminal phase (``"terminal"``).
        """
        console = Console()
        if self.done():
            if not quiet:
                if self.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
                    console.print(
                        f"[bold green]Action '{self.name}' in Run '{self.run_name}'"
                        f" completed successfully.[/bold green]"
                    )
                else:
                    details = await self.details()
                    error_message = details.error_info.message if details.error_info else ""
                    console.print(
                        f"[bold red]Action '{self.name}' in Run '{self.run_name}'"
                        f" exited unsuccessfully in state {self.phase} with error: {error_message}[/bold red]"
                    )
            return

        try:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                TimeElapsedColumn(),
                console=console,
                transient=True,
                disable=quiet,
            ) as progress:
                task_id = progress.add_task(f"Waiting for run '{self.name}'...", start=False)
                progress.start_task(task_id)

                # NOTE(review): this loop breaks out of watch() on the update it stops at, so the
                # generator never resumes and its cache_data_on_done step does not run here.
                async for ad in self.watch(cache_data_on_done=True, wait_for=wait_for):
                    if ad.is_running and wait_for == "running":
                        progress.stop_task(task_id)
                        break

                    if ad.logs_available() and wait_for == "logs-ready":
                        progress.stop_task(task_id)
                        break

                    # Update progress description with the current phase; runtime is a timedelta,
                    # so convert it to whole seconds to match the "secs" label.
                    progress.update(
                        task_id,
                        description=f"Run: {self.run_name} in {ad.phase}, "
                        f"Runtime: {int(ad.runtime.total_seconds())} secs "
                        f"Attempts[{ad.attempts}]",
                    )

                    # If the action is done, handle the final state
                    if ad.done():
                        progress.stop_task(task_id)
                        if not quiet:
                            if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
                                console.print(f"[bold green]Run '{self.run_name}' completed successfully.[/bold green]")
                            else:
                                error_message = ad.error_info.message if ad.error_info else ""
                                console.print(
                                    f"[bold red]Run '{self.run_name}' exited unsuccessfully in state {ad.phase}"
                                    f" with error: {error_message}[/bold red]"
                                )
                        break
        except asyncio.CancelledError:
            # Deliberately handled "gracefully". NOTE(review): swallowing CancelledError means a
            # caller that cancels this coroutine does not observe the cancellation.
            pass
        except KeyboardInterrupt:
            # Handle keyboard interrupt gracefully
            pass

    def done(self) -> bool:
        """
        Check if the action is done (in a terminal phase).
        """
        return _action_done_check(self.raw_phase)

    async def sync(self) -> Action:
        """
        Sync the action with the remote server. Placeholder: currently returns ``self`` unchanged
        without contacting the server.
        """
        return self

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the Action object, followed by its cached details if any.
        """
        yield from _action_rich_repr(self.pb2)
        if self._details:
            yield from self._details.__rich_repr__()

    def __repr__(self) -> str:
        """
        String representation of the Action object.
        """
        return rich.pretty.pretty_repr(self)
418
+
419
+
420
@dataclass
class ActionDetails(ToJSONMixin):
    """
    A class representing the details of an action. It is used to manage the run of a task and its state on the
    remote Union API.

    ``pb2`` is the ``ActionDetails`` protobuf. ``_inputs`` / ``_outputs`` cache the action's data once it has
    been fetched by ``inputs()`` / ``outputs()``.
    """

    pb2: run_definition_pb2.ActionDetails
    _inputs: ActionInputs | None = None
    _outputs: ActionOutputs | None = None

    @syncify
    @classmethod
    async def get_details(cls, action_id: identifier_pb2.ActionIdentifier) -> ActionDetails:
        """
        Fetch the details of an action from the run service.

        :param action_id: The identifier of the action.
        :return: A new instance (of the class this is called on) wrapping the response details.
        """
        ensure_client()
        resp = await get_client().run_service.GetActionDetails(
            run_service_pb2.GetActionDetailsRequest(
                action_id=action_id,
            )
        )
        return cls(resp.details)

    @syncify
    @classmethod
    async def get(
        cls,
        uri: str | None = None,
        /,
        run_name: str | None = None,
        name: str | None = None,
    ) -> ActionDetails:
        """
        Get the details of an action by the name of its run and its own name.

        The run is looked up in the org/project/domain of the current init config.

        :param uri: The URI of the action. NOTE(review): only checked for presence; the lookup is
            always built from ``run_name`` and ``name``.
        :param name: The name of the action.
        :param run_name: The name of the run.
        :raises ValueError: If no ``uri`` is given and ``name`` or ``run_name`` is missing.
        """
        ensure_client()
        # Raise instead of assert so the check survives ``python -O``.
        if not uri and (name is None or run_name is None):
            raise ValueError("Either uri or name and run_name must be provided")
        cfg = get_init_config()
        return await cls.get_details.aio(
            identifier_pb2.ActionIdentifier(
                run=identifier_pb2.RunIdentifier(
                    org=cfg.org,
                    project=cfg.project,
                    domain=cfg.domain,
                    name=run_name,
                ),
                name=name,
            ),
        )

    @syncify
    @classmethod
    async def watch(cls, action_id: identifier_pb2.ActionIdentifier) -> AsyncIterator[ActionDetails]:
        """
        Stream updates for an action, yielding a new instance per server update.

        The stream ends after the first terminal update. A CANCELLED gRPC status ends the stream
        quietly; any other gRPC error propagates.

        :param action_id: The action to watch.
        :raises ValueError: If ``action_id`` is falsy (e.g. ``None``).
        """
        ensure_client()
        if not action_id:
            raise ValueError("Action ID is required")

        call = cast(
            AsyncIterator[WatchActionDetailsResponse],
            get_client().run_service.WatchActionDetails(
                request=run_service_pb2.WatchActionDetailsRequest(
                    action_id=action_id,
                )
            ),
        )
        try:
            async for resp in call:
                v = cls(resp.details)
                yield v
                if v.done():
                    return
        except grpc.aio.AioRpcError as e:
            if e.code() != grpc.StatusCode.CANCELLED:
                raise

    async def watch_updates(self, cache_data_on_done: bool = False) -> AsyncGenerator[ActionDetails, None]:
        """
        Watch this action, yielding every update and refreshing ``self.pb2`` as updates arrive.

        :param cache_data_on_done: If the action is terminal once the watch ends, fetch and cache its
            inputs/outputs on ``self``. This only runs if the consumer iterates to the end.
        """
        async for d in self.watch.aio(action_id=self.pb2.id):
            # Refresh on every update (not only the terminal one) so phase/done()/attempts on self
            # reflect the latest state while watching.
            self.pb2 = d.pb2
            yield d
            if d.done():
                break

        if cache_data_on_done and self.done():
            await self._cache_data.aio()

    @property
    def phase(self) -> str:
        """
        Get the phase of the action.
        """
        return run_definition_pb2.Phase.Name(self.status.phase)

    @property
    def raw_phase(self) -> run_definition_pb2.Phase:
        """
        Get the raw phase of the action.
        """
        return self.status.phase

    @property
    def is_running(self) -> bool:
        """
        Check if the action is currently running.
        """
        return self.status.phase == run_definition_pb2.PHASE_RUNNING

    @property
    def name(self) -> str:
        """
        Get the name of the action.
        """
        return self.action_id.name

    @property
    def run_name(self) -> str:
        """
        Get the name of the run.
        """
        return self.action_id.run.name

    @property
    def task_name(self) -> str | None:
        """
        Get the name of the task, or None if the metadata carries no task id.
        """
        if self.pb2.metadata.HasField("task") and self.pb2.metadata.task.HasField("id"):
            return self.pb2.metadata.task.id.name
        return None

    @property
    def action_id(self) -> identifier_pb2.ActionIdentifier:
        """
        Get the action ID.
        """
        return self.pb2.id

    @property
    def metadata(self) -> run_definition_pb2.ActionMetadata:
        """
        Get the action metadata.
        """
        return self.pb2.metadata

    @property
    def status(self) -> run_definition_pb2.ActionStatus:
        """
        Get the action status.
        """
        return self.pb2.status

    @property
    def error_info(self) -> run_definition_pb2.ErrorInfo | None:
        """
        Get the error info, or None if the field is not set.
        """
        if self.pb2.HasField("error_info"):
            return self.pb2.error_info
        return None

    @property
    def abort_info(self) -> run_definition_pb2.AbortInfo | None:
        """
        Get the abort info, or None if the field is not set.
        """
        if self.pb2.HasField("abort_info"):
            return self.pb2.abort_info
        return None

    @property
    def runtime(self) -> timedelta:
        """
        Get the runtime of the action; measured up to now if no end time is set yet.
        """
        start_time = self.pb2.status.start_time.ToDatetime().replace(tzinfo=timezone.utc)
        if self.pb2.status.HasField("end_time"):
            end_time = self.pb2.status.end_time.ToDatetime().replace(tzinfo=timezone.utc)
            return end_time - start_time
        return datetime.now(timezone.utc) - start_time

    @property
    def attempts(self) -> int:
        """
        Get the number of attempts of the action.
        """
        return self.pb2.status.attempts

    def logs_available(self, attempt: int | None = None) -> bool:
        """
        Check if logs are available for the action, optionally for a specific (1-based) attempt.
        If attempt is None, it checks for the latest attempt.
        """
        if attempt is None:
            attempt = self.pb2.status.attempts
        attempts = self.pb2.attempts
        # Attempts are 1-based; guard the lower bound too, otherwise attempt 0 (or a negative value)
        # would index from the end and report another attempt's log availability.
        if 1 <= attempt <= len(attempts):
            return attempts[attempt - 1].logs_available
        return False

    @syncify
    async def _cache_data(self) -> bool:
        """
        Cache the inputs and outputs of the action.
        :return: True if outputs are cached after this call, else False. Returns False early
            (without fetching) when inputs are already cached but the action is not done.
        """
        from flyte._internal.runtime import convert

        if self._inputs and self._outputs:
            return True
        if self._inputs and not self.done():
            return False
        resp = await get_client().run_service.GetActionData(
            request=run_service_pb2.GetActionDataRequest(
                action_id=self.pb2.id,
            )
        )
        # Without a task or trace interface the native values cannot be decoded, so data is left empty.
        native_iface = None
        if self.pb2.HasField("task"):
            iface = self.pb2.task.task_template.interface
            native_iface = types.guess_interface(iface)
        elif self.pb2.HasField("trace"):
            iface = self.pb2.trace.interface
            native_iface = types.guess_interface(iface)

        # NOTE(review): these checks rely on message truthiness; confirm that matches "field is set"
        # for these protobuf messages (HasField may be what is intended).
        if resp.inputs:
            data_dict = (
                await convert.convert_from_inputs_to_native(native_iface, convert.Inputs(resp.inputs))
                if native_iface
                else {}
            )
            self._inputs = ActionInputs(pb2=resp.inputs, data=data_dict)

        if resp.outputs:
            data_tuple = (
                await convert.convert_outputs_to_native(native_iface, convert.Outputs(resp.outputs))
                if native_iface
                else ()
            )
            # A single output comes back unwrapped; ActionOutputs always holds a tuple.
            if not isinstance(data_tuple, tuple):
                data_tuple = (data_tuple,)
            self._outputs = ActionOutputs(pb2=resp.outputs, data=data_tuple)

        return self._outputs is not None

    async def inputs(self) -> ActionInputs:
        """
        Return the inputs of the action, fetching and caching them on first use.
        """
        if not self._inputs:
            await self._cache_data.aio()
        return cast(ActionInputs, self._inputs)

    async def outputs(self) -> ActionOutputs:
        """
        Return the outputs of the action, fetching and caching them on first use.

        :raises RuntimeError: If the outputs could not be cached.
        """
        if not self._outputs:
            if not await self._cache_data.aio():
                raise RuntimeError(
                    "Action is not in a terminal state, outputs are not available. "
                    "Please wait for the action to complete."
                )
        return cast(ActionOutputs, self._outputs)

    def done(self) -> bool:
        """
        Check if the action is in a terminal phase (failed, succeeded, aborted or timed out).
        """
        return _action_done_check(self.raw_phase)

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the ActionDetails object.
        """
        yield from _action_details_rich_repr(self.pb2)

    def __repr__(self) -> str:
        """
        String representation of the ActionDetails object.
        """
        return rich.pretty.pretty_repr(self)
702
+
703
+
704
@dataclass
class ActionInputs(UserDict, ToJSONMixin):
    """
    Inputs of an action: the raw protobuf inputs (``pb2``) plus their native Python values (``data``).

    ``data`` is also the backing dictionary that ``UserDict`` operates on, so an instance can be used
    as a mapping of input name to native value.
    """

    pb2: common_pb2.Inputs
    data: Dict[str, Any]

    def __repr__(self):
        import rich.pretty

        from flyte.types import literal_string_repr

        rendered = literal_string_repr(self.pb2)
        return rich.pretty.pretty_repr(rendered)
720
+
721
+
722
class ActionOutputs(tuple, ToJSONMixin):
    """
    Outputs of an action: a tuple of native Python values that also carries the raw protobuf
    outputs as ``pb2``.
    """

    def __new__(cls, pb2: common_pb2.Outputs, data: Tuple[Any, ...]):
        # Tuple contents are fixed at creation, so the values have to be supplied to __new__.
        instance = super().__new__(cls, data)
        # A tuple subclass (without __slots__) still has an instance __dict__, so pb2 can be attached.
        instance.pb2 = pb2
        return instance

    def __init__(self, pb2: common_pb2.Outputs, data: Tuple[Any, ...]):
        # Runs after __new__ with the same arguments; re-assigns the same pb2 value.
        self.pb2 = pb2