flyte 2.0.0b22__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/runtime.py +43 -5
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +216 -0
  5. flyte/_code_bundle/_ignore.py +1 -1
  6. flyte/_code_bundle/_packaging.py +4 -4
  7. flyte/_code_bundle/_utils.py +14 -8
  8. flyte/_code_bundle/bundle.py +13 -5
  9. flyte/_constants.py +1 -0
  10. flyte/_context.py +4 -1
  11. flyte/_custom_context.py +73 -0
  12. flyte/_debug/constants.py +0 -1
  13. flyte/_debug/vscode.py +6 -1
  14. flyte/_deploy.py +223 -59
  15. flyte/_environment.py +5 -0
  16. flyte/_excepthook.py +1 -1
  17. flyte/_image.py +144 -82
  18. flyte/_initialize.py +95 -12
  19. flyte/_interface.py +2 -0
  20. flyte/_internal/controllers/_local_controller.py +65 -24
  21. flyte/_internal/controllers/_trace.py +1 -1
  22. flyte/_internal/controllers/remote/_action.py +13 -11
  23. flyte/_internal/controllers/remote/_client.py +1 -1
  24. flyte/_internal/controllers/remote/_controller.py +9 -4
  25. flyte/_internal/controllers/remote/_core.py +16 -16
  26. flyte/_internal/controllers/remote/_informer.py +4 -4
  27. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  28. flyte/_internal/imagebuild/docker_builder.py +139 -84
  29. flyte/_internal/imagebuild/image_builder.py +7 -13
  30. flyte/_internal/imagebuild/remote_builder.py +65 -13
  31. flyte/_internal/imagebuild/utils.py +51 -3
  32. flyte/_internal/resolvers/_task_module.py +5 -38
  33. flyte/_internal/resolvers/default.py +2 -2
  34. flyte/_internal/runtime/convert.py +42 -20
  35. flyte/_internal/runtime/entrypoints.py +24 -1
  36. flyte/_internal/runtime/io.py +21 -8
  37. flyte/_internal/runtime/resources_serde.py +20 -6
  38. flyte/_internal/runtime/reuse.py +1 -1
  39. flyte/_internal/runtime/rusty.py +20 -5
  40. flyte/_internal/runtime/task_serde.py +33 -27
  41. flyte/_internal/runtime/taskrunner.py +10 -1
  42. flyte/_internal/runtime/trigger_serde.py +160 -0
  43. flyte/_internal/runtime/types_serde.py +1 -1
  44. flyte/_keyring/file.py +39 -9
  45. flyte/_logging.py +79 -12
  46. flyte/_map.py +31 -12
  47. flyte/_module.py +70 -0
  48. flyte/_pod.py +2 -2
  49. flyte/_resources.py +213 -31
  50. flyte/_run.py +107 -41
  51. flyte/_task.py +66 -10
  52. flyte/_task_environment.py +96 -24
  53. flyte/_task_plugins.py +4 -2
  54. flyte/_trigger.py +1000 -0
  55. flyte/_utils/__init__.py +2 -1
  56. flyte/_utils/asyn.py +3 -1
  57. flyte/_utils/docker_credentials.py +173 -0
  58. flyte/_utils/module_loader.py +17 -2
  59. flyte/_version.py +3 -3
  60. flyte/cli/_abort.py +3 -3
  61. flyte/cli/_build.py +1 -3
  62. flyte/cli/_common.py +78 -7
  63. flyte/cli/_create.py +178 -3
  64. flyte/cli/_delete.py +23 -1
  65. flyte/cli/_deploy.py +49 -11
  66. flyte/cli/_get.py +79 -34
  67. flyte/cli/_params.py +8 -6
  68. flyte/cli/_plugins.py +209 -0
  69. flyte/cli/_run.py +127 -11
  70. flyte/cli/_serve.py +64 -0
  71. flyte/cli/_update.py +37 -0
  72. flyte/cli/_user.py +17 -0
  73. flyte/cli/main.py +30 -4
  74. flyte/config/_config.py +2 -0
  75. flyte/config/_internal.py +1 -0
  76. flyte/config/_reader.py +3 -3
  77. flyte/connectors/__init__.py +11 -0
  78. flyte/connectors/_connector.py +270 -0
  79. flyte/connectors/_server.py +197 -0
  80. flyte/connectors/utils.py +135 -0
  81. flyte/errors.py +10 -1
  82. flyte/extend.py +8 -1
  83. flyte/extras/_container.py +6 -1
  84. flyte/git/_config.py +11 -9
  85. flyte/io/__init__.py +2 -0
  86. flyte/io/_dataframe/__init__.py +2 -0
  87. flyte/io/_dataframe/basic_dfs.py +1 -1
  88. flyte/io/_dataframe/dataframe.py +12 -8
  89. flyte/io/_dir.py +551 -120
  90. flyte/io/_file.py +538 -141
  91. flyte/models.py +57 -12
  92. flyte/remote/__init__.py +6 -1
  93. flyte/remote/_action.py +18 -16
  94. flyte/remote/_client/_protocols.py +39 -4
  95. flyte/remote/_client/auth/_channel.py +10 -6
  96. flyte/remote/_client/controlplane.py +17 -5
  97. flyte/remote/_console.py +3 -2
  98. flyte/remote/_data.py +4 -3
  99. flyte/remote/_logs.py +3 -3
  100. flyte/remote/_run.py +47 -7
  101. flyte/remote/_secret.py +26 -17
  102. flyte/remote/_task.py +21 -9
  103. flyte/remote/_trigger.py +306 -0
  104. flyte/remote/_user.py +33 -0
  105. flyte/storage/__init__.py +6 -1
  106. flyte/storage/_parallel_reader.py +274 -0
  107. flyte/storage/_storage.py +185 -103
  108. flyte/types/__init__.py +16 -0
  109. flyte/types/_interface.py +2 -2
  110. flyte/types/_pickle.py +17 -4
  111. flyte/types/_string_literals.py +8 -9
  112. flyte/types/_type_engine.py +26 -19
  113. flyte/types/_utils.py +1 -1
  114. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/runtime.py +43 -5
  115. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/METADATA +8 -1
  116. flyte-2.0.0b30.dist-info/RECORD +192 -0
  117. flyte/_protos/__init__.py +0 -0
  118. flyte/_protos/common/authorization_pb2.py +0 -66
  119. flyte/_protos/common/authorization_pb2.pyi +0 -108
  120. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  121. flyte/_protos/common/identifier_pb2.py +0 -99
  122. flyte/_protos/common/identifier_pb2.pyi +0 -120
  123. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  124. flyte/_protos/common/identity_pb2.py +0 -48
  125. flyte/_protos/common/identity_pb2.pyi +0 -72
  126. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  127. flyte/_protos/common/list_pb2.py +0 -36
  128. flyte/_protos/common/list_pb2.pyi +0 -71
  129. flyte/_protos/common/list_pb2_grpc.py +0 -4
  130. flyte/_protos/common/policy_pb2.py +0 -37
  131. flyte/_protos/common/policy_pb2.pyi +0 -27
  132. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  133. flyte/_protos/common/role_pb2.py +0 -37
  134. flyte/_protos/common/role_pb2.pyi +0 -53
  135. flyte/_protos/common/role_pb2_grpc.py +0 -4
  136. flyte/_protos/common/runtime_version_pb2.py +0 -28
  137. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  138. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  139. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  140. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  141. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  142. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  143. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  144. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  145. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  146. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  147. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  148. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  149. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  150. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  151. flyte/_protos/secret/definition_pb2.py +0 -49
  152. flyte/_protos/secret/definition_pb2.pyi +0 -93
  153. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  154. flyte/_protos/secret/payload_pb2.py +0 -62
  155. flyte/_protos/secret/payload_pb2.pyi +0 -94
  156. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  157. flyte/_protos/secret/secret_pb2.py +0 -38
  158. flyte/_protos/secret/secret_pb2.pyi +0 -6
  159. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  160. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  161. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  162. flyte/_protos/workflow/common_pb2.py +0 -27
  163. flyte/_protos/workflow/common_pb2.pyi +0 -14
  164. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  165. flyte/_protos/workflow/environment_pb2.py +0 -29
  166. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  167. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  168. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  169. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  170. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  171. flyte/_protos/workflow/queue_service_pb2.py +0 -111
  172. flyte/_protos/workflow/queue_service_pb2.pyi +0 -168
  173. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  174. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  175. flyte/_protos/workflow/run_definition_pb2.pyi +0 -352
  176. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  177. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  178. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  179. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  180. flyte/_protos/workflow/run_service_pb2.py +0 -137
  181. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  182. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  183. flyte/_protos/workflow/state_service_pb2.py +0 -67
  184. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  185. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  186. flyte/_protos/workflow/task_definition_pb2.py +0 -82
  187. flyte/_protos/workflow/task_definition_pb2.pyi +0 -88
  188. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  189. flyte/_protos/workflow/task_service_pb2.py +0 -60
  190. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  191. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  192. flyte-2.0.0b22.dist-info/RECORD +0 -250
  193. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/debug.py +0 -0
  194. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  195. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +0 -0
  196. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  197. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/models.py CHANGED
