flyte 2.0.0b22__py3-none-any.whl → 2.0.0b23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (88)
  1. flyte/__init__.py +5 -0
  2. flyte/_bin/runtime.py +35 -5
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +215 -0
  5. flyte/_code_bundle/bundle.py +1 -0
  6. flyte/_debug/constants.py +0 -1
  7. flyte/_debug/vscode.py +6 -1
  8. flyte/_deploy.py +193 -52
  9. flyte/_environment.py +5 -0
  10. flyte/_excepthook.py +1 -1
  11. flyte/_image.py +101 -72
  12. flyte/_initialize.py +23 -0
  13. flyte/_internal/controllers/_local_controller.py +64 -24
  14. flyte/_internal/controllers/remote/_action.py +4 -1
  15. flyte/_internal/controllers/remote/_controller.py +5 -2
  16. flyte/_internal/controllers/remote/_core.py +6 -3
  17. flyte/_internal/controllers/remote/_informer.py +1 -1
  18. flyte/_internal/imagebuild/docker_builder.py +92 -28
  19. flyte/_internal/imagebuild/image_builder.py +7 -13
  20. flyte/_internal/imagebuild/remote_builder.py +6 -1
  21. flyte/_internal/runtime/io.py +13 -1
  22. flyte/_internal/runtime/rusty.py +17 -2
  23. flyte/_internal/runtime/task_serde.py +14 -20
  24. flyte/_internal/runtime/taskrunner.py +1 -1
  25. flyte/_internal/runtime/trigger_serde.py +153 -0
  26. flyte/_logging.py +1 -1
  27. flyte/_protos/common/identifier_pb2.py +19 -1
  28. flyte/_protos/common/identifier_pb2.pyi +22 -0
  29. flyte/_protos/workflow/common_pb2.py +14 -3
  30. flyte/_protos/workflow/common_pb2.pyi +49 -0
  31. flyte/_protos/workflow/queue_service_pb2.py +41 -35
  32. flyte/_protos/workflow/queue_service_pb2.pyi +26 -12
  33. flyte/_protos/workflow/queue_service_pb2_grpc.py +34 -0
  34. flyte/_protos/workflow/run_definition_pb2.py +38 -38
  35. flyte/_protos/workflow/run_definition_pb2.pyi +4 -2
  36. flyte/_protos/workflow/run_service_pb2.py +60 -50
  37. flyte/_protos/workflow/run_service_pb2.pyi +24 -6
  38. flyte/_protos/workflow/run_service_pb2_grpc.py +34 -0
  39. flyte/_protos/workflow/task_definition_pb2.py +15 -11
  40. flyte/_protos/workflow/task_definition_pb2.pyi +19 -2
  41. flyte/_protos/workflow/task_service_pb2.py +18 -17
  42. flyte/_protos/workflow/task_service_pb2.pyi +5 -2
  43. flyte/_protos/workflow/trigger_definition_pb2.py +66 -0
  44. flyte/_protos/workflow/trigger_definition_pb2.pyi +117 -0
  45. flyte/_protos/workflow/trigger_definition_pb2_grpc.py +4 -0
  46. flyte/_protos/workflow/trigger_service_pb2.py +96 -0
  47. flyte/_protos/workflow/trigger_service_pb2.pyi +110 -0
  48. flyte/_protos/workflow/trigger_service_pb2_grpc.py +281 -0
  49. flyte/_run.py +42 -15
  50. flyte/_task.py +35 -4
  51. flyte/_task_environment.py +60 -15
  52. flyte/_trigger.py +382 -0
  53. flyte/_version.py +3 -3
  54. flyte/cli/_abort.py +3 -3
  55. flyte/cli/_build.py +1 -3
  56. flyte/cli/_common.py +15 -2
  57. flyte/cli/_create.py +74 -0
  58. flyte/cli/_delete.py +23 -1
  59. flyte/cli/_deploy.py +5 -9
  60. flyte/cli/_get.py +75 -34
  61. flyte/cli/_params.py +4 -2
  62. flyte/cli/_run.py +12 -3
  63. flyte/cli/_update.py +36 -0
  64. flyte/cli/_user.py +17 -0
  65. flyte/cli/main.py +9 -1
  66. flyte/errors.py +9 -0
  67. flyte/io/_dir.py +513 -115
  68. flyte/io/_file.py +495 -135
  69. flyte/models.py +32 -0
  70. flyte/remote/__init__.py +6 -1
  71. flyte/remote/_client/_protocols.py +36 -2
  72. flyte/remote/_client/controlplane.py +19 -3
  73. flyte/remote/_run.py +42 -2
  74. flyte/remote/_task.py +14 -1
  75. flyte/remote/_trigger.py +308 -0
  76. flyte/remote/_user.py +33 -0
  77. flyte/storage/__init__.py +6 -1
  78. flyte/storage/_storage.py +119 -101
  79. flyte/types/_pickle.py +16 -3
  80. {flyte-2.0.0b22.data → flyte-2.0.0b23.data}/scripts/runtime.py +35 -5
  81. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/METADATA +3 -1
  82. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/RECORD +87 -75
  83. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  84. {flyte-2.0.0b22.data → flyte-2.0.0b23.data}/scripts/debug.py +0 -0
  85. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/WHEEL +0 -0
  86. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/entry_points.txt +0 -0
  87. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/licenses/LICENSE +0 -0
  88. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/top_level.txt +0 -0
