flyte 2.0.0b22__py3-none-any.whl → 2.0.0b23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +5 -0
- flyte/_bin/runtime.py +35 -5
- flyte/_cache/cache.py +4 -2
- flyte/_cache/local_cache.py +215 -0
- flyte/_code_bundle/bundle.py +1 -0
- flyte/_debug/constants.py +0 -1
- flyte/_debug/vscode.py +6 -1
- flyte/_deploy.py +193 -52
- flyte/_environment.py +5 -0
- flyte/_excepthook.py +1 -1
- flyte/_image.py +101 -72
- flyte/_initialize.py +23 -0
- flyte/_internal/controllers/_local_controller.py +64 -24
- flyte/_internal/controllers/remote/_action.py +4 -1
- flyte/_internal/controllers/remote/_controller.py +5 -2
- flyte/_internal/controllers/remote/_core.py +6 -3
- flyte/_internal/controllers/remote/_informer.py +1 -1
- flyte/_internal/imagebuild/docker_builder.py +92 -28
- flyte/_internal/imagebuild/image_builder.py +7 -13
- flyte/_internal/imagebuild/remote_builder.py +6 -1
- flyte/_internal/runtime/io.py +13 -1
- flyte/_internal/runtime/rusty.py +17 -2
- flyte/_internal/runtime/task_serde.py +14 -20
- flyte/_internal/runtime/taskrunner.py +1 -1
- flyte/_internal/runtime/trigger_serde.py +153 -0
- flyte/_logging.py +1 -1
- flyte/_protos/common/identifier_pb2.py +19 -1
- flyte/_protos/common/identifier_pb2.pyi +22 -0
- flyte/_protos/workflow/common_pb2.py +14 -3
- flyte/_protos/workflow/common_pb2.pyi +49 -0
- flyte/_protos/workflow/queue_service_pb2.py +41 -35
- flyte/_protos/workflow/queue_service_pb2.pyi +26 -12
- flyte/_protos/workflow/queue_service_pb2_grpc.py +34 -0
- flyte/_protos/workflow/run_definition_pb2.py +38 -38
- flyte/_protos/workflow/run_definition_pb2.pyi +4 -2
- flyte/_protos/workflow/run_service_pb2.py +60 -50
- flyte/_protos/workflow/run_service_pb2.pyi +24 -6
- flyte/_protos/workflow/run_service_pb2_grpc.py +34 -0
- flyte/_protos/workflow/task_definition_pb2.py +15 -11
- flyte/_protos/workflow/task_definition_pb2.pyi +19 -2
- flyte/_protos/workflow/task_service_pb2.py +18 -17
- flyte/_protos/workflow/task_service_pb2.pyi +5 -2
- flyte/_protos/workflow/trigger_definition_pb2.py +66 -0
- flyte/_protos/workflow/trigger_definition_pb2.pyi +117 -0
- flyte/_protos/workflow/trigger_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/trigger_service_pb2.py +96 -0
- flyte/_protos/workflow/trigger_service_pb2.pyi +110 -0
- flyte/_protos/workflow/trigger_service_pb2_grpc.py +281 -0
- flyte/_run.py +42 -15
- flyte/_task.py +35 -4
- flyte/_task_environment.py +60 -15
- flyte/_trigger.py +382 -0
- flyte/_version.py +3 -3
- flyte/cli/_abort.py +3 -3
- flyte/cli/_build.py +1 -3
- flyte/cli/_common.py +15 -2
- flyte/cli/_create.py +74 -0
- flyte/cli/_delete.py +23 -1
- flyte/cli/_deploy.py +5 -9
- flyte/cli/_get.py +75 -34
- flyte/cli/_params.py +4 -2
- flyte/cli/_run.py +12 -3
- flyte/cli/_update.py +36 -0
- flyte/cli/_user.py +17 -0
- flyte/cli/main.py +9 -1
- flyte/errors.py +9 -0
- flyte/io/_dir.py +513 -115
- flyte/io/_file.py +495 -135
- flyte/models.py +32 -0
- flyte/remote/__init__.py +6 -1
- flyte/remote/_client/_protocols.py +36 -2
- flyte/remote/_client/controlplane.py +19 -3
- flyte/remote/_run.py +42 -2
- flyte/remote/_task.py +14 -1
- flyte/remote/_trigger.py +308 -0
- flyte/remote/_user.py +33 -0
- flyte/storage/__init__.py +6 -1
- flyte/storage/_storage.py +119 -101
- flyte/types/_pickle.py +16 -3
- {flyte-2.0.0b22.data → flyte-2.0.0b23.data}/scripts/runtime.py +35 -5
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/METADATA +3 -1
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/RECORD +87 -75
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- {flyte-2.0.0b22.data → flyte-2.0.0b23.data}/scripts/debug.py +0 -0
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/entry_points.txt +0 -0
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/top_level.txt +0 -0
flyte/models.py
CHANGED
|
@@ -77,6 +77,37 @@ class ActionID:
|
|
|
77
77
|
return self.new_sub_action(new_name)
|
|
78
78
|
|
|
79
79
|
|
|
80
|
+
@rich.repr.auto
|
|
81
|
+
@dataclass
|
|
82
|
+
class PathRewrite:
|
|
83
|
+
"""
|
|
84
|
+
Configuration for rewriting paths during input loading.
|
|
85
|
+
"""
|
|
86
|
+
|
|
87
|
+
# If set, rewrites any path starting with this prefix to the new prefix.
|
|
88
|
+
old_prefix: str
|
|
89
|
+
new_prefix: str
|
|
90
|
+
|
|
91
|
+
def __post_init__(self):
|
|
92
|
+
if not self.old_prefix or not self.new_prefix:
|
|
93
|
+
raise ValueError("Both old_prefix and new_prefix must be non-empty strings.")
|
|
94
|
+
if self.old_prefix == self.new_prefix:
|
|
95
|
+
raise ValueError("old_prefix and new_prefix must be different.")
|
|
96
|
+
|
|
97
|
+
@classmethod
|
|
98
|
+
def from_str(cls, pattern: str) -> PathRewrite:
|
|
99
|
+
"""
|
|
100
|
+
Create a PathRewrite from a string pattern of the form `old_prefix->new_prefix`.
|
|
101
|
+
"""
|
|
102
|
+
parts = pattern.split("->")
|
|
103
|
+
if len(parts) != 2:
|
|
104
|
+
raise ValueError(f"Invalid path rewrite pattern: {pattern}. Expected format 'old_prefix->new_prefix'.")
|
|
105
|
+
return cls(old_prefix=parts[0], new_prefix=parts[1])
|
|
106
|
+
|
|
107
|
+
def __repr__(self) -> str:
|
|
108
|
+
return f"{self.old_prefix}->{self.new_prefix}"
|
|
109
|
+
|
|
110
|
+
|
|
80
111
|
@rich.repr.auto
|
|
81
112
|
@dataclass(frozen=True, kw_only=True)
|
|
82
113
|
class RawDataPath:
|
|
@@ -86,6 +117,7 @@ class RawDataPath:
|
|
|
86
117
|
"""
|
|
87
118
|
|
|
88
119
|
path: str
|
|
120
|
+
path_rewrite: Optional[PathRewrite] = None
|
|
89
121
|
|
|
90
122
|
@classmethod
|
|
91
123
|
def from_local_folder(cls, local_folder: str | pathlib.Path | None = None) -> RawDataPath:
|
flyte/remote/__init__.py
CHANGED
|
@@ -7,12 +7,15 @@ __all__ = [
|
|
|
7
7
|
"ActionDetails",
|
|
8
8
|
"ActionInputs",
|
|
9
9
|
"ActionOutputs",
|
|
10
|
+
"Phase",
|
|
10
11
|
"Project",
|
|
11
12
|
"Run",
|
|
12
13
|
"RunDetails",
|
|
13
14
|
"Secret",
|
|
14
15
|
"SecretTypes",
|
|
15
16
|
"Task",
|
|
17
|
+
"Trigger",
|
|
18
|
+
"User",
|
|
16
19
|
"create_channel",
|
|
17
20
|
"upload_dir",
|
|
18
21
|
"upload_file",
|
|
@@ -22,6 +25,8 @@ from ._action import Action, ActionDetails, ActionInputs, ActionOutputs
|
|
|
22
25
|
from ._client.auth import create_channel
|
|
23
26
|
from ._data import upload_dir, upload_file
|
|
24
27
|
from ._project import Project
|
|
25
|
-
from ._run import Run, RunDetails
|
|
28
|
+
from ._run import Phase, Run, RunDetails
|
|
26
29
|
from ._secret import Secret, SecretTypes
|
|
27
30
|
from ._task import Task
|
|
31
|
+
from ._trigger import Trigger
|
|
32
|
+
from ._user import User
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
from typing import AsyncIterator, Protocol
|
|
2
2
|
|
|
3
3
|
from flyteidl.admin import project_attributes_pb2, project_pb2, version_pb2
|
|
4
|
-
from flyteidl.service import dataproxy_pb2
|
|
4
|
+
from flyteidl.service import dataproxy_pb2, identity_pb2
|
|
5
5
|
from grpc.aio import UnaryStreamCall
|
|
6
6
|
from grpc.aio._typing import RequestType
|
|
7
7
|
|
|
8
8
|
from flyte._protos.secret import payload_pb2
|
|
9
|
-
from flyte._protos.workflow import run_logs_service_pb2, run_service_pb2, task_service_pb2
|
|
9
|
+
from flyte._protos.workflow import run_logs_service_pb2, run_service_pb2, task_service_pb2, trigger_service_pb2
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class MetadataServiceProtocol(Protocol):
|
|
@@ -131,3 +131,37 @@ class SecretService(Protocol):
|
|
|
131
131
|
async def ListSecrets(self, request: payload_pb2.ListSecretsRequest) -> payload_pb2.ListSecretsResponse: ...
|
|
132
132
|
|
|
133
133
|
async def DeleteSecret(self, request: payload_pb2.DeleteSecretRequest) -> payload_pb2.DeleteSecretResponse: ...
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class IdentityService(Protocol):
|
|
137
|
+
async def UserInfo(self, request: identity_pb2.UserInfoRequest) -> identity_pb2.UserInfoResponse: ...
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class TriggerService(Protocol):
|
|
141
|
+
async def DeployTrigger(
|
|
142
|
+
self, request: trigger_service_pb2.DeployTriggerRequest
|
|
143
|
+
) -> trigger_service_pb2.DeployTriggerResponse: ...
|
|
144
|
+
|
|
145
|
+
async def GetTriggerDetails(
|
|
146
|
+
self, request: trigger_service_pb2.GetTriggerDetailsRequest
|
|
147
|
+
) -> trigger_service_pb2.GetTriggerDetailsResponse: ...
|
|
148
|
+
|
|
149
|
+
async def GetTriggerRevisionDetails(
|
|
150
|
+
self, request: trigger_service_pb2.GetTriggerRevisionDetailsRequest
|
|
151
|
+
) -> trigger_service_pb2.GetTriggerRevisionDetailsResponse: ...
|
|
152
|
+
|
|
153
|
+
async def ListTriggers(
|
|
154
|
+
self, request: trigger_service_pb2.ListTriggersRequest
|
|
155
|
+
) -> trigger_service_pb2.ListTriggersResponse: ...
|
|
156
|
+
|
|
157
|
+
async def GetTriggerRevisionHistory(
|
|
158
|
+
self, request: trigger_service_pb2.GetTriggerRevisionHistoryRequest
|
|
159
|
+
) -> trigger_service_pb2.GetTriggerRevisionHistoryResponse: ...
|
|
160
|
+
|
|
161
|
+
async def UpdateTriggers(
|
|
162
|
+
self, request: trigger_service_pb2.UpdateTriggersRequest
|
|
163
|
+
) -> trigger_service_pb2.UpdateTriggersResponse: ...
|
|
164
|
+
|
|
165
|
+
async def DeleteTriggers(
|
|
166
|
+
self, request: trigger_service_pb2.DeleteTriggersRequest
|
|
167
|
+
) -> trigger_service_pb2.DeleteTriggersResponse: ...
|
|
@@ -15,19 +15,26 @@ if "GRPC_VERBOSITY" not in os.environ:
|
|
|
15
15
|
#### Has to be before grpc
|
|
16
16
|
|
|
17
17
|
import grpc
|
|
18
|
-
from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc
|
|
18
|
+
from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc, identity_pb2_grpc
|
|
19
19
|
|
|
20
20
|
from flyte._protos.secret import secret_pb2_grpc
|
|
21
|
-
from flyte._protos.workflow import
|
|
21
|
+
from flyte._protos.workflow import (
|
|
22
|
+
run_logs_service_pb2_grpc,
|
|
23
|
+
run_service_pb2_grpc,
|
|
24
|
+
task_service_pb2_grpc,
|
|
25
|
+
trigger_service_pb2_grpc,
|
|
26
|
+
)
|
|
22
27
|
|
|
23
28
|
from ._protocols import (
|
|
24
29
|
DataProxyService,
|
|
30
|
+
IdentityService,
|
|
25
31
|
MetadataServiceProtocol,
|
|
26
32
|
ProjectDomainService,
|
|
27
33
|
RunLogsService,
|
|
28
34
|
RunService,
|
|
29
35
|
SecretService,
|
|
30
36
|
TaskService,
|
|
37
|
+
TriggerService,
|
|
31
38
|
)
|
|
32
39
|
from .auth import create_channel
|
|
33
40
|
|
|
@@ -38,7 +45,6 @@ class ClientSet:
|
|
|
38
45
|
channel: grpc.aio.Channel,
|
|
39
46
|
endpoint: str,
|
|
40
47
|
insecure: bool = False,
|
|
41
|
-
data_proxy_channel: grpc.aio.Channel | None = None,
|
|
42
48
|
**kwargs,
|
|
43
49
|
):
|
|
44
50
|
self.endpoint = endpoint
|
|
@@ -50,6 +56,8 @@ class ClientSet:
|
|
|
50
56
|
self._dataproxy = dataproxy_pb2_grpc.DataProxyServiceStub(channel=channel)
|
|
51
57
|
self._log_service = run_logs_service_pb2_grpc.RunLogsServiceStub(channel=channel)
|
|
52
58
|
self._secrets_service = secret_pb2_grpc.SecretServiceStub(channel=channel)
|
|
59
|
+
self._identity_service = identity_pb2_grpc.IdentityServiceStub(channel=channel)
|
|
60
|
+
self._trigger_service = trigger_service_pb2_grpc.TriggerServiceStub(channel=channel)
|
|
53
61
|
|
|
54
62
|
@classmethod
|
|
55
63
|
async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
|
|
@@ -105,5 +113,13 @@ class ClientSet:
|
|
|
105
113
|
def secrets_service(self) -> SecretService:
|
|
106
114
|
return self._secrets_service
|
|
107
115
|
|
|
116
|
+
@property
|
|
117
|
+
def identity_service(self) -> IdentityService:
|
|
118
|
+
return self._identity_service
|
|
119
|
+
|
|
120
|
+
@property
|
|
121
|
+
def trigger_service(self) -> TriggerService:
|
|
122
|
+
return self._trigger_service
|
|
123
|
+
|
|
108
124
|
async def close(self, grace: float | None = None):
|
|
109
125
|
return await self._channel.close(grace=grace)
|
flyte/remote/_run.py
CHANGED
|
@@ -7,6 +7,7 @@ import grpc
|
|
|
7
7
|
import rich.repr
|
|
8
8
|
|
|
9
9
|
from flyte._initialize import ensure_client, get_client, get_common_config
|
|
10
|
+
from flyte._logging import logger
|
|
10
11
|
from flyte._protos.common import identifier_pb2, list_pb2
|
|
11
12
|
from flyte._protos.workflow import run_definition_pb2, run_service_pb2
|
|
12
13
|
from flyte.syncify import syncify
|
|
@@ -16,6 +17,11 @@ from ._action import _action_details_rich_repr, _action_rich_repr
|
|
|
16
17
|
from ._common import ToJSONMixin
|
|
17
18
|
from ._console import get_run_url
|
|
18
19
|
|
|
20
|
+
# @kumare3 is sadpanda, because we have to create a mirror of phase types here, because protobuf phases are ghastly
|
|
21
|
+
Phase = Literal[
|
|
22
|
+
"queued", "waiting_for_resources", "initializing", "running", "succeeded", "failed", "aborted", "timed_out"
|
|
23
|
+
]
|
|
24
|
+
|
|
19
25
|
|
|
20
26
|
@dataclass
|
|
21
27
|
class Run(ToJSONMixin):
|
|
@@ -40,14 +46,16 @@ class Run(ToJSONMixin):
|
|
|
40
46
|
@classmethod
|
|
41
47
|
async def listall(
|
|
42
48
|
cls,
|
|
43
|
-
|
|
49
|
+
in_phase: Tuple[Phase] | None = None,
|
|
50
|
+
created_by_subject: str | None = None,
|
|
44
51
|
sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
|
|
45
52
|
limit: int = 100,
|
|
46
53
|
) -> AsyncIterator[Run]:
|
|
47
54
|
"""
|
|
48
55
|
Get all runs for the current project and domain.
|
|
49
56
|
|
|
50
|
-
:param
|
|
57
|
+
:param in_phase: Filter runs by one or more phases.
|
|
58
|
+
:param created_by_subject: Filter runs by the subject that created them. (this is not username, but the subject)
|
|
51
59
|
:param sort_by: The sorting criteria for the project list, in the format (field, order).
|
|
52
60
|
:param limit: The maximum number of runs to return.
|
|
53
61
|
:return: An iterator of runs.
|
|
@@ -59,6 +67,36 @@ class Run(ToJSONMixin):
|
|
|
59
67
|
key=sort_by[0],
|
|
60
68
|
direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
|
|
61
69
|
)
|
|
70
|
+
filters = []
|
|
71
|
+
if in_phase:
|
|
72
|
+
phases = [str(run_definition_pb2.Phase.Value(f"PHASE_{p.upper()}")) for p in in_phase]
|
|
73
|
+
logger.debug(f"Fetching run phases: {phases}")
|
|
74
|
+
if len(phases) > 1:
|
|
75
|
+
filters.append(
|
|
76
|
+
list_pb2.Filter(
|
|
77
|
+
function=list_pb2.Filter.Function.VALUE_IN,
|
|
78
|
+
field="phase",
|
|
79
|
+
values=phases,
|
|
80
|
+
),
|
|
81
|
+
)
|
|
82
|
+
else:
|
|
83
|
+
filters.append(
|
|
84
|
+
list_pb2.Filter(
|
|
85
|
+
function=list_pb2.Filter.Function.EQUAL,
|
|
86
|
+
field="phase",
|
|
87
|
+
values=phases[0],
|
|
88
|
+
),
|
|
89
|
+
)
|
|
90
|
+
if created_by_subject:
|
|
91
|
+
logger.debug(f"Fetching runs created by: {created_by_subject}")
|
|
92
|
+
filters.append(
|
|
93
|
+
list_pb2.Filter(
|
|
94
|
+
function=list_pb2.Filter.Function.EQUAL,
|
|
95
|
+
field="created_by",
|
|
96
|
+
values=[created_by_subject],
|
|
97
|
+
),
|
|
98
|
+
)
|
|
99
|
+
|
|
62
100
|
cfg = get_common_config()
|
|
63
101
|
i = 0
|
|
64
102
|
while True:
|
|
@@ -66,6 +104,7 @@ class Run(ToJSONMixin):
|
|
|
66
104
|
limit=min(100, limit),
|
|
67
105
|
token=token,
|
|
68
106
|
sort_by=sort_pb2,
|
|
107
|
+
filters=filters,
|
|
69
108
|
)
|
|
70
109
|
resp = await get_client().run_service.ListRuns(
|
|
71
110
|
run_service_pb2.ListRunsRequest(
|
|
@@ -225,6 +264,7 @@ class Run(ToJSONMixin):
|
|
|
225
264
|
"""
|
|
226
265
|
Rich representation of the Run object.
|
|
227
266
|
"""
|
|
267
|
+
yield "url", f"[blue bold][link={self.url}]link[/link][/blue bold]"
|
|
228
268
|
yield from _action_rich_repr(self.pb2.action)
|
|
229
269
|
|
|
230
270
|
def __repr__(self) -> str:
|
flyte/remote/_task.py
CHANGED
|
@@ -99,6 +99,7 @@ AutoVersioning = Literal["latest", "current"]
|
|
|
99
99
|
class TaskDetails(ToJSONMixin):
|
|
100
100
|
pb2: task_definition_pb2.TaskDetails
|
|
101
101
|
max_inline_io_bytes: int = 10 * 1024 * 1024 # 10 MB
|
|
102
|
+
overriden_queue: Optional[str] = None
|
|
102
103
|
|
|
103
104
|
@classmethod
|
|
104
105
|
def get(
|
|
@@ -284,6 +285,13 @@ class TaskDetails(ToJSONMixin):
|
|
|
284
285
|
f"Reference tasks [{self.name}] cannot be executed locally, only remotely."
|
|
285
286
|
)
|
|
286
287
|
|
|
288
|
+
@property
|
|
289
|
+
def queue(self) -> Optional[str]:
|
|
290
|
+
"""
|
|
291
|
+
The queue to use for the task.
|
|
292
|
+
"""
|
|
293
|
+
return self.overriden_queue
|
|
294
|
+
|
|
287
295
|
def override(
|
|
288
296
|
self,
|
|
289
297
|
*,
|
|
@@ -295,6 +303,7 @@ class TaskDetails(ToJSONMixin):
|
|
|
295
303
|
secrets: Optional[flyte.SecretRequest] = None,
|
|
296
304
|
max_inline_io_bytes: Optional[int] = None,
|
|
297
305
|
cache: Optional[flyte.Cache] = None,
|
|
306
|
+
queue: Optional[str] = None,
|
|
298
307
|
**kwargs: Any,
|
|
299
308
|
) -> TaskDetails:
|
|
300
309
|
if len(kwargs) > 0:
|
|
@@ -342,7 +351,11 @@ class TaskDetails(ToJSONMixin):
|
|
|
342
351
|
md.cache_serializable = cache.serialize
|
|
343
352
|
md.cache_ignore_input_vars[:] = list(cache.ignored_inputs or ())
|
|
344
353
|
|
|
345
|
-
return TaskDetails(
|
|
354
|
+
return TaskDetails(
|
|
355
|
+
pb2,
|
|
356
|
+
max_inline_io_bytes=max_inline_io_bytes or self.max_inline_io_bytes,
|
|
357
|
+
overriden_queue=queue,
|
|
358
|
+
)
|
|
346
359
|
|
|
347
360
|
def __rich_repr__(self) -> rich.repr.Result:
|
|
348
361
|
"""
|
flyte/remote/_trigger.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from functools import cached_property
|
|
5
|
+
from typing import AsyncIterator
|
|
6
|
+
|
|
7
|
+
import grpc.aio
|
|
8
|
+
|
|
9
|
+
import flyte
|
|
10
|
+
from flyte._initialize import ensure_client, get_client, get_common_config
|
|
11
|
+
from flyte._internal.runtime import trigger_serde
|
|
12
|
+
from flyte._protos.common import identifier_pb2, list_pb2
|
|
13
|
+
from flyte._protos.workflow import common_pb2, task_definition_pb2, trigger_definition_pb2, trigger_service_pb2
|
|
14
|
+
from flyte.syncify import syncify
|
|
15
|
+
|
|
16
|
+
from ._common import ToJSONMixin
|
|
17
|
+
from ._task import Task, TaskDetails
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
|
|
21
|
+
class TriggerDetails(ToJSONMixin):
|
|
22
|
+
pb2: trigger_definition_pb2.TriggerDetails
|
|
23
|
+
|
|
24
|
+
@syncify
|
|
25
|
+
@classmethod
|
|
26
|
+
async def get(cls, *, name: str, task_name: str) -> TriggerDetails:
|
|
27
|
+
"""
|
|
28
|
+
Retrieve detailed information about a specific trigger by its name.
|
|
29
|
+
"""
|
|
30
|
+
ensure_client()
|
|
31
|
+
cfg = get_common_config()
|
|
32
|
+
resp = await get_client().trigger_service.GetTriggerDetails(
|
|
33
|
+
request=trigger_service_pb2.GetTriggerDetailsRequest(
|
|
34
|
+
name=identifier_pb2.TriggerName(
|
|
35
|
+
task_name=task_name,
|
|
36
|
+
name=name,
|
|
37
|
+
org=cfg.org,
|
|
38
|
+
project=cfg.project,
|
|
39
|
+
domain=cfg.domain,
|
|
40
|
+
),
|
|
41
|
+
)
|
|
42
|
+
)
|
|
43
|
+
return cls(pb2=resp.trigger)
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def name(self) -> str:
|
|
47
|
+
return self.id.name.name
|
|
48
|
+
|
|
49
|
+
@property
|
|
50
|
+
def id(self) -> identifier_pb2.TriggerIdentifier:
|
|
51
|
+
return self.pb2.id
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def task_name(self) -> str:
|
|
55
|
+
return self.pb2.id.name.task_name
|
|
56
|
+
|
|
57
|
+
@property
|
|
58
|
+
def automation_spec(self) -> common_pb2.TriggerAutomationSpec:
|
|
59
|
+
return self.pb2.automation_spec
|
|
60
|
+
|
|
61
|
+
@property
|
|
62
|
+
def metadata(self) -> trigger_definition_pb2.TriggerMetadata:
|
|
63
|
+
return self.pb2.metadata
|
|
64
|
+
|
|
65
|
+
@property
|
|
66
|
+
def status(self) -> trigger_definition_pb2.TriggerStatus:
|
|
67
|
+
return self.pb2.status
|
|
68
|
+
|
|
69
|
+
@property
|
|
70
|
+
def is_active(self) -> bool:
|
|
71
|
+
return self.pb2.spec.active
|
|
72
|
+
|
|
73
|
+
@cached_property
|
|
74
|
+
def trigger(self) -> trigger_definition_pb2.Trigger:
|
|
75
|
+
return trigger_definition_pb2.Trigger(
|
|
76
|
+
id=self.pb2.id,
|
|
77
|
+
automation_spec=self.automation_spec,
|
|
78
|
+
metadata=self.metadata,
|
|
79
|
+
status=self.status,
|
|
80
|
+
active=self.is_active,
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
|
|
85
|
+
class Trigger(ToJSONMixin):
|
|
86
|
+
pb2: trigger_definition_pb2.Trigger
|
|
87
|
+
details: TriggerDetails | None = None
|
|
88
|
+
|
|
89
|
+
@syncify
|
|
90
|
+
@classmethod
|
|
91
|
+
async def create(
|
|
92
|
+
cls,
|
|
93
|
+
trigger: flyte.Trigger,
|
|
94
|
+
task_name: str,
|
|
95
|
+
task_version: str | None = None,
|
|
96
|
+
) -> Trigger:
|
|
97
|
+
"""
|
|
98
|
+
Create a new trigger in the Flyte platform.
|
|
99
|
+
|
|
100
|
+
:param trigger: The flyte.Trigger object containing the trigger definition.
|
|
101
|
+
:param task_name: Optional name of the task to associate with the trigger.
|
|
102
|
+
"""
|
|
103
|
+
ensure_client()
|
|
104
|
+
cfg = get_common_config()
|
|
105
|
+
|
|
106
|
+
# Fetch the task to ensure it exists and to get its input definitions
|
|
107
|
+
try:
|
|
108
|
+
lazy = (
|
|
109
|
+
Task.get(name=task_name, version=task_version)
|
|
110
|
+
if task_version
|
|
111
|
+
else Task.get(name=task_name, auto_version="latest")
|
|
112
|
+
)
|
|
113
|
+
task: TaskDetails = await lazy.fetch.aio()
|
|
114
|
+
|
|
115
|
+
task_trigger = await trigger_serde.to_task_trigger(
|
|
116
|
+
t=trigger,
|
|
117
|
+
task_name=task_name,
|
|
118
|
+
task_inputs=task.pb2.spec.task_template.interface.inputs,
|
|
119
|
+
task_default_inputs=list(task.pb2.spec.default_inputs),
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
resp = await get_client().trigger_service.DeployTrigger(
|
|
123
|
+
request=trigger_service_pb2.DeployTriggerRequest(
|
|
124
|
+
id=identifier_pb2.TriggerIdentifier(
|
|
125
|
+
name=identifier_pb2.TriggerName(
|
|
126
|
+
name=trigger.name,
|
|
127
|
+
task_name=task_name,
|
|
128
|
+
org=cfg.org,
|
|
129
|
+
project=cfg.project,
|
|
130
|
+
domain=cfg.domain,
|
|
131
|
+
),
|
|
132
|
+
revision=1,
|
|
133
|
+
),
|
|
134
|
+
spec=trigger_definition_pb2.TriggerSpec(
|
|
135
|
+
active=task_trigger.spec.active,
|
|
136
|
+
inputs=task_trigger.spec.inputs,
|
|
137
|
+
run_spec=task_trigger.spec.run_spec,
|
|
138
|
+
task_version=task.version,
|
|
139
|
+
),
|
|
140
|
+
automation_spec=task_trigger.automation_spec,
|
|
141
|
+
)
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
details = TriggerDetails(pb2=resp.trigger)
|
|
145
|
+
|
|
146
|
+
return cls(pb2=details.trigger, details=details)
|
|
147
|
+
except grpc.aio.AioRpcError as e:
|
|
148
|
+
if e.code() == grpc.StatusCode.NOT_FOUND:
|
|
149
|
+
raise ValueError(f"Task {task_name}:{task_version or 'latest'} not found") from e
|
|
150
|
+
raise
|
|
151
|
+
|
|
152
|
+
@syncify
|
|
153
|
+
@classmethod
|
|
154
|
+
async def get(cls, *, name: str, task_name: str) -> TriggerDetails:
|
|
155
|
+
"""
|
|
156
|
+
Retrieve a trigger by its name and associated task name.
|
|
157
|
+
"""
|
|
158
|
+
return await TriggerDetails.get(name=name, task_name=task_name)
|
|
159
|
+
|
|
160
|
+
@syncify
|
|
161
|
+
@classmethod
|
|
162
|
+
async def listall(
|
|
163
|
+
cls, task_name: str | None = None, task_version: str | None = None, limit: int = 100
|
|
164
|
+
) -> AsyncIterator[Trigger]:
|
|
165
|
+
"""
|
|
166
|
+
List all triggers associated with a specific task or all tasks if no task name is provided.
|
|
167
|
+
"""
|
|
168
|
+
ensure_client()
|
|
169
|
+
cfg = get_common_config()
|
|
170
|
+
token = None
|
|
171
|
+
# task_name_id = None TODO: implement listing by task name only
|
|
172
|
+
project_id = None
|
|
173
|
+
task_id = None
|
|
174
|
+
if task_name and task_version:
|
|
175
|
+
task_id = task_definition_pb2.TaskIdentifier(
|
|
176
|
+
name=task_name,
|
|
177
|
+
project=cfg.project,
|
|
178
|
+
domain=cfg.domain,
|
|
179
|
+
org=cfg.org,
|
|
180
|
+
version=task_version,
|
|
181
|
+
)
|
|
182
|
+
# elif task_name: TODO: implement listing by task name only
|
|
183
|
+
# task_name_id = task_definition_pb2.TaskName(
|
|
184
|
+
# name=task_name,
|
|
185
|
+
# project=cfg.project,
|
|
186
|
+
# domain=cfg.domain,
|
|
187
|
+
# org=cfg.org,
|
|
188
|
+
# )
|
|
189
|
+
else:
|
|
190
|
+
project_id = identifier_pb2.ProjectIdentifier(
|
|
191
|
+
organization=cfg.org,
|
|
192
|
+
domain=cfg.domain,
|
|
193
|
+
name=cfg.project,
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
while True:
|
|
197
|
+
resp = await get_client().trigger_service.ListTriggers(
|
|
198
|
+
request=trigger_service_pb2.ListTriggersRequest(
|
|
199
|
+
project_id=project_id,
|
|
200
|
+
task_id=task_id,
|
|
201
|
+
# task_name=task_name_id,
|
|
202
|
+
request=list_pb2.ListRequest(
|
|
203
|
+
limit=limit,
|
|
204
|
+
token=token,
|
|
205
|
+
),
|
|
206
|
+
)
|
|
207
|
+
)
|
|
208
|
+
token = resp.token
|
|
209
|
+
for r in resp.triggers:
|
|
210
|
+
yield cls(r)
|
|
211
|
+
if not token:
|
|
212
|
+
break
|
|
213
|
+
|
|
214
|
+
@syncify
|
|
215
|
+
@classmethod
|
|
216
|
+
async def update(cls, name: str, task_name: str, active: bool):
|
|
217
|
+
"""
|
|
218
|
+
Pause a trigger by its name and associated task name.
|
|
219
|
+
"""
|
|
220
|
+
ensure_client()
|
|
221
|
+
cfg = get_common_config()
|
|
222
|
+
await get_client().trigger_service.UpdateTriggers(
|
|
223
|
+
request=trigger_service_pb2.UpdateTriggersRequest(
|
|
224
|
+
names=[
|
|
225
|
+
identifier_pb2.TriggerName(
|
|
226
|
+
org=cfg.org,
|
|
227
|
+
project=cfg.project,
|
|
228
|
+
domain=cfg.domain,
|
|
229
|
+
name=name,
|
|
230
|
+
task_name=task_name,
|
|
231
|
+
)
|
|
232
|
+
],
|
|
233
|
+
active=active,
|
|
234
|
+
)
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
@syncify
|
|
238
|
+
@classmethod
|
|
239
|
+
async def delete(cls, name: str, task_name: str):
|
|
240
|
+
"""
|
|
241
|
+
Delete a trigger by its name.
|
|
242
|
+
"""
|
|
243
|
+
ensure_client()
|
|
244
|
+
cfg = get_common_config()
|
|
245
|
+
await get_client().trigger_service.DeleteTriggers(
|
|
246
|
+
request=trigger_service_pb2.DeleteTriggersRequest(
|
|
247
|
+
names=[
|
|
248
|
+
identifier_pb2.TriggerName(
|
|
249
|
+
org=cfg.org,
|
|
250
|
+
project=cfg.project,
|
|
251
|
+
domain=cfg.domain,
|
|
252
|
+
name=name,
|
|
253
|
+
task_name=task_name,
|
|
254
|
+
)
|
|
255
|
+
],
|
|
256
|
+
)
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
@property
|
|
260
|
+
def id(self) -> identifier_pb2.TriggerIdentifier:
|
|
261
|
+
return self.pb2.id
|
|
262
|
+
|
|
263
|
+
@property
|
|
264
|
+
def name(self) -> str:
|
|
265
|
+
return self.id.name.name
|
|
266
|
+
|
|
267
|
+
@property
|
|
268
|
+
def task_name(self) -> str:
|
|
269
|
+
return self.id.name.task_name
|
|
270
|
+
|
|
271
|
+
@property
|
|
272
|
+
def automation_spec(self) -> common_pb2.TriggerAutomationSpec:
|
|
273
|
+
return self.pb2.automation_spec
|
|
274
|
+
|
|
275
|
+
async def get_details(self) -> TriggerDetails:
|
|
276
|
+
"""
|
|
277
|
+
Get detailed information about this trigger.
|
|
278
|
+
"""
|
|
279
|
+
if not self.details:
|
|
280
|
+
details = await TriggerDetails.get.aio(name=self.pb2.id.name.name)
|
|
281
|
+
self.details = details
|
|
282
|
+
return self.details
|
|
283
|
+
|
|
284
|
+
@property
|
|
285
|
+
def is_active(self) -> bool:
|
|
286
|
+
return self.pb2.active
|
|
287
|
+
|
|
288
|
+
def _rich_automation(self, automation: common_pb2.TriggerAutomationSpec):
|
|
289
|
+
if automation.type == common_pb2.TriggerAutomationSpec.TYPE_NONE:
|
|
290
|
+
yield "none", None
|
|
291
|
+
elif automation.type == common_pb2.TriggerAutomationSpec.TYPE_SCHEDULE:
|
|
292
|
+
if automation.schedule.cron_expression is not None:
|
|
293
|
+
yield "cron", automation.schedule.cron_expression
|
|
294
|
+
elif automation.schedule.rate is not None:
|
|
295
|
+
r = automation.schedule.rate
|
|
296
|
+
yield (
|
|
297
|
+
"fixed_rate",
|
|
298
|
+
(
|
|
299
|
+
f"Every [{r.value}] {r.unit} starting at "
|
|
300
|
+
f"{r.start_time.ToDatetime() if automation.HasField('start_time') else 'now'}"
|
|
301
|
+
),
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
def __rich_repr__(self):
|
|
305
|
+
yield "task_name", self.task_name
|
|
306
|
+
yield "name", self.name
|
|
307
|
+
yield from self._rich_automation(self.pb2.automation_spec)
|
|
308
|
+
yield "auto_activate", self.is_active
|
flyte/remote/_user.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from flyteidl.service import identity_pb2
|
|
6
|
+
from flyteidl.service.identity_pb2 import UserInfoResponse
|
|
7
|
+
|
|
8
|
+
from .._initialize import ensure_client, get_client
|
|
9
|
+
from ..syncify import syncify
|
|
10
|
+
from ._common import ToJSONMixin
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class User(ToJSONMixin):
|
|
15
|
+
pb2: UserInfoResponse
|
|
16
|
+
|
|
17
|
+
@syncify
|
|
18
|
+
@classmethod
|
|
19
|
+
async def get(cls) -> User:
|
|
20
|
+
"""
|
|
21
|
+
Fetches information about the currently logged in user.
|
|
22
|
+
Returns: A User object containing details about the user.
|
|
23
|
+
"""
|
|
24
|
+
ensure_client()
|
|
25
|
+
|
|
26
|
+
resp = await get_client().identity_service.UserInfo(identity_pb2.UserInfoRequest())
|
|
27
|
+
return cls(resp)
|
|
28
|
+
|
|
29
|
+
def subject(self) -> str:
|
|
30
|
+
return self.pb2.subject
|
|
31
|
+
|
|
32
|
+
def name(self) -> str:
|
|
33
|
+
return self.pb2.name
|
flyte/storage/__init__.py
CHANGED
|
@@ -3,6 +3,8 @@ __all__ = [
|
|
|
3
3
|
"GCS",
|
|
4
4
|
"S3",
|
|
5
5
|
"Storage",
|
|
6
|
+
"exists",
|
|
7
|
+
"exists_sync",
|
|
6
8
|
"get",
|
|
7
9
|
"get_configured_fsspec_kwargs",
|
|
8
10
|
"get_random_local_directory",
|
|
@@ -11,13 +13,15 @@ __all__ = [
|
|
|
11
13
|
"get_underlying_filesystem",
|
|
12
14
|
"is_remote",
|
|
13
15
|
"join",
|
|
16
|
+
"open",
|
|
14
17
|
"put",
|
|
15
18
|
"put_stream",
|
|
16
|
-
"put_stream",
|
|
17
19
|
]
|
|
18
20
|
|
|
19
21
|
from ._config import ABFS, GCS, S3, Storage
|
|
20
22
|
from ._storage import (
|
|
23
|
+
exists,
|
|
24
|
+
exists_sync,
|
|
21
25
|
get,
|
|
22
26
|
get_configured_fsspec_kwargs,
|
|
23
27
|
get_random_local_directory,
|
|
@@ -26,6 +30,7 @@ from ._storage import (
|
|
|
26
30
|
get_underlying_filesystem,
|
|
27
31
|
is_remote,
|
|
28
32
|
join,
|
|
33
|
+
open,
|
|
29
34
|
put,
|
|
30
35
|
put_stream,
|
|
31
36
|
)
|