@@ -14,7 +14,7 @@ from flyte._interface import extract_return_annotation, literal_to_enum
14
14
  from flyte._logging import logger
15
15
 
16
16
  if TYPE_CHECKING:
17
- from flyteidl.core import literals_pb2
17
+ from flyteidl2.core import literals_pb2
18
18
 
19
19
  from flyte._internal.imagebuild.image_builder import ImageCache
20
20
  from flyte.report import Report
@@ -77,6 +77,37 @@ class ActionID:
77
77
  return self.new_sub_action(new_name)
78
78
 
79
79
 
80
+ @rich.repr.auto
81
+ @dataclass
82
+ class PathRewrite:
83
+ """
84
+ Configuration for rewriting paths during input loading.
85
+ """
86
+
87
+ # If set, rewrites any path starting with this prefix to the new prefix.
88
+ old_prefix: str
89
+ new_prefix: str
90
+
91
+ def __post_init__(self):
92
+ if not self.old_prefix or not self.new_prefix:
93
+ raise ValueError("Both old_prefix and new_prefix must be non-empty strings.")
94
+ if self.old_prefix == self.new_prefix:
95
+ raise ValueError("old_prefix and new_prefix must be different.")
96
+
97
+ @classmethod
98
+ def from_str(cls, pattern: str) -> PathRewrite:
99
+ """
100
+ Create a PathRewrite from a string pattern of the form `old_prefix->new_prefix`.
101
+ """
102
+ parts = pattern.split("->")
103
+ if len(parts) != 2:
104
+ raise ValueError(f"Invalid path rewrite pattern: {pattern}. Expected format 'old_prefix->new_prefix'.")
105
+ return cls(old_prefix=parts[0], new_prefix=parts[1])
106
+
107
+ def __repr__(self) -> str:
108
+ return f"{self.old_prefix}->{self.new_prefix}"
109
+
110
+
80
111
  @rich.repr.auto