flyte/models.py CHANGED
@@ -77,6 +77,37 @@ class ActionID:
77
77
  return self.new_sub_action(new_name)
78
78
 
79
79
 
80
@rich.repr.auto
@dataclass
class PathRewrite:
    """
    Configuration for rewriting paths during input loading.

    A path that starts with ``old_prefix`` is meant to be rewritten so that it
    starts with ``new_prefix`` instead. This class only holds and validates the
    configuration; it does not perform the rewrite itself.
    """

    # Prefix to match at the start of a path. Required and non-empty.
    old_prefix: str
    # Replacement for ``old_prefix``. Required, non-empty, and different from ``old_prefix``.
    new_prefix: str

    def __post_init__(self):
        # Fail fast on empty prefixes and on no-op rewrites.
        if not self.old_prefix or not self.new_prefix:
            raise ValueError("Both old_prefix and new_prefix must be non-empty strings.")
        if self.old_prefix == self.new_prefix:
            raise ValueError("old_prefix and new_prefix must be different.")

    @classmethod
    def from_str(cls, pattern: str) -> PathRewrite:
        """
        Create a PathRewrite from a string pattern of the form ``old_prefix->new_prefix``.

        Whitespace surrounding either side of the ``->`` separator is ignored, so
        ``"s3://bucket -> /mnt/bucket"`` is equivalent to ``"s3://bucket->/mnt/bucket"``.

        :param pattern: The rewrite pattern.
        :return: A validated PathRewrite.
        :raises ValueError: If the pattern does not contain exactly one ``->`` separator,
            if either side is empty (or whitespace only), or if both sides are equal.
        """
        parts = pattern.split("->")
        if len(parts) != 2:
            raise ValueError(f"Invalid path rewrite pattern: {pattern}. Expected format 'old_prefix->new_prefix'.")
        # Strip so that "a -> b" does not produce prefixes with stray spaces and a
        # whitespace-only side is rejected by __post_init__ as empty.
        old_prefix, new_prefix = (part.strip() for part in parts)
        return cls(old_prefix=old_prefix, new_prefix=new_prefix)

    def __repr__(self) -> str:
        # Same "old->new" shape that from_str() parses.
        return f"{self.old_prefix}->{self.new_prefix}"
109
+
110
+
80
111
  @rich.repr.auto
81
112
  @dataclass(frozen=True, kw_only=True)
82
113
  class RawDataPath:
