flyte 0.2.0b1__py3-none-any.whl → 2.0.0b46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyte/__init__.py +83 -30
- flyte/_bin/connect.py +61 -0
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +87 -19
- flyte/_bin/serve.py +351 -0
- flyte/_build.py +3 -2
- flyte/_cache/cache.py +6 -5
- flyte/_cache/local_cache.py +216 -0
- flyte/_code_bundle/_ignore.py +31 -5
- flyte/_code_bundle/_packaging.py +42 -11
- flyte/_code_bundle/_utils.py +57 -34
- flyte/_code_bundle/bundle.py +130 -27
- flyte/_constants.py +1 -0
- flyte/_context.py +21 -5
- flyte/_custom_context.py +73 -0
- flyte/_debug/constants.py +37 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +315 -0
- flyte/_deploy.py +396 -75
- flyte/_deployer.py +109 -0
- flyte/_environment.py +94 -11
- flyte/_excepthook.py +37 -0
- flyte/_group.py +2 -1
- flyte/_hash.py +1 -16
- flyte/_image.py +544 -231
- flyte/_initialize.py +456 -316
- flyte/_interface.py +40 -5
- flyte/_internal/controllers/__init__.py +22 -8
- flyte/_internal/controllers/_local_controller.py +159 -35
- flyte/_internal/controllers/_trace.py +18 -10
- flyte/_internal/controllers/remote/__init__.py +38 -9
- flyte/_internal/controllers/remote/_action.py +82 -12
- flyte/_internal/controllers/remote/_client.py +6 -2
- flyte/_internal/controllers/remote/_controller.py +290 -64
- flyte/_internal/controllers/remote/_core.py +155 -95
- flyte/_internal/controllers/remote/_informer.py +40 -20
- flyte/_internal/controllers/remote/_service_protocol.py +2 -2
- flyte/_internal/imagebuild/__init__.py +2 -10
- flyte/_internal/imagebuild/docker_builder.py +391 -84
- flyte/_internal/imagebuild/image_builder.py +111 -55
- flyte/_internal/imagebuild/remote_builder.py +409 -0
- flyte/_internal/imagebuild/utils.py +79 -0
- flyte/_internal/resolvers/_app_env_module.py +92 -0
- flyte/_internal/resolvers/_task_module.py +5 -38
- flyte/_internal/resolvers/app_env.py +26 -0
- flyte/_internal/resolvers/common.py +8 -1
- flyte/_internal/resolvers/default.py +2 -2
- flyte/_internal/runtime/convert.py +319 -36
- flyte/_internal/runtime/entrypoints.py +106 -18
- flyte/_internal/runtime/io.py +71 -23
- flyte/_internal/runtime/resources_serde.py +21 -7
- flyte/_internal/runtime/reuse.py +125 -0
- flyte/_internal/runtime/rusty.py +196 -0
- flyte/_internal/runtime/task_serde.py +239 -66
- flyte/_internal/runtime/taskrunner.py +48 -8
- flyte/_internal/runtime/trigger_serde.py +162 -0
- flyte/_internal/runtime/types_serde.py +7 -16
- flyte/_keyring/file.py +115 -0
- flyte/_link.py +30 -0
- flyte/_logging.py +241 -42
- flyte/_map.py +312 -0
- flyte/_metrics.py +59 -0
- flyte/_module.py +74 -0
- flyte/_pod.py +30 -0
- flyte/_resources.py +296 -33
- flyte/_retry.py +1 -7
- flyte/_reusable_environment.py +72 -7
- flyte/_run.py +462 -132
- flyte/_secret.py +47 -11
- flyte/_serve.py +333 -0
- flyte/_task.py +245 -56
- flyte/_task_environment.py +219 -97
- flyte/_task_plugins.py +47 -0
- flyte/_tools.py +8 -8
- flyte/_trace.py +15 -24
- flyte/_trigger.py +1027 -0
- flyte/_utils/__init__.py +12 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +5 -4
- flyte/_utils/description_parser.py +19 -0
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/helpers.py +45 -19
- flyte/_utils/module_loader.py +123 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +8 -1
- flyte/_version.py +16 -3
- flyte/app/__init__.py +27 -0
- flyte/app/_app_environment.py +362 -0
- flyte/app/_connector_environment.py +40 -0
- flyte/app/_deploy.py +130 -0
- flyte/app/_parameter.py +343 -0
- flyte/app/_runtime/__init__.py +3 -0
- flyte/app/_runtime/app_serde.py +383 -0
- flyte/app/_types.py +113 -0
- flyte/app/extras/__init__.py +9 -0
- flyte/app/extras/_auth_middleware.py +217 -0
- flyte/app/extras/_fastapi.py +93 -0
- flyte/app/extras/_model_loader/__init__.py +3 -0
- flyte/app/extras/_model_loader/config.py +7 -0
- flyte/app/extras/_model_loader/loader.py +288 -0
- flyte/cli/__init__.py +12 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_build.py +114 -0
- flyte/cli/_common.py +493 -0
- flyte/cli/_create.py +371 -0
- flyte/cli/_delete.py +45 -0
- flyte/cli/_deploy.py +401 -0
- flyte/cli/_gen.py +316 -0
- flyte/cli/_get.py +446 -0
- flyte/cli/_option.py +33 -0
- flyte/{_cli → cli}/_params.py +57 -17
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_prefetch.py +292 -0
- flyte/cli/_run.py +690 -0
- flyte/cli/_serve.py +338 -0
- flyte/cli/_update.py +86 -0
- flyte/cli/_user.py +20 -0
- flyte/cli/main.py +246 -0
- flyte/config/__init__.py +2 -167
- flyte/config/_config.py +215 -163
- flyte/config/_internal.py +10 -1
- flyte/config/_reader.py +225 -0
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +330 -0
- flyte/connectors/_server.py +194 -0
- flyte/connectors/utils.py +159 -0
- flyte/errors.py +134 -2
- flyte/extend.py +24 -0
- flyte/extras/_container.py +69 -56
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +279 -0
- flyte/io/__init__.py +8 -1
- flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
- flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
- flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
- flyte/io/_dir.py +575 -113
- flyte/io/_file.py +587 -141
- flyte/io/_hashing_io.py +342 -0
- flyte/io/extend.py +7 -0
- flyte/models.py +635 -0
- flyte/prefetch/__init__.py +22 -0
- flyte/prefetch/_hf_model.py +563 -0
- flyte/remote/__init__.py +14 -3
- flyte/remote/_action.py +879 -0
- flyte/remote/_app.py +346 -0
- flyte/remote/_auth_metadata.py +42 -0
- flyte/remote/_client/_protocols.py +62 -4
- flyte/remote/_client/auth/_auth_utils.py +19 -0
- flyte/remote/_client/auth/_authenticators/base.py +8 -2
- flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
- flyte/remote/_client/auth/_authenticators/factory.py +4 -0
- flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
- flyte/remote/_client/auth/_channel.py +47 -18
- flyte/remote/_client/auth/_client_config.py +5 -3
- flyte/remote/_client/auth/_keyring.py +15 -2
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +206 -18
- flyte/remote/_common.py +66 -0
- flyte/remote/_data.py +107 -22
- flyte/remote/_logs.py +116 -33
- flyte/remote/_project.py +21 -19
- flyte/remote/_run.py +164 -631
- flyte/remote/_secret.py +72 -29
- flyte/remote/_task.py +387 -46
- flyte/remote/_trigger.py +368 -0
- flyte/remote/_user.py +43 -0
- flyte/report/_report.py +10 -6
- flyte/storage/__init__.py +13 -1
- flyte/storage/_config.py +237 -0
- flyte/storage/_parallel_reader.py +289 -0
- flyte/storage/_storage.py +268 -59
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +414 -0
- flyte/types/__init__.py +39 -0
- flyte/types/_interface.py +22 -7
- flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +226 -126
- flyte/types/_utils.py +1 -1
- flyte-2.0.0b46.data/scripts/debug.py +38 -0
- flyte-2.0.0b46.data/scripts/runtime.py +194 -0
- flyte-2.0.0b46.dist-info/METADATA +352 -0
- flyte-2.0.0b46.dist-info/RECORD +221 -0
- flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
- flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
- flyte/_api_commons.py +0 -3
- flyte/_cli/_common.py +0 -299
- flyte/_cli/_create.py +0 -42
- flyte/_cli/_delete.py +0 -23
- flyte/_cli/_deploy.py +0 -140
- flyte/_cli/_get.py +0 -235
- flyte/_cli/_run.py +0 -174
- flyte/_cli/main.py +0 -98
- flyte/_datastructures.py +0 -342
- flyte/_internal/controllers/pbhash.py +0 -39
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -71
- flyte/_protos/common/identifier_pb2.pyi +0 -82
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -69
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -106
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- flyte/_protos/workflow/run_definition_pb2.py +0 -128
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -133
- flyte/_protos/workflow/run_service_pb2.pyi +0 -175
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
- flyte/_protos/workflow/state_service_pb2.py +0 -58
- flyte/_protos/workflow/state_service_pb2.pyi +0 -71
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -72
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -44
- flyte/_protos/workflow/task_service_pb2.pyi +0 -31
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
- flyte/io/_dataframe.py +0 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/remote/_console.py +0 -18
- flyte-0.2.0b1.dist-info/METADATA +0 -179
- flyte-0.2.0b1.dist-info/RECORD +0 -204
- flyte-0.2.0b1.dist-info/entry_points.txt +0 -3
- /flyte/{_cli → _debug}/__init__.py +0 -0
- /flyte/{_protos → _keyring}/__init__.py +0 -0
- {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
flyte/remote/_task.py
CHANGED
|
@@ -1,19 +1,42 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import functools
|
|
4
5
|
from dataclasses import dataclass
|
|
5
|
-
from
|
|
6
|
-
from typing import Any, Callable, Coroutine, Dict, Literal, Optional, Union
|
|
6
|
+
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union, cast
|
|
7
7
|
|
|
8
|
+
import grpc
|
|
8
9
|
import rich.repr
|
|
10
|
+
from flyteidl2.common import identifier_pb2, list_pb2
|
|
11
|
+
from flyteidl2.core import literals_pb2
|
|
12
|
+
from flyteidl2.task import task_definition_pb2, task_service_pb2
|
|
9
13
|
|
|
10
14
|
import flyte
|
|
11
15
|
import flyte.errors
|
|
12
|
-
from flyte.
|
|
16
|
+
from flyte._cache.cache import CacheBehavior
|
|
13
17
|
from flyte._context import internal_ctx
|
|
14
|
-
from flyte.
|
|
15
|
-
from flyte.
|
|
16
|
-
from flyte.
|
|
18
|
+
from flyte._initialize import ensure_client, get_client, get_init_config
|
|
19
|
+
from flyte._internal.runtime.resources_serde import get_proto_resources
|
|
20
|
+
from flyte._internal.runtime.task_serde import get_proto_retry_strategy, get_proto_timeout, get_security_context
|
|
21
|
+
from flyte._logging import logger
|
|
22
|
+
from flyte.models import NativeInterface
|
|
23
|
+
from flyte.syncify import syncify
|
|
24
|
+
|
|
25
|
+
from ._common import ToJSONMixin
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _repr_task_metadata(metadata: task_definition_pb2.TaskMetadata) -> rich.repr.Result:
|
|
29
|
+
"""
|
|
30
|
+
Rich representation of the task metadata.
|
|
31
|
+
"""
|
|
32
|
+
if metadata.deployed_by:
|
|
33
|
+
if metadata.deployed_by.user:
|
|
34
|
+
yield "deployed_by", f"User: {metadata.deployed_by.user.spec.email}"
|
|
35
|
+
else:
|
|
36
|
+
yield "deployed_by", f"App: {metadata.deployed_by.application.spec.name}"
|
|
37
|
+
yield "short_name", metadata.short_name
|
|
38
|
+
yield "deployed_at", metadata.deployed_at.ToDatetime()
|
|
39
|
+
yield "environment_name", metadata.environment_name
|
|
17
40
|
|
|
18
41
|
|
|
19
42
|
class LazyEntity:
|
|
@@ -22,33 +45,47 @@ class LazyEntity:
|
|
|
22
45
|
The entity is derived from RemoteEntity so that it behaves exactly like the mimicked entity.
|
|
23
46
|
"""
|
|
24
47
|
|
|
25
|
-
def __init__(self, name: str, getter: Callable[..., Coroutine[Any, Any,
|
|
26
|
-
self._task: Optional[
|
|
48
|
+
def __init__(self, name: str, getter: Callable[..., Coroutine[Any, Any, TaskDetails]], *args, **kwargs):
|
|
49
|
+
self._task: Optional[TaskDetails] = None
|
|
27
50
|
self._getter = getter
|
|
28
51
|
self._name = name
|
|
29
|
-
self._mutex = Lock()
|
|
52
|
+
self._mutex = asyncio.Lock()
|
|
30
53
|
|
|
31
54
|
@property
|
|
32
55
|
def name(self) -> str:
|
|
56
|
+
"""
|
|
57
|
+
Get the name of the task.
|
|
58
|
+
"""
|
|
33
59
|
return self._name
|
|
34
60
|
|
|
35
|
-
@
|
|
36
|
-
async def fetch(self) ->
|
|
61
|
+
@syncify
|
|
62
|
+
async def fetch(self) -> TaskDetails:
|
|
37
63
|
"""
|
|
38
64
|
Forwards all other attributes to task, causing the task to be fetched!
|
|
39
65
|
"""
|
|
40
|
-
with self._mutex:
|
|
66
|
+
async with self._mutex:
|
|
41
67
|
if self._task is None:
|
|
42
68
|
self._task = await self._getter()
|
|
43
|
-
|
|
44
|
-
|
|
69
|
+
if self._task is None:
|
|
70
|
+
raise RuntimeError(f"Error downloading the task {self._name}, (check original exception...)")
|
|
45
71
|
return self._task
|
|
46
72
|
|
|
73
|
+
@syncify
|
|
74
|
+
async def override(
|
|
75
|
+
self,
|
|
76
|
+
**kwargs: Any,
|
|
77
|
+
) -> LazyEntity:
|
|
78
|
+
task_details = cast(TaskDetails, await self.fetch.aio())
|
|
79
|
+
new_task_details = task_details.override(**kwargs)
|
|
80
|
+
new_entity = LazyEntity(self._name, self._getter)
|
|
81
|
+
new_entity._task = new_task_details
|
|
82
|
+
return new_entity
|
|
83
|
+
|
|
47
84
|
async def __call__(self, *args, **kwargs):
|
|
48
85
|
"""
|
|
49
86
|
Forwards the call to the underlying task. The entity will be fetched if not already present
|
|
50
87
|
"""
|
|
51
|
-
tk = await self.fetch.aio(
|
|
88
|
+
tk = await self.fetch.aio()
|
|
52
89
|
return await tk(*args, **kwargs)
|
|
53
90
|
|
|
54
91
|
def __repr__(self) -> str:
|
|
@@ -61,19 +98,29 @@ class LazyEntity:
|
|
|
61
98
|
AutoVersioning = Literal["latest", "current"]
|
|
62
99
|
|
|
63
100
|
|
|
64
|
-
@dataclass
|
|
65
|
-
class
|
|
101
|
+
@dataclass(frozen=True)
|
|
102
|
+
class TaskDetails(ToJSONMixin):
|
|
66
103
|
pb2: task_definition_pb2.TaskDetails
|
|
104
|
+
max_inline_io_bytes: int = 10 * 1024 * 1024 # 10 MB
|
|
105
|
+
overriden_queue: Optional[str] = None
|
|
67
106
|
|
|
68
107
|
@classmethod
|
|
69
|
-
def get(
|
|
108
|
+
def get(
|
|
109
|
+
cls,
|
|
110
|
+
name: str,
|
|
111
|
+
project: str | None,
|
|
112
|
+
domain: str | None,
|
|
113
|
+
version: str | None = None,
|
|
114
|
+
auto_version: AutoVersioning | None = None,
|
|
115
|
+
) -> LazyEntity:
|
|
70
116
|
"""
|
|
71
117
|
Get a task by its ID or name. If both are provided, the ID will take precedence.
|
|
72
118
|
|
|
73
119
|
Either version or auto_version are required parameters.
|
|
74
120
|
|
|
75
|
-
:param uri: The URI of the task. If provided, do not provide the rest of the parameters.
|
|
76
121
|
:param name: The name of the task.
|
|
122
|
+
:param project: The project of the task.
|
|
123
|
+
:param domain: The domain of the task.
|
|
77
124
|
:param version: The version of the task.
|
|
78
125
|
:param auto_version: If set to "latest", the latest-by-time ordered from now, version of the task will be used.
|
|
79
126
|
If set to "current", the version will be derived from the callee tasks context. This is useful if you are
|
|
@@ -87,34 +134,66 @@ class Task:
|
|
|
87
134
|
if version is None and auto_version not in ["latest", "current"]:
|
|
88
135
|
raise ValueError("auto_version must be either 'latest' or 'current'.")
|
|
89
136
|
|
|
90
|
-
async def deferred_get(_version: str | None, _auto_version: AutoVersioning | None) ->
|
|
137
|
+
async def deferred_get(_version: str | None, _auto_version: AutoVersioning | None) -> TaskDetails:
|
|
91
138
|
if _version is None:
|
|
92
139
|
if _auto_version == "latest":
|
|
93
|
-
|
|
140
|
+
tasks = []
|
|
141
|
+
async for x in Task.listall.aio(
|
|
142
|
+
by_task_name=name,
|
|
143
|
+
project=project,
|
|
144
|
+
domain=domain,
|
|
145
|
+
sort_by=("created_at", "desc"),
|
|
146
|
+
limit=1,
|
|
147
|
+
):
|
|
148
|
+
tasks.append(x)
|
|
149
|
+
if not tasks:
|
|
150
|
+
raise flyte.errors.RemoteTaskError(
|
|
151
|
+
f"No versions found for Task {name} in project {project}, domain {domain}."
|
|
152
|
+
)
|
|
153
|
+
_version = tasks[0].version
|
|
94
154
|
elif _auto_version == "current":
|
|
95
155
|
ctx = flyte.ctx()
|
|
96
156
|
if ctx is None:
|
|
97
157
|
raise ValueError("auto_version=current can only be used within a task context.")
|
|
98
158
|
_version = ctx.version
|
|
99
|
-
cfg =
|
|
159
|
+
cfg = get_init_config()
|
|
100
160
|
task_id = task_definition_pb2.TaskIdentifier(
|
|
101
161
|
org=cfg.org,
|
|
102
|
-
project=cfg.project,
|
|
103
|
-
domain=cfg.domain,
|
|
162
|
+
project=project or cfg.project,
|
|
163
|
+
domain=domain or cfg.domain,
|
|
104
164
|
name=name,
|
|
105
165
|
version=_version,
|
|
106
166
|
)
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
167
|
+
try:
|
|
168
|
+
resp = await get_client().task_service.GetTaskDetails(
|
|
169
|
+
task_service_pb2.GetTaskDetailsRequest(
|
|
170
|
+
task_id=task_id,
|
|
171
|
+
)
|
|
110
172
|
)
|
|
111
|
-
|
|
112
|
-
|
|
173
|
+
return cls(resp.details)
|
|
174
|
+
except grpc.aio.AioRpcError as err:
|
|
175
|
+
if err.code() == grpc.StatusCode.NOT_FOUND:
|
|
176
|
+
raise flyte.errors.RemoteTaskError(
|
|
177
|
+
f"Task {name}, version {_version} not found in {project} {domain}."
|
|
178
|
+
)
|
|
179
|
+
raise
|
|
113
180
|
|
|
114
181
|
return LazyEntity(
|
|
115
182
|
name=name, getter=functools.partial(deferred_get, _version=version, _auto_version=auto_version)
|
|
116
183
|
)
|
|
117
184
|
|
|
185
|
+
@classmethod
|
|
186
|
+
async def fetch(
|
|
187
|
+
cls,
|
|
188
|
+
name: str,
|
|
189
|
+
project: str | None = None,
|
|
190
|
+
domain: str | None = None,
|
|
191
|
+
version: str | None = None,
|
|
192
|
+
auto_version: AutoVersioning | None = None,
|
|
193
|
+
) -> TaskDetails:
|
|
194
|
+
lazy = TaskDetails.get(name, project=project, domain=domain, version=version, auto_version=auto_version)
|
|
195
|
+
return await lazy.fetch.aio()
|
|
196
|
+
|
|
118
197
|
@property
|
|
119
198
|
def name(self) -> str:
|
|
120
199
|
"""
|
|
@@ -136,6 +215,20 @@ class Task:
|
|
|
136
215
|
"""
|
|
137
216
|
return self.pb2.spec.task_template.type
|
|
138
217
|
|
|
218
|
+
@property
|
|
219
|
+
def default_input_args(self) -> Tuple[str, ...]:
|
|
220
|
+
"""
|
|
221
|
+
The default input arguments of the task.
|
|
222
|
+
"""
|
|
223
|
+
return tuple(x.name for x in self.pb2.spec.default_inputs)
|
|
224
|
+
|
|
225
|
+
@property
|
|
226
|
+
def required_args(self) -> Tuple[str, ...]:
|
|
227
|
+
"""
|
|
228
|
+
The required input arguments of the task.
|
|
229
|
+
"""
|
|
230
|
+
return tuple(x for x, _ in self.interface.inputs.items() if x not in self.default_input_args)
|
|
231
|
+
|
|
139
232
|
@functools.cached_property
|
|
140
233
|
def interface(self) -> NativeInterface:
|
|
141
234
|
"""
|
|
@@ -143,31 +236,40 @@ class Task:
|
|
|
143
236
|
"""
|
|
144
237
|
import flyte.types as types
|
|
145
238
|
|
|
146
|
-
return types.guess_interface(self.pb2.spec.task_template.interface)
|
|
239
|
+
return types.guess_interface(self.pb2.spec.task_template.interface, default_inputs=self.pb2.spec.default_inputs)
|
|
147
240
|
|
|
148
241
|
@property
|
|
149
242
|
def cache(self) -> flyte.Cache:
|
|
150
243
|
"""
|
|
151
244
|
The cache policy of the task.
|
|
152
245
|
"""
|
|
246
|
+
metadata = self.pb2.spec.task_template.metadata
|
|
247
|
+
behavior: CacheBehavior
|
|
248
|
+
if not metadata.discoverable:
|
|
249
|
+
behavior = "disable"
|
|
250
|
+
elif metadata.discovery_version:
|
|
251
|
+
behavior = "override"
|
|
252
|
+
else:
|
|
253
|
+
behavior = "auto"
|
|
254
|
+
|
|
153
255
|
return flyte.Cache(
|
|
154
|
-
behavior=
|
|
155
|
-
version_override=
|
|
156
|
-
serialize=
|
|
157
|
-
ignored_inputs=tuple(
|
|
256
|
+
behavior=behavior,
|
|
257
|
+
version_override=metadata.discovery_version if metadata.discovery_version else None,
|
|
258
|
+
serialize=metadata.cache_serializable,
|
|
259
|
+
ignored_inputs=tuple(metadata.cache_ignore_input_vars),
|
|
158
260
|
)
|
|
159
261
|
|
|
160
262
|
@property
|
|
161
263
|
def secrets(self):
|
|
162
264
|
"""
|
|
163
|
-
|
|
265
|
+
Get the list of secret keys required by the task.
|
|
164
266
|
"""
|
|
165
267
|
return [s.key for s in self.pb2.spec.task_template.security_context.secrets]
|
|
166
268
|
|
|
167
269
|
@property
|
|
168
270
|
def resources(self):
|
|
169
271
|
"""
|
|
170
|
-
|
|
272
|
+
Get the resource requests and limits for the task as a tuple (requests, limits).
|
|
171
273
|
"""
|
|
172
274
|
if self.pb2.spec.task_template.container is None:
|
|
173
275
|
return ()
|
|
@@ -180,6 +282,12 @@ class Task:
|
|
|
180
282
|
"""
|
|
181
283
|
Forwards the call to the underlying task. The entity will be fetched if not already present
|
|
182
284
|
"""
|
|
285
|
+
# TODO support kwargs, for this we need ordered inputs to be stored in the task spec.
|
|
286
|
+
if len(args) > 0:
|
|
287
|
+
raise flyte.errors.RemoteTaskError(
|
|
288
|
+
f"Remote task {self.name} does not support positional argumentscurrently. Please use keyword arguments."
|
|
289
|
+
)
|
|
290
|
+
|
|
183
291
|
ctx = internal_ctx()
|
|
184
292
|
if ctx.is_task_context():
|
|
185
293
|
# If we are in a task context, that implies we are executing a Run.
|
|
@@ -187,31 +295,112 @@ class Task:
|
|
|
187
295
|
# We will also check if we are not initialized, It is not expected to be not initialized
|
|
188
296
|
from flyte._internal.controllers import get_controller
|
|
189
297
|
|
|
190
|
-
controller =
|
|
298
|
+
controller = get_controller()
|
|
299
|
+
if len(self.required_args) > 0:
|
|
300
|
+
if len(args) + len(kwargs) < len(self.required_args):
|
|
301
|
+
raise ValueError(
|
|
302
|
+
f"Task {self.name} requires at least {self.required_args} arguments, "
|
|
303
|
+
f"but only received args:{args} kwargs{kwargs}."
|
|
304
|
+
)
|
|
191
305
|
if controller:
|
|
192
|
-
return await controller.submit_task_ref(self
|
|
193
|
-
raise flyte.errors
|
|
306
|
+
return await controller.submit_task_ref(self, *args, **kwargs)
|
|
307
|
+
raise flyte.errors.RemoteTaskError(f"Remote tasks [{self.name}] cannot be executed locally, only remotely.")
|
|
308
|
+
|
|
309
|
+
@property
|
|
310
|
+
def queue(self) -> Optional[str]:
|
|
311
|
+
"""
|
|
312
|
+
Get the queue name to use for task execution, if overridden.
|
|
313
|
+
"""
|
|
314
|
+
return self.overriden_queue
|
|
194
315
|
|
|
195
316
|
def override(
|
|
196
317
|
self,
|
|
197
318
|
*,
|
|
198
|
-
|
|
199
|
-
ref: Optional[bool] = None,
|
|
319
|
+
short_name: Optional[str] = None,
|
|
200
320
|
resources: Optional[flyte.Resources] = None,
|
|
201
|
-
cache: flyte.CacheRequest = "auto",
|
|
202
321
|
retries: Union[int, flyte.RetryStrategy] = 0,
|
|
203
322
|
timeout: Optional[flyte.TimeoutType] = None,
|
|
204
|
-
|
|
205
|
-
env: Optional[Dict[str, str]] = None,
|
|
323
|
+
env_vars: Optional[Dict[str, str]] = None,
|
|
206
324
|
secrets: Optional[flyte.SecretRequest] = None,
|
|
325
|
+
max_inline_io_bytes: Optional[int] = None,
|
|
326
|
+
cache: Optional[flyte.Cache] = None,
|
|
327
|
+
queue: Optional[str] = None,
|
|
207
328
|
**kwargs: Any,
|
|
208
|
-
) ->
|
|
209
|
-
|
|
329
|
+
) -> TaskDetails:
|
|
330
|
+
"""
|
|
331
|
+
Create a new TaskDetails with overridden properties.
|
|
332
|
+
|
|
333
|
+
:param short_name: Optional short name for the task.
|
|
334
|
+
:param resources: Optional resource requirements.
|
|
335
|
+
:param retries: Number of retries or retry strategy.
|
|
336
|
+
:param timeout: Execution timeout.
|
|
337
|
+
:param env_vars: Environment variables to set.
|
|
338
|
+
:param secrets: Secret requests for the task.
|
|
339
|
+
:param max_inline_io_bytes: Maximum inline I/O size in bytes.
|
|
340
|
+
:param cache: Cache configuration.
|
|
341
|
+
:param queue: Queue name for task execution.
|
|
342
|
+
:return: A new TaskDetails instance with the overrides applied.
|
|
343
|
+
"""
|
|
344
|
+
if len(kwargs) > 0:
|
|
345
|
+
raise ValueError(
|
|
346
|
+
f"RemoteTasks [{self.name}] do not support overriding with kwargs: {kwargs}, "
|
|
347
|
+
f"Check the parameters for override method."
|
|
348
|
+
)
|
|
349
|
+
pb2 = task_definition_pb2.TaskDetails()
|
|
350
|
+
pb2.CopyFrom(self.pb2)
|
|
351
|
+
|
|
352
|
+
if short_name:
|
|
353
|
+
pb2.metadata.short_name = short_name
|
|
354
|
+
|
|
355
|
+
template = pb2.spec.task_template
|
|
356
|
+
if secrets:
|
|
357
|
+
template.security_context.CopyFrom(get_security_context(secrets))
|
|
358
|
+
|
|
359
|
+
if template.HasField("container"):
|
|
360
|
+
if env_vars:
|
|
361
|
+
template.container.env.clear()
|
|
362
|
+
template.container.env.extend([literals_pb2.KeyValuePair(key=k, value=v) for k, v in env_vars.items()])
|
|
363
|
+
if resources:
|
|
364
|
+
template.container.resources.CopyFrom(get_proto_resources(resources))
|
|
365
|
+
|
|
366
|
+
md = template.metadata
|
|
367
|
+
if retries:
|
|
368
|
+
md.retries.CopyFrom(get_proto_retry_strategy(retries))
|
|
369
|
+
|
|
370
|
+
if timeout:
|
|
371
|
+
md.timeout.CopyFrom(get_proto_timeout(timeout))
|
|
372
|
+
|
|
373
|
+
if cache:
|
|
374
|
+
if cache.behavior == "disable":
|
|
375
|
+
md.discoverable = False
|
|
376
|
+
md.discovery_version = ""
|
|
377
|
+
elif cache.behavior == "override":
|
|
378
|
+
md.discoverable = True
|
|
379
|
+
if not cache.version_override:
|
|
380
|
+
raise ValueError("cache.version_override must be set when cache.behavior is 'override'")
|
|
381
|
+
md.discovery_version = cache.version_override
|
|
382
|
+
else:
|
|
383
|
+
if cache.behavior == "auto":
|
|
384
|
+
raise ValueError("cache.behavior must be 'disable' or 'override' for remote tasks")
|
|
385
|
+
raise ValueError(f"Invalid cache behavior: {cache.behavior}.")
|
|
386
|
+
md.cache_serializable = cache.serialize
|
|
387
|
+
md.cache_ignore_input_vars[:] = list(cache.ignored_inputs or ())
|
|
388
|
+
|
|
389
|
+
return TaskDetails(
|
|
390
|
+
pb2,
|
|
391
|
+
max_inline_io_bytes=max_inline_io_bytes or self.max_inline_io_bytes,
|
|
392
|
+
overriden_queue=queue,
|
|
393
|
+
)
|
|
210
394
|
|
|
211
395
|
def __rich_repr__(self) -> rich.repr.Result:
|
|
212
396
|
"""
|
|
213
397
|
Rich representation of the task.
|
|
214
398
|
"""
|
|
399
|
+
yield "short_name", self.pb2.spec.short_name
|
|
400
|
+
yield "environment", self.pb2.spec.environment
|
|
401
|
+
yield "default_inputs_keys", self.default_input_args
|
|
402
|
+
yield "required_args", self.required_args
|
|
403
|
+
yield "raw_default_inputs", [str(x) for x in self.pb2.spec.default_inputs]
|
|
215
404
|
yield "project", self.pb2.task_id.project
|
|
216
405
|
yield "domain", self.pb2.task_id.domain
|
|
217
406
|
yield "name", self.name
|
|
@@ -223,5 +412,157 @@ class Task:
|
|
|
223
412
|
yield "resources", self.resources
|
|
224
413
|
|
|
225
414
|
|
|
415
|
+
@dataclass
|
|
416
|
+
class Task(ToJSONMixin):
|
|
417
|
+
pb2: task_definition_pb2.Task
|
|
418
|
+
|
|
419
|
+
def __init__(self, pb2: task_definition_pb2.Task):
|
|
420
|
+
"""
|
|
421
|
+
Initialize a Task object.
|
|
422
|
+
|
|
423
|
+
:param pb2: The task protobuf definition.
|
|
424
|
+
"""
|
|
425
|
+
self.pb2 = pb2
|
|
426
|
+
|
|
427
|
+
@property
|
|
428
|
+
def name(self) -> str:
|
|
429
|
+
"""
|
|
430
|
+
The name of the task.
|
|
431
|
+
"""
|
|
432
|
+
return self.pb2.task_id.name
|
|
433
|
+
|
|
434
|
+
@property
|
|
435
|
+
def version(self) -> str:
|
|
436
|
+
"""
|
|
437
|
+
The version of the task.
|
|
438
|
+
"""
|
|
439
|
+
return self.pb2.task_id.version
|
|
440
|
+
|
|
441
|
+
@property
|
|
442
|
+
def url(self) -> str:
|
|
443
|
+
"""
|
|
444
|
+
Get the console URL for viewing the task.
|
|
445
|
+
"""
|
|
446
|
+
client = get_client()
|
|
447
|
+
return client.console.task_url(
|
|
448
|
+
project=self.pb2.task_id.project,
|
|
449
|
+
domain=self.pb2.task_id.domain,
|
|
450
|
+
task_name=self.pb2.task_id.name,
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
@classmethod
|
|
454
|
+
def get(
|
|
455
|
+
cls,
|
|
456
|
+
name: str,
|
|
457
|
+
project: str | None = None,
|
|
458
|
+
domain: str | None = None,
|
|
459
|
+
version: str | None = None,
|
|
460
|
+
auto_version: AutoVersioning | None = None,
|
|
461
|
+
) -> LazyEntity:
|
|
462
|
+
"""
|
|
463
|
+
Get a task by its ID or name. If both are provided, the ID will take precedence.
|
|
464
|
+
|
|
465
|
+
Either version or auto_version are required parameters.
|
|
466
|
+
|
|
467
|
+
:param name: The name of the task.
|
|
468
|
+
:param project: The project of the task.
|
|
469
|
+
:param domain: The domain of the task.
|
|
470
|
+
:param version: The version of the task.
|
|
471
|
+
:param auto_version: If set to "latest", the latest-by-time ordered from now, version of the task will be used.
|
|
472
|
+
If set to "current", the version will be derived from the callee tasks context. This is useful if you are
|
|
473
|
+
deploying all environments with the same version. If auto_version is current, you can only access the task from
|
|
474
|
+
within a task context.
|
|
475
|
+
"""
|
|
476
|
+
return TaskDetails.get(name, project=project, domain=domain, version=version, auto_version=auto_version)
|
|
477
|
+
|
|
478
|
+
@syncify
|
|
479
|
+
@classmethod
|
|
480
|
+
async def listall(
|
|
481
|
+
cls,
|
|
482
|
+
by_task_name: str | None = None,
|
|
483
|
+
by_task_env: str | None = None,
|
|
484
|
+
project: str | None = None,
|
|
485
|
+
domain: str | None = None,
|
|
486
|
+
sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
|
|
487
|
+
limit: int = 100,
|
|
488
|
+
) -> Union[AsyncIterator[Task], Iterator[Task]]:
|
|
489
|
+
"""
|
|
490
|
+
Get all runs for the current project and domain.
|
|
491
|
+
|
|
492
|
+
:param by_task_name: If provided, only tasks with this name will be returned.
|
|
493
|
+
:param by_task_env: If provided, only tasks with this environment prefix will be returned.
|
|
494
|
+
:param project: The project to filter tasks by. If None, the current project will be used.
|
|
495
|
+
:param domain: The domain to filter tasks by. If None, the current domain will be used.
|
|
496
|
+
:param sort_by: The sorting criteria for the project list, in the format (field, order).
|
|
497
|
+
:param limit: The maximum number of tasks to return.
|
|
498
|
+
:return: An iterator of runs.
|
|
499
|
+
"""
|
|
500
|
+
ensure_client()
|
|
501
|
+
token = None
|
|
502
|
+
sort_by = sort_by or ("created_at", "asc")
|
|
503
|
+
sort_pb2 = list_pb2.Sort(
|
|
504
|
+
key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
|
|
505
|
+
)
|
|
506
|
+
cfg = get_init_config()
|
|
507
|
+
filters = []
|
|
508
|
+
if by_task_name:
|
|
509
|
+
filters.append(
|
|
510
|
+
list_pb2.Filter(
|
|
511
|
+
function=list_pb2.Filter.Function.EQUAL,
|
|
512
|
+
field="name",
|
|
513
|
+
values=[by_task_name],
|
|
514
|
+
)
|
|
515
|
+
)
|
|
516
|
+
if by_task_env:
|
|
517
|
+
# ideally we should have a STARTS_WITH filter, but it is not supported yet
|
|
518
|
+
filters.append(
|
|
519
|
+
list_pb2.Filter(
|
|
520
|
+
function=list_pb2.Filter.Function.CONTAINS,
|
|
521
|
+
field="name",
|
|
522
|
+
values=[f"{by_task_env}."],
|
|
523
|
+
)
|
|
524
|
+
)
|
|
525
|
+
original_limit = limit
|
|
526
|
+
if limit > cfg.batch_size:
|
|
527
|
+
limit = cfg.batch_size
|
|
528
|
+
retrieved = 0
|
|
529
|
+
while True:
|
|
530
|
+
resp = await get_client().task_service.ListTasks(
|
|
531
|
+
task_service_pb2.ListTasksRequest(
|
|
532
|
+
org=cfg.org,
|
|
533
|
+
project_id=identifier_pb2.ProjectIdentifier(
|
|
534
|
+
organization=cfg.org,
|
|
535
|
+
domain=domain or cfg.domain,
|
|
536
|
+
name=project or cfg.project,
|
|
537
|
+
),
|
|
538
|
+
request=list_pb2.ListRequest(
|
|
539
|
+
sort_by=sort_pb2,
|
|
540
|
+
filters=filters,
|
|
541
|
+
limit=limit,
|
|
542
|
+
token=token,
|
|
543
|
+
),
|
|
544
|
+
)
|
|
545
|
+
)
|
|
546
|
+
token = resp.token
|
|
547
|
+
for t in resp.tasks:
|
|
548
|
+
retrieved += 1
|
|
549
|
+
yield cls(t)
|
|
550
|
+
if not token or retrieved >= original_limit:
|
|
551
|
+
logger.debug(f"Retrieved {retrieved} tasks, stopping iteration.")
|
|
552
|
+
break
|
|
553
|
+
|
|
554
|
+
def __rich_repr__(self) -> rich.repr.Result:
|
|
555
|
+
"""
|
|
556
|
+
Rich representation of the task.
|
|
557
|
+
"""
|
|
558
|
+
yield "project", self.pb2.task_id.project
|
|
559
|
+
yield "domain", self.pb2.task_id.domain
|
|
560
|
+
yield "name", self.pb2.task_id.name
|
|
561
|
+
yield "version", self.pb2.task_id.version
|
|
562
|
+
yield "short_name", self.pb2.metadata.short_name
|
|
563
|
+
for t in _repr_task_metadata(self.pb2.metadata):
|
|
564
|
+
yield t
|
|
565
|
+
|
|
566
|
+
|
|
226
567
|
if __name__ == "__main__":
|
|
227
568
|
tk = Task.get(name="example_task")
|