81
112
  @dataclass(frozen=True, kw_only=True)
82
113
  class RawDataPath:
@@ -86,6 +117,7 @@ class RawDataPath:
86
117
  """
87
118
 
88
119
  path: str
120
+ path_rewrite: Optional[PathRewrite] = None
89
121
 
90
122
  @classmethod
91
123
  def from_local_folder(cls, local_folder: str | pathlib.Path | None = None) -> RawDataPath:
@@ -112,7 +144,7 @@ class RawDataPath:
112
144
 
113
145
  def get_random_remote_path(self, file_name: Optional[str] = None) -> str:
114
146
  """
115
- Returns a random path for uploading a file/directory to.
147
+ Returns a random path for uploading a file/directory to. This file/folder will not be created, it's just a path.
116
148
 
117
149
  :param file_name: If given, will be joined after a randomly generated portion.
118
150
  :return:
@@ -128,13 +160,14 @@ class RawDataPath:
128
160
 
129
161
  protocol = get_protocol(file_prefix)
130
162
  if "file" in protocol:
131
- local_path = pathlib.Path(file_prefix) / random_string
163
+ parent_folder = pathlib.Path(file_prefix)
164
+ parent_folder.mkdir(exist_ok=True, parents=True)
132
165
  if file_name:
133
- # Only if file name is given do we create the parent, because it may be needed as a folder otherwise
134
- local_path = local_path / file_name
135
- if not local_path.exists():
136
- local_path.parent.mkdir(exist_ok=True, parents=True)
137
- local_path.touch()
166
+ random_folder = parent_folder / random_string
167
+ random_folder.mkdir()
168
+ local_path = random_folder / file_name
169
+ else:
170
+ local_path = parent_folder / random_string
138
171
  return str(local_path.absolute())
139
172
 
140
173
  fs = fsspec.filesystem(protocol)
@@ -162,6 +195,8 @@ class TaskContext:
162
195
  :param action: The action ID of the current execution. This is always set, within a run.
163
196
  :param version: The version of the executed task. This is set when the task is executed by an action and will be
164
197
  set on all sub-actions.
198
+ :param custom_context: Context metadata for the action. If an action receives context, it'll automatically pass it
199
+ to any actions it spawns. Context will not be used for cache key computation.
165
200
  """
166
201
 
167
202
  action: ActionID
@@ -178,6 +213,7 @@ class TaskContext:
178
213
  data: Dict[str, Any] = field(default_factory=dict)
179
214
  mode: Literal["local", "remote", "hybrid"] = "remote"
180
215
  interactive_mode: bool = False
216
+ custom_context: Dict[str, str] = field(default_factory=dict)
181
217
 
182
218
  def replace(self, **kwargs) -> TaskContext:
183
219
  if "data" in kwargs:
@@ -319,10 +355,18 @@ class NativeInterface:
319
355
  """
320
356
  Extract the native interface from the given function. This is used to create a native interface for the task.