@@ -86,6 +117,7 @@ class RawDataPath:
86
117
  """
87
118
 
88
119
  path: str
120
+ path_rewrite: Optional[PathRewrite] = None
89
121
 
90
122
  @classmethod
91
123
  def from_local_folder(cls, local_folder: str | pathlib.Path | None = None) -> RawDataPath:
flyte/remote/__init__.py CHANGED
@@ -7,12 +7,15 @@ __all__ = [
7
7
  "ActionDetails",
8
8
  "ActionInputs",
9
9
  "ActionOutputs",
10
+ "Phase",
10
11
  "Project",
11
12
  "Run",
12
13
  "RunDetails",
13
14
  "Secret",
14
15
  "SecretTypes",
15
16
  "Task",
17
+ "Trigger",
18
+ "User",
16
19
  "create_channel",
17
20
  "upload_dir",
18
21
  "upload_file",
@@ -22,6 +25,8 @@ from ._action import Action, ActionDetails, ActionInputs, ActionOutputs
22
25
  from ._client.auth import create_channel
23
26
  from ._data import upload_dir, upload_file
24
27
  from ._project import Project
25
- from ._run import Run, RunDetails
28
+ from ._run import Phase, Run, RunDetails
26
29
  from ._secret import Secret, SecretTypes
27
30
  from ._task import Task
31
+ from ._trigger import Trigger
32
+ from ._user import User
flyte/remote/_client/_protocols.py CHANGED
@@ -1,12 +1,12 @@
1
1
  from typing import AsyncIterator, Protocol
2
2
 
3
3
  from flyteidl.admin import project_attributes_pb2, project_pb2, version_pb2
4
- from flyteidl.service import dataproxy_pb2
4
+ from flyteidl.service import dataproxy_pb2, identity_pb2
5
5
  from grpc.aio import UnaryStreamCall
6
6
  from grpc.aio._typing import RequestType
7
7
 
8
8
  from flyte._protos.secret import payload_pb2
9
- from flyte._protos.workflow import run_logs_service_pb2, run_service_pb2, task_service_pb2
9
+ from flyte._protos.workflow import run_logs_service_pb2, run_service_pb2, task_service_pb2, trigger_service_pb2
10
10
 
11
11
 
12
12
  class MetadataServiceProtocol(Protocol):
@@ -131,3 +131,37 @@ class SecretService(Protocol):
131
131
  async def ListSecrets(self, request: payload_pb2.ListSecretsRequest) -> payload_pb2.ListSecretsResponse: ...
132
132
 
133
133
  async def DeleteSecret(self, request: payload_pb2.DeleteSecretRequest) -> payload_pb2.DeleteSecretResponse: ...
134
+
135
+
136
class IdentityService(Protocol):
    """
    Structural interface of the identity service client exposed as
    ``ClientSet.identity_service``.
    """

    async def UserInfo(self, request: identity_pb2.UserInfoRequest) -> identity_pb2.UserInfoResponse:
        """Call the ``UserInfo`` RPC and return its ``UserInfoResponse``."""
138
+
139
+
140
class TriggerService(Protocol):
    """
    Structural interface of the trigger service client exposed as
    ``ClientSet.trigger_service``.

    Each RPC takes one request message from ``trigger_service_pb2`` and
    asynchronously returns the matching response message.
    """

    async def DeployTrigger(
        self, request: trigger_service_pb2.DeployTriggerRequest
    ) -> trigger_service_pb2.DeployTriggerResponse:
        """``DeployTrigger`` RPC."""

    async def GetTriggerDetails(
        self, request: trigger_service_pb2.GetTriggerDetailsRequest
    ) -> trigger_service_pb2.GetTriggerDetailsResponse:
        """``GetTriggerDetails`` RPC."""

    async def GetTriggerRevisionDetails(
        self, request: trigger_service_pb2.GetTriggerRevisionDetailsRequest
    ) -> trigger_service_pb2.GetTriggerRevisionDetailsResponse:
        """``GetTriggerRevisionDetails`` RPC."""

    async def ListTriggers(
        self, request: trigger_service_pb2.ListTriggersRequest
    ) -> trigger_service_pb2.ListTriggersResponse:
        """``ListTriggers`` RPC."""

    async def GetTriggerRevisionHistory(
        self, request: trigger_service_pb2.GetTriggerRevisionHistoryRequest
    ) -> trigger_service_pb2.GetTriggerRevisionHistoryResponse:
        """``GetTriggerRevisionHistory`` RPC."""

    async def UpdateTriggers(
        self, request: trigger_service_pb2.UpdateTriggersRequest
    ) -> trigger_service_pb2.UpdateTriggersResponse:
        """``UpdateTriggers`` RPC."""

    async def DeleteTriggers(
        self, request: trigger_service_pb2.DeleteTriggersRequest
    ) -> trigger_service_pb2.DeleteTriggersResponse:
        """``DeleteTriggers`` RPC."""
flyte/remote/_client/controlplane.py CHANGED
@@ -15,19 +15,26 @@ if "GRPC_VERBOSITY" not in os.environ:
15
15
  #### Has to be before grpc
16
16
 
17
17
  import grpc
18
- from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc
18
+ from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc, identity_pb2_grpc
19
19
 
20
20
  from flyte._protos.secret import secret_pb2_grpc
21
- from flyte._protos.workflow import run_logs_service_pb2_grpc, run_service_pb2_grpc, task_service_pb2_grpc
21
+ from flyte._protos.workflow import (
22
+ run_logs_service_pb2_grpc,
23
+ run_service_pb2_grpc,
24
+ task_service_pb2_grpc,
25
+ trigger_service_pb2_grpc,
26
+ )
22
27
 
23
28
  from ._protocols import (
24
29
  DataProxyService,
30
+ IdentityService,
25
31
  MetadataServiceProtocol,
26
32
  ProjectDomainService,
27
33
  RunLogsService,
28
34
  RunService,
29
35
  SecretService,
30
36
  TaskService,
37
+ TriggerService,
31
38
  )
32
39
  from .auth import create_channel
33
40
 
@@ -38,7 +45,6 @@ class ClientSet:
38
45
  channel: grpc.aio.Channel,
39
46
  endpoint: str,
40
47
  insecure: bool = False,
41
- data_proxy_channel: grpc.aio.Channel | None = None,
42
48
  **kwargs,
43
49
  ):
44
50
  self.endpoint = endpoint
@@ -50,6 +56,8 @@ class ClientSet:
50
56
  self._dataproxy = dataproxy_pb2_grpc.DataProxyServiceStub(channel=channel)
51
57
  self._log_service = run_logs_service_pb2_grpc.RunLogsServiceStub(channel=channel)
52
58
  self._secrets_service = secret_pb2_grpc.SecretServiceStub(channel=channel)
59
+ self._identity_service = identity_pb2_grpc.IdentityServiceStub(channel=channel)
60
+ self._trigger_service = trigger_service_pb2_grpc.TriggerServiceStub(channel=channel)
53
61
 
54
62
  @classmethod
55
63
  async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
@@ -105,5 +113,13 @@ class ClientSet:
105
113
  def secrets_service(self) -> SecretService:
106
114
  return self._secrets_service
107
115
 
116
@property
def identity_service(self) -> IdentityService:
    """The identity service stub, created on the client's shared channel."""
    stub = self._identity_service
    return stub
