flyte 2.0.0b32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +108 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +195 -0
- flyte/_bin/serve.py +178 -0
- flyte/_build.py +26 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +147 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/local_cache.py +216 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +121 -0
- flyte/_code_bundle/_packaging.py +218 -0
- flyte/_code_bundle/_utils.py +347 -0
- flyte/_code_bundle/bundle.py +266 -0
- flyte/_constants.py +1 -0
- flyte/_context.py +155 -0
- flyte/_custom_context.py +73 -0
- flyte/_debug/__init__.py +0 -0
- flyte/_debug/constants.py +38 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +307 -0
- flyte/_deploy.py +408 -0
- flyte/_deployer.py +109 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +122 -0
- flyte/_excepthook.py +37 -0
- flyte/_group.py +32 -0
- flyte/_hash.py +8 -0
- flyte/_image.py +1055 -0
- flyte/_initialize.py +628 -0
- flyte/_interface.py +119 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +129 -0
- flyte/_internal/controllers/_local_controller.py +239 -0
- flyte/_internal/controllers/_trace.py +48 -0
- flyte/_internal/controllers/remote/__init__.py +58 -0
- flyte/_internal/controllers/remote/_action.py +211 -0
- flyte/_internal/controllers/remote/_client.py +47 -0
- flyte/_internal/controllers/remote/_controller.py +583 -0
- flyte/_internal/controllers/remote/_core.py +465 -0
- flyte/_internal/controllers/remote/_informer.py +381 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +3 -0
- flyte/_internal/imagebuild/docker_builder.py +706 -0
- flyte/_internal/imagebuild/image_builder.py +277 -0
- flyte/_internal/imagebuild/remote_builder.py +386 -0
- flyte/_internal/imagebuild/utils.py +78 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +21 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +486 -0
- flyte/_internal/runtime/entrypoints.py +204 -0
- flyte/_internal/runtime/io.py +188 -0
- flyte/_internal/runtime/resources_serde.py +152 -0
- flyte/_internal/runtime/reuse.py +125 -0
- flyte/_internal/runtime/rusty.py +193 -0
- flyte/_internal/runtime/task_serde.py +362 -0
- flyte/_internal/runtime/taskrunner.py +209 -0
- flyte/_internal/runtime/trigger_serde.py +160 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_keyring/__init__.py +0 -0
- flyte/_keyring/file.py +115 -0
- flyte/_logging.py +300 -0
- flyte/_map.py +312 -0
- flyte/_module.py +72 -0
- flyte/_pod.py +30 -0
- flyte/_resources.py +473 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +102 -0
- flyte/_run.py +724 -0
- flyte/_secret.py +96 -0
- flyte/_task.py +550 -0
- flyte/_task_environment.py +316 -0
- flyte/_task_plugins.py +47 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +119 -0
- flyte/_trigger.py +1000 -0
- flyte/_utils/__init__.py +30 -0
- flyte/_utils/asyn.py +121 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +27 -0
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +134 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/module_loader.py +104 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +34 -0
- flyte/app/__init__.py +22 -0
- flyte/app/_app_environment.py +157 -0
- flyte/app/_deploy.py +125 -0
- flyte/app/_input.py +160 -0
- flyte/app/_runtime/__init__.py +3 -0
- flyte/app/_runtime/app_serde.py +347 -0
- flyte/app/_types.py +101 -0
- flyte/app/extras/__init__.py +3 -0
- flyte/app/extras/_fastapi.py +151 -0
- flyte/cli/__init__.py +12 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_build.py +114 -0
- flyte/cli/_common.py +468 -0
- flyte/cli/_create.py +371 -0
- flyte/cli/_delete.py +45 -0
- flyte/cli/_deploy.py +293 -0
- flyte/cli/_gen.py +176 -0
- flyte/cli/_get.py +370 -0
- flyte/cli/_option.py +33 -0
- flyte/cli/_params.py +554 -0
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_run.py +597 -0
- flyte/cli/_serve.py +64 -0
- flyte/cli/_update.py +37 -0
- flyte/cli/_user.py +17 -0
- flyte/cli/main.py +221 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +248 -0
- flyte/config/_internal.py +73 -0
- flyte/config/_reader.py +225 -0
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +270 -0
- flyte/connectors/_server.py +197 -0
- flyte/connectors/utils.py +135 -0
- flyte/errors.py +243 -0
- flyte/extend.py +19 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +286 -0
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +21 -0
- flyte/io/__init__.py +29 -0
- flyte/io/_dataframe/__init__.py +131 -0
- flyte/io/_dataframe/basic_dfs.py +223 -0
- flyte/io/_dataframe/dataframe.py +1026 -0
- flyte/io/_dir.py +910 -0
- flyte/io/_file.py +914 -0
- flyte/io/_hashing_io.py +342 -0
- flyte/models.py +479 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +35 -0
- flyte/remote/_action.py +738 -0
- flyte/remote/_app.py +57 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +189 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +403 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +213 -0
- flyte/remote/_client/auth/_client_config.py +85 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +152 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +128 -0
- flyte/remote/_common.py +30 -0
- flyte/remote/_console.py +19 -0
- flyte/remote/_data.py +161 -0
- flyte/remote/_logs.py +185 -0
- flyte/remote/_project.py +88 -0
- flyte/remote/_run.py +386 -0
- flyte/remote/_secret.py +142 -0
- flyte/remote/_task.py +527 -0
- flyte/remote/_trigger.py +306 -0
- flyte/remote/_user.py +33 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +182 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +36 -0
- flyte/storage/_config.py +237 -0
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +456 -0
- flyte/storage/_utils.py +5 -0
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +375 -0
- flyte/types/__init__.py +52 -0
- flyte/types/_interface.py +40 -0
- flyte/types/_pickle.py +145 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +119 -0
- flyte/types/_type_engine.py +2254 -0
- flyte/types/_utils.py +80 -0
- flyte-2.0.0b32.data/scripts/debug.py +38 -0
- flyte-2.0.0b32.data/scripts/runtime.py +195 -0
- flyte-2.0.0b32.dist-info/METADATA +351 -0
- flyte-2.0.0b32.dist-info/RECORD +204 -0
- flyte-2.0.0b32.dist-info/WHEEL +5 -0
- flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
- flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
- flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/remote/_task.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import functools
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union, cast
|
|
7
|
+
|
|
8
|
+
import rich.repr
|
|
9
|
+
from flyteidl2.common import identifier_pb2, list_pb2
|
|
10
|
+
from flyteidl2.core import literals_pb2
|
|
11
|
+
from flyteidl2.task import task_definition_pb2, task_service_pb2
|
|
12
|
+
|
|
13
|
+
import flyte
|
|
14
|
+
import flyte.errors
|
|
15
|
+
from flyte._cache.cache import CacheBehavior
|
|
16
|
+
from flyte._context import internal_ctx
|
|
17
|
+
from flyte._initialize import ensure_client, get_client, get_init_config
|
|
18
|
+
from flyte._internal.runtime.resources_serde import get_proto_resources
|
|
19
|
+
from flyte._internal.runtime.task_serde import get_proto_retry_strategy, get_proto_timeout, get_security_context
|
|
20
|
+
from flyte._logging import logger
|
|
21
|
+
from flyte.models import NativeInterface
|
|
22
|
+
from flyte.syncify import syncify
|
|
23
|
+
|
|
24
|
+
from ._common import ToJSONMixin
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _repr_task_metadata(metadata: task_definition_pb2.TaskMetadata) -> rich.repr.Result:
    """
    Rich representation of the task metadata.

    Yields ``(key, value)`` pairs for ``deployed_by`` (only when that field is set),
    ``short_name``, ``deployed_at`` (converted with ``ToDatetime()``) and ``environment_name``.

    :param metadata: The task metadata message to describe.
    """
    # Accessing a protobuf sub-message never returns None (an unset field comes back as an
    # empty default message), so presence must be tested with HasField rather than truthiness.
    if metadata.HasField("deployed_by"):
        deployed_by = metadata.deployed_by
        if deployed_by.HasField("user"):
            yield "deployed_by", f"User: {deployed_by.user.spec.email}"
        else:
            yield "deployed_by", f"App: {deployed_by.application.spec.name}"
    yield "short_name", metadata.short_name
    yield "deployed_at", metadata.deployed_at.ToDatetime()
    yield "environment_name", metadata.environment_name
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LazyEntity:
    """
    Deferred handle to a remote task.

    The task details are obtained by awaiting the supplied ``getter`` the first time they are
    needed (on ``fetch``, ``override`` or when the entity is called) and are then kept on the
    instance for subsequent use.
    """

    def __init__(self, name: str, getter: Callable[..., Coroutine[Any, Any, TaskDetails]], *args, **kwargs):
        # NOTE: extra positional/keyword arguments are accepted but not used by this class.
        self._name = name
        self._getter = getter
        # Serializes concurrent fetches so the getter is awaited by one caller at a time.
        self._mutex = asyncio.Lock()
        self._task: Optional[TaskDetails] = None

    @property
    def name(self) -> str:
        """
        The name this entity was created with.
        """
        return self._name

    @syncify
    async def fetch(self) -> TaskDetails:
        """
        Return the task details, awaiting the getter if they have not been loaded yet.

        :raises RuntimeError: If the getter produced no task details.
        """
        async with self._mutex:
            details = self._task
            if details is None:
                details = await self._getter()
                self._task = details
            if details is None:
                raise RuntimeError(f"Error downloading the task {self._name}, (check original exception...)")
            return details

    @syncify
    async def override(
        self,
        **kwargs: Any,
    ) -> LazyEntity:
        """
        Fetch the task, apply ``TaskDetails.override(**kwargs)`` to it and return a new
        LazyEntity that already holds the overridden details. This entity is left untouched.
        """
        fetched = cast(TaskDetails, await self.fetch.aio())
        overridden = fetched.override(**kwargs)
        clone = LazyEntity(self._name, self._getter)
        # Pre-populate the clone so it does not call the getter again.
        clone._task = overridden
        return clone

    async def __call__(self, *args, **kwargs):
        """
        Fetch the underlying task if necessary and await it with the given arguments.
        """
        task = await self.fetch.aio()
        result = await task(*args, **kwargs)
        return result

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        return f"Future for task with name {self._name}"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Accepted values for the `auto_version` argument of TaskDetails.get / Task.get.
AutoVersioning = Literal["latest", "current"]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@dataclass(frozen=True)
class TaskDetails(ToJSONMixin):
    """
    Immutable wrapper around a ``task_definition_pb2.TaskDetails`` message, exposing the
    commonly used parts of the task definition as properties and allowing the task to be
    invoked from within a task context.
    """

    # Underlying protobuf message describing the task.
    pb2: task_definition_pb2.TaskDetails
    max_inline_io_bytes: int = 10 * 1024 * 1024  # 10 MB
    # Queue selected through override(); None when no queue was given.
    overriden_queue: Optional[str] = None

    @classmethod
    def get(
        cls,
        name: str,
        project: str | None,
        domain: str | None,
        version: str | None = None,
        auto_version: AutoVersioning | None = None,
    ) -> LazyEntity:
        """
        Build a lazy reference to a task identified by name and version. Nothing is requested
        from the service until the returned LazyEntity is fetched or called.

        Either version or auto_version are required parameters. If both are given, version is used.

        :param name: The name of the task.
        :param project: The project of the task. Falls back to the configured project when empty.
        :param domain: The domain of the task. Falls back to the configured domain when empty.
        :param version: The version of the task.
        :param auto_version: If set to "latest", the most recently created version of the task is used.
         If set to "current", the version is taken from the calling task's context. This is useful if you are
         deploying all environments with the same version. If auto_version is current, you can only access the
         task from within a task context.
        :raises ValueError: If neither version nor a valid auto_version is provided.
        """

        if version is None and auto_version is None:
            raise ValueError("Either version or auto_version must be provided.")

        if version is None and auto_version not in ["latest", "current"]:
            raise ValueError("auto_version must be either 'latest' or 'current'.")

        async def deferred_get(_version: str | None, _auto_version: AutoVersioning | None) -> TaskDetails:
            # Resolve the concrete version first when only auto_version was supplied.
            if _version is None:
                if _auto_version == "latest":
                    tasks = []
                    # Newest first, one entry: the first (and only) result is the latest version.
                    async for x in Task.listall.aio(
                        by_task_name=name,
                        project=project,
                        domain=domain,
                        sort_by=("created_at", "desc"),
                        limit=1,
                    ):
                        tasks.append(x)
                    if not tasks:
                        raise flyte.errors.ReferenceTaskError(f"Task {name} not found.")
                    _version = tasks[0].version
                elif _auto_version == "current":
                    ctx = flyte.ctx()
                    if ctx is None:
                        raise ValueError("auto_version=current can only be used within a task context.")
                    _version = ctx.version
            cfg = get_init_config()
            task_id = task_definition_pb2.TaskIdentifier(
                org=cfg.org,
                project=project or cfg.project,
                domain=domain or cfg.domain,
                name=name,
                version=_version,
            )
            resp = await get_client().task_service.GetTaskDetails(
                task_service_pb2.GetTaskDetailsRequest(
                    task_id=task_id,
                )
            )
            return cls(resp.details)

        return LazyEntity(
            name=name, getter=functools.partial(deferred_get, _version=version, _auto_version=auto_version)
        )

    @classmethod
    async def fetch(
        cls,
        name: str,
        project: str | None = None,
        domain: str | None = None,
        version: str | None = None,
        auto_version: AutoVersioning | None = None,
    ) -> TaskDetails:
        """
        Resolve a task immediately: builds the lazy reference via ``TaskDetails.get`` and awaits its fetch.

        Parameters are the same as for ``get``.
        """
        lazy = TaskDetails.get(name, project=project, domain=domain, version=version, auto_version=auto_version)
        return await lazy.fetch.aio()

    @property
    def name(self) -> str:
        """
        The name of the task.
        """
        return self.pb2.task_id.name

    @property
    def version(self) -> str:
        """
        The version of the task.
        """
        return self.pb2.task_id.version

    @property
    def task_type(self) -> str:
        """
        The type of the task.
        """
        return self.pb2.spec.task_template.type

    @property
    def default_input_args(self) -> Tuple[str, ...]:
        """
        The names of the inputs for which the task spec carries a default value.
        """
        return tuple(x.name for x in self.pb2.spec.default_inputs)

    @property
    def required_args(self) -> Tuple[str, ...]:
        """
        The names of the interface inputs that have no default value, in interface order.
        """
        # Build the lookup once instead of recomputing the defaults tuple for every input.
        defaults = set(self.default_input_args)
        return tuple(x for x in self.interface.inputs if x not in defaults)

    @functools.cached_property
    def interface(self) -> NativeInterface:
        """
        The interface of the task, derived from the task template interface and default inputs.
        """
        import flyte.types as types

        return types.guess_interface(self.pb2.spec.task_template.interface, default_inputs=self.pb2.spec.default_inputs)

    @property
    def cache(self) -> flyte.Cache:
        """
        The cache policy of the task, reconstructed from the task template metadata.
        """
        metadata = self.pb2.spec.task_template.metadata
        behavior: CacheBehavior
        if not metadata.discoverable:
            behavior = "disable"
        elif metadata.discovery_version:
            behavior = "override"
        else:
            behavior = "auto"

        return flyte.Cache(
            behavior=behavior,
            version_override=metadata.discovery_version if metadata.discovery_version else None,
            serialize=metadata.cache_serializable,
            ignored_inputs=tuple(metadata.cache_ignore_input_vars),
        )

    @property
    def secrets(self):
        """
        The keys of the secrets listed in the task's security context.
        """
        return [s.key for s in self.pb2.spec.task_template.security_context.secrets]

    @property
    def resources(self):
        """
        The resources of the task as a ``(requests, limits)`` tuple, or an empty tuple when the
        task template has no container.
        """
        template = self.pb2.spec.task_template
        # A protobuf sub-message is never None on access; presence has to be checked explicitly.
        if not template.HasField("container"):
            return ()
        return (
            template.container.resources.requests,
            template.container.resources.limits,
        )

    async def __call__(self, *args, **kwargs):
        """
        Submit this reference task to the controller. Only keyword arguments are supported and
        the call only works from within a task context.

        :raises flyte.errors.ReferenceTaskError: If positional arguments are passed, or if the
         task cannot be submitted (not in a task context, or no controller available).
        :raises ValueError: If fewer arguments are supplied than the task requires.
        """
        # TODO support positional args, for this we need ordered inputs to be stored in the task spec.
        if len(args) > 0:
            raise flyte.errors.ReferenceTaskError(
                f"Reference task {self.name} does not support positional arguments "
                f"currently. Please use keyword arguments."
            )

        ctx = internal_ctx()
        if ctx.is_task_context():
            # If we are in a task context, that implies we are executing a Run.
            # In this scenario, we should submit the task to the controller.
            # We will also check if we are not initialized, It is not expected to be not initialized
            from flyte._internal.controllers import get_controller

            controller = get_controller()
            if len(self.required_args) > 0:
                if len(args) + len(kwargs) < len(self.required_args):
                    raise ValueError(
                        f"Task {self.name} requires at least {self.required_args} arguments, "
                        f"but only received args:{args} kwargs:{kwargs}."
                    )
            if controller:
                return await controller.submit_task_ref(self, *args, **kwargs)
        raise flyte.errors.ReferenceTaskError(
            f"Reference tasks [{self.name}] cannot be executed locally, only remotely."
        )

    @property
    def queue(self) -> Optional[str]:
        """
        The queue to use for the task.
        """
        return self.overriden_queue

    def override(
        self,
        *,
        short_name: Optional[str] = None,
        resources: Optional[flyte.Resources] = None,
        retries: Union[int, flyte.RetryStrategy] = 0,
        timeout: Optional[flyte.TimeoutType] = None,
        env_vars: Optional[Dict[str, str]] = None,
        secrets: Optional[flyte.SecretRequest] = None,
        max_inline_io_bytes: Optional[int] = None,
        cache: Optional[flyte.Cache] = None,
        queue: Optional[str] = None,
        **kwargs: Any,
    ) -> TaskDetails:
        """
        Return a copy of this task with the given settings replaced. The original is not modified.
        Arguments left at their (falsy) defaults keep the current value.

        :param short_name: New short name stored in the task metadata.
        :param resources: New container resources (applied only when the task has a container).
        :param retries: New retry strategy.
        :param timeout: New timeout.
        :param env_vars: Replacement environment variables (applied only when the task has a container).
        :param secrets: New secrets for the security context.
        :param max_inline_io_bytes: New inline IO size limit.
        :param cache: New cache policy; behavior must be 'disable' or 'override'.
        :param queue: New queue.
        :raises ValueError: On unknown keyword arguments or an unsupported cache policy.
        """
        if len(kwargs) > 0:
            raise ValueError(
                f"ReferenceTasks [{self.name}] do not support overriding with kwargs: {kwargs}, "
                f"Check the parameters for override method."
            )
        # Work on a copy so this (frozen) instance keeps its original message.
        pb2 = task_definition_pb2.TaskDetails()
        pb2.CopyFrom(self.pb2)

        if short_name:
            pb2.metadata.short_name = short_name

        template = pb2.spec.task_template
        if secrets:
            template.security_context.CopyFrom(get_security_context(secrets))

        if template.HasField("container"):
            if env_vars:
                template.container.env.clear()
                template.container.env.extend([literals_pb2.KeyValuePair(key=k, value=v) for k, v in env_vars.items()])
            if resources:
                template.container.resources.CopyFrom(get_proto_resources(resources))

        md = template.metadata
        if retries:
            md.retries.CopyFrom(get_proto_retry_strategy(retries))

        if timeout:
            md.timeout.CopyFrom(get_proto_timeout(timeout))

        if cache:
            if cache.behavior == "disable":
                md.discoverable = False
                md.discovery_version = ""
            elif cache.behavior == "override":
                md.discoverable = True
                if not cache.version_override:
                    raise ValueError("cache.version_override must be set when cache.behavior is 'override'")
                md.discovery_version = cache.version_override
            else:
                if cache.behavior == "auto":
                    raise ValueError("cache.behavior must be 'disable' or 'override' for reference tasks")
                raise ValueError(f"Invalid cache behavior: {cache.behavior}.")
            md.cache_serializable = cache.serialize
            md.cache_ignore_input_vars[:] = list(cache.ignored_inputs or ())

        return TaskDetails(
            pb2,
            max_inline_io_bytes=max_inline_io_bytes or self.max_inline_io_bytes,
            # Keep a previously overridden queue when this call does not specify one,
            # mirroring how max_inline_io_bytes is carried over.
            overriden_queue=queue or self.overriden_queue,
        )

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the task.
        """
        yield "short_name", self.pb2.spec.short_name
        yield "environment", self.pb2.spec.environment
        yield "default_inputs_keys", self.default_input_args
        yield "required_args", self.required_args
        yield "raw_default_inputs", [str(x) for x in self.pb2.spec.default_inputs]
        yield "project", self.pb2.task_id.project
        yield "domain", self.pb2.task_id.domain
        yield "name", self.name
        yield "version", self.version
        yield "task_type", self.task_type
        yield "cache", self.cache
        yield "interface", self.name + str(self.interface)
        yield "secrets", self.secrets
        yield "resources", self.resources
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
@dataclass
class Task(ToJSONMixin):
    """
    Lightweight wrapper around a ``task_definition_pb2.Task`` message, as returned when listing tasks.
    """

    # Underlying protobuf message for the listed task.
    pb2: task_definition_pb2.Task

    def __init__(self, pb2: task_definition_pb2.Task):
        self.pb2 = pb2

    @property
    def name(self) -> str:
        """
        The name of the task.
        """
        return self.pb2.task_id.name

    @property
    def version(self) -> str:
        """
        The version of the task.
        """
        return self.pb2.task_id.version

    @classmethod
    def get(
        cls,
        name: str,
        project: str | None = None,
        domain: str | None = None,
        version: str | None = None,
        auto_version: AutoVersioning | None = None,
    ) -> LazyEntity:
        """
        Build a lazy reference to a task identified by name and version. Delegates to ``TaskDetails.get``.

        Either version or auto_version are required parameters.

        :param name: The name of the task.
        :param project: The project of the task.
        :param domain: The domain of the task.
        :param version: The version of the task.
        :param auto_version: If set to "latest", the most recently created version of the task is used.
         If set to "current", the version is taken from the calling task's context. This is useful if you are
         deploying all environments with the same version. If auto_version is current, you can only access the
         task from within a task context.
        """
        return TaskDetails.get(name, project=project, domain=domain, version=version, auto_version=auto_version)

    @syncify
    @classmethod
    async def listall(
        cls,
        by_task_name: str | None = None,
        by_task_env: str | None = None,
        project: str | None = None,
        domain: str | None = None,
        sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
        limit: int = 100,
    ) -> Union[AsyncIterator[Task], Iterator[Task]]:
        """
        Get all tasks for the current project and domain.

        :param by_task_name: If provided, only tasks with this name will be returned.
        :param by_task_env: If provided, only tasks with this environment prefix will be returned.
        :param project: The project to filter tasks by. If None, the current project will be used.
        :param domain: The domain to filter tasks by. If None, the current domain will be used.
        :param sort_by: The sorting criteria for the task list, in the format (field, order).
         Defaults to ("created_at", "asc").
        :param limit: The maximum number of tasks to return.
        :return: An iterator of tasks.
        """
        ensure_client()
        token = None
        sort_by = sort_by or ("created_at", "asc")
        sort_pb2 = list_pb2.Sort(
            key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
        )
        cfg = get_init_config()
        filters = []
        if by_task_name:
            filters.append(
                list_pb2.Filter(
                    function=list_pb2.Filter.Function.EQUAL,
                    field="name",
                    values=[by_task_name],
                )
            )
        if by_task_env:
            # ideally we should have a STARTS_WITH filter, but it is not supported yet
            filters.append(
                list_pb2.Filter(
                    function=list_pb2.Filter.Function.CONTAINS,
                    field="name",
                    values=[f"{by_task_env}."],
                )
            )
        retrieved = 0
        while True:
            # Never ask for more than a batch, nor for more than is still needed to reach `limit`.
            page_size = min(limit - retrieved, cfg.batch_size)
            resp = await get_client().task_service.ListTasks(
                task_service_pb2.ListTasksRequest(
                    org=cfg.org,
                    project_id=identifier_pb2.ProjectIdentifier(
                        organization=cfg.org,
                        domain=domain or cfg.domain,
                        name=project or cfg.project,
                    ),
                    request=list_pb2.ListRequest(
                        sort_by=sort_pb2,
                        filters=filters,
                        limit=page_size,
                        token=token,
                    ),
                )
            )
            token = resp.token
            for t in resp.tasks:
                # Stop mid-page if the service returned more than requested, so that
                # the caller never receives more than `limit` tasks.
                if retrieved >= limit:
                    break
                retrieved += 1
                yield cls(t)
            if not token or retrieved >= limit:
                logger.debug(f"Retrieved {retrieved} tasks, stopping iteration.")
                break

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the task.
        """
        yield "project", self.pb2.task_id.project
        yield "domain", self.pb2.task_id.domain
        yield "name", self.pb2.task_id.name
        yield "version", self.pb2.task_id.version
        yield "short_name", self.pb2.metadata.short_name
        for t in _repr_task_metadata(self.pb2.metadata):
            yield t
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
if __name__ == "__main__":
    # Task.get requires either `version` or `auto_version`; without one it raises ValueError.
    tk = Task.get(name="example_task", auto_version="latest")
|