321
357
  """
358
+ # Get function parameters, defaults, varargs info (POSITIONAL_ONLY, VAR_POSITIONAL, KEYWORD_ONLY, etc.).
322
359
  sig = inspect.signature(func)
323
360
 
324
361
  # Extract parameter details (name, type, default value)
325
362
  param_info = {}
363
+ try:
364
+ # Get fully evaluated, real Python types for type checking.
365
+ hints = typing.get_type_hints(func, include_extras=True)
366
+ except Exception as e:
367
+ logger.warning(f"Could not get type hints for function {func.__name__}: {e}")
368
+ raise
369
+
326
370
  for name, param in sig.parameters.items():
327
371
  if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
328
372
  raise ValueError(f"Function {func.__name__} cannot have variable positional or keyword arguments.")
@@ -330,13 +374,14 @@ class NativeInterface:
330
374
  logger.warning(
331
375
  f"Function {func.__name__} has parameter {name} without type annotation. Data will be pickled."
332
376
  )
333
- if typing.get_origin(param.annotation) is Literal:
334
- param_info[name] = (literal_to_enum(param.annotation), param.default)
377
+ arg_type = hints.get(name, param.annotation)
378
+ if typing.get_origin(arg_type) is Literal:
379
+ param_info[name] = (literal_to_enum(arg_type), param.default)
335
380
  else:
336
- param_info[name] = (param.annotation, param.default)
381
+ param_info[name] = (arg_type, param.default)
337
382
 
338
383
  # Get return type
339
- outputs = extract_return_annotation(sig.return_annotation)
384
+ outputs = extract_return_annotation(hints.get("return", sig.return_annotation))
340
385
  return cls(inputs=param_info, outputs=outputs)
341
386
 
342
387
  def convert_to_kwargs(self, *args, **kwargs) -> Dict[str, Any]:
flyte/remote/__init__.py CHANGED
@@ -7,12 +7,15 @@ __all__ = [
7
7
  "ActionDetails",
8
8
  "ActionInputs",
9
9
  "ActionOutputs",
10
+ "Phase",
10
11
  "Project",
11
12
  "Run",
12
13
  "RunDetails",
13
14
  "Secret",
14
15
  "SecretTypes",
15
16
  "Task",
17
+ "Trigger",
18
+ "User",
16
19
  "create_channel",
17
20
  "upload_dir",
18
21
  "upload_file",
@@ -22,6 +25,8 @@ from ._action import Action, ActionDetails, ActionInputs, ActionOutputs
22
25
  from ._client.auth import create_channel
23
26
  from ._data import upload_dir, upload_file
24
27
  from ._project import Project
25
- from ._run import Run, RunDetails
28
+ from ._run import Phase, Run, RunDetails
26
29
  from ._secret import Secret, SecretTypes
27
30
  from ._task import Task
31
+ from ._trigger import Trigger
32
+ from ._user import User
flyte/remote/_action.py CHANGED
@@ -20,14 +20,15 @@ from typing import (
20
20
  import grpc
21
21
  import rich.pretty
22
22
  import rich.repr
23
+ from flyteidl2.common import identifier_pb2, list_pb2
24
+ from flyteidl2.task import common_pb2
25
+ from flyteidl2.workflow import run_definition_pb2, run_service_pb2
26
+ from flyteidl2.workflow.run_service_pb2 import WatchActionDetailsResponse
23
27
  from rich.console import Console
24
28
  from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
25
29
 
26
30
  from flyte import types
27
- from flyte._initialize import ensure_client, get_client, get_common_config
28
- from flyte._protos.common import identifier_pb2, list_pb2
29
- from flyte._protos.workflow import run_definition_pb2, run_service_pb2
30
- from flyte._protos.workflow.run_service_pb2 import WatchActionDetailsResponse
31
+ from flyte._initialize import ensure_client, get_client, get_init_config
31
32
  from flyte.remote._common import ToJSONMixin
32
33
  from flyte.remote._logs import Logs
33
34
  from flyte.syncify import syncify
@@ -67,14 +68,14 @@ def _action_rich_repr(action: run_definition_pb2.Action) -> rich.repr.Result:
67
68
  """
68
69
  Rich representation of the action.
69
70
  """
70
- yield "run", action.id.run.name
71
+ yield "name", action.id.run.name
71
72
  if action.metadata.HasField("task"):
72
- yield "task", action.metadata.task.id.name
73
+ yield "task name", action.metadata.task.id.name
73
74
  yield "type", action.metadata.task.task_type
74
75
  elif action.metadata.HasField("trace"):
75
76
  yield "trace", action.metadata.trace.name
76
77
  yield "type", "trace"
77
- yield "name", action.id.name
78
+ yield "action name", action.id.name
78
79
  yield from _action_time_phase(action)
79
80
  yield "group", action.metadata.group
80
81
  yield "parent", action.metadata.parent
@@ -98,9 +99,10 @@ def _action_details_rich_repr(
98
99
  """
99
100
  yield "name", action.id.run.name
100
101
  yield from _action_time_phase(action)
101
- yield "task", action.resolved_task_spec.task_template.id.name
102
- yield "task_type", action.resolved_task_spec.task_template.type
103
- yield "task_version", action.resolved_task_spec.task_template.id.version
102
+ if action.HasField("task"):
103
+ yield "task", action.task.task_template.id.name
104
+ yield "task_type", action.task.task_template.type
105
+ yield "task_version", action.task.task_template.id.version
104
106
  yield "attempts", action.attempts