119
+
120
@property
def trigger_service(self) -> TriggerService:
    """The trigger service stub, created on the client's shared channel."""
    stub = self._trigger_service
    return stub
123
+
108
124
  async def close(self, grace: float | None = None):
109
125
  return await self._channel.close(grace=grace)
flyte/remote/_run.py CHANGED
@@ -7,6 +7,7 @@ import grpc
7
7
  import rich.repr
8
8
 
9
9
  from flyte._initialize import ensure_client, get_client, get_common_config
10
+ from flyte._logging import logger
10
11
  from flyte._protos.common import identifier_pb2, list_pb2
11
12
  from flyte._protos.workflow import run_definition_pb2, run_service_pb2
12
13
  from flyte.syncify import syncify
@@ -16,6 +17,11 @@ from ._action import _action_details_rich_repr, _action_rich_repr
16
17
  from ._common import ToJSONMixin
17
18
  from ._console import get_run_url
18
19
 
20
# Plain-string mirror of the protobuf run phases. Each value corresponds to the
# ``run_definition_pb2.Phase`` member ``PHASE_<value.upper()>`` (that is how
# ``Run.listall`` maps it back); it exists because the generated protobuf enum
# is awkward to expose directly in a public API.
Phase = Literal[
    "queued",
    "waiting_for_resources",
    "initializing",
    "running",
    "succeeded",
    "failed",
    "aborted",
    "timed_out",
]
24
+
19
25
 
20
26
  @dataclass
21
27
  class Run(ToJSONMixin):
@@ -40,14 +46,16 @@ class Run(ToJSONMixin):
40
46
  @classmethod
41
47
  async def listall(
42
48
  cls,
43
- filters: str | None = None,
49
+ in_phase: Tuple[Phase] | None = None,
50
+ created_by_subject: str | None = None,
44
51
  sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
45
52
  limit: int = 100,
46
53
  ) -> AsyncIterator[Run]:
