flyte 0.1.0__py3-none-any.whl → 0.2.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (219) hide show
  1. flyte/__init__.py +78 -2
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/runtime.py +152 -0
  4. flyte/_build.py +26 -0
  5. flyte/_cache/__init__.py +12 -0
  6. flyte/_cache/cache.py +145 -0
  7. flyte/_cache/defaults.py +9 -0
  8. flyte/_cache/policy_function_body.py +42 -0
  9. flyte/_code_bundle/__init__.py +8 -0
  10. flyte/_code_bundle/_ignore.py +113 -0
  11. flyte/_code_bundle/_packaging.py +187 -0
  12. flyte/_code_bundle/_utils.py +323 -0
  13. flyte/_code_bundle/bundle.py +209 -0
  14. flyte/_context.py +152 -0
  15. flyte/_deploy.py +243 -0
  16. flyte/_doc.py +29 -0
  17. flyte/_docstring.py +32 -0
  18. flyte/_environment.py +84 -0
  19. flyte/_excepthook.py +37 -0
  20. flyte/_group.py +32 -0
  21. flyte/_hash.py +23 -0
  22. flyte/_image.py +762 -0
  23. flyte/_initialize.py +492 -0
  24. flyte/_interface.py +84 -0
  25. flyte/_internal/__init__.py +3 -0
  26. flyte/_internal/controllers/__init__.py +128 -0
  27. flyte/_internal/controllers/_local_controller.py +193 -0
  28. flyte/_internal/controllers/_trace.py +41 -0
  29. flyte/_internal/controllers/remote/__init__.py +60 -0
  30. flyte/_internal/controllers/remote/_action.py +146 -0
  31. flyte/_internal/controllers/remote/_client.py +47 -0
  32. flyte/_internal/controllers/remote/_controller.py +494 -0
  33. flyte/_internal/controllers/remote/_core.py +410 -0
  34. flyte/_internal/controllers/remote/_informer.py +361 -0
  35. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  36. flyte/_internal/imagebuild/__init__.py +11 -0
  37. flyte/_internal/imagebuild/docker_builder.py +427 -0
  38. flyte/_internal/imagebuild/image_builder.py +246 -0
  39. flyte/_internal/imagebuild/remote_builder.py +0 -0
  40. flyte/_internal/resolvers/__init__.py +0 -0
  41. flyte/_internal/resolvers/_task_module.py +54 -0
  42. flyte/_internal/resolvers/common.py +31 -0
  43. flyte/_internal/resolvers/default.py +28 -0
  44. flyte/_internal/runtime/__init__.py +0 -0
  45. flyte/_internal/runtime/convert.py +342 -0
  46. flyte/_internal/runtime/entrypoints.py +135 -0
  47. flyte/_internal/runtime/io.py +136 -0
  48. flyte/_internal/runtime/resources_serde.py +138 -0
  49. flyte/_internal/runtime/task_serde.py +330 -0
  50. flyte/_internal/runtime/taskrunner.py +191 -0
  51. flyte/_internal/runtime/types_serde.py +54 -0
  52. flyte/_logging.py +135 -0
  53. flyte/_map.py +215 -0
  54. flyte/_pod.py +19 -0
  55. flyte/_protos/__init__.py +0 -0
  56. flyte/_protos/common/authorization_pb2.py +66 -0
  57. flyte/_protos/common/authorization_pb2.pyi +108 -0
  58. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  59. flyte/_protos/common/identifier_pb2.py +71 -0
  60. flyte/_protos/common/identifier_pb2.pyi +82 -0
  61. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  62. flyte/_protos/common/identity_pb2.py +48 -0
  63. flyte/_protos/common/identity_pb2.pyi +72 -0
  64. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  65. flyte/_protos/common/list_pb2.py +36 -0
  66. flyte/_protos/common/list_pb2.pyi +71 -0
  67. flyte/_protos/common/list_pb2_grpc.py +4 -0
  68. flyte/_protos/common/policy_pb2.py +37 -0
  69. flyte/_protos/common/policy_pb2.pyi +27 -0
  70. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  71. flyte/_protos/common/role_pb2.py +37 -0
  72. flyte/_protos/common/role_pb2.pyi +53 -0
  73. flyte/_protos/common/role_pb2_grpc.py +4 -0
  74. flyte/_protos/common/runtime_version_pb2.py +28 -0
  75. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  76. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  77. flyte/_protos/logs/dataplane/payload_pb2.py +100 -0
  78. flyte/_protos/logs/dataplane/payload_pb2.pyi +177 -0
  79. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  80. flyte/_protos/secret/definition_pb2.py +49 -0
  81. flyte/_protos/secret/definition_pb2.pyi +93 -0
  82. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  83. flyte/_protos/secret/payload_pb2.py +62 -0
  84. flyte/_protos/secret/payload_pb2.pyi +94 -0
  85. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  86. flyte/_protos/secret/secret_pb2.py +38 -0
  87. flyte/_protos/secret/secret_pb2.pyi +6 -0
  88. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  89. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  90. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  91. flyte/_protos/workflow/common_pb2.py +27 -0
  92. flyte/_protos/workflow/common_pb2.pyi +14 -0
  93. flyte/_protos/workflow/common_pb2_grpc.py +4 -0
  94. flyte/_protos/workflow/environment_pb2.py +29 -0
  95. flyte/_protos/workflow/environment_pb2.pyi +12 -0
  96. flyte/_protos/workflow/environment_pb2_grpc.py +4 -0
  97. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  98. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  99. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  100. flyte/_protos/workflow/queue_service_pb2.py +105 -0
  101. flyte/_protos/workflow/queue_service_pb2.pyi +146 -0
  102. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  103. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  104. flyte/_protos/workflow/run_definition_pb2.pyi +314 -0
  105. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  106. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  107. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  108. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  109. flyte/_protos/workflow/run_service_pb2.py +129 -0
  110. flyte/_protos/workflow/run_service_pb2.pyi +171 -0
  111. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  112. flyte/_protos/workflow/state_service_pb2.py +66 -0
  113. flyte/_protos/workflow/state_service_pb2.pyi +75 -0
  114. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  115. flyte/_protos/workflow/task_definition_pb2.py +79 -0
  116. flyte/_protos/workflow/task_definition_pb2.pyi +81 -0
  117. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  118. flyte/_protos/workflow/task_service_pb2.py +60 -0
  119. flyte/_protos/workflow/task_service_pb2.pyi +59 -0
  120. flyte/_protos/workflow/task_service_pb2_grpc.py +138 -0
  121. flyte/_resources.py +226 -0
  122. flyte/_retry.py +32 -0
  123. flyte/_reusable_environment.py +25 -0
  124. flyte/_run.py +482 -0
  125. flyte/_secret.py +61 -0
  126. flyte/_task.py +449 -0
  127. flyte/_task_environment.py +183 -0
  128. flyte/_timeout.py +47 -0
  129. flyte/_tools.py +27 -0
  130. flyte/_trace.py +120 -0
  131. flyte/_utils/__init__.py +26 -0
  132. flyte/_utils/asyn.py +119 -0
  133. flyte/_utils/async_cache.py +139 -0
  134. flyte/_utils/coro_management.py +23 -0
  135. flyte/_utils/file_handling.py +72 -0
  136. flyte/_utils/helpers.py +134 -0
  137. flyte/_utils/lazy_module.py +54 -0
  138. flyte/_utils/org_discovery.py +57 -0
  139. flyte/_utils/uv_script_parser.py +49 -0
  140. flyte/_version.py +21 -0
  141. flyte/cli/__init__.py +3 -0
  142. flyte/cli/_abort.py +28 -0
  143. flyte/cli/_common.py +337 -0
  144. flyte/cli/_create.py +145 -0
  145. flyte/cli/_delete.py +23 -0
  146. flyte/cli/_deploy.py +152 -0
  147. flyte/cli/_gen.py +163 -0
  148. flyte/cli/_get.py +310 -0
  149. flyte/cli/_params.py +538 -0
  150. flyte/cli/_run.py +231 -0
  151. flyte/cli/main.py +166 -0
  152. flyte/config/__init__.py +3 -0
  153. flyte/config/_config.py +216 -0
  154. flyte/config/_internal.py +64 -0
  155. flyte/config/_reader.py +207 -0
  156. flyte/connectors/__init__.py +0 -0
  157. flyte/errors.py +172 -0
  158. flyte/extras/__init__.py +5 -0
  159. flyte/extras/_container.py +263 -0
  160. flyte/io/__init__.py +27 -0
  161. flyte/io/_dir.py +448 -0
  162. flyte/io/_file.py +467 -0
  163. flyte/io/_structured_dataset/__init__.py +129 -0
  164. flyte/io/_structured_dataset/basic_dfs.py +219 -0
  165. flyte/io/_structured_dataset/structured_dataset.py +1061 -0
  166. flyte/models.py +391 -0
  167. flyte/remote/__init__.py +26 -0
  168. flyte/remote/_client/__init__.py +0 -0
  169. flyte/remote/_client/_protocols.py +133 -0
  170. flyte/remote/_client/auth/__init__.py +12 -0
  171. flyte/remote/_client/auth/_auth_utils.py +14 -0
  172. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  173. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  174. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  175. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  176. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  177. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  178. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  179. flyte/remote/_client/auth/_channel.py +215 -0
  180. flyte/remote/_client/auth/_client_config.py +83 -0
  181. flyte/remote/_client/auth/_default_html.py +32 -0
  182. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  183. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  184. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  185. flyte/remote/_client/auth/_keyring.py +143 -0
  186. flyte/remote/_client/auth/_token_client.py +260 -0
  187. flyte/remote/_client/auth/errors.py +16 -0
  188. flyte/remote/_client/controlplane.py +95 -0
  189. flyte/remote/_console.py +18 -0
  190. flyte/remote/_data.py +159 -0
  191. flyte/remote/_logs.py +176 -0
  192. flyte/remote/_project.py +85 -0
  193. flyte/remote/_run.py +970 -0
  194. flyte/remote/_secret.py +132 -0
  195. flyte/remote/_task.py +391 -0
  196. flyte/report/__init__.py +3 -0
  197. flyte/report/_report.py +178 -0
  198. flyte/report/_template.html +124 -0
  199. flyte/storage/__init__.py +29 -0
  200. flyte/storage/_config.py +233 -0
  201. flyte/storage/_remote_fs.py +34 -0
  202. flyte/storage/_storage.py +271 -0
  203. flyte/storage/_utils.py +5 -0
  204. flyte/syncify/__init__.py +56 -0
  205. flyte/syncify/_api.py +371 -0
  206. flyte/types/__init__.py +36 -0
  207. flyte/types/_interface.py +40 -0
  208. flyte/types/_pickle.py +118 -0
  209. flyte/types/_renderer.py +162 -0
  210. flyte/types/_string_literals.py +120 -0
  211. flyte/types/_type_engine.py +2287 -0
  212. flyte/types/_utils.py +80 -0
  213. flyte-0.2.0a0.dist-info/METADATA +249 -0
  214. flyte-0.2.0a0.dist-info/RECORD +218 -0
  215. {flyte-0.1.0.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +2 -1
  216. flyte-0.2.0a0.dist-info/entry_points.txt +3 -0
  217. flyte-0.2.0a0.dist-info/top_level.txt +1 -0
  218. flyte-0.1.0.dist-info/METADATA +0 -6
  219. flyte-0.1.0.dist-info/RECORD +0 -5
flyte/remote/_run.py ADDED
@@ -0,0 +1,970 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from dataclasses import dataclass, field
5
+ from datetime import datetime, timedelta, timezone
6
+ from typing import AsyncGenerator, AsyncIterator, Iterator, List, Literal, Tuple, Union, cast
7
+
8
+ import grpc
9
+ import rich.repr
10
+ from google.protobuf import timestamp
11
+ from rich.console import Console
12
+ from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
13
+
14
+ from flyte._initialize import ensure_client, get_client, get_common_config
15
+ from flyte._protos.common import identifier_pb2, list_pb2
16
+ from flyte._protos.workflow import run_definition_pb2, run_service_pb2
17
+ from flyte.syncify import syncify
18
+
19
+ from .._protos.workflow.run_service_pb2 import WatchActionDetailsResponse
20
+ from ._console import get_run_url
21
+ from ._logs import Logs
22
+
23
+ WaitFor = Literal["terminal", "running", "logs-ready"]
24
+
25
+
26
def _action_time_phase(action: run_definition_pb2.Action | run_definition_pb2.ActionDetails) -> rich.repr.Result:
    """
    Rich representation of the action time and phase.

    Yields, in order, ``start_time``, ``end_time``, ``run_time`` and ``phase``. When ``action`` is an
    ``ActionDetails`` message an ``error`` entry is yielded as well (``"NA"`` when no error info is set).

    :param action: The action (or action details) protobuf message to describe.
    :return: A generator of ``(key, value)`` pairs suitable for ``__rich_repr__``.
    """
    start_time = timestamp.to_datetime(action.status.start_time, timezone.utc)
    yield "start_time", start_time.isoformat()
    if _action_done_check(action.status.phase):
        # Terminal phase: the recorded end time bounds the run time.
        end_time = timestamp.to_datetime(action.status.end_time, timezone.utc)
        yield "end_time", end_time.isoformat()
        elapsed = end_time - start_time
    else:
        # Still in flight: report the time elapsed so far.
        yield "end_time", None
        elapsed = datetime.now(timezone.utc) - start_time
    # ``timedelta.seconds`` is only the seconds component and drops whole days, so a run longer than
    # 24 hours would be under-reported; ``total_seconds()`` covers the full duration.
    yield "run_time", f"{int(elapsed.total_seconds())} secs"
    yield "phase", run_definition_pb2.Phase.Name(action.status.phase)
    if isinstance(action, run_definition_pb2.ActionDetails):
        yield (
            "error",
            f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA",
        )
50
+
51
+
52
def _action_rich_repr(action: run_definition_pb2.Action) -> rich.repr.Result:
    """
    Produce the rich ``(key, value)`` pairs describing an action.

    The ``task`` and ``type`` entries are only emitted when the action metadata carries a task.
    """
    identifier = action.id
    metadata = action.metadata

    yield "run", identifier.run.name
    if metadata.HasField("task"):
        yield "task", metadata.task.id.name
        yield "type", "task"
    yield "name", identifier.name
    # Timing and phase entries come from the shared helper.
    yield from _action_time_phase(action)
    yield "group", metadata.group
    yield "parent", metadata.parent
    yield "attempts", action.status.attempts
65
+
66
+
67
def _attempt_rich_repr(action: List[run_definition_pb2.ActionAttempt]) -> rich.repr.Result:
    """
    Produce rich ``(key, value)`` pairs for every attempt in the given list.

    For each attempt the attempt number, the phase name and the ``logs_available`` flag are yielded.
    """
    for entry in action:
        phase_name = run_definition_pb2.Phase.Name(entry.phase)
        yield "attempt", entry.attempt
        yield "phase", phase_name
        yield "logs_available", entry.logs_available
72
+
73
+
74
def _action_details_rich_repr(action: run_definition_pb2.ActionDetails) -> rich.repr.Result:
    """
    Rich representation of the action details.

    Yields the run name, the time/phase/error entries from ``_action_time_phase``, the resolved task
    template information, the attempts, and the group/parent metadata.

    :param action: The action details protobuf message to describe.
    :return: A generator of ``(key, value)`` pairs suitable for ``__rich_repr__``.
    """
    task_template = action.resolved_task_spec.task_template

    yield "name", action.id.run.name
    # ``_action_time_phase`` already yields "phase" and, for ActionDetails, "error"; they are not
    # repeated below so each key shows up only once in the rendered repr.
    yield from _action_time_phase(action)
    yield "task", task_template.id.name
    yield "task_type", task_template.type
    yield "task_version", task_template.id.version
    yield "attempts", action.attempts
    yield "group", action.metadata.group
    yield "parent", action.metadata.parent
88
+
89
+
90
def _action_done_check(phase: run_definition_pb2.Phase) -> bool:
    """
    Return True when ``phase`` is one of the terminal phases (failed, succeeded, aborted, timed out).
    """
    terminal_phases = (
        run_definition_pb2.PHASE_FAILED,
        run_definition_pb2.PHASE_SUCCEEDED,
        run_definition_pb2.PHASE_ABORTED,
        run_definition_pb2.PHASE_TIMED_OUT,
    )
    return phase in terminal_phases
100
+
101
+
102
@dataclass
class Run:
    """
    A class representing a run of a task. It is used to manage the run of a task and its state on the remote
    Union API.

    A ``Run`` wraps the ``run_definition_pb2.Run`` message and delegates most behaviour to the
    :class:`Action` built from the message's ``action`` field.
    """

    pb2: run_definition_pb2.Run
    # Wrapper around ``pb2.action``; populated in ``__post_init__``.
    action: Action = field(init=False)
    # Lazily fetched (or pre-supplied) run details, see ``details()``.
    _details: RunDetails | None = None

    def __post_init__(self):
        """
        Initialize the Run object with the given run definition.

        :raises RuntimeError: If the run message has no ``action`` field set.
        """
        if not self.pb2.HasField("action"):
            raise RuntimeError("Run does not have an action")
        self.action = Action(self.pb2.action)

    @syncify
    @classmethod
    async def listall(
        cls,
        filters: str | None = None,
        sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
    ) -> AsyncIterator[Run]:
        """
        Get all runs for the current project and domain.

        Results are fetched page by page (100 per request) until the service returns no continuation token.

        :param filters: The filters to apply to the run list. NOTE: currently not forwarded to the request.
        :param sort_by: The sorting criteria for the run list, in the format (field, order). Defaults to
            ``("created_at", "asc")``.
        :return: An iterator of runs.
        """
        ensure_client()
        token = None
        sort_by = sort_by or ("created_at", "asc")
        sort_pb2 = list_pb2.Sort(
            key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
        )
        cfg = get_common_config()
        while True:
            req = list_pb2.ListRequest(
                limit=100,
                token=token,
                sort_by=sort_pb2,
            )
            resp = await get_client().run_service.ListRuns(
                run_service_pb2.ListRunsRequest(
                    request=req,
                    org=cfg.org,
                    project_id=identifier_pb2.ProjectIdentifier(
                        organization=cfg.org,
                        domain=cfg.domain,
                        name=cfg.project,
                    ),
                )
            )
            token = resp.token
            for r in resp.runs:
                yield cls(r)
            # An empty token means there are no further pages.
            if not token:
                break

    @syncify
    @classmethod
    async def get(cls, name: str) -> Run:
        """
        Get a run by name from the currently configured org, project and domain.

        The run details are fetched once and cached on the returned object.

        :param name: The name of the run.
        :return: The run.
        """
        ensure_client()
        run_details: RunDetails = await RunDetails.get.aio(name=name)
        # Rebuild the lightweight Run message from the detailed response.
        run = run_definition_pb2.Run(
            action=run_definition_pb2.Action(
                id=run_details.action_id,
                metadata=run_details.action_details.pb2.metadata,
                status=run_details.action_details.pb2.status,
            ),
        )
        return cls(pb2=run, _details=run_details)

    @property
    def name(self) -> str:
        """
        Get the name of the run.
        """
        return self.pb2.action.id.run.name

    @property
    def phase(self) -> str:
        """
        Get the phase of the run.
        """
        return self.action.phase

    @property
    def raw_phase(self) -> run_definition_pb2.Phase:
        """
        Get the raw phase of the run.
        """
        return self.action.raw_phase

    @syncify
    async def wait(self, quiet: bool = False, wait_for: Literal["terminal", "running"] = "terminal") -> None:
        """
        Wait for the run to complete, displaying a rich progress panel with status transitions,
        time elapsed, and error details in case of failure.

        :param quiet: If True, the progress display is disabled.
        :param wait_for: The state to wait for before returning.
        """
        return await self.action.wait(quiet=quiet, wait_for=wait_for)

    async def watch(self, cache_data_on_done: bool = False) -> AsyncGenerator[ActionDetails, None]:
        """
        Watch the run's action for updates.

        NOTE(review): this is a coroutine that *returns* the async generator of the underlying action, so
        callers must ``await`` it before iterating (``async for ad in await run.watch(): ...``). Kept as is
        for backward compatibility.

        :param cache_data_on_done: Forwarded to :meth:`Action.watch`.
        :return: The async generator produced by :meth:`Action.watch`.
        """
        return self.action.watch(cache_data_on_done=cache_data_on_done)

    async def show_logs(
        self,
        attempt: int | None = None,
        max_lines: int = 100,
        show_ts: bool = False,
        raw: bool = False,
        filter_system: bool = False,
    ):
        """
        Show the logs of the run's action. All arguments are forwarded to :meth:`Action.show_logs`.
        """
        await self.action.show_logs(attempt, max_lines, show_ts, raw, filter_system=filter_system)

    async def details(self) -> RunDetails:
        """
        Get the details of the run, fetching them from the remote service on first use and caching them.

        :return: The run details.
        """
        if self._details is None:
            # Use the async entry point of the syncified classmethod, like every other call site in this
            # module; the class is bound automatically and must not be passed explicitly.
            self._details = await RunDetails.get_details.aio(self.pb2.action.id.run)
        return self._details

    @property
    def url(self) -> str:
        """
        Get the URL of the run.
        """
        client = get_client()
        return get_run_url(
            client.endpoint,
            insecure=client.insecure,
            project=self.pb2.action.id.run.project,
            domain=self.pb2.action.id.run.domain,
            run_name=self.name,
        )

    @syncify
    async def abort(self):
        """
        Aborts / Terminates the run.

        A ``NOT_FOUND`` response from the service is ignored; any other gRPC error is re-raised.
        """
        try:
            await get_client().run_service.AbortRun(
                run_service_pb2.AbortRunRequest(
                    run_id=self.pb2.action.id.run,
                )
            )
        except grpc.aio.AioRpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                return
            raise

    def done(self) -> bool:
        """
        Check if the run is done.
        """
        return self.action.done()

    def sync(self) -> Run:
        """
        Sync the run with the remote server. This is a placeholder for syncing the run.
        """
        return self

    # TODO add add_done_callback, maybe implement sync apis etc

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the Run object.
        """
        yield from _action_rich_repr(self.pb2.action)

    def __repr__(self) -> str:
        """
        String representation of the Run object.
        """
        import rich.pretty

        return rich.pretty.pretty_repr(self)
294
+
295
+
296
@dataclass
class RunDetails:
    """
    Detailed view of a run on the remote Union API.

    Wraps the ``run_definition_pb2.RunDetails`` message and exposes the details of the run's action
    through ``action_details``.
    """

    pb2: run_definition_pb2.RunDetails
    # Wrapper around ``pb2.action``; populated in ``__post_init__``.
    action_details: ActionDetails = field(init=False)

    def __post_init__(self):
        """
        Build the ``ActionDetails`` wrapper from the action carried by the run details message.
        """
        self.action_details = ActionDetails(self.pb2.action)

    @syncify
    @classmethod
    async def get_details(cls, run_id: run_definition_pb2.RunIdentifier) -> RunDetails:
        """
        Fetch the details of the run identified by ``run_id`` from the run service.

        :param run_id: The identifier of the run.
        :return: The run details.
        """
        ensure_client()
        service = get_client().run_service
        request = run_service_pb2.GetRunDetailsRequest(run_id=run_id)
        response = await service.GetRunDetails(request)
        return cls(response.details)

    @syncify
    @classmethod
    async def get(cls, name: str | None = None) -> RunDetails:
        """
        Fetch the details of a run by name, using the org, project and domain of the common config.

        :param name: The name of the run.
        :return: The run details.
        """
        ensure_client()
        cfg = get_common_config()
        fetch = RunDetails.get_details.aio
        identifier = run_definition_pb2.RunIdentifier(
            org=cfg.org,
            project=cfg.project,
            domain=cfg.domain,
            name=name,
        )
        return await fetch(run_id=identifier)

    @property
    def name(self) -> str:
        """
        The run name, as reported by the action details.
        """
        return self.action_details.run_name

    @property
    def task_name(self) -> str | None:
        """
        The task name, as reported by the action details.
        """
        return self.action_details.task_name

    @property
    def action_id(self) -> run_definition_pb2.ActionIdentifier:
        """
        The identifier of the run's action.
        """
        return self.action_details.action_id

    def done(self) -> bool:
        """
        Whether the run's action reports itself as done.
        """
        return self.action_details.done()

    async def inputs(self) -> ActionInputs:
        """
        The inputs of the run's action (delegates to the action details).
        """
        return await self.action_details.inputs()

    async def outputs(self) -> ActionOutputs:
        """
        The outputs of the run's action (delegates to the action details).
        """
        return await self.action_details.outputs()

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the RunDetails object.
        """
        yield from _action_details_rich_repr(self.pb2.action)

    def __repr__(self) -> str:
        """
        String representation of the RunDetails object, rendered through rich.
        """
        import rich.pretty as rich_pretty

        return rich_pretty.pretty_repr(self)
399
+
400
+
401
@dataclass
class Action:
    """
    A class representing an action. It is used to manage the run of a task and its state on the remote Union API.

    Wraps the ``run_definition_pb2.Action`` message; detailed information is fetched lazily through
    ``details()`` and refreshed while watching.
    """

    pb2: run_definition_pb2.Action
    # Most recently fetched details; set by ``get()``, ``details()`` and ``watch()``.
    _details: ActionDetails | None = None

    @syncify
    @classmethod
    async def listall(
        cls,
        for_run_name: str,
        filters: str | None = None,
        sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
    ) -> Union[Iterator[Action], AsyncIterator[Action]]:
        """
        Get all actions for a given run.

        Results are fetched page by page (100 per request) until the service returns no continuation token.

        :param for_run_name: The name of the run.
        :param filters: The filters to apply to the action list. NOTE: currently not forwarded to the request.
        :param sort_by: The sorting criteria for the action list, in the format (field, order). Defaults to
            ``("created_at", "asc")``.
        :return: An iterator of actions.
        """
        ensure_client()
        token = None
        sort_by = sort_by or ("created_at", "asc")
        sort_pb2 = list_pb2.Sort(
            key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
        )
        cfg = get_common_config()
        while True:
            req = list_pb2.ListRequest(
                limit=100,
                token=token,
                sort_by=sort_pb2,
            )
            resp = await get_client().run_service.ListActions(
                run_service_pb2.ListActionsRequest(
                    request=req,
                    run_id=run_definition_pb2.RunIdentifier(
                        org=cfg.org,
                        project=cfg.project,
                        domain=cfg.domain,
                        name=for_run_name,
                    ),
                )
            )
            token = resp.token
            for r in resp.actions:
                yield cls(r)
            # An empty token means there are no further pages.
            if not token:
                break

    @syncify
    @classmethod
    async def get(cls, uri: str | None = None, /, run_name: str | None = None, name: str | None = None) -> Action:
        """
        Get an action by run name and action name, using the org, project and domain of the common config.

        The fetched details are cached on the returned object.

        :param uri: The URI of the action. NOTE: currently not used for the lookup.
        :param run_name: The name of the run the action belongs to.
        :param name: The name of the action.
        :return: The action.
        """
        ensure_client()
        cfg = get_common_config()
        details: ActionDetails = await ActionDetails.get_details.aio(
            run_definition_pb2.ActionIdentifier(
                run=run_definition_pb2.RunIdentifier(
                    org=cfg.org,
                    project=cfg.project,
                    domain=cfg.domain,
                    name=run_name,
                ),
                name=name,
            ),
        )
        # Rebuild the lightweight Action message from the detailed response.
        return cls(
            pb2=run_definition_pb2.Action(
                id=details.action_id,
                metadata=details.pb2.metadata,
                status=details.pb2.status,
            ),
            _details=details,
        )

    @property
    def phase(self) -> str:
        """
        Get the phase of the action.
        """
        return run_definition_pb2.Phase.Name(self.pb2.status.phase)

    @property
    def raw_phase(self) -> run_definition_pb2.Phase:
        """
        Get the raw phase of the action.
        """
        return self.pb2.status.phase

    @property
    def name(self) -> str:
        """
        Get the name of the action.
        """
        return self.action_id.name

    @property
    def run_name(self) -> str:
        """
        Get the name of the run.
        """
        return self.action_id.run.name

    @property
    def task_name(self) -> str | None:
        """
        Get the name of the task, or None when the metadata carries no task identifier.
        """
        if self.pb2.metadata.HasField("task") and self.pb2.metadata.task.HasField("id"):
            return self.pb2.metadata.task.id.name
        return None

    @property
    def action_id(self) -> run_definition_pb2.ActionIdentifier:
        """
        Get the action ID.
        """
        return self.pb2.id

    async def show_logs(
        self,
        attempt: int | None = None,
        max_lines: int = 30,
        show_ts: bool = False,
        raw: bool = False,
        filter_system: bool = False,
    ):
        """
        Show the logs of the action through ``Logs.create_viewer``.

        If the action is neither running nor done, this first waits until logs are ready.

        :param attempt: The attempt to show logs for. A falsy value selects ``details.attempts``.
        :param max_lines: Forwarded to the log viewer.
        :param show_ts: Forwarded to the log viewer.
        :param raw: Forwarded to the log viewer.
        :param filter_system: Forwarded to the log viewer.
        """
        details = await self.details()
        if not details.is_running and not details.done():
            # TODO we can short circuit here if the attempt is not the last one and it is done!
            await self.wait(wait_for="logs-ready")
            details = await self.details()
        if not attempt:
            attempt = details.attempts
        return await Logs.create_viewer(
            action_id=self.action_id,
            attempt=attempt,
            max_lines=max_lines,
            show_ts=show_ts,
            raw=raw,
            filter_system=filter_system,
        )

    async def details(self) -> ActionDetails:
        """
        Get the details of the action, fetching them from the remote service when none are cached.

        :return: The action details.
        """
        if not self._details:
            self._details = await ActionDetails.get_details.aio(self.action_id)
        return cast(ActionDetails, self._details)

    async def watch(
        self, cache_data_on_done: bool = False, wait_for: WaitFor = "terminal"
    ) -> AsyncGenerator[ActionDetails, None]:
        """
        Watch the action for updates, yielding every received ``ActionDetails`` and caching the latest one.

        Iteration stops once the requested ``wait_for`` condition is met or the action is done.

        :param cache_data_on_done: If True and the last update is done, the outputs are fetched before
            the generator finishes.
        :param wait_for: The state to wait for: "terminal", "running" or "logs-ready".
        """
        ad = None
        async for ad in ActionDetails.watch.aio(self.action_id):
            if ad is None:
                return
            self._details = ad
            yield ad
            if wait_for == "running" and ad.is_running:
                break
            elif wait_for == "logs-ready" and ad.logs_available():
                break
            if ad.done():
                break
        if cache_data_on_done and ad and ad.done():
            await cast(ActionDetails, self._details).outputs()

    async def wait(self, quiet: bool = False, wait_for: WaitFor = "terminal") -> None:
        """
        Wait for the run to complete, displaying a rich progress panel with status transitions,
        time elapsed, and error details in case of failure.

        :param quiet: If True, the progress display is disabled and, for an action that is already done,
            no summary is printed.
        :param wait_for: The state to wait for: "terminal", "running" or "logs-ready".
        """
        console = Console()
        if self.done():
            # Already terminal: print a one-off summary instead of watching.
            if not quiet:
                if self.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
                    console.print(
                        f"[bold green]Action '{self.name}' in Run '{self.run_name}'"
                        f" completed successfully.[/bold green]"
                    )
                else:
                    details = await self.details()
                    console.print(
                        f"[bold red]Action '{self.name}' in Run '{self.run_name}'"
                        f" exited unsuccessfully in state {self.phase} with error: {details.error_info}[/bold red]"
                    )
            return

        try:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                TimeElapsedColumn(),
                console=console,
                transient=True,
                disable=quiet,
            ) as progress:
                task_id = progress.add_task(f"Waiting for run '{self.name}'...", start=False)
                progress.start_task(task_id)

                async for ad in self.watch(cache_data_on_done=True, wait_for=wait_for):
                    if ad is None:
                        progress.stop_task(task_id)
                        break

                    if ad.is_running and wait_for == "running":
                        progress.start_task(task_id)
                        break

                    if ad.logs_available() and wait_for == "logs-ready":
                        progress.start_task(task_id)
                        break

                    # Update progress description with the current phase
                    progress.update(
                        task_id,
                        description=f"Run: {self.name} in {ad.phase}, Runtime: {ad.runtime} secs "
                        f"Attempts[{ad.attempts}]",
                    )

                    # If the action is done, handle the final state
                    if ad.done():
                        progress.stop_task(task_id)
                        if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
                            console.print(f"[bold green]Run '{self.name}' completed successfully.[/bold green]")
                        else:
                            # The leading space before "with" keeps the phase and the error text apart.
                            console.print(
                                f"[bold red]Run '{self.name}' exited unsuccessfully in state {ad.phase}"
                                f" with error: {ad.error_info}[/bold red]"
                            )
                        break
        except asyncio.CancelledError:
            # Handle cancellation gracefully
            pass
        except KeyboardInterrupt:
            # Handle keyboard interrupt gracefully
            pass

    def done(self) -> bool:
        """
        Check if the action is done.
        """
        return _action_done_check(self.raw_phase)

    async def sync(self) -> Action:
        """
        Sync the action with the remote server. This is a placeholder for syncing the action.
        """
        return self

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the Action object, followed by the cached details when available.
        """
        yield from _action_rich_repr(self.pb2)
        if self._details:
            yield from self._details.__rich_repr__()

    def __repr__(self) -> str:
        """
        String representation of the Action object.
        """
        import rich.pretty

        return rich.pretty.pretty_repr(self)
683
+
684
+
685
@dataclass
class ActionDetails:
    """
    Detailed view of a single action, wrapping the ``ActionDetails`` protobuf returned by the remote run service.

    Besides exposing convenient accessors over the protobuf (phase, names, runtime, attempts, ...), the object
    lazily fetches and caches the action's inputs and outputs.
    """

    # Raw protobuf message describing the action.
    pb2: run_definition_pb2.ActionDetails
    # Lazily populated by `_cache_data`; None until fetched.
    _inputs: ActionInputs | None = None
    # Lazily populated by `_cache_data`; stays None when the service response carries no outputs.
    _outputs: ActionOutputs | None = None

    @syncify
    @classmethod
    async def get_details(cls, action_id: run_definition_pb2.ActionIdentifier) -> ActionDetails:
        """
        Fetch the details of an action from the run service.

        :param action_id: The identifier of the action to look up.
        :return: A new instance wrapping the ``details`` field of the service response.
        """
        ensure_client()
        resp = await get_client().run_service.GetActionDetails(
            run_service_pb2.GetActionDetailsRequest(
                action_id=action_id,
            )
        )
        # Use `cls` (as `watch` does) so subclasses get instances of their own type.
        return cls(resp.details)

    @syncify
    @classmethod
    async def get(
        cls, uri: str | None = None, /, run_name: str | None = None, name: str | None = None
    ) -> ActionDetails:
        """
        Get the details of an action, addressed by the name of its run and its own name. The org, project and
        domain are taken from the common config.

        :param uri: The URI of the action. NOTE(review): this value is only checked for truthiness here; the
            lookup below is always built from ``run_name`` and ``name`` - confirm whether URI parsing is pending.
        :param name: The name of the action.
        :param run_name: The name of the run.
        :return: The details of the requested action.
        """
        ensure_client()
        if not uri:
            # NOTE(review): assert is stripped under `python -O`; kept as-is so the raised type does not change.
            assert name is not None and run_name is not None, "Either uri or name and run_name must be provided"
        cfg = get_common_config()
        return await cls.get_details.aio(
            run_definition_pb2.ActionIdentifier(
                run=run_definition_pb2.RunIdentifier(
                    org=cfg.org,
                    project=cfg.project,
                    domain=cfg.domain,
                    name=run_name,
                ),
                name=name,
            ),
        )

    @syncify
    @classmethod
    async def watch(cls, action_id: run_definition_pb2.ActionIdentifier) -> AsyncIterator[ActionDetails]:
        """
        Watch the action for updates.

        Each response of the ``WatchActionDetails`` stream is wrapped in a new instance and yielded. Iteration
        stops after the first update that reports a terminal phase.

        :param action_id: The identifier of the action to watch.
        :raises ValueError: If ``action_id`` is falsy.
        :raises grpc.aio.AioRpcError: For any stream error other than a cancellation.
        """
        ensure_client()
        if not action_id:
            raise ValueError("Action ID is required")

        call = cast(
            AsyncIterator[WatchActionDetailsResponse],
            get_client().run_service.WatchActionDetails(
                request=run_service_pb2.WatchActionDetailsRequest(
                    action_id=action_id,
                )
            ),
        )
        try:
            async for resp in call:
                v = cls(resp.details)
                yield v
                if v.done():
                    return
        except grpc.aio.AioRpcError as e:
            # A cancelled stream simply ends the iteration; everything else propagates with its original traceback.
            if e.code() != grpc.StatusCode.CANCELLED:
                raise

    async def watch_updates(self, cache_data_on_done: bool = False) -> AsyncGenerator[ActionDetails, None]:
        """
        Yield updates for this action until it reaches a terminal phase.

        When a terminal update arrives, this object's ``pb2`` is replaced with the one from that update.

        :param cache_data_on_done: If True and this action is done after watching, fetch and cache its
            inputs and outputs.
        """
        async for d in self.watch.aio(action_id=self.pb2.id):
            yield d
            if d.done():
                self.pb2 = d.pb2
                break

        if cache_data_on_done and self.done():
            await self._cache_data.aio()

    @property
    def phase(self) -> str:
        """
        Get the phase of the action as the name of the ``Phase`` enum value.
        """
        return run_definition_pb2.Phase.Name(self.status.phase)

    @property
    def raw_phase(self) -> run_definition_pb2.Phase:
        """
        Get the raw phase of the action.
        """
        return self.status.phase

    @property
    def is_running(self) -> bool:
        """
        Check if the action is currently running.
        """
        return self.status.phase == run_definition_pb2.PHASE_RUNNING

    @property
    def name(self) -> str:
        """
        Get the name of the action.
        """
        return self.action_id.name

    @property
    def run_name(self) -> str:
        """
        Get the name of the run.
        """
        return self.action_id.run.name

    @property
    def task_name(self) -> str | None:
        """
        Get the name of the task, or None when the metadata carries no task or the task carries no id.
        """
        if self.pb2.metadata.HasField("task") and self.pb2.metadata.task.HasField("id"):
            return self.pb2.metadata.task.id.name
        return None

    @property
    def action_id(self) -> run_definition_pb2.ActionIdentifier:
        """
        Get the action ID.
        """
        return self.pb2.id

    @property
    def metadata(self) -> run_definition_pb2.ActionMetadata:
        """
        Get the metadata message of the action.
        """
        return self.pb2.metadata

    @property
    def status(self) -> run_definition_pb2.ActionStatus:
        """
        Get the status message of the action.
        """
        return self.pb2.status

    @property
    def error_info(self) -> run_definition_pb2.ErrorInfo | None:
        """
        Get the error info of the action, or None when the field is not set.
        """
        if self.pb2.HasField("error_info"):
            return self.pb2.error_info
        return None

    @property
    def abort_info(self) -> run_definition_pb2.AbortInfo | None:
        """
        Get the abort info of the action, or None when the field is not set.
        """
        if self.pb2.HasField("abort_info"):
            return self.pb2.abort_info
        return None

    @property
    def runtime(self) -> timedelta:
        """
        Get the runtime of the action.

        Measured from the status start time to the status end time when one is set, otherwise to the
        current UTC time.
        """
        start_time = timestamp.to_datetime(self.pb2.status.start_time, timezone.utc)
        if self.pb2.status.HasField("end_time"):
            end_time = timestamp.to_datetime(self.pb2.status.end_time, timezone.utc)
            return end_time - start_time
        return datetime.now(timezone.utc) - start_time

    @property
    def attempts(self) -> int:
        """
        Get the number of attempts of the action.
        """
        return self.pb2.status.attempts

    def logs_available(self, attempt: int | None = None) -> bool:
        """
        Check if logs are available for the action, optionally for a specific attempt.
        If attempt is None, it checks for the latest attempt.

        :param attempt: 1-based attempt number. Values outside ``1..len(attempts)`` yield False.
        :return: The ``logs_available`` flag of the selected attempt, or False when no such attempt exists.
        """
        if attempt is None:
            attempt = self.pb2.status.attempts
        attempts = self.pb2.attempts
        # Attempt numbers are 1-based. The lower bound keeps 0 or negative numbers from aliasing another
        # attempt through negative indexing (e.g. attempt 0 would otherwise read attempts[-1]).
        if 1 <= attempt <= len(attempts):
            return attempts[attempt - 1].logs_available
        return False

    @syncify
    async def _cache_data(self) -> bool:
        """
        Cache the inputs and outputs of the action.

        No request is made when both are already cached, or when inputs are cached and the action is not
        done yet.

        :return: Returns True if Action is terminal and all data is cached else False.
        """
        if self._inputs and self._outputs:
            return True
        if self._inputs and not self.done():
            return False
        resp = await get_client().run_service.GetActionData(
            request=run_service_pb2.GetActionDataRequest(
                action_id=self.pb2.id,
            )
        )
        self._inputs = ActionInputs(resp.inputs)
        self._outputs = ActionOutputs(resp.outputs) if resp.HasField("outputs") else None
        return self._outputs is not None

    async def inputs(self) -> ActionInputs:
        """
        Get the inputs of the action, fetching and caching them on first access.
        """
        if not self._inputs:
            await self._cache_data.aio()
        return cast(ActionInputs, self._inputs)

    async def outputs(self) -> ActionOutputs:
        """
        Get the outputs of the action, fetching and caching them on first access.

        :raises RuntimeError: If no outputs could be cached for the action.
        """
        if not self._outputs:
            if not await self._cache_data.aio():
                raise RuntimeError(
                    "Action is not in a terminal state, outputs are not available. "
                    "Please wait for the action to complete."
                )
        return cast(ActionOutputs, self._outputs)

    def done(self) -> bool:
        """
        Check if the action is in a terminal state, as decided by `_action_done_check` on the raw phase.
        """
        return _action_done_check(self.raw_phase)

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the ActionDetails object.
        """
        yield from _action_details_rich_repr(self.pb2)

    def __repr__(self) -> str:
        """
        String representation of the ActionDetails object.
        """
        import rich.pretty

        return rich.pretty.pretty_repr(self)
937
+
938
+
939
@dataclass
class ActionInputs:
    """
    Wrapper around the inputs of an action as returned by the remote Union API.
    """

    # Raw protobuf message holding the action's inputs.
    pb2: run_definition_pb2.Inputs

    def __repr__(self):
        """
        Pretty-print the literal string representation of the wrapped inputs.
        """
        from rich.pretty import pretty_repr

        from flyte.types import literal_string_repr

        readable = literal_string_repr(self.pb2)
        return pretty_repr(readable)
954
+
955
+
956
@dataclass
class ActionOutputs:
    """
    Wrapper around the outputs of an action as returned by the remote Union API.
    """

    # Raw protobuf message holding the action's outputs.
    pb2: run_definition_pb2.Outputs

    def __repr__(self):
        """
        Pretty-print the literal string representation of the wrapped outputs.
        """
        from rich.pretty import pretty_repr

        from flyte.types import literal_string_repr

        readable = literal_string_repr(self.pb2)
        return pretty_repr(readable)