flyte-2.0.0b13-py3-none-any.whl → flyte-2.0.0b30-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/remote/_action.py CHANGED
@@ -20,14 +20,15 @@ from typing import (
20
20
  import grpc
21
21
  import rich.pretty
22
22
  import rich.repr
23
+ from flyteidl2.common import identifier_pb2, list_pb2
24
+ from flyteidl2.task import common_pb2
25
+ from flyteidl2.workflow import run_definition_pb2, run_service_pb2
26
+ from flyteidl2.workflow.run_service_pb2 import WatchActionDetailsResponse
23
27
  from rich.console import Console
24
28
  from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
25
29
 
26
30
  from flyte import types
27
- from flyte._initialize import ensure_client, get_client, get_common_config
28
- from flyte._protos.common import identifier_pb2, list_pb2
29
- from flyte._protos.workflow import run_definition_pb2, run_service_pb2
30
- from flyte._protos.workflow.run_service_pb2 import WatchActionDetailsResponse
31
+ from flyte._initialize import ensure_client, get_client, get_init_config
31
32
  from flyte.remote._common import ToJSONMixin
32
33
  from flyte.remote._logs import Logs
33
34
  from flyte.syncify import syncify
@@ -67,14 +68,14 @@ def _action_rich_repr(action: run_definition_pb2.Action) -> rich.repr.Result:
67
68
  """
68
69
  Rich representation of the action.
69
70
  """
70
- yield "run", action.id.run.name
71
+ yield "name", action.id.run.name
71
72
  if action.metadata.HasField("task"):
72
- yield "task", action.metadata.task.id.name
73
+ yield "task name", action.metadata.task.id.name
73
74
  yield "type", action.metadata.task.task_type
74
75
  elif action.metadata.HasField("trace"):
75
76
  yield "trace", action.metadata.trace.name
76
77
  yield "type", "trace"
77
- yield "name", action.id.name
78
+ yield "action name", action.id.name
78
79
  yield from _action_time_phase(action)
79
80
  yield "group", action.metadata.group
80
81
  yield "parent", action.metadata.parent
@@ -98,9 +99,10 @@ def _action_details_rich_repr(
98
99
  """
99
100
  yield "name", action.id.run.name
100
101
  yield from _action_time_phase(action)
101
- yield "task", action.resolved_task_spec.task_template.id.name
102
- yield "task_type", action.resolved_task_spec.task_template.type
103
- yield "task_version", action.resolved_task_spec.task_template.id.version
102
+ if action.HasField("task"):
103
+ yield "task", action.task.task_template.id.name
104
+ yield "task_type", action.task.task_template.type
105
+ yield "task_version", action.task.task_template.id.version
104
106
  yield "attempts", action.attempts
105
107
  yield "error", (f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA")
106
108
  yield "phase", run_definition_pb2.Phase.Name(action.status.phase)
@@ -152,7 +154,7 @@ class Action(ToJSONMixin):
152
154
  key=sort_by[0],
153
155
  direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
154
156
  )
155
- cfg = get_common_config()
157
+ cfg = get_init_config()
156
158
  while True:
157
159
  req = list_pb2.ListRequest(
158
160
  limit=100,
@@ -193,7 +195,7 @@ class Action(ToJSONMixin):
193
195
  :param name: The name of the action.
194
196
  """
195
197
  ensure_client()
196
- cfg = get_common_config()
198
+ cfg = get_init_config()
197
199
  details: ActionDetails = await ActionDetails.get_details.aio(
198
200
  identifier_pb2.ActionIdentifier(
199
201
  run=identifier_pb2.RunIdentifier(
@@ -327,9 +329,10 @@ class Action(ToJSONMixin):
327
329
  )
328
330
  else:
329
331
  details = await self.details()
332
+ error_message = details.error_info.message if details.error_info else ""
330
333
  console.print(
331
334
  f"[bold red]Action '{self.name}' in Run '{self.run_name}'"
332
- f" exited unsuccessfully in state {self.phase} with error: {details.error_info}[/bold red]"
335
+ f" exited unsuccessfully in state {self.phase} with error: {error_message}[/bold red]"
333
336
  )
334
337
  return
335
338
 
@@ -368,13 +371,15 @@ class Action(ToJSONMixin):
368
371
  # If the action is done, handle the final state
369
372
  if ad.done():
370
373
  progress.stop_task(task_id)
371
- if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
372
- console.print(f"[bold green]Run '{self.run_name}' completed successfully.[/bold green]")
373
- else:
374
- console.print(
375
- f"[bold red]Run '{self.run_name}' exited unsuccessfully in state {ad.phase}"
376
- f" with error: {ad.error_info}[/bold red]"
377
- )
374
+ if not quiet:
375
+ if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
376
+ console.print(f"[bold green]Run '{self.run_name}' completed successfully.[/bold green]")
377
+ else:
378
+ error_message = ad.error_info.message if ad.error_info else ""
379
+ console.print(
380
+ f"[bold red]Run '{self.run_name}' exited unsuccessfully in state {ad.phase}"
381
+ f" with error: {error_message}[/bold red]"
382
+ )
378
383
  break
379
384
  except asyncio.CancelledError:
380
385
  # Handle cancellation gracefully
@@ -455,7 +460,7 @@ class ActionDetails(ToJSONMixin):
455
460
  ensure_client()
456
461
  if not uri:
457
462
  assert name is not None and run_name is not None, "Either uri or name and run_name must be provided"
458
- cfg = get_common_config()
463
+ cfg = get_init_config()
459
464
  return await cls.get_details.aio(
460
465
  identifier_pb2.ActionIdentifier(
461
466
  run=identifier_pb2.RunIdentifier(
@@ -627,8 +632,11 @@ class ActionDetails(ToJSONMixin):
627
632
  )
628
633
  )
629
634
  native_iface = None
630
- if self.pb2.resolved_task_spec:
631
- iface = self.pb2.resolved_task_spec.task_template.interface
635
+ if self.pb2.HasField("task"):
636
+ iface = self.pb2.task.task_template.interface
637
+ native_iface = types.guess_interface(iface)
638
+ elif self.pb2.HasField("trace"):
639
+ iface = self.pb2.trace.interface
632
640
  native_iface = types.guess_interface(iface)
633
641
 
634
642
  if resp.inputs:
@@ -700,7 +708,7 @@ class ActionInputs(UserDict, ToJSONMixin):
700
708
  remote Union API.
701
709
  """
702
710
 
703
- pb2: run_definition_pb2.Inputs
711
+ pb2: common_pb2.Inputs
704
712
  data: Dict[str, Any]
705
713
 
706
714
  def __repr__(self):
@@ -717,14 +725,14 @@ class ActionOutputs(tuple, ToJSONMixin):
717
725
  remote Union API.
718
726
  """
719
727
 
720
- def __new__(cls, pb2: run_definition_pb2.Outputs, data: Tuple[Any, ...]):
728
+ def __new__(cls, pb2: common_pb2.Outputs, data: Tuple[Any, ...]):
721
729
  # Create the tuple part
722
730
  obj = super().__new__(cls, data)
723
731
  # Store extra data (you can't do this here directly since it's immutable)
724
732
  obj.pb2 = pb2
725
733
  return obj
726
734
 
727
- def __init__(self, pb2: run_definition_pb2.Outputs, data: Tuple[Any, ...]):
735
+ def __init__(self, pb2: common_pb2.Outputs, data: Tuple[Any, ...]):
728
736
  # Normally you'd set instance attributes here,
729
737
  # but we've already set `pb2` in `__new__`
730
738
  self.pb2 = pb2
@@ -1,13 +1,14 @@
1
1
  from typing import AsyncIterator, Protocol
2
2
 
3
3
  from flyteidl.admin import project_attributes_pb2, project_pb2, version_pb2
4
- from flyteidl.service import dataproxy_pb2
4
+ from flyteidl.service import dataproxy_pb2, identity_pb2
5
+ from flyteidl2.secret import payload_pb2
6
+ from flyteidl2.task import task_service_pb2
7
+ from flyteidl2.trigger import trigger_service_pb2
8
+ from flyteidl2.workflow import run_logs_service_pb2, run_service_pb2
5
9
  from grpc.aio import UnaryStreamCall
6
10
  from grpc.aio._typing import RequestType
7
11
 
8
- from flyte._protos.secret import payload_pb2
9
- from flyte._protos.workflow import run_logs_service_pb2, run_service_pb2, task_service_pb2
10
-
11
12
 
12
13
  class MetadataServiceProtocol(Protocol):
13
14
  async def GetVersion(self, request: version_pb2.GetVersionRequest) -> version_pb2.GetVersionResponse: ...
@@ -131,3 +132,37 @@ class SecretService(Protocol):
131
132
  async def ListSecrets(self, request: payload_pb2.ListSecretsRequest) -> payload_pb2.ListSecretsResponse: ...
132
133
 
133
134
  async def DeleteSecret(self, request: payload_pb2.DeleteSecretRequest) -> payload_pb2.DeleteSecretResponse: ...
135
+
136
+
137
+ class IdentityService(Protocol):
138
+ async def UserInfo(self, request: identity_pb2.UserInfoRequest) -> identity_pb2.UserInfoResponse: ...
139
+
140
+
141
+ class TriggerService(Protocol):
142
+ async def DeployTrigger(
143
+ self, request: trigger_service_pb2.DeployTriggerRequest
144
+ ) -> trigger_service_pb2.DeployTriggerResponse: ...
145
+
146
+ async def GetTriggerDetails(
147
+ self, request: trigger_service_pb2.GetTriggerDetailsRequest
148
+ ) -> trigger_service_pb2.GetTriggerDetailsResponse: ...
149
+
150
+ async def GetTriggerRevisionDetails(
151
+ self, request: trigger_service_pb2.GetTriggerRevisionDetailsRequest
152
+ ) -> trigger_service_pb2.GetTriggerRevisionDetailsResponse: ...
153
+
154
+ async def ListTriggers(
155
+ self, request: trigger_service_pb2.ListTriggersRequest
156
+ ) -> trigger_service_pb2.ListTriggersResponse: ...
157
+
158
+ async def GetTriggerRevisionHistory(
159
+ self, request: trigger_service_pb2.GetTriggerRevisionHistoryRequest
160
+ ) -> trigger_service_pb2.GetTriggerRevisionHistoryResponse: ...
161
+
162
+ async def UpdateTriggers(
163
+ self, request: trigger_service_pb2.UpdateTriggersRequest
164
+ ) -> trigger_service_pb2.UpdateTriggersResponse: ...
165
+
166
+ async def DeleteTriggers(
167
+ self, request: trigger_service_pb2.DeleteTriggersRequest
168
+ ) -> trigger_service_pb2.DeleteTriggersResponse: ...
@@ -1,4 +1,4 @@
1
- import click
1
+ from rich import print as rich_print
2
2
 
3
3
  from flyte._logging import logger
4
4
  from flyte.remote._client.auth import _token_client as token_client
@@ -81,7 +81,7 @@ class DeviceCodeAuthenticator(Authenticator):
81
81
  for_endpoint=self._endpoint,
82
82
  )
83
83
  except (AuthenticationError, AuthenticationPending):
84
- logger.warning("Failed to refresh token. Kicking off a full authorization flow.")
84
+ logger.warning("Logging in...")
85
85
 
86
86
  """Fall back to device flow"""
87
87
  resp = await token_client.get_device_code(
@@ -94,10 +94,9 @@ class DeviceCodeAuthenticator(Authenticator):
94
94
 
95
95
  full_uri = f"{resp.verification_uri}?user_code={resp.user_code}"
96
96
  text = (
97
- f"To Authenticate, navigate in a browser to the following URL: "
98
- f"{click.style(full_uri, fg='blue', underline=True)}"
97
+ f"To Authenticate, navigate in a browser to the following URL: [blue link={full_uri}]{full_uri}[/blue link]"
99
98
  )
100
- click.secho(text)
99
+ rich_print(text)
101
100
  try:
102
101
  token, refresh_token, expires_in = await token_client.poll_token_endpoint(
103
102
  resp,
@@ -123,7 +123,7 @@ class PKCEAuthenticator(Authenticator):
123
123
  try:
124
124
  return await self._auth_client.refresh_access_token(self._creds)
125
125
  except AccessTokenNotFoundError:
126
- logger.warning("Failed to refresh token. Kicking off a full authorization flow.")
126
+ logger.warning("Logging in...")
127
127
 
128
128
  return await self._auth_client.get_creds_from_remote()
129
129
 
@@ -7,6 +7,7 @@ import httpx
7
7
  from grpc.experimental.aio import init_grpc_aio
8
8
 
9
9
  from flyte._logging import logger
10
+ from flyte._utils.org_discovery import hostname_from_url
10
11
 
11
12
  from ._authenticators.base import get_async_session
12
13
  from ._authenticators.factory import (
@@ -30,16 +31,19 @@ def bootstrap_ssl_from_server(endpoint: str) -> grpc.ChannelCredentials:
30
31
  :param endpoint: The endpoint URL to retrieve the SSL certificate from, may include port number
31
32
  :return: gRPC channel credentials created from the retrieved certificate
32
33
  """
34
+ hostname = hostname_from_url(endpoint)
35
+
33
36
  # Get port from endpoint or use 443
34
- endpoint_parts = endpoint.rsplit(":", 1)
37
+ endpoint_parts = hostname.rsplit(":", 1)
35
38
  if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit():
36
39
  server_address = (endpoint_parts[0], int(endpoint_parts[1]))
37
40
  else:
38
- logger.warning(f"Unrecognized port in endpoint [{endpoint}], defaulting to 443.")
39
- server_address = (endpoint, 443)
41
+ logger.warning(f"Unrecognized port in endpoint [{hostname}], defaulting to 443.")
42
+ server_address = (hostname, 443)
40
43
 
41
- # Run the blocking SSL certificate retrieval in a thread pool
42
- cert = ssl.get_server_certificate(server_address)
44
+ # Run the blocking SSL certificate retrieval with a timeout
45
+ logger.debug(f"Retrieving SSL certificate from {server_address}")
46
+ cert = ssl.get_server_certificate(server_address, timeout=10)
43
47
  return grpc.ssl_channel_credentials(str.encode(cert))
44
48
 
45
49
 
@@ -112,7 +116,7 @@ async def create_channel(
112
116
  if api_key:
113
117
  from flyte.remote._client.auth._auth_utils import decode_api_key
114
118
 
115
- endpoint, client_id, client_secret, org = decode_api_key(api_key)
119
+ endpoint, client_id, client_secret, _org = decode_api_key(api_key)
116
120
  kwargs["auth_type"] = "ClientSecret"
117
121
  kwargs["client_id"] = client_id
118
122
  kwargs["client_secret"] = client_secret
@@ -15,19 +15,22 @@ if "GRPC_VERBOSITY" not in os.environ:
15
15
  #### Has to be before grpc
16
16
 
17
17
  import grpc
18
- from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc
19
-
20
- from flyte._protos.secret import secret_pb2_grpc
21
- from flyte._protos.workflow import run_logs_service_pb2_grpc, run_service_pb2_grpc, task_service_pb2_grpc
18
+ from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc, identity_pb2_grpc
19
+ from flyteidl2.secret import secret_pb2_grpc
20
+ from flyteidl2.task import task_service_pb2_grpc
21
+ from flyteidl2.trigger import trigger_service_pb2_grpc
22
+ from flyteidl2.workflow import run_logs_service_pb2_grpc, run_service_pb2_grpc
22
23
 
23
24
  from ._protocols import (
24
25
  DataProxyService,
26
+ IdentityService,
25
27
  MetadataServiceProtocol,
26
28
  ProjectDomainService,
27
29
  RunLogsService,
28
30
  RunService,
29
31
  SecretService,
30
32
  TaskService,
33
+ TriggerService,
31
34
  )
32
35
  from .auth import create_channel
33
36
 
@@ -38,7 +41,6 @@ class ClientSet:
38
41
  channel: grpc.aio.Channel,
39
42
  endpoint: str,
40
43
  insecure: bool = False,
41
- data_proxy_channel: grpc.aio.Channel | None = None,
42
44
  **kwargs,
43
45
  ):
44
46
  self.endpoint = endpoint
@@ -50,6 +52,8 @@ class ClientSet:
50
52
  self._dataproxy = dataproxy_pb2_grpc.DataProxyServiceStub(channel=channel)
51
53
  self._log_service = run_logs_service_pb2_grpc.RunLogsServiceStub(channel=channel)
52
54
  self._secrets_service = secret_pb2_grpc.SecretServiceStub(channel=channel)
55
+ self._identity_service = identity_pb2_grpc.IdentityServiceStub(channel=channel)
56
+ self._trigger_service = trigger_service_pb2_grpc.TriggerServiceStub(channel=channel)
53
57
 
54
58
  @classmethod
55
59
  async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
@@ -105,5 +109,13 @@ class ClientSet:
105
109
  def secrets_service(self) -> SecretService:
106
110
  return self._secrets_service
107
111
 
112
+ @property
113
+ def identity_service(self) -> IdentityService:
114
+ return self._identity_service
115
+
116
+ @property
117
+ def trigger_service(self) -> TriggerService:
118
+ return self._trigger_service
119
+
108
120
  async def close(self, grace: float | None = None):
109
121
  return await self._channel.close(grace=grace)
flyte/remote/_console.py CHANGED
@@ -9,8 +9,9 @@ def _get_http_domain(endpoint: str, insecure: bool) -> str:
9
9
  else:
10
10
  domain = parsed.netloc or parsed.path
11
11
  # TODO: make console url configurable
12
- if domain.split(":")[0] == "localhost":
13
- domain = "localhost:8080"
12
+ domain_split = domain.split(":")
13
+ if domain_split[0] == "localhost":
14
+ domain = domain if len(domain_split) > 1 else f"{domain}:8080"
14
15
  return f"{scheme}://{domain}"
15
16
 
16
17
 
flyte/remote/_data.py CHANGED
@@ -15,8 +15,7 @@ import httpx
15
15
  from flyteidl.service import dataproxy_pb2
16
16
  from google.protobuf import duration_pb2
17
17
 
18
- from flyte._initialize import CommonInit, ensure_client, get_client, get_common_config
19
- from flyte._logging import make_hyperlink
18
+ from flyte._initialize import CommonInit, ensure_client, get_client, get_init_config, require_project_and_domain
20
19
  from flyte.errors import InitializationError, RuntimeSystemError
21
20
  from flyte.syncify import syncify
22
21
 
@@ -55,6 +54,7 @@ def hash_file(file_path: typing.Union[os.PathLike, str]) -> Tuple[bytes, str, in
55
54
  return h.digest(), h.hexdigest(), size
56
55
 
57
56
 
57
+ @require_project_and_domain
58
58
  async def _upload_single_file(
59
59
  cfg: CommonInit, fp: Path, verify: bool = True, basedir: str | None = None
60
60
  ) -> Tuple[str, str]:
@@ -91,7 +91,7 @@ async def _upload_single_file(
91
91
  raise RuntimeSystemError(e.code().value, f"Failed to get signed url for {fp}: {e.details()}")
92
92
  except Exception as e:
93
93
  raise RuntimeSystemError(type(e).__name__, f"Failed to get signed url for {fp}.") from e
94
- logger.debug(f"Uploading to {make_hyperlink('signed url', resp.signed_url)} for {fp}")
94
+ logger.debug(f"Uploading to [link={resp.signed_url}]signed url[/link] for [link=file://{fp}]{fp}[/link]")
95
95
  extra_headers = get_extra_headers_for_protocol(resp.native_url)
96
96
  extra_headers.update(resp.headers)
97
97
  encoded_md5 = b64encode(md5_bytes)
@@ -101,7 +101,7 @@ async def _upload_single_file(
101
101
  extra_headers.update({"Content-Length": str(content_length), "Content-MD5": encoded_md5.decode("utf-8")})
102
102
  async with httpx.AsyncClient(verify=verify) as aclient:
103
103
  put_resp = await aclient.put(resp.signed_url, headers=extra_headers, content=file)
104
- if put_resp.status_code != 200:
104
+ if put_resp.status_code not in [200, 201, 204]:
105
105
  raise RuntimeSystemError(
106
106
  "UploadFailed",
107
107
  f"Failed to upload {fp} to {resp.signed_url}, status code: {put_resp.status_code}, "
@@ -125,7 +125,7 @@ async def upload_file(fp: Path, verify: bool = True) -> Tuple[str, str]:
125
125
  """
126
126
  # This is a placeholder implementation. Replace with actual upload logic.
127
127
  ensure_client()
128
- cfg = get_common_config()
128
+ cfg = get_init_config()
129
129
  if not fp.is_file():
130
130
  raise ValueError(f"{fp} is not a single file, upload arg must be a single file.")
131
131
  return await _upload_single_file(cfg, fp, verify=verify)
@@ -141,7 +141,7 @@ async def upload_dir(dir_path: Path, verify: bool = True) -> str:
141
141
  """
142
142
  # This is a placeholder implementation. Replace with actual upload logic.
143
143
  ensure_client()
144
- cfg = get_common_config()
144
+ cfg = get_init_config()
145
145
  if not dir_path.is_dir():
146
146
  raise ValueError(f"{dir_path} is not a directory, upload arg must be a directory.")
147
147
 
flyte/remote/_logs.py CHANGED
@@ -4,6 +4,9 @@ from dataclasses import dataclass
4
4
  from typing import AsyncGenerator, AsyncIterator
5
5
 
6
6
  import grpc
7
+ from flyteidl2.common import identifier_pb2
8
+ from flyteidl2.logs.dataplane import payload_pb2
9
+ from flyteidl2.workflow import run_logs_service_pb2
7
10
  from rich.console import Console
8
11
  from rich.live import Live
9
12
  from rich.panel import Panel
@@ -11,9 +14,6 @@ from rich.text import Text
11
14
 
12
15
  from flyte._initialize import ensure_client, get_client
13
16
  from flyte._logging import logger
14
- from flyte._protos.common import identifier_pb2
15
- from flyte._protos.logs.dataplane import payload_pb2
16
- from flyte._protos.workflow import run_logs_service_pb2
17
17
  from flyte._tools import ipython_check, ipywidgets_check
18
18
  from flyte.errors import LogsNotYetAvailableError
19
19
  from flyte.syncify import syncify
flyte/remote/_run.py CHANGED
@@ -5,10 +5,11 @@ from typing import AsyncGenerator, AsyncIterator, Literal, Tuple
5
5
 
6
6
  import grpc
7
7
  import rich.repr
8
+ from flyteidl2.common import identifier_pb2, list_pb2
9
+ from flyteidl2.workflow import run_definition_pb2, run_service_pb2
8
10
 
9
- from flyte._initialize import ensure_client, get_client, get_common_config
10
- from flyte._protos.common import identifier_pb2, list_pb2
11
- from flyte._protos.workflow import run_definition_pb2, run_service_pb2
11
+ from flyte._initialize import ensure_client, get_client, get_init_config
12
+ from flyte._logging import logger
12
13
  from flyte.syncify import syncify
13
14
 
14
15
  from . import Action, ActionDetails, ActionInputs, ActionOutputs
@@ -16,6 +17,11 @@ from ._action import _action_details_rich_repr, _action_rich_repr
16
17
  from ._common import ToJSONMixin
17
18
  from ._console import get_run_url
18
19
 
20
+ # @kumare3 is sadpanda, because we have to create a mirror of phase types here, because protobuf phases are ghastly
21
+ Phase = Literal[
22
+ "queued", "waiting_for_resources", "initializing", "running", "succeeded", "failed", "aborted", "timed_out"
23
+ ]
24
+
19
25
 
20
26
  @dataclass
21
27
  class Run(ToJSONMixin):
@@ -40,14 +46,16 @@ class Run(ToJSONMixin):
40
46
  @classmethod
41
47
  async def listall(
42
48
  cls,
43
- filters: str | None = None,
49
+ in_phase: Tuple[Phase] | None = None,
50
+ created_by_subject: str | None = None,
44
51
  sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
45
52
  limit: int = 100,
46
53
  ) -> AsyncIterator[Run]:
47
54
  """
48
55
  Get all runs for the current project and domain.
49
56
 
50
- :param filters: The filters to apply to the project list.
57
+ :param in_phase: Filter runs by one or more phases.
58
+ :param created_by_subject: Filter runs by the subject that created them. (this is not username, but the subject)
51
59
  :param sort_by: The sorting criteria for the project list, in the format (field, order).
52
60
  :param limit: The maximum number of runs to return.
53
61
  :return: An iterator of runs.
@@ -59,13 +67,44 @@ class Run(ToJSONMixin):
59
67
  key=sort_by[0],
60
68
  direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
61
69
  )
62
- cfg = get_common_config()
70
+ filters = []
71
+ if in_phase:
72
+ phases = [str(run_definition_pb2.Phase.Value(f"PHASE_{p.upper()}")) for p in in_phase]
73
+ logger.debug(f"Fetching run phases: {phases}")
74
+ if len(phases) > 1:
75
+ filters.append(
76
+ list_pb2.Filter(
77
+ function=list_pb2.Filter.Function.VALUE_IN,
78
+ field="phase",
79
+ values=phases,
80
+ ),
81
+ )
82
+ else:
83
+ filters.append(
84
+ list_pb2.Filter(
85
+ function=list_pb2.Filter.Function.EQUAL,
86
+ field="phase",
87
+ values=phases[0],
88
+ ),
89
+ )
90
+ if created_by_subject:
91
+ logger.debug(f"Fetching runs created by: {created_by_subject}")
92
+ filters.append(
93
+ list_pb2.Filter(
94
+ function=list_pb2.Filter.Function.EQUAL,
95
+ field="created_by",
96
+ values=[created_by_subject],
97
+ ),
98
+ )
99
+
100
+ cfg = get_init_config()
63
101
  i = 0
64
102
  while True:
65
103
  req = list_pb2.ListRequest(
66
104
  limit=min(100, limit),
67
105
  token=token,
68
106
  sort_by=sort_pb2,
107
+ filters=filters,
69
108
  )
70
109
  resp = await get_client().run_service.ListRuns(
71
110
  run_service_pb2.ListRunsRequest(
@@ -157,10 +196,26 @@ class Run(ToJSONMixin):
157
196
  """
158
197
  Get the details of the run. This is a placeholder for getting the run details.
159
198
  """
160
- if self._details is None:
199
+ if self._details is None or not self._details.done():
161
200
  self._details = await RunDetails.get_details.aio(self.pb2.action.id.run)
162
201
  return self._details
163
202
 
203
+ @syncify
204
+ async def inputs(self) -> ActionInputs:
205
+ """
206
+ Get the inputs of the run. This is a placeholder for getting the run inputs.
207
+ """
208
+ details = await self.details.aio()
209
+ return await details.inputs()
210
+
211
+ @syncify
212
+ async def outputs(self) -> ActionOutputs:
213
+ """
214
+ Get the outputs of the run. This is a placeholder for getting the run outputs.
215
+ """
216
+ details = await self.details.aio()
217
+ return await details.outputs()
218
+
164
219
  @property
165
220
  def url(self) -> str:
166
221
  """
@@ -209,6 +264,7 @@ class Run(ToJSONMixin):
209
264
  """
210
265
  Rich representation of the Run object.
211
266
  """
267
+ yield "url", f"[blue bold][link={self.url}]link[/link][/blue bold]"
212
268
  yield from _action_rich_repr(self.pb2.action)
213
269
 
214
270
  def __repr__(self) -> str:
@@ -260,7 +316,7 @@ class RunDetails(ToJSONMixin):
260
316
  :param name: The name of the run.
261
317
  """
262
318
  ensure_client()
263
- cfg = get_common_config()
319
+ cfg = get_init_config()
264
320
  return await RunDetails.get_details.aio(
265
321
  run_id=identifier_pb2.RunIdentifier(
266
322
  org=cfg.org,
flyte/remote/_secret.py CHANGED
@@ -4,9 +4,9 @@ from dataclasses import dataclass
4
4
  from typing import AsyncIterator, Literal, Union
5
5
 
6
6
  import rich.repr
7
+ from flyteidl2.secret import definition_pb2, payload_pb2
7
8
 
8
- from flyte._initialize import ensure_client, get_client, get_common_config
9
- from flyte._protos.secret import definition_pb2, payload_pb2
9
+ from flyte._initialize import ensure_client, get_client, get_init_config
10
10
  from flyte.remote._common import ToJSONMixin
11
11
  from flyte.syncify import syncify
12
12
 
@@ -21,12 +21,19 @@ class Secret(ToJSONMixin):
21
21
  @classmethod
22
22
  async def create(cls, name: str, value: Union[str, bytes], type: SecretTypes = "regular"):
23
23
  ensure_client()
24
- cfg = get_common_config()
25
- secret_type = (
26
- definition_pb2.SecretType.SECRET_TYPE_GENERIC
27
- if type == "regular"
28
- else definition_pb2.SecretType.SECRET_TYPE_IMAGE_PULL_SECRET
29
- )
24
+ cfg = get_init_config()
25
+ project = cfg.project
26
+ domain = cfg.domain
27
+
28
+ if type == "regular":
29
+ secret_type = definition_pb2.SecretType.SECRET_TYPE_GENERIC
30
+
31
+ else:
32
+ secret_type = definition_pb2.SecretType.SECRET_TYPE_IMAGE_PULL_SECRET
33
+ if project or domain:
34
+ raise ValueError(
35
+ f"Project `{project}` or domain `{domain}` should not be set when creating the image pull secret."
36
+ )
30
37
 
31
38
  if isinstance(value, str):
32
39
  secret = definition_pb2.SecretSpec(
@@ -42,8 +49,8 @@ class Secret(ToJSONMixin):
42
49
  request=payload_pb2.CreateSecretRequest(
43
50
  id=definition_pb2.SecretIdentifier(
44
51
  organization=cfg.org,
45
- project=cfg.project,
46
- domain=cfg.domain,
52
+ project=project,
53
+ domain=domain,
47
54
  name=name,
48
55
  ),
49
56
  secret_spec=secret,
@@ -54,7 +61,7 @@ class Secret(ToJSONMixin):
54
61
  @classmethod
55
62
  async def get(cls, name: str) -> Secret:
56
63
  ensure_client()
57
- cfg = get_common_config()
64
+ cfg = get_init_config()
58
65
  resp = await get_client().secrets_service.GetSecret(
59
66
  request=payload_pb2.GetSecretRequest(
60
67
  id=definition_pb2.SecretIdentifier(
@@ -71,29 +78,31 @@ class Secret(ToJSONMixin):
71
78
  @classmethod
72
79
  async def listall(cls, limit: int = 100) -> AsyncIterator[Secret]:
73
80
  ensure_client()
74
- cfg = get_common_config()
75
- token = None
81
+ cfg = get_init_config()
82
+ per_cluster_tokens = None
76
83
  while True:
77
84
  resp = await get_client().secrets_service.ListSecrets( # type: ignore
78
85
  request=payload_pb2.ListSecretsRequest(
79
86
  organization=cfg.org,
80
87
  project=cfg.project,
81
88
  domain=cfg.domain,
82
- token=token,
89
+ per_cluster_tokens=per_cluster_tokens,
83
90
  limit=limit,
84
91
  ),
85
92
  )
86
- token = resp.token
93
+ per_cluster_tokens = resp.per_cluster_tokens
94
+ round_items = [v for _, v in per_cluster_tokens.items() if v]
95
+ has_next = any(round_items)
87
96
  for r in resp.secrets:
88
97
  yield cls(r)
89
- if not token:
98
+ if not has_next:
90
99
  break
91
100
 
92
101
  @syncify
93
102
  @classmethod
94
103
  async def delete(cls, name):
95
104
  ensure_client()
96
- cfg = get_common_config()
105
+ cfg = get_init_config()
97
106
  await get_client().secrets_service.DeleteSecret( # type: ignore
98
107
  request=payload_pb2.DeleteSecretRequest(
99
108
  id=definition_pb2.SecretIdentifier(