47
54
  """
48
55
  Get all runs for the current project and domain.
49
56
 
50
- :param filters: The filters to apply to the project list.
57
+ :param in_phase: Filter runs by one or more phases.
58
+ :param created_by_subject: Filter runs by the subject that created them. (this is not username, but the subject)
51
59
  :param sort_by: The sorting criteria for the project list, in the format (field, order).
52
60
  :param limit: The maximum number of runs to return.
53
61
  :return: An iterator of runs.
@@ -59,6 +67,36 @@ class Run(ToJSONMixin):
59
67
  key=sort_by[0],
60
68
  direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
61
69
  )
70
+ filters = []
71
+ if in_phase:
72
+ phases = [str(run_definition_pb2.Phase.Value(f"PHASE_{p.upper()}")) for p in in_phase]
73
+ logger.debug(f"Fetching run phases: {phases}")
74
+ if len(phases) > 1:
75
+ filters.append(
76
+ list_pb2.Filter(
77
+ function=list_pb2.Filter.Function.VALUE_IN,
78
+ field="phase",
79
+ values=phases,
80
+ ),
81
+ )
82
+ else:
83
+ filters.append(
84
+ list_pb2.Filter(
85
+ function=list_pb2.Filter.Function.EQUAL,
86
+ field="phase",
87
+ values=phases[0],
88
+ ),
89
+ )
90
+ if created_by_subject:
91
+ logger.debug(f"Fetching runs created by: {created_by_subject}")
92
+ filters.append(
93
+ list_pb2.Filter(
94
+ function=list_pb2.Filter.Function.EQUAL,
95
+ field="created_by",
96
+ values=[created_by_subject],
97
+ ),
98
+ )
99
+
62
100
  cfg = get_common_config()
63
101
  i = 0
64
102
  while True:
@@ -66,6 +104,7 @@ class Run(ToJSONMixin):
66
104
  limit=min(100, limit),
67
105
  token=token,
68
106
  sort_by=sort_pb2,
107
+ filters=filters,
69
108
  )
70
109
  resp = await get_client().run_service.ListRuns(
71
110
  run_service_pb2.ListRunsRequest(
@@ -225,6 +264,7 @@ class Run(ToJSONMixin):
225
264
  """
226
265
  Rich representation of the Run object.
227
266
  """
267
+ yield "url", f"[blue bold][link={self.url}]link[/link][/blue bold]"
228
268
  yield from _action_rich_repr(self.pb2.action)
229
269
 
230
270
  def __repr__(self) -> str:
flyte/remote/_task.py CHANGED
@@ -99,6 +99,7 @@ AutoVersioning = Literal["latest", "current"]
99
99
  class TaskDetails(ToJSONMixin):
100
100
  pb2: task_definition_pb2.TaskDetails
101
101
  max_inline_io_bytes: int = 10 * 1024 * 1024 # 10 MB
102
+ overriden_queue: Optional[str] = None
102
103
 
103
104
  @classmethod
104
105
  def get(
@@ -284,6 +285,13 @@ class TaskDetails(ToJSONMixin):
284
285
  f"Reference tasks [{self.name}] cannot be executed locally, only remotely."
285
286
  )
286
287
 
288
@property
def queue(self) -> Optional[str]:
    """
    Queue override for this task, as set through ``override(queue=...)``;
    ``None`` when no override was given.
    """
    queue_name = self.overriden_queue
    return queue_name
294
+
287
295
  def override(
288
296
  self,
289
297
  *,
@@ -295,6 +303,7 @@ class TaskDetails(ToJSONMixin):
295
303
  secrets: Optional[flyte.SecretRequest] = None,
296
304
  max_inline_io_bytes: Optional[int] = None,
297
305
  cache: Optional[flyte.Cache] = None,
306
+ queue: Optional[str] = None,
298
307
  **kwargs: Any,
299
308
  ) -> TaskDetails:
300
309
  if len(kwargs) > 0:
@@ -342,7 +351,11 @@ class TaskDetails(ToJSONMixin):
342
351
  md.cache_serializable = cache.serialize
343
352
  md.cache_ignore_input_vars[:] = list(cache.ignored_inputs or ())
344
353
 
345
- return TaskDetails(pb2, max_inline_io_bytes=max_inline_io_bytes or self.max_inline_io_bytes)
354
+ return TaskDetails(
355
+ pb2,
356
+ max_inline_io_bytes=max_inline_io_bytes or self.max_inline_io_bytes,
357
+ overriden_queue=queue,
358
+ )
346
359
 
347
360
  def __rich_repr__(self) -> rich.repr.Result:
348
361
  """