105
107
  yield "error", (f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA")
106
108
  yield "phase", run_definition_pb2.Phase.Name(action.status.phase)
@@ -152,7 +154,7 @@ class Action(ToJSONMixin):
152
154
  key=sort_by[0],
153
155
  direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
154
156
  )
155
- cfg = get_common_config()
157
+ cfg = get_init_config()
156
158
  while True:
157
159
  req = list_pb2.ListRequest(
158
160
  limit=100,
@@ -193,7 +195,7 @@ class Action(ToJSONMixin):
193
195
  :param name: The name of the action.
194
196
  """
195
197
  ensure_client()
196
- cfg = get_common_config()
198
+ cfg = get_init_config()
197
199
  details: ActionDetails = await ActionDetails.get_details.aio(
198
200
  identifier_pb2.ActionIdentifier(
199
201
  run=identifier_pb2.RunIdentifier(
@@ -458,7 +460,7 @@ class ActionDetails(ToJSONMixin):
458
460
  ensure_client()
459
461
  if not uri:
460
462
  assert name is not None and run_name is not None, "Either uri or name and run_name must be provided"
461
- cfg = get_common_config()
463
+ cfg = get_init_config()
462
464
  return await cls.get_details.aio(
463
465
  identifier_pb2.ActionIdentifier(
464
466
  run=identifier_pb2.RunIdentifier(
@@ -706,7 +708,7 @@ class ActionInputs(UserDict, ToJSONMixin):
706
708
  remote Union API.
707
709
  """
708
710
 
709
- pb2: run_definition_pb2.Inputs
711
+ pb2: common_pb2.Inputs
710
712
  data: Dict[str, Any]
711
713
 
712
714
  def __repr__(self):
@@ -723,14 +725,14 @@ class ActionOutputs(tuple, ToJSONMixin):
723
725
  remote Union API.
724
726
  """
725
727
 
726
- def __new__(cls, pb2: run_definition_pb2.Outputs, data: Tuple[Any, ...]):
728
+ def __new__(cls, pb2: common_pb2.Outputs, data: Tuple[Any, ...]):
727
729
  # Create the tuple part
728
730
  obj = super().__new__(cls, data)
729
731
  # Store extra data (you can't do this here directly since it's immutable)
730
732
  obj.pb2 = pb2
731
733
  return obj
732
734
 
733
- def __init__(self, pb2: run_definition_pb2.Outputs, data: Tuple[Any, ...]):
735
+ def __init__(self, pb2: common_pb2.Outputs, data: Tuple[Any, ...]):
734
736
  # Normally you'd set instance attributes here,
735
737
  # but we've already set `pb2` in `__new__`
736
738
  self.pb2 = pb2
@@ -1,13 +1,14 @@
1
1
  from typing import AsyncIterator, Protocol
2
2
 
3
3
  from flyteidl.admin import project_attributes_pb2, project_pb2, version_pb2
4
- from flyteidl.service import dataproxy_pb2
4
+ from flyteidl.service import dataproxy_pb2, identity_pb2
5
+ from flyteidl2.secret import payload_pb2
6
+ from flyteidl2.task import task_service_pb2
7
+ from flyteidl2.trigger import trigger_service_pb2
8
+ from flyteidl2.workflow import run_logs_service_pb2, run_service_pb2
5
9
  from grpc.aio import UnaryStreamCall
6
10
  from grpc.aio._typing import RequestType
7
11
 
8
- from flyte._protos.secret import payload_pb2
9
- from flyte._protos.workflow import run_logs_service_pb2, run_service_pb2, task_service_pb2
10
-
11
12
 
12
13
  class MetadataServiceProtocol(Protocol):
13
14
  async def GetVersion(self, request: version_pb2.GetVersionRequest) -> version_pb2.GetVersionResponse: ...
@@ -131,3 +132,37 @@ class SecretService(Protocol):
131
132
  async def ListSecrets(self, request: payload_pb2.ListSecretsRequest) -> payload_pb2.ListSecretsResponse: ...
132
133
 
133
134
  async def DeleteSecret(self, request: payload_pb2.DeleteSecretRequest) -> payload_pb2.DeleteSecretResponse: ...
135
+
136
+
137
+ class IdentityService(Protocol):
138
+ async def UserInfo(self, request: identity_pb2.UserInfoRequest) -> identity_pb2.UserInfoResponse: ...
139
+
140
+
141
+ class TriggerService(Protocol):
142
+ async def DeployTrigger(
143
+ self, request: trigger_service_pb2.DeployTriggerRequest
144
+ ) -> trigger_service_pb2.DeployTriggerResponse: ...
145
+
146
+ async def GetTriggerDetails(
147
+ self, request: trigger_service_pb2.GetTriggerDetailsRequest
148
+ ) -> trigger_service_pb2.GetTriggerDetailsResponse: ...
149
+
150
+ async def GetTriggerRevisionDetails(
151
+ self, request: trigger_service_pb2.GetTriggerRevisionDetailsRequest
152
+ ) -> trigger_service_pb2.GetTriggerRevisionDetailsResponse: ...
153
+
154
+ async def ListTriggers(
155
+ self, request: trigger_service_pb2.ListTriggersRequest
156
+ ) -> trigger_service_pb2.ListTriggersResponse: ...
157
+
158
+ async def GetTriggerRevisionHistory(
159
+ self, request: trigger_service_pb2.GetTriggerRevisionHistoryRequest
160
+ ) -> trigger_service_pb2.GetTriggerRevisionHistoryResponse: ...
161
+
162
+ async def UpdateTriggers(
163
+ self, request: trigger_service_pb2.UpdateTriggersRequest
164
+ ) -> trigger_service_pb2.UpdateTriggersResponse: ...
165
+
166
+ async def DeleteTriggers(
167
+ self, request: trigger_service_pb2.DeleteTriggersRequest
168
+ ) -> trigger_service_pb2.DeleteTriggersResponse: ...
@@ -7,6 +7,7 @@ import httpx
7
7
  from grpc.experimental.aio import init_grpc_aio
8
8
 
9
9
  from flyte._logging import logger
10
+ from flyte._utils.org_discovery import hostname_from_url
10
11
 
11
12
  from ._authenticators.base import get_async_session
12
13
  from ._authenticators.factory import (
@@ -30,16 +31,19 @@ def bootstrap_ssl_from_server(endpoint: str) -> grpc.ChannelCredentials:
30
31
  :param endpoint: The endpoint URL to retrieve the SSL certificate from, may include port number
31
32
  :return: gRPC channel credentials created from the retrieved certificate
32
33
  """
34
+ hostname = hostname_from_url(endpoint)
35
+
33
36
  # Get port from endpoint or use 443
34
- endpoint_parts = endpoint.rsplit(":", 1)
37
+ endpoint_parts = hostname.rsplit(":", 1)
35
38
  if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit():
36
39
  server_address = (endpoint_parts[0], int(endpoint_parts[1]))
37
40
  else:
38
- logger.warning(f"Unrecognized port in endpoint [{endpoint}], defaulting to 443.")
39
- server_address = (endpoint, 443)
41
+ logger.warning(f"Unrecognized port in endpoint [{hostname}], defaulting to 443.")
42
+ server_address = (hostname, 443)
40
43
 
41
- # Run the blocking SSL certificate retrieval in a thread pool
42
- cert = ssl.get_server_certificate(server_address)
44
+ # Run the blocking SSL certificate retrieval with a timeout
45
+ logger.debug(f"Retrieving SSL certificate from {server_address}")
46
+ cert = ssl.get_server_certificate(server_address, timeout=10)
43
47
  return grpc.ssl_channel_credentials(str.encode(cert))
44
48
 
45
49
 
@@ -112,7 +116,7 @@ async def create_channel(
112
116
  if api_key:
113
117
  from flyte.remote._client.auth._auth_utils import decode_api_key
114
118
 
115
- endpoint, client_id, client_secret, org = decode_api_key(api_key)
119
+ endpoint, client_id, client_secret, _org = decode_api_key(api_key)
116
120
  kwargs["auth_type"] = "ClientSecret"
117
121
  kwargs["client_id"] = client_id
118
122
  kwargs["client_secret"] = client_secret
@@ -15,19 +15,22 @@ if "GRPC_VERBOSITY" not in os.environ:
15
15
  #### Has to be before grpc
16
16
 
17
17
  import grpc
18
- from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc
19
-
20
- from flyte._protos.secret import secret_pb2_grpc
21
- from flyte._protos.workflow import run_logs_service_pb2_grpc, run_service_pb2_grpc, task_service_pb2_grpc
18
+ from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc, identity_pb2_grpc
19
+ from flyteidl2.secret import secret_pb2_grpc
20
+ from flyteidl2.task import task_service_pb2_grpc
21
+ from flyteidl2.trigger import trigger_service_pb2_grpc
22
+ from flyteidl2.workflow import run_logs_service_pb2_grpc, run_service_pb2_grpc
22
23
 
23
24
  from ._protocols import (
24
25
  DataProxyService,
26
+ IdentityService,
25
27
  MetadataServiceProtocol,
26
28
  ProjectDomainService,
27
29
  RunLogsService,
28
30
  RunService,
29
31
  SecretService,
30
32
  TaskService,
33
+ TriggerService,
31
34
  )
32
35
  from .auth import create_channel
33
36
 
@@ -38,7 +41,6 @@ class ClientSet:
38
41
  channel: grpc.aio.Channel,
39
42
  endpoint: str,
40
43
  insecure: bool = False,
41
- data_proxy_channel: grpc.aio.Channel | None = None,
42
44
  **kwargs,
43
45
  ):
44
46
  self.endpoint = endpoint
@@ -50,6 +52,8 @@ class ClientSet:
50
52
  self._dataproxy = dataproxy_pb2_grpc.DataProxyServiceStub(channel=channel)
51
53
  self._log_service = run_logs_service_pb2_grpc.RunLogsServiceStub(channel=channel)
52
54
  self._secrets_service = secret_pb2_grpc.SecretServiceStub(channel=channel)
55
+ self._identity_service = identity_pb2_grpc.IdentityServiceStub(channel=channel)
56
+ self._trigger_service = trigger_service_pb2_grpc.TriggerServiceStub(channel=channel)
53
57
 
54
58
  @classmethod
55
59
  async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
@@ -105,5 +109,13 @@ class ClientSet:
105
109
  def secrets_service(self) -> SecretService:
106
110
  return self._secrets_service
107
111
 
112
+ @property
113
+ def identity_service(self) -> IdentityService:
114
+ return self._identity_service
115
+
116
+ @property
117
+ def trigger_service(self) -> TriggerService:
118
+ return self._trigger_service
119
+
108
120
  async def close(self, grace: float | None = None):
109
121
  return await self._channel.close(grace=grace)
flyte/remote/_console.py CHANGED
@@ -9,8 +9,9 @@ def _get_http_domain(endpoint: str, insecure: bool) -> str:
9
9
  else:
10
10
  domain = parsed.netloc or parsed.path
11
11
  # TODO: make console url configurable
12
- if domain.split(":")[0] == "localhost":
13
- domain = "localhost:8080"
12
+ domain_split = domain.split(":")
13
+ if domain_split[0] == "localhost":
14
+ domain = domain if len(domain_split) > 1 else f"{domain}:8080"
14
15
  return f"{scheme}://{domain}"
15
16
 
16
17
 
flyte/remote/_data.py CHANGED
@@ -15,7 +15,7 @@ import httpx
15
15
  from flyteidl.service import dataproxy_pb2
16
16
  from google.protobuf import duration_pb2
17
17
 
18
- from flyte._initialize import CommonInit, ensure_client, get_client, get_common_config
18
+ from flyte._initialize import CommonInit, ensure_client, get_client, get_init_config, require_project_and_domain
19
19
  from flyte.errors import InitializationError, RuntimeSystemError
20
20
  from flyte.syncify import syncify
21
21
 
@@ -54,6 +54,7 @@ def hash_file(file_path: typing.Union[os.PathLike, str]) -> Tuple[bytes, str, in
54
54
  return h.digest(), h.hexdigest(), size
55
55
 
56
56
 
57
+ @require_project_and_domain
57
58
  async def _upload_single_file(
58
59
  cfg: CommonInit, fp: Path, verify: bool = True, basedir: str | None = None
59
60
  ) -> Tuple[str, str]:
@@ -124,7 +125,7 @@ async def upload_file(fp: Path, verify: bool = True) -> Tuple[str, str]:
124
125
  """
125
126
  # This is a placeholder implementation. Replace with actual upload logic.
126
127
  ensure_client()
127
- cfg = get_common_config()
128
+ cfg = get_init_config()
128
129
  if not fp.is_file():
129
130
  raise ValueError(f"{fp} is not a single file, upload arg must be a single file.")
130
131
  return await _upload_single_file(cfg, fp, verify=verify)
@@ -140,7 +141,7 @@ async def upload_dir(dir_path: Path, verify: bool = True) -> str:
140
141
  """
141
142
  # This is a placeholder implementation. Replace with actual upload logic.
142
143
  ensure_client()
143
- cfg = get_common_config()
144
+ cfg = get_init_config()
144
145
  if not dir_path.is_dir():
145
146
  raise ValueError(f"{dir_path} is not a directory, upload arg must be a directory.")
146
147
 
flyte/remote/_logs.py CHANGED
@@ -4,6 +4,9 @@ from dataclasses import dataclass
4
4
  from typing import AsyncGenerator, AsyncIterator
5
5
 
6
6
  import grpc
7
+ from flyteidl2.common import identifier_pb2
8
+ from flyteidl2.logs.dataplane import payload_pb2
9
+ from flyteidl2.workflow import run_logs_service_pb2
7
10
  from rich.console import Console
8
11
  from rich.live import Live
9
12
  from rich.panel import Panel
@@ -11,9 +14,6 @@ from rich.text import Text
11
14
 
12
15
  from flyte._initialize import ensure_client, get_client
13
16
  from flyte._logging import logger
14
- from flyte._protos.common import identifier_pb2
15
- from flyte._protos.logs.dataplane import payload_pb2
16
- from flyte._protos.workflow import run_logs_service_pb2
17
17
  from flyte._tools import ipython_check, ipywidgets_check
18
18
  from flyte.errors import LogsNotYetAvailableError
19
19
  from flyte.syncify import syncify
flyte/remote/_run.py CHANGED
@@ -5,10 +5,11 @@ from typing import AsyncGenerator, AsyncIterator, Literal, Tuple
5
5
 
6
6
  import grpc
7
7
  import rich.repr
8
+ from flyteidl2.common import identifier_pb2, list_pb2
9
+ from flyteidl2.workflow import run_definition_pb2, run_service_pb2
8
10
 
9
- from flyte._initialize import ensure_client, get_client, get_common_config
10
- from flyte._protos.common import identifier_pb2, list_pb2
11
- from flyte._protos.workflow import run_definition_pb2, run_service_pb2
11
+ from flyte._initialize import ensure_client, get_client, get_init_config
12
+ from flyte._logging import logger
12
13
  from flyte.syncify import syncify
13
14
 
14
15
  from . import Action, ActionDetails, ActionInputs, ActionOutputs
@@ -16,6 +17,11 @@ from ._action import _action_details_rich_repr, _action_rich_repr
16
17
  from ._common import ToJSONMixin
17
18
  from ._console import get_run_url
18
19
 
20
+ # @kumare3 is sadpanda, because we have to create a mirror of phase types here, because protobuf phases are ghastly
21
+ Phase = Literal[
22
+ "queued", "waiting_for_resources", "initializing", "running", "succeeded", "failed", "aborted", "timed_out"
23
+ ]
24
+
19
25
 
20
26
  @dataclass
21
27
  class Run(ToJSONMixin):
@@ -40,14 +46,16 @@ class Run(ToJSONMixin):
40
46
  @classmethod
41
47
  async def listall(
42
48
  cls,
43
- filters: str | None = None,
49
+ in_phase: Tuple[Phase] | None = None,
50
+ created_by_subject: str | None = None,
44
51
  sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
45
52
  limit: int = 100,
46
53
  ) -> AsyncIterator[Run]:
47
54
  """
48
55
  Get all runs for the current project and domain.
49
56
 
50
- :param filters: The filters to apply to the project list.
57
+ :param in_phase: Filter runs by one or more phases.
58
+ :param created_by_subject: Filter runs by the subject that created them. (this is not username, but the subject)
51
59
  :param sort_by: The sorting criteria for the project list, in the format (field, order).
52
60
  :param limit: The maximum number of runs to return.
53
61
  :return: An iterator of runs.
@@ -59,13 +67,44 @@ class Run(ToJSONMixin):
59
67
  key=sort_by[0],
60
68
  direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
61
69
  )
62
- cfg = get_common_config()
70
+ filters = []
71
+ if in_phase:
72
+ phases = [str(run_definition_pb2.Phase.Value(f"PHASE_{p.upper()}")) for p in in_phase]
73
+ logger.debug(f"Fetching run phases: {phases}")
74
+ if len(phases) > 1:
75
+ filters.append(
76
+ list_pb2.Filter(
77
+ function=list_pb2.Filter.Function.VALUE_IN,
78
+ field="phase",
79
+ values=phases,
80
+ ),
81
+ )
82
+ else:
83
+ filters.append(
84
+ list_pb2.Filter(
85
+ function=list_pb2.Filter.Function.EQUAL,
86
+ field="phase",
87
+ values=phases[0],
88
+ ),
89
+ )
90
+ if created_by_subject:
91
+ logger.debug(f"Fetching runs created by: {created_by_subject}")
92
+ filters.append(
93
+ list_pb2.Filter(
94
+ function=list_pb2.Filter.Function.EQUAL,
95
+ field="created_by",
96
+ values=[created_by_subject],
97
+ ),
98
+ )
99
+
100
+ cfg = get_init_config()
63
101
  i = 0
64
102
  while True:
65
103
  req = list_pb2.ListRequest(
66
104
  limit=min(100, limit),
67
105
  token=token,
68
106
  sort_by=sort_pb2,
107
+ filters=filters,
69
108
  )
70
109
  resp = await get_client().run_service.ListRuns(
71
110
  run_service_pb2.ListRunsRequest(
@@ -225,6 +264,7 @@ class Run(ToJSONMixin):
225
264
  """
226
265
  Rich representation of the Run object.
227
266
  """
267
+ yield "url", f"[blue bold][link={self.url}]link[/link][/blue bold]"
228
268
  yield from _action_rich_repr(self.pb2.action)
229
269
 
230
270
  def __repr__(self) -> str:
@@ -276,7 +316,7 @@ class RunDetails(ToJSONMixin):
276
316
  :param name: The name of the run.
277
317
  """
278
318
  ensure_client()
279
- cfg = get_common_config()
319
+ cfg = get_init_config()
280
320
  return await RunDetails.get_details.aio(
281
321
  run_id=identifier_pb2.RunIdentifier(
282
322
  org=cfg.org,