flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyte/__init__.py +18 -2
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +62 -8
- flyte/_cache/cache.py +4 -2
- flyte/_cache/local_cache.py +216 -0
- flyte/_code_bundle/_ignore.py +12 -4
- flyte/_code_bundle/_packaging.py +13 -9
- flyte/_code_bundle/_utils.py +18 -10
- flyte/_code_bundle/bundle.py +17 -9
- flyte/_constants.py +1 -0
- flyte/_context.py +4 -1
- flyte/_custom_context.py +73 -0
- flyte/_debug/constants.py +38 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +307 -0
- flyte/_deploy.py +235 -61
- flyte/_environment.py +20 -6
- flyte/_excepthook.py +1 -1
- flyte/_hash.py +1 -16
- flyte/_image.py +178 -81
- flyte/_initialize.py +132 -51
- flyte/_interface.py +39 -2
- flyte/_internal/controllers/__init__.py +4 -5
- flyte/_internal/controllers/_local_controller.py +70 -29
- flyte/_internal/controllers/_trace.py +1 -1
- flyte/_internal/controllers/remote/__init__.py +0 -2
- flyte/_internal/controllers/remote/_action.py +14 -16
- flyte/_internal/controllers/remote/_client.py +1 -1
- flyte/_internal/controllers/remote/_controller.py +68 -70
- flyte/_internal/controllers/remote/_core.py +127 -99
- flyte/_internal/controllers/remote/_informer.py +19 -10
- flyte/_internal/controllers/remote/_service_protocol.py +7 -7
- flyte/_internal/imagebuild/docker_builder.py +181 -69
- flyte/_internal/imagebuild/image_builder.py +0 -5
- flyte/_internal/imagebuild/remote_builder.py +155 -64
- flyte/_internal/imagebuild/utils.py +51 -2
- flyte/_internal/resolvers/_task_module.py +5 -38
- flyte/_internal/resolvers/default.py +2 -2
- flyte/_internal/runtime/convert.py +110 -21
- flyte/_internal/runtime/entrypoints.py +27 -1
- flyte/_internal/runtime/io.py +21 -8
- flyte/_internal/runtime/resources_serde.py +20 -6
- flyte/_internal/runtime/reuse.py +1 -1
- flyte/_internal/runtime/rusty.py +20 -5
- flyte/_internal/runtime/task_serde.py +34 -19
- flyte/_internal/runtime/taskrunner.py +22 -4
- flyte/_internal/runtime/trigger_serde.py +160 -0
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_keyring/__init__.py +0 -0
- flyte/_keyring/file.py +115 -0
- flyte/_logging.py +201 -39
- flyte/_map.py +111 -14
- flyte/_module.py +70 -0
- flyte/_pod.py +4 -3
- flyte/_resources.py +213 -31
- flyte/_run.py +110 -39
- flyte/_task.py +75 -16
- flyte/_task_environment.py +105 -29
- flyte/_task_plugins.py +4 -2
- flyte/_trace.py +5 -0
- flyte/_trigger.py +1000 -0
- flyte/_utils/__init__.py +2 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/coro_management.py +2 -1
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/module_loader.py +17 -2
- flyte/_version.py +3 -3
- flyte/cli/_abort.py +3 -3
- flyte/cli/_build.py +3 -6
- flyte/cli/_common.py +78 -7
- flyte/cli/_create.py +182 -4
- flyte/cli/_delete.py +23 -1
- flyte/cli/_deploy.py +63 -16
- flyte/cli/_get.py +79 -34
- flyte/cli/_params.py +26 -10
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_run.py +151 -26
- flyte/cli/_serve.py +64 -0
- flyte/cli/_update.py +37 -0
- flyte/cli/_user.py +17 -0
- flyte/cli/main.py +30 -4
- flyte/config/_config.py +10 -6
- flyte/config/_internal.py +1 -0
- flyte/config/_reader.py +29 -8
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +270 -0
- flyte/connectors/_server.py +197 -0
- flyte/connectors/utils.py +135 -0
- flyte/errors.py +22 -2
- flyte/extend.py +8 -1
- flyte/extras/_container.py +6 -1
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +21 -0
- flyte/io/__init__.py +2 -0
- flyte/io/_dataframe/__init__.py +2 -0
- flyte/io/_dataframe/basic_dfs.py +17 -8
- flyte/io/_dataframe/dataframe.py +98 -132
- flyte/io/_dir.py +575 -113
- flyte/io/_file.py +582 -139
- flyte/io/_hashing_io.py +342 -0
- flyte/models.py +74 -15
- flyte/remote/__init__.py +6 -1
- flyte/remote/_action.py +34 -26
- flyte/remote/_client/_protocols.py +39 -4
- flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
- flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
- flyte/remote/_client/auth/_channel.py +10 -6
- flyte/remote/_client/controlplane.py +17 -5
- flyte/remote/_console.py +3 -2
- flyte/remote/_data.py +6 -6
- flyte/remote/_logs.py +3 -3
- flyte/remote/_run.py +64 -8
- flyte/remote/_secret.py +26 -17
- flyte/remote/_task.py +75 -33
- flyte/remote/_trigger.py +306 -0
- flyte/remote/_user.py +33 -0
- flyte/report/_report.py +1 -1
- flyte/storage/__init__.py +6 -1
- flyte/storage/_config.py +5 -1
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_storage.py +200 -103
- flyte/types/__init__.py +16 -0
- flyte/types/_interface.py +2 -2
- flyte/types/_pickle.py +35 -8
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +40 -70
- flyte/types/_utils.py +1 -1
- flyte-2.0.0b30.data/scripts/debug.py +38 -0
- {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
- flyte-2.0.0b30.dist-info/RECORD +192 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -93
- flyte/_protos/common/identifier_pb2.pyi +0 -110
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -71
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/definition_pb2.py +0 -59
- flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
- flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/payload_pb2.py +0 -32
- flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
- flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/service_pb2.py +0 -29
- flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
- flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/common_pb2.py +0 -27
- flyte/_protos/workflow/common_pb2.pyi +0 -14
- flyte/_protos/workflow/common_pb2_grpc.py +0 -4
- flyte/_protos/workflow/environment_pb2.py +0 -29
- flyte/_protos/workflow/environment_pb2.pyi +0 -12
- flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -109
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- flyte/_protos/workflow/run_definition_pb2.py +0 -121
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -137
- flyte/_protos/workflow/run_service_pb2.pyi +0 -185
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
- flyte/_protos/workflow/state_service_pb2.py +0 -67
- flyte/_protos/workflow/state_service_pb2.pyi +0 -76
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -79
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -60
- flyte/_protos/workflow/task_service_pb2.pyi +0 -59
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
- flyte-2.0.0b13.dist-info/RECORD +0 -239
- /flyte/{_protos → _debug}/__init__.py +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/remote/_task.py
CHANGED
|
@@ -1,24 +1,23 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import functools
|
|
4
5
|
from dataclasses import dataclass
|
|
5
|
-
from threading import Lock
|
|
6
6
|
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union, cast
|
|
7
7
|
|
|
8
8
|
import rich.repr
|
|
9
|
-
from
|
|
10
|
-
from
|
|
9
|
+
from flyteidl2.common import identifier_pb2, list_pb2
|
|
10
|
+
from flyteidl2.core import literals_pb2
|
|
11
|
+
from flyteidl2.task import task_definition_pb2, task_service_pb2
|
|
11
12
|
|
|
12
13
|
import flyte
|
|
13
14
|
import flyte.errors
|
|
14
15
|
from flyte._cache.cache import CacheBehavior
|
|
15
16
|
from flyte._context import internal_ctx
|
|
16
|
-
from flyte._initialize import ensure_client, get_client,
|
|
17
|
+
from flyte._initialize import ensure_client, get_client, get_init_config
|
|
17
18
|
from flyte._internal.runtime.resources_serde import get_proto_resources
|
|
18
19
|
from flyte._internal.runtime.task_serde import get_proto_retry_strategy, get_proto_timeout, get_security_context
|
|
19
20
|
from flyte._logging import logger
|
|
20
|
-
from flyte._protos.common import identifier_pb2, list_pb2
|
|
21
|
-
from flyte._protos.workflow import task_definition_pb2, task_service_pb2
|
|
22
21
|
from flyte.models import NativeInterface
|
|
23
22
|
from flyte.syncify import syncify
|
|
24
23
|
|
|
@@ -35,7 +34,7 @@ def _repr_task_metadata(metadata: task_definition_pb2.TaskMetadata) -> rich.repr
|
|
|
35
34
|
else:
|
|
36
35
|
yield "deployed_by", f"App: {metadata.deployed_by.application.spec.name}"
|
|
37
36
|
yield "short_name", metadata.short_name
|
|
38
|
-
yield "deployed_at",
|
|
37
|
+
yield "deployed_at", metadata.deployed_at.ToDatetime()
|
|
39
38
|
yield "environment_name", metadata.environment_name
|
|
40
39
|
|
|
41
40
|
|
|
@@ -49,7 +48,7 @@ class LazyEntity:
|
|
|
49
48
|
self._task: Optional[TaskDetails] = None
|
|
50
49
|
self._getter = getter
|
|
51
50
|
self._name = name
|
|
52
|
-
self._mutex = Lock()
|
|
51
|
+
self._mutex = asyncio.Lock()
|
|
53
52
|
|
|
54
53
|
@property
|
|
55
54
|
def name(self) -> str:
|
|
@@ -60,11 +59,11 @@ class LazyEntity:
|
|
|
60
59
|
"""
|
|
61
60
|
Forwards all other attributes to task, causing the task to be fetched!
|
|
62
61
|
"""
|
|
63
|
-
with self._mutex:
|
|
62
|
+
async with self._mutex:
|
|
64
63
|
if self._task is None:
|
|
65
64
|
self._task = await self._getter()
|
|
66
|
-
|
|
67
|
-
|
|
65
|
+
if self._task is None:
|
|
66
|
+
raise RuntimeError(f"Error downloading the task {self._name}, (check original exception...)")
|
|
68
67
|
return self._task
|
|
69
68
|
|
|
70
69
|
@syncify
|
|
@@ -73,8 +72,10 @@ class LazyEntity:
|
|
|
73
72
|
**kwargs: Any,
|
|
74
73
|
) -> LazyEntity:
|
|
75
74
|
task_details = cast(TaskDetails, await self.fetch.aio())
|
|
76
|
-
task_details.override(**kwargs)
|
|
77
|
-
|
|
75
|
+
new_task_details = task_details.override(**kwargs)
|
|
76
|
+
new_entity = LazyEntity(self._name, self._getter)
|
|
77
|
+
new_entity._task = new_task_details
|
|
78
|
+
return new_entity
|
|
78
79
|
|
|
79
80
|
async def __call__(self, *args, **kwargs):
|
|
80
81
|
"""
|
|
@@ -93,10 +94,11 @@ class LazyEntity:
|
|
|
93
94
|
AutoVersioning = Literal["latest", "current"]
|
|
94
95
|
|
|
95
96
|
|
|
96
|
-
@dataclass
|
|
97
|
+
@dataclass(frozen=True)
|
|
97
98
|
class TaskDetails(ToJSONMixin):
|
|
98
99
|
pb2: task_definition_pb2.TaskDetails
|
|
99
100
|
max_inline_io_bytes: int = 10 * 1024 * 1024 # 10 MB
|
|
101
|
+
overriden_queue: Optional[str] = None
|
|
100
102
|
|
|
101
103
|
@classmethod
|
|
102
104
|
def get(
|
|
@@ -148,7 +150,7 @@ class TaskDetails(ToJSONMixin):
|
|
|
148
150
|
if ctx is None:
|
|
149
151
|
raise ValueError("auto_version=current can only be used within a task context.")
|
|
150
152
|
_version = ctx.version
|
|
151
|
-
cfg =
|
|
153
|
+
cfg = get_init_config()
|
|
152
154
|
task_id = task_definition_pb2.TaskIdentifier(
|
|
153
155
|
org=cfg.org,
|
|
154
156
|
project=project or cfg.project,
|
|
@@ -261,12 +263,6 @@ class TaskDetails(ToJSONMixin):
|
|
|
261
263
|
f"Reference task {self.name} does not support positional arguments"
|
|
262
264
|
f"currently. Please use keyword arguments."
|
|
263
265
|
)
|
|
264
|
-
if len(self.required_args) > 0:
|
|
265
|
-
if len(args) + len(kwargs) < len(self.required_args):
|
|
266
|
-
raise ValueError(
|
|
267
|
-
f"Task {self.name} requires at least {self.required_args} arguments, "
|
|
268
|
-
f"but only received args:{args} kwargs{kwargs}."
|
|
269
|
-
)
|
|
270
266
|
|
|
271
267
|
ctx = internal_ctx()
|
|
272
268
|
if ctx.is_task_context():
|
|
@@ -276,19 +272,37 @@ class TaskDetails(ToJSONMixin):
|
|
|
276
272
|
from flyte._internal.controllers import get_controller
|
|
277
273
|
|
|
278
274
|
controller = get_controller()
|
|
275
|
+
if len(self.required_args) > 0:
|
|
276
|
+
if len(args) + len(kwargs) < len(self.required_args):
|
|
277
|
+
raise ValueError(
|
|
278
|
+
f"Task {self.name} requires at least {self.required_args} arguments, "
|
|
279
|
+
f"but only received args:{args} kwargs{kwargs}."
|
|
280
|
+
)
|
|
279
281
|
if controller:
|
|
280
|
-
return await controller.submit_task_ref(self
|
|
281
|
-
raise flyte.errors
|
|
282
|
+
return await controller.submit_task_ref(self, *args, **kwargs)
|
|
283
|
+
raise flyte.errors.ReferenceTaskError(
|
|
284
|
+
f"Reference tasks [{self.name}] cannot be executed locally, only remotely."
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
@property
|
|
288
|
+
def queue(self) -> Optional[str]:
|
|
289
|
+
"""
|
|
290
|
+
The queue to use for the task.
|
|
291
|
+
"""
|
|
292
|
+
return self.overriden_queue
|
|
282
293
|
|
|
283
294
|
def override(
|
|
284
295
|
self,
|
|
285
296
|
*,
|
|
286
|
-
|
|
297
|
+
short_name: Optional[str] = None,
|
|
287
298
|
resources: Optional[flyte.Resources] = None,
|
|
288
299
|
retries: Union[int, flyte.RetryStrategy] = 0,
|
|
289
300
|
timeout: Optional[flyte.TimeoutType] = None,
|
|
290
301
|
env_vars: Optional[Dict[str, str]] = None,
|
|
291
302
|
secrets: Optional[flyte.SecretRequest] = None,
|
|
303
|
+
max_inline_io_bytes: Optional[int] = None,
|
|
304
|
+
cache: Optional[flyte.Cache] = None,
|
|
305
|
+
queue: Optional[str] = None,
|
|
292
306
|
**kwargs: Any,
|
|
293
307
|
) -> TaskDetails:
|
|
294
308
|
if len(kwargs) > 0:
|
|
@@ -296,29 +310,57 @@ class TaskDetails(ToJSONMixin):
|
|
|
296
310
|
f"ReferenceTasks [{self.name}] do not support overriding with kwargs: {kwargs}, "
|
|
297
311
|
f"Check the parameters for override method."
|
|
298
312
|
)
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
313
|
+
pb2 = task_definition_pb2.TaskDetails()
|
|
314
|
+
pb2.CopyFrom(self.pb2)
|
|
315
|
+
|
|
316
|
+
if short_name:
|
|
317
|
+
pb2.metadata.short_name = short_name
|
|
318
|
+
|
|
319
|
+
template = pb2.spec.task_template
|
|
302
320
|
if secrets:
|
|
303
321
|
template.security_context.CopyFrom(get_security_context(secrets))
|
|
322
|
+
|
|
304
323
|
if template.HasField("container"):
|
|
305
324
|
if env_vars:
|
|
306
325
|
template.container.env.clear()
|
|
307
326
|
template.container.env.extend([literals_pb2.KeyValuePair(key=k, value=v) for k, v in env_vars.items()])
|
|
308
327
|
if resources:
|
|
309
328
|
template.container.resources.CopyFrom(get_proto_resources(resources))
|
|
329
|
+
|
|
330
|
+
md = template.metadata
|
|
310
331
|
if retries:
|
|
311
|
-
|
|
312
|
-
if timeout:
|
|
313
|
-
template.metadata.timeout.CopyFrom(get_proto_timeout(timeout))
|
|
332
|
+
md.retries.CopyFrom(get_proto_retry_strategy(retries))
|
|
314
333
|
|
|
315
|
-
|
|
334
|
+
if timeout:
|
|
335
|
+
md.timeout.CopyFrom(get_proto_timeout(timeout))
|
|
336
|
+
|
|
337
|
+
if cache:
|
|
338
|
+
if cache.behavior == "disable":
|
|
339
|
+
md.discoverable = False
|
|
340
|
+
md.discovery_version = ""
|
|
341
|
+
elif cache.behavior == "override":
|
|
342
|
+
md.discoverable = True
|
|
343
|
+
if not cache.version_override:
|
|
344
|
+
raise ValueError("cache.version_override must be set when cache.behavior is 'override'")
|
|
345
|
+
md.discovery_version = cache.version_override
|
|
346
|
+
else:
|
|
347
|
+
if cache.behavior == "auto":
|
|
348
|
+
raise ValueError("cache.behavior must be 'disable' or 'override' for reference tasks")
|
|
349
|
+
raise ValueError(f"Invalid cache behavior: {cache.behavior}.")
|
|
350
|
+
md.cache_serializable = cache.serialize
|
|
351
|
+
md.cache_ignore_input_vars[:] = list(cache.ignored_inputs or ())
|
|
352
|
+
|
|
353
|
+
return TaskDetails(
|
|
354
|
+
pb2,
|
|
355
|
+
max_inline_io_bytes=max_inline_io_bytes or self.max_inline_io_bytes,
|
|
356
|
+
overriden_queue=queue,
|
|
357
|
+
)
|
|
316
358
|
|
|
317
359
|
def __rich_repr__(self) -> rich.repr.Result:
|
|
318
360
|
"""
|
|
319
361
|
Rich representation of the task.
|
|
320
362
|
"""
|
|
321
|
-
yield "
|
|
363
|
+
yield "short_name", self.pb2.spec.short_name
|
|
322
364
|
yield "environment", self.pb2.spec.environment
|
|
323
365
|
yield "default_inputs_keys", self.default_input_args
|
|
324
366
|
yield "required_args", self.required_args
|
|
@@ -408,7 +450,7 @@ class Task(ToJSONMixin):
|
|
|
408
450
|
sort_pb2 = list_pb2.Sort(
|
|
409
451
|
key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
|
|
410
452
|
)
|
|
411
|
-
cfg =
|
|
453
|
+
cfg = get_init_config()
|
|
412
454
|
filters = []
|
|
413
455
|
if by_task_name:
|
|
414
456
|
filters.append(
|
flyte/remote/_trigger.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from functools import cached_property
|
|
5
|
+
from typing import AsyncIterator
|
|
6
|
+
|
|
7
|
+
import grpc.aio
|
|
8
|
+
from flyteidl2.common import identifier_pb2, list_pb2
|
|
9
|
+
from flyteidl2.task import common_pb2, task_definition_pb2
|
|
10
|
+
from flyteidl2.trigger import trigger_definition_pb2, trigger_service_pb2
|
|
11
|
+
|
|
12
|
+
import flyte
|
|
13
|
+
from flyte._initialize import ensure_client, get_client, get_init_config
|
|
14
|
+
from flyte._internal.runtime import trigger_serde
|
|
15
|
+
from flyte.syncify import syncify
|
|
16
|
+
|
|
17
|
+
from ._common import ToJSONMixin
|
|
18
|
+
from ._task import Task, TaskDetails
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
|
|
22
|
+
class TriggerDetails(ToJSONMixin):
|
|
23
|
+
pb2: trigger_definition_pb2.TriggerDetails
|
|
24
|
+
|
|
25
|
+
@syncify
|
|
26
|
+
@classmethod
|
|
27
|
+
async def get(cls, *, name: str, task_name: str) -> TriggerDetails:
|
|
28
|
+
"""
|
|
29
|
+
Retrieve detailed information about a specific trigger by its name.
|
|
30
|
+
"""
|
|
31
|
+
ensure_client()
|
|
32
|
+
cfg = get_init_config()
|
|
33
|
+
resp = await get_client().trigger_service.GetTriggerDetails(
|
|
34
|
+
request=trigger_service_pb2.GetTriggerDetailsRequest(
|
|
35
|
+
name=identifier_pb2.TriggerName(
|
|
36
|
+
task_name=task_name,
|
|
37
|
+
name=name,
|
|
38
|
+
org=cfg.org,
|
|
39
|
+
project=cfg.project,
|
|
40
|
+
domain=cfg.domain,
|
|
41
|
+
),
|
|
42
|
+
)
|
|
43
|
+
)
|
|
44
|
+
return cls(pb2=resp.trigger)
|
|
45
|
+
|
|
46
|
+
@property
|
|
47
|
+
def name(self) -> str:
|
|
48
|
+
return self.id.name.name
|
|
49
|
+
|
|
50
|
+
@property
|
|
51
|
+
def id(self) -> identifier_pb2.TriggerIdentifier:
|
|
52
|
+
return self.pb2.id
|
|
53
|
+
|
|
54
|
+
@property
|
|
55
|
+
def task_name(self) -> str:
|
|
56
|
+
return self.pb2.id.name.task_name
|
|
57
|
+
|
|
58
|
+
@property
|
|
59
|
+
def automation_spec(self) -> common_pb2.TriggerAutomationSpec:
|
|
60
|
+
return self.pb2.automation_spec
|
|
61
|
+
|
|
62
|
+
@property
|
|
63
|
+
def metadata(self) -> trigger_definition_pb2.TriggerMetadata:
|
|
64
|
+
return self.pb2.metadata
|
|
65
|
+
|
|
66
|
+
@property
|
|
67
|
+
def status(self) -> trigger_definition_pb2.TriggerStatus:
|
|
68
|
+
return self.pb2.status
|
|
69
|
+
|
|
70
|
+
@property
|
|
71
|
+
def is_active(self) -> bool:
|
|
72
|
+
return self.pb2.spec.active
|
|
73
|
+
|
|
74
|
+
@cached_property
|
|
75
|
+
def trigger(self) -> trigger_definition_pb2.Trigger:
|
|
76
|
+
return trigger_definition_pb2.Trigger(
|
|
77
|
+
id=self.pb2.id,
|
|
78
|
+
automation_spec=self.automation_spec,
|
|
79
|
+
metadata=self.metadata,
|
|
80
|
+
status=self.status,
|
|
81
|
+
active=self.is_active,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
|
|
86
|
+
class Trigger(ToJSONMixin):
|
|
87
|
+
pb2: trigger_definition_pb2.Trigger
|
|
88
|
+
details: TriggerDetails | None = None
|
|
89
|
+
|
|
90
|
+
@syncify
|
|
91
|
+
@classmethod
|
|
92
|
+
async def create(
|
|
93
|
+
cls,
|
|
94
|
+
trigger: flyte.Trigger,
|
|
95
|
+
task_name: str,
|
|
96
|
+
task_version: str | None = None,
|
|
97
|
+
) -> Trigger:
|
|
98
|
+
"""
|
|
99
|
+
Create a new trigger in the Flyte platform.
|
|
100
|
+
|
|
101
|
+
:param trigger: The flyte.Trigger object containing the trigger definition.
|
|
102
|
+
:param task_name: Optional name of the task to associate with the trigger.
|
|
103
|
+
"""
|
|
104
|
+
ensure_client()
|
|
105
|
+
cfg = get_init_config()
|
|
106
|
+
|
|
107
|
+
# Fetch the task to ensure it exists and to get its input definitions
|
|
108
|
+
try:
|
|
109
|
+
lazy = (
|
|
110
|
+
Task.get(name=task_name, version=task_version)
|
|
111
|
+
if task_version
|
|
112
|
+
else Task.get(name=task_name, auto_version="latest")
|
|
113
|
+
)
|
|
114
|
+
task: TaskDetails = await lazy.fetch.aio()
|
|
115
|
+
|
|
116
|
+
task_trigger = await trigger_serde.to_task_trigger(
|
|
117
|
+
t=trigger,
|
|
118
|
+
task_name=task_name,
|
|
119
|
+
task_inputs=task.pb2.spec.task_template.interface.inputs,
|
|
120
|
+
task_default_inputs=list(task.pb2.spec.default_inputs),
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
resp = await get_client().trigger_service.DeployTrigger(
|
|
124
|
+
request=trigger_service_pb2.DeployTriggerRequest(
|
|
125
|
+
name=identifier_pb2.TriggerName(
|
|
126
|
+
name=trigger.name,
|
|
127
|
+
task_name=task_name,
|
|
128
|
+
org=cfg.org,
|
|
129
|
+
project=cfg.project,
|
|
130
|
+
domain=cfg.domain,
|
|
131
|
+
),
|
|
132
|
+
spec=trigger_definition_pb2.TriggerSpec(
|
|
133
|
+
active=task_trigger.spec.active,
|
|
134
|
+
inputs=task_trigger.spec.inputs,
|
|
135
|
+
run_spec=task_trigger.spec.run_spec,
|
|
136
|
+
task_version=task.version,
|
|
137
|
+
),
|
|
138
|
+
automation_spec=task_trigger.automation_spec,
|
|
139
|
+
)
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
details = TriggerDetails(pb2=resp.trigger)
|
|
143
|
+
|
|
144
|
+
return cls(pb2=details.trigger, details=details)
|
|
145
|
+
except grpc.aio.AioRpcError as e:
|
|
146
|
+
if e.code() == grpc.StatusCode.NOT_FOUND:
|
|
147
|
+
raise ValueError(f"Task {task_name}:{task_version or 'latest'} not found") from e
|
|
148
|
+
raise
|
|
149
|
+
|
|
150
|
+
@syncify
|
|
151
|
+
@classmethod
|
|
152
|
+
async def get(cls, *, name: str, task_name: str) -> TriggerDetails:
|
|
153
|
+
"""
|
|
154
|
+
Retrieve a trigger by its name and associated task name.
|
|
155
|
+
"""
|
|
156
|
+
return await TriggerDetails.get.aio(name=name, task_name=task_name)
|
|
157
|
+
|
|
158
|
+
@syncify
|
|
159
|
+
@classmethod
|
|
160
|
+
async def listall(
|
|
161
|
+
cls, task_name: str | None = None, task_version: str | None = None, limit: int = 100
|
|
162
|
+
) -> AsyncIterator[Trigger]:
|
|
163
|
+
"""
|
|
164
|
+
List all triggers associated with a specific task or all tasks if no task name is provided.
|
|
165
|
+
"""
|
|
166
|
+
ensure_client()
|
|
167
|
+
cfg = get_init_config()
|
|
168
|
+
token = None
|
|
169
|
+
task_name_id = None
|
|
170
|
+
project_id = None
|
|
171
|
+
task_id = None
|
|
172
|
+
if task_name and task_version:
|
|
173
|
+
task_id = task_definition_pb2.TaskIdentifier(
|
|
174
|
+
name=task_name,
|
|
175
|
+
project=cfg.project,
|
|
176
|
+
domain=cfg.domain,
|
|
177
|
+
org=cfg.org,
|
|
178
|
+
version=task_version,
|
|
179
|
+
)
|
|
180
|
+
elif task_name:
|
|
181
|
+
task_name_id = task_definition_pb2.TaskName(
|
|
182
|
+
name=task_name,
|
|
183
|
+
project=cfg.project,
|
|
184
|
+
domain=cfg.domain,
|
|
185
|
+
org=cfg.org,
|
|
186
|
+
)
|
|
187
|
+
else:
|
|
188
|
+
project_id = identifier_pb2.ProjectIdentifier(
|
|
189
|
+
organization=cfg.org,
|
|
190
|
+
domain=cfg.domain,
|
|
191
|
+
name=cfg.project,
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
while True:
|
|
195
|
+
resp = await get_client().trigger_service.ListTriggers(
|
|
196
|
+
request=trigger_service_pb2.ListTriggersRequest(
|
|
197
|
+
project_id=project_id,
|
|
198
|
+
task_id=task_id,
|
|
199
|
+
task_name=task_name_id,
|
|
200
|
+
request=list_pb2.ListRequest(
|
|
201
|
+
limit=limit,
|
|
202
|
+
token=token,
|
|
203
|
+
),
|
|
204
|
+
)
|
|
205
|
+
)
|
|
206
|
+
token = resp.token
|
|
207
|
+
for r in resp.triggers:
|
|
208
|
+
yield cls(r)
|
|
209
|
+
if not token:
|
|
210
|
+
break
|
|
211
|
+
|
|
212
|
+
@syncify
|
|
213
|
+
@classmethod
|
|
214
|
+
async def update(cls, name: str, task_name: str, active: bool):
|
|
215
|
+
"""
|
|
216
|
+
Pause a trigger by its name and associated task name.
|
|
217
|
+
"""
|
|
218
|
+
ensure_client()
|
|
219
|
+
cfg = get_init_config()
|
|
220
|
+
await get_client().trigger_service.UpdateTriggers(
|
|
221
|
+
request=trigger_service_pb2.UpdateTriggersRequest(
|
|
222
|
+
names=[
|
|
223
|
+
identifier_pb2.TriggerName(
|
|
224
|
+
org=cfg.org,
|
|
225
|
+
project=cfg.project,
|
|
226
|
+
domain=cfg.domain,
|
|
227
|
+
name=name,
|
|
228
|
+
task_name=task_name,
|
|
229
|
+
)
|
|
230
|
+
],
|
|
231
|
+
active=active,
|
|
232
|
+
)
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
@syncify
|
|
236
|
+
@classmethod
|
|
237
|
+
async def delete(cls, name: str, task_name: str):
|
|
238
|
+
"""
|
|
239
|
+
Delete a trigger by its name.
|
|
240
|
+
"""
|
|
241
|
+
ensure_client()
|
|
242
|
+
cfg = get_init_config()
|
|
243
|
+
await get_client().trigger_service.DeleteTriggers(
|
|
244
|
+
request=trigger_service_pb2.DeleteTriggersRequest(
|
|
245
|
+
names=[
|
|
246
|
+
identifier_pb2.TriggerName(
|
|
247
|
+
org=cfg.org,
|
|
248
|
+
project=cfg.project,
|
|
249
|
+
domain=cfg.domain,
|
|
250
|
+
name=name,
|
|
251
|
+
task_name=task_name,
|
|
252
|
+
)
|
|
253
|
+
],
|
|
254
|
+
)
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
@property
|
|
258
|
+
def id(self) -> identifier_pb2.TriggerIdentifier:
|
|
259
|
+
return self.pb2.id
|
|
260
|
+
|
|
261
|
+
@property
|
|
262
|
+
def name(self) -> str:
|
|
263
|
+
return self.id.name.name
|
|
264
|
+
|
|
265
|
+
@property
|
|
266
|
+
def task_name(self) -> str:
|
|
267
|
+
return self.id.name.task_name
|
|
268
|
+
|
|
269
|
+
@property
|
|
270
|
+
def automation_spec(self) -> common_pb2.TriggerAutomationSpec:
|
|
271
|
+
return self.pb2.automation_spec
|
|
272
|
+
|
|
273
|
+
async def get_details(self) -> TriggerDetails:
|
|
274
|
+
"""
|
|
275
|
+
Get detailed information about this trigger.
|
|
276
|
+
"""
|
|
277
|
+
if not self.details:
|
|
278
|
+
details = await TriggerDetails.get.aio(name=self.pb2.id.name.name)
|
|
279
|
+
self.details = details
|
|
280
|
+
return self.details
|
|
281
|
+
|
|
282
|
+
@property
|
|
283
|
+
def is_active(self) -> bool:
|
|
284
|
+
return self.pb2.active
|
|
285
|
+
|
|
286
|
+
def _rich_automation(self, automation: common_pb2.TriggerAutomationSpec):
|
|
287
|
+
if automation.type == common_pb2.TriggerAutomationSpec.type.TYPE_NONE:
|
|
288
|
+
yield "none", None
|
|
289
|
+
elif automation.type == common_pb2.TriggerAutomationSpec.type.TYPE_SCHEDULE:
|
|
290
|
+
if automation.schedule.cron is not None:
|
|
291
|
+
yield "cron", automation.schedule.cron
|
|
292
|
+
elif automation.schedule.rate is not None:
|
|
293
|
+
r = automation.schedule.rate
|
|
294
|
+
yield (
|
|
295
|
+
"fixed_rate",
|
|
296
|
+
(
|
|
297
|
+
f"Every [{r.value}] {r.unit} starting at "
|
|
298
|
+
f"{r.start_time.ToDatetime() if automation.HasField('start_time') else 'now'}"
|
|
299
|
+
),
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
def __rich_repr__(self):
|
|
303
|
+
yield "task_name", self.task_name
|
|
304
|
+
yield "name", self.name
|
|
305
|
+
yield from self._rich_automation(self.pb2.automation_spec)
|
|
306
|
+
yield "auto_activate", self.is_active
|
flyte/remote/_user.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from flyteidl.service import identity_pb2
|
|
6
|
+
from flyteidl.service.identity_pb2 import UserInfoResponse
|
|
7
|
+
|
|
8
|
+
from .._initialize import ensure_client, get_client
|
|
9
|
+
from ..syncify import syncify
|
|
10
|
+
from ._common import ToJSONMixin
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class User(ToJSONMixin):
|
|
15
|
+
pb2: UserInfoResponse
|
|
16
|
+
|
|
17
|
+
@syncify
|
|
18
|
+
@classmethod
|
|
19
|
+
async def get(cls) -> User:
|
|
20
|
+
"""
|
|
21
|
+
Fetches information about the currently logged in user.
|
|
22
|
+
Returns: A User object containing details about the user.
|
|
23
|
+
"""
|
|
24
|
+
ensure_client()
|
|
25
|
+
|
|
26
|
+
resp = await get_client().identity_service.UserInfo(identity_pb2.UserInfoRequest())
|
|
27
|
+
return cls(resp)
|
|
28
|
+
|
|
29
|
+
def subject(self) -> str:
|
|
30
|
+
return self.pb2.subject
|
|
31
|
+
|
|
32
|
+
def name(self) -> str:
|
|
33
|
+
return self.pb2.name
|
flyte/report/_report.py
CHANGED
|
@@ -4,7 +4,6 @@ import string
|
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
5
|
from typing import TYPE_CHECKING, Dict, List, Union
|
|
6
6
|
|
|
7
|
-
from flyte._internal.runtime import io
|
|
8
7
|
from flyte._logging import logger
|
|
9
8
|
from flyte._tools import ipython_check
|
|
10
9
|
from flyte.syncify import syncify
|
|
@@ -133,6 +132,7 @@ async def flush():
|
|
|
133
132
|
"""
|
|
134
133
|
import flyte.storage as storage
|
|
135
134
|
from flyte._context import internal_ctx
|
|
135
|
+
from flyte._internal.runtime import io
|
|
136
136
|
|
|
137
137
|
if not internal_ctx().is_task_context():
|
|
138
138
|
return
|
flyte/storage/__init__.py
CHANGED
|
@@ -3,6 +3,8 @@ __all__ = [
|
|
|
3
3
|
"GCS",
|
|
4
4
|
"S3",
|
|
5
5
|
"Storage",
|
|
6
|
+
"exists",
|
|
7
|
+
"exists_sync",
|
|
6
8
|
"get",
|
|
7
9
|
"get_configured_fsspec_kwargs",
|
|
8
10
|
"get_random_local_directory",
|
|
@@ -11,13 +13,15 @@ __all__ = [
|
|
|
11
13
|
"get_underlying_filesystem",
|
|
12
14
|
"is_remote",
|
|
13
15
|
"join",
|
|
16
|
+
"open",
|
|
14
17
|
"put",
|
|
15
18
|
"put_stream",
|
|
16
|
-
"put_stream",
|
|
17
19
|
]
|
|
18
20
|
|
|
19
21
|
from ._config import ABFS, GCS, S3, Storage
|
|
20
22
|
from ._storage import (
|
|
23
|
+
exists,
|
|
24
|
+
exists_sync,
|
|
21
25
|
get,
|
|
22
26
|
get_configured_fsspec_kwargs,
|
|
23
27
|
get_random_local_directory,
|
|
@@ -26,6 +30,7 @@ from ._storage import (
|
|
|
26
30
|
get_underlying_filesystem,
|
|
27
31
|
is_remote,
|
|
28
32
|
join,
|
|
33
|
+
open,
|
|
29
34
|
put,
|
|
30
35
|
put_stream,
|
|
31
36
|
)
|
flyte/storage/_config.py
CHANGED
|
@@ -61,6 +61,7 @@ class S3(Storage):
|
|
|
61
61
|
endpoint: typing.Optional[str] = None
|
|
62
62
|
access_key_id: typing.Optional[str] = None
|
|
63
63
|
secret_access_key: typing.Optional[str] = None
|
|
64
|
+
region: typing.Optional[str] = None
|
|
64
65
|
|
|
65
66
|
_KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
|
|
66
67
|
"endpoint": "FLYTE_AWS_ENDPOINT",
|
|
@@ -76,7 +77,7 @@ class S3(Storage):
|
|
|
76
77
|
_KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"
|
|
77
78
|
|
|
78
79
|
@classmethod
|
|
79
|
-
def auto(cls) -> S3:
|
|
80
|
+
def auto(cls, region: str | None = None) -> S3:
|
|
80
81
|
"""
|
|
81
82
|
:return: Config
|
|
82
83
|
"""
|
|
@@ -88,6 +89,7 @@ class S3(Storage):
|
|
|
88
89
|
kwargs = set_if_exists(kwargs, "endpoint", endpoint)
|
|
89
90
|
kwargs = set_if_exists(kwargs, "access_key_id", access_key_id)
|
|
90
91
|
kwargs = set_if_exists(kwargs, "secret_access_key", secret_access_key)
|
|
92
|
+
kwargs = set_if_exists(kwargs, "region", region)
|
|
91
93
|
|
|
92
94
|
return S3(**kwargs)
|
|
93
95
|
|
|
@@ -141,6 +143,8 @@ class S3(Storage):
|
|
|
141
143
|
kwargs["config"] = config
|
|
142
144
|
kwargs["client_options"] = client_options or None
|
|
143
145
|
kwargs["retry_config"] = retry_config or None
|
|
146
|
+
if self.region:
|
|
147
|
+
kwargs["region"] = self.region
|
|
144
148
|
|
|
145
149
|
return kwargs
|
|
146
150
|
|