flyte/remote/_trigger.py ADDED
@@ -0,0 +1,308 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from functools import cached_property
5
+ from typing import AsyncIterator
6
+
7
+ import grpc.aio
8
+
9
+ import flyte
10
+ from flyte._initialize import ensure_client, get_client, get_common_config
11
+ from flyte._internal.runtime import trigger_serde
12
+ from flyte._protos.common import identifier_pb2, list_pb2
13
+ from flyte._protos.workflow import common_pb2, task_definition_pb2, trigger_definition_pb2, trigger_service_pb2
14
+ from flyte.syncify import syncify
15
+
16
+ from ._common import ToJSONMixin
17
+ from ._task import Task, TaskDetails
18
+
19
+
20
@dataclass
class TriggerDetails(ToJSONMixin):
    """
    Detailed view of a deployed trigger, wrapping ``trigger_definition_pb2.TriggerDetails``.
    """

    pb2: trigger_definition_pb2.TriggerDetails

    @syncify
    @classmethod
    async def get(cls, *, name: str, task_name: str) -> TriggerDetails:
        """
        Retrieve detailed information about a specific trigger.

        The trigger is looked up in the org/project/domain of the current common config.

        :param name: Name of the trigger.
        :param task_name: Name of the task the trigger is attached to.
        :return: The trigger details returned by the server.
        """
        ensure_client()
        cfg = get_common_config()
        trigger_name = identifier_pb2.TriggerName(
            task_name=task_name,
            name=name,
            org=cfg.org,
            project=cfg.project,
            domain=cfg.domain,
        )
        request = trigger_service_pb2.GetTriggerDetailsRequest(name=trigger_name)
        resp = await get_client().trigger_service.GetTriggerDetails(request=request)
        return cls(pb2=resp.trigger)

    @property
    def id(self) -> identifier_pb2.TriggerIdentifier:
        """Full identifier of the trigger."""
        return self.pb2.id

    @property
    def name(self) -> str:
        """Name of the trigger."""
        return self.pb2.id.name.name

    @property
    def task_name(self) -> str:
        """Name of the task the trigger is attached to."""
        return self.pb2.id.name.task_name

    @property
    def automation_spec(self) -> common_pb2.TriggerAutomationSpec:
        """Automation (e.g. schedule) specification of the trigger."""
        return self.pb2.automation_spec

    @property
    def metadata(self) -> trigger_definition_pb2.TriggerMetadata:
        """Trigger metadata as returned by the server."""
        return self.pb2.metadata

    @property
    def status(self) -> trigger_definition_pb2.TriggerStatus:
        """Trigger status as returned by the server."""
        return self.pb2.status

    @property
    def is_active(self) -> bool:
        """Whether the trigger is active, read from ``pb2.spec.active``."""
        return self.pb2.spec.active

    @cached_property
    def trigger(self) -> trigger_definition_pb2.Trigger:
        """
        Compact ``Trigger`` protobuf assembled from these details (computed once and cached).
        """
        spec_active = self.is_active
        return trigger_definition_pb2.Trigger(
            id=self.id,
            metadata=self.metadata,
            status=self.status,
            automation_spec=self.automation_spec,
            active=spec_active,
        )
82
+
83
+
84
@dataclass
class Trigger(ToJSONMixin):
    """
    A trigger deployed to the Flyte platform, wrapping ``trigger_definition_pb2.Trigger``.

    ``details`` is populated when the object is returned by :meth:`create`; otherwise it
    is fetched on demand by :meth:`get_details` and cached on the instance.
    """

    pb2: trigger_definition_pb2.Trigger
    # Cached detailed view; set by create() or by the first get_details() call.
    details: TriggerDetails | None = None

    @syncify
    @classmethod
    async def create(
        cls,
        trigger: flyte.Trigger,
        task_name: str,
        task_version: str | None = None,
    ) -> Trigger:
        """
        Create (deploy) a new trigger for a task in the Flyte platform.

        The task is fetched first to confirm it exists and to obtain its input interface
        and default inputs, which are used to serialize the trigger.

        :param trigger: The flyte.Trigger object containing the trigger definition.
        :param task_name: Name of the task to associate with the trigger.
        :param task_version: Version of the task to bind the trigger to. When omitted,
            the latest version of the task is used.
        :return: The deployed trigger, with ``details`` populated from the server response.
        :raises ValueError: If any RPC made here fails with NOT_FOUND (reported as the
            task not being found).
        """
        ensure_client()
        cfg = get_common_config()

        try:
            # Fetch the task to ensure it exists and to get its input definitions.
            lazy = (
                Task.get(name=task_name, version=task_version)
                if task_version
                else Task.get(name=task_name, auto_version="latest")
            )
            task: TaskDetails = await lazy.fetch.aio()

            task_trigger = await trigger_serde.to_task_trigger(
                t=trigger,
                task_name=task_name,
                task_inputs=task.pb2.spec.task_template.interface.inputs,
                task_default_inputs=list(task.pb2.spec.default_inputs),
            )

            resp = await get_client().trigger_service.DeployTrigger(
                request=trigger_service_pb2.DeployTriggerRequest(
                    id=identifier_pb2.TriggerIdentifier(
                        name=identifier_pb2.TriggerName(
                            name=trigger.name,
                            task_name=task_name,
                            org=cfg.org,
                            project=cfg.project,
                            domain=cfg.domain,
                        ),
                        # NOTE(review): revision is always sent as 1 — confirm the server
                        # assigns the actual revision on deploy.
                        revision=1,
                    ),
                    spec=trigger_definition_pb2.TriggerSpec(
                        active=task_trigger.spec.active,
                        inputs=task_trigger.spec.inputs,
                        run_spec=task_trigger.spec.run_spec,
                        # Bind to the concrete task version resolved above.
                        task_version=task.version,
                    ),
                    automation_spec=task_trigger.automation_spec,
                )
            )

            details = TriggerDetails(pb2=resp.trigger)

            return cls(pb2=details.trigger, details=details)
        except grpc.aio.AioRpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise ValueError(f"Task {task_name}:{task_version or 'latest'} not found") from e
            raise

    @syncify
    @classmethod
    async def get(cls, *, name: str, task_name: str) -> TriggerDetails:
        """
        Retrieve a trigger by its name and associated task name.

        Note that this returns a :class:`TriggerDetails`, not a :class:`Trigger`.
        """
        # We are already inside a coroutine, so use the async entry point; the plain
        # call form of a syncified function is the blocking variant and is not awaitable.
        return await TriggerDetails.get.aio(name=name, task_name=task_name)

    @syncify
    @classmethod
    async def listall(
        cls, task_name: str | None = None, task_version: str | None = None, limit: int = 100
    ) -> AsyncIterator[Trigger]:
        """
        List triggers in the current project and domain.

        Results are restricted to a single task only when both ``task_name`` and
        ``task_version`` are given. Passing only ``task_name`` currently lists every
        trigger in the project, because filtering by task name alone is not implemented.

        :param task_name: Name of the task whose triggers to list.
        :param task_version: Version of the task whose triggers to list.
        :param limit: Page size requested on each ListTriggers call. Pages are followed
            until the server returns an empty continuation token.
        :return: An async iterator of triggers.
        """
        ensure_client()
        cfg = get_common_config()
        token = None
        project_id = None
        task_id = None
        if task_name and task_version:
            task_id = task_definition_pb2.TaskIdentifier(
                name=task_name,
                project=cfg.project,
                domain=cfg.domain,
                org=cfg.org,
                version=task_version,
            )
        else:
            # TODO: support listing by task name only (via task_definition_pb2.TaskName).
            project_id = identifier_pb2.ProjectIdentifier(
                organization=cfg.org,
                domain=cfg.domain,
                name=cfg.project,
            )

        while True:
            resp = await get_client().trigger_service.ListTriggers(
                request=trigger_service_pb2.ListTriggersRequest(
                    project_id=project_id,
                    task_id=task_id,
                    request=list_pb2.ListRequest(
                        limit=limit,
                        token=token,
                    ),
                )
            )
            token = resp.token
            for r in resp.triggers:
                yield cls(r)
            if not token:
                break

    @syncify
    @classmethod
    async def update(cls, name: str, task_name: str, active: bool):
        """
        Activate or deactivate (pause) a trigger by its name and associated task name.

        :param name: Name of the trigger.
        :param task_name: Name of the task the trigger belongs to.
        :param active: ``True`` to activate the trigger, ``False`` to deactivate it.
        """
        ensure_client()
        cfg = get_common_config()
        await get_client().trigger_service.UpdateTriggers(
            request=trigger_service_pb2.UpdateTriggersRequest(
                names=[
                    identifier_pb2.TriggerName(
                        org=cfg.org,
                        project=cfg.project,
                        domain=cfg.domain,
                        name=name,
                        task_name=task_name,
                    )
                ],
                active=active,
            )
        )

    @syncify
    @classmethod
    async def delete(cls, name: str, task_name: str):
        """
        Delete a trigger by its name and associated task name.

        :param name: Name of the trigger.
        :param task_name: Name of the task the trigger belongs to.
        """
        ensure_client()
        cfg = get_common_config()
        await get_client().trigger_service.DeleteTriggers(
            request=trigger_service_pb2.DeleteTriggersRequest(
                names=[
                    identifier_pb2.TriggerName(
                        org=cfg.org,
                        project=cfg.project,
                        domain=cfg.domain,
                        name=name,
                        task_name=task_name,
                    )
                ],
            )
        )

    @property
    def id(self) -> identifier_pb2.TriggerIdentifier:
        """Full identifier of the trigger."""
        return self.pb2.id

    @property
    def name(self) -> str:
        """Name of the trigger."""
        return self.id.name.name

    @property
    def task_name(self) -> str:
        """Name of the task the trigger is attached to."""
        return self.id.name.task_name

    @property
    def automation_spec(self) -> common_pb2.TriggerAutomationSpec:
        """Automation (e.g. schedule) specification of the trigger."""
        return self.pb2.automation_spec

    async def get_details(self) -> TriggerDetails:
        """
        Get detailed information about this trigger, fetching it on first use and caching it.
        """
        if self.details is None:
            # TriggerDetails.get requires both the trigger name and its task name
            # (both keyword-only); omitting task_name raises TypeError.
            self.details = await TriggerDetails.get.aio(name=self.name, task_name=self.task_name)
        return self.details

    @property
    def is_active(self) -> bool:
        """Whether the trigger is active."""
        return self.pb2.active

    def _rich_automation(self, automation: common_pb2.TriggerAutomationSpec):
        """Yield rich-repr ``(key, value)`` pairs describing the trigger's automation."""
        if automation.type == common_pb2.TriggerAutomationSpec.TYPE_NONE:
            yield "none", None
        elif automation.type == common_pb2.TriggerAutomationSpec.TYPE_SCHEDULE:
            schedule = automation.schedule
            # Protobuf scalar fields are never None (an unset string reads as ""), so a
            # "is not None" test would always pick the cron branch. Test the cron string
            # for emptiness and use HasField() for the rate message instead.
            if schedule.cron_expression:
                yield "cron", schedule.cron_expression
            elif schedule.HasField("rate"):
                r = schedule.rate
                # start_time is a field of the rate message, not of the automation spec.
                start = r.start_time.ToDatetime() if r.HasField("start_time") else "now"
                yield "fixed_rate", f"Every [{r.value}] {r.unit} starting at {start}"

    def __rich_repr__(self):
        """Rich representation: task, name, automation and activation state."""
        yield "task_name", self.task_name
        yield "name", self.name
        yield from self._rich_automation(self.pb2.automation_spec)
        yield "auto_activate", self.is_active
flyte/remote/_user.py ADDED
@@ -0,0 +1,33 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from flyteidl.service import identity_pb2
6
+ from flyteidl.service.identity_pb2 import UserInfoResponse
7
+
8
+ from .._initialize import ensure_client, get_client
9
+ from ..syncify import syncify
10
+ from ._common import ToJSONMixin
11
+
12
+
13
@dataclass
class User(ToJSONMixin):
    """
    The currently logged-in user, wrapping the identity service's ``UserInfoResponse``.
    """

    pb2: UserInfoResponse

    @syncify
    @classmethod
    async def get(cls) -> User:
        """
        Fetch information about the currently logged-in user.

        :return: A User wrapping the ``UserInfo`` RPC response.
        """
        ensure_client()
        identity = get_client().identity_service
        info = await identity.UserInfo(identity_pb2.UserInfoRequest())
        return cls(info)

    def subject(self) -> str:
        """Return the ``subject`` field of the user-info response."""
        return self.pb2.subject

    def name(self) -> str:
        """Return the ``name`` field of the user-info response."""
        return self.pb2.name
flyte/storage/__init__.py CHANGED
@@ -3,6 +3,8 @@ __all__ = [
3
3
  "GCS",
4
4
  "S3",
5
5
  "Storage",
6
+ "exists",
7
+ "exists_sync",
6
8
  "get",
7
9
  "get_configured_fsspec_kwargs",
8
10
  "get_random_local_directory",
@@ -11,13 +13,15 @@ __all__ = [
11
13
  "get_underlying_filesystem",
12
14
  "is_remote",
13
15
  "join",
16
+ "open",
14
17
  "put",
15
18
  "put_stream",
16
- "put_stream",
17
19
  ]
18
20
 
19
21
  from ._config import ABFS, GCS, S3, Storage
20
22
  from ._storage import (
23
+ exists,
24
+ exists_sync,
21
25
  get,
22
26
  get_configured_fsspec_kwargs,
23
27
  get_random_local_directory,
@@ -26,6 +30,7 @@ from ._storage import (
26
30
  get_underlying_filesystem,
27
31
  is_remote,
28
32
  join,
33
+ open,
29
34
  put,
30
35
  put_stream,
31
36
  )