flyte 2.0.0b23__py3-none-any.whl → 2.0.0b25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +11 -2
- flyte/_cache/local_cache.py +4 -3
- flyte/_code_bundle/_utils.py +3 -3
- flyte/_code_bundle/bundle.py +12 -5
- flyte/_context.py +4 -1
- flyte/_custom_context.py +73 -0
- flyte/_deploy.py +31 -7
- flyte/_image.py +48 -16
- flyte/_initialize.py +69 -26
- flyte/_internal/controllers/_local_controller.py +1 -0
- flyte/_internal/controllers/_trace.py +1 -1
- flyte/_internal/controllers/remote/_action.py +9 -10
- flyte/_internal/controllers/remote/_client.py +1 -1
- flyte/_internal/controllers/remote/_controller.py +4 -2
- flyte/_internal/controllers/remote/_core.py +10 -13
- flyte/_internal/controllers/remote/_informer.py +3 -3
- flyte/_internal/controllers/remote/_service_protocol.py +7 -7
- flyte/_internal/imagebuild/docker_builder.py +45 -59
- flyte/_internal/imagebuild/remote_builder.py +51 -11
- flyte/_internal/imagebuild/utils.py +51 -3
- flyte/_internal/runtime/convert.py +39 -18
- flyte/_internal/runtime/io.py +8 -7
- flyte/_internal/runtime/resources_serde.py +20 -6
- flyte/_internal/runtime/reuse.py +1 -1
- flyte/_internal/runtime/task_serde.py +7 -10
- flyte/_internal/runtime/taskrunner.py +10 -1
- flyte/_internal/runtime/trigger_serde.py +13 -13
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_keyring/file.py +2 -2
- flyte/_map.py +65 -13
- flyte/_pod.py +2 -2
- flyte/_resources.py +175 -31
- flyte/_run.py +37 -21
- flyte/_task.py +27 -6
- flyte/_task_environment.py +37 -10
- flyte/_utils/module_loader.py +2 -2
- flyte/_version.py +3 -3
- flyte/cli/_common.py +47 -5
- flyte/cli/_create.py +4 -0
- flyte/cli/_deploy.py +8 -0
- flyte/cli/_get.py +4 -0
- flyte/cli/_params.py +4 -4
- flyte/cli/_run.py +50 -7
- flyte/cli/_update.py +4 -3
- flyte/config/_config.py +2 -0
- flyte/config/_internal.py +1 -0
- flyte/config/_reader.py +3 -3
- flyte/errors.py +1 -1
- flyte/extend.py +4 -0
- flyte/extras/_container.py +6 -1
- flyte/git/_config.py +11 -9
- flyte/io/_dataframe/basic_dfs.py +1 -1
- flyte/io/_dataframe/dataframe.py +12 -8
- flyte/io/_dir.py +48 -15
- flyte/io/_file.py +48 -11
- flyte/models.py +12 -8
- flyte/remote/_action.py +18 -16
- flyte/remote/_client/_protocols.py +4 -3
- flyte/remote/_client/auth/_channel.py +1 -1
- flyte/remote/_client/controlplane.py +4 -8
- flyte/remote/_data.py +4 -3
- flyte/remote/_logs.py +3 -3
- flyte/remote/_run.py +5 -5
- flyte/remote/_secret.py +20 -13
- flyte/remote/_task.py +7 -8
- flyte/remote/_trigger.py +25 -27
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_storage.py +66 -2
- flyte/types/_interface.py +2 -2
- flyte/types/_pickle.py +1 -1
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +25 -17
- flyte/types/_utils.py +1 -1
- {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/METADATA +2 -1
- flyte-2.0.0b25.dist-info/RECORD +184 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -117
- flyte/_protos/common/identifier_pb2.pyi +0 -142
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -71
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/definition_pb2.py +0 -60
- flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
- flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/payload_pb2.py +0 -32
- flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
- flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/service_pb2.py +0 -29
- flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
- flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/common_pb2.py +0 -38
- flyte/_protos/workflow/common_pb2.pyi +0 -63
- flyte/_protos/workflow/common_pb2_grpc.py +0 -4
- flyte/_protos/workflow/environment_pb2.py +0 -29
- flyte/_protos/workflow/environment_pb2.pyi +0 -12
- flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -117
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -182
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -206
- flyte/_protos/workflow/run_definition_pb2.py +0 -123
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -354
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -147
- flyte/_protos/workflow/run_service_pb2.pyi +0 -203
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -480
- flyte/_protos/workflow/state_service_pb2.py +0 -67
- flyte/_protos/workflow/state_service_pb2.pyi +0 -76
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -86
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -105
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -61
- flyte/_protos/workflow/task_service_pb2.pyi +0 -62
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/trigger_definition_pb2.py +0 -66
- flyte/_protos/workflow/trigger_definition_pb2.pyi +0 -117
- flyte/_protos/workflow/trigger_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/trigger_service_pb2.py +0 -96
- flyte/_protos/workflow/trigger_service_pb2.pyi +0 -110
- flyte/_protos/workflow/trigger_service_pb2_grpc.py +0 -281
- flyte-2.0.0b23.dist-info/RECORD +0 -262
- {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/debug.py +0 -0
- {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/runtime.py +0 -0
- {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/entry_points.txt +0 -0
- {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/top_level.txt +0 -0
flyte/_map.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import functools
|
|
3
3
|
import logging
|
|
4
|
-
from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
|
|
4
|
+
from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast, overload
|
|
5
5
|
|
|
6
6
|
from flyte.syncify import syncify
|
|
7
7
|
|
|
8
8
|
from ._group import group
|
|
9
9
|
from ._logging import logger
|
|
10
|
-
from ._task import P, R
|
|
10
|
+
from ._task import AsyncFunctionTaskTemplate, F, P, R
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class MapAsyncIterator(Generic[P, R]):
|
|
@@ -15,7 +15,7 @@ class MapAsyncIterator(Generic[P, R]):
|
|
|
15
15
|
|
|
16
16
|
def __init__(
|
|
17
17
|
self,
|
|
18
|
-
func:
|
|
18
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
19
19
|
args: tuple,
|
|
20
20
|
name: str,
|
|
21
21
|
concurrency: int,
|
|
@@ -78,7 +78,7 @@ class MapAsyncIterator(Generic[P, R]):
|
|
|
78
78
|
|
|
79
79
|
if isinstance(self.func, functools.partial):
|
|
80
80
|
# Handle partial functions by merging bound args/kwargs with mapped args
|
|
81
|
-
base_func = cast(
|
|
81
|
+
base_func = cast(AsyncFunctionTaskTemplate, self.func.func)
|
|
82
82
|
bound_args = self.func.args
|
|
83
83
|
bound_kwargs = self.func.keywords or {}
|
|
84
84
|
|
|
@@ -144,7 +144,7 @@ class _Mapper(Generic[P, R]):
|
|
|
144
144
|
:param func: partial function to validate
|
|
145
145
|
:raises TypeError: if the partial function is not valid for mapping
|
|
146
146
|
"""
|
|
147
|
-
f = cast(
|
|
147
|
+
f = cast(AsyncFunctionTaskTemplate, func.func)
|
|
148
148
|
inputs = f.native_interface.inputs
|
|
149
149
|
params = list(inputs.keys())
|
|
150
150
|
total_params = len(params)
|
|
@@ -172,9 +172,28 @@ class _Mapper(Generic[P, R]):
|
|
|
172
172
|
f"in partial function {f.name}."
|
|
173
173
|
)
|
|
174
174
|
|
|
175
|
+
@overload
|
|
175
176
|
def __call__(
|
|
176
177
|
self,
|
|
177
|
-
func:
|
|
178
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
179
|
+
*args: Iterable[Any],
|
|
180
|
+
group_name: str | None = None,
|
|
181
|
+
concurrency: int = 0,
|
|
182
|
+
) -> Iterator[R]: ...
|
|
183
|
+
|
|
184
|
+
@overload
|
|
185
|
+
def __call__(
|
|
186
|
+
self,
|
|
187
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
188
|
+
*args: Iterable[Any],
|
|
189
|
+
group_name: str | None = None,
|
|
190
|
+
concurrency: int = 0,
|
|
191
|
+
return_exceptions: bool = True,
|
|
192
|
+
) -> Iterator[R]: ...
|
|
193
|
+
|
|
194
|
+
def __call__(
|
|
195
|
+
self,
|
|
196
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
178
197
|
*args: Iterable[Any],
|
|
179
198
|
group_name: str | None = None,
|
|
180
199
|
concurrency: int = 0,
|
|
@@ -194,10 +213,10 @@ class _Mapper(Generic[P, R]):
|
|
|
194
213
|
return
|
|
195
214
|
|
|
196
215
|
if isinstance(func, functools.partial):
|
|
197
|
-
f = cast(
|
|
216
|
+
f = cast(AsyncFunctionTaskTemplate, func.func)
|
|
198
217
|
self.validate_partial(func)
|
|
199
218
|
else:
|
|
200
|
-
f = cast(
|
|
219
|
+
f = cast(AsyncFunctionTaskTemplate, func)
|
|
201
220
|
|
|
202
221
|
name = self._get_name(f.name, group_name)
|
|
203
222
|
logger.debug(f"Blocking Map for {name}")
|
|
@@ -234,7 +253,7 @@ class _Mapper(Generic[P, R]):
|
|
|
234
253
|
|
|
235
254
|
async def aio(
|
|
236
255
|
self,
|
|
237
|
-
func:
|
|
256
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
238
257
|
*args: Iterable[Any],
|
|
239
258
|
group_name: str | None = None,
|
|
240
259
|
concurrency: int = 0,
|
|
@@ -244,10 +263,10 @@ class _Mapper(Generic[P, R]):
|
|
|
244
263
|
return
|
|
245
264
|
|
|
246
265
|
if isinstance(func, functools.partial):
|
|
247
|
-
f = cast(
|
|
266
|
+
f = cast(AsyncFunctionTaskTemplate, func.func)
|
|
248
267
|
self.validate_partial(func)
|
|
249
268
|
else:
|
|
250
|
-
f = cast(
|
|
269
|
+
f = cast(AsyncFunctionTaskTemplate, func)
|
|
251
270
|
|
|
252
271
|
name = self._get_name(f.name, group_name)
|
|
253
272
|
with group(name):
|
|
@@ -277,7 +296,7 @@ class _Mapper(Generic[P, R]):
|
|
|
277
296
|
|
|
278
297
|
@syncify
|
|
279
298
|
async def _map(
|
|
280
|
-
func:
|
|
299
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
281
300
|
*args: Iterable[Any],
|
|
282
301
|
name: str = "map",
|
|
283
302
|
concurrency: int = 0,
|
|
@@ -290,4 +309,37 @@ async def _map(
|
|
|
290
309
|
yield result
|
|
291
310
|
|
|
292
311
|
|
|
293
|
-
|
|
312
|
+
@overload
|
|
313
|
+
def map(
|
|
314
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
315
|
+
*args: Iterable[Any],
|
|
316
|
+
group_name: str | None = None,
|
|
317
|
+
concurrency: int = 0,
|
|
318
|
+
) -> Iterator[R]: ...
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
@overload
|
|
322
|
+
def map(
|
|
323
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
324
|
+
*args: Iterable[Any],
|
|
325
|
+
group_name: str | None = None,
|
|
326
|
+
concurrency: int = 0,
|
|
327
|
+
return_exceptions: bool = True,
|
|
328
|
+
) -> Iterator[Union[R, Exception]]: ...
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def map(
|
|
332
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
333
|
+
*args: Iterable[Any],
|
|
334
|
+
group_name: str | None = None,
|
|
335
|
+
concurrency: int = 0,
|
|
336
|
+
return_exceptions: bool = True,
|
|
337
|
+
) -> Iterator[Union[R, Exception]]:
|
|
338
|
+
map: _Mapper = _Mapper()
|
|
339
|
+
return map(
|
|
340
|
+
func,
|
|
341
|
+
*args,
|
|
342
|
+
group_name=group_name,
|
|
343
|
+
concurrency=concurrency,
|
|
344
|
+
return_exceptions=return_exceptions,
|
|
345
|
+
)
|
flyte/_pod.py
CHANGED
|
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
|
|
|
2
2
|
from typing import TYPE_CHECKING, Dict, Optional
|
|
3
3
|
|
|
4
4
|
if TYPE_CHECKING:
|
|
5
|
-
from
|
|
5
|
+
from flyteidl2.core.tasks_pb2 import K8sPod
|
|
6
6
|
from kubernetes.client import V1PodSpec
|
|
7
7
|
|
|
8
8
|
|
|
@@ -20,7 +20,7 @@ class PodTemplate(object):
|
|
|
20
20
|
annotations: Optional[Dict[str, str]] = None
|
|
21
21
|
|
|
22
22
|
def to_k8s_pod(self) -> "K8sPod":
|
|
23
|
-
from
|
|
23
|
+
from flyteidl2.core.tasks_pb2 import K8sObjectMetadata, K8sPod
|
|
24
24
|
from kubernetes.client import ApiClient
|
|
25
25
|
|
|
26
26
|
return K8sPod(
|
flyte/_resources.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
|
+
import typing
|
|
1
2
|
from dataclasses import dataclass, fields
|
|
2
|
-
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union, get_args
|
|
3
|
+
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union, get_args
|
|
3
4
|
|
|
4
5
|
import rich.repr
|
|
5
6
|
|
|
@@ -10,7 +11,7 @@ if TYPE_CHECKING:
|
|
|
10
11
|
|
|
11
12
|
PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
|
|
12
13
|
|
|
13
|
-
GPUType = Literal["
|
|
14
|
+
GPUType = Literal["A10", "A10G", "A100", "A100 80G", "B200", "H100", "L4", "L40s", "T4", "V100", "RTX PRO 6000"]
|
|
14
15
|
GPUQuantity = Literal[1, 2, 3, 4, 5, 6, 7, 8]
|
|
15
16
|
A100Parts = Literal["1g.5gb", "2g.10gb", "3g.20gb", "4g.20gb", "7g.40gb"]
|
|
16
17
|
"""
|
|
@@ -37,31 +38,32 @@ V6EParts = Literal["1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"]
|
|
|
37
38
|
Slices for Google Cloud TPU v6e.
|
|
38
39
|
"""
|
|
39
40
|
|
|
41
|
+
NeuronType = Literal["Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u"]
|
|
42
|
+
|
|
43
|
+
AMD_GPUType = Literal["MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X", "MI325X", "MI350X", "MI355X"]
|
|
44
|
+
|
|
45
|
+
HABANA_GAUDIType = Literal["Gaudi1"]
|
|
46
|
+
|
|
40
47
|
Accelerators = Literal[
|
|
41
|
-
|
|
42
|
-
"
|
|
43
|
-
"
|
|
44
|
-
"
|
|
45
|
-
"
|
|
46
|
-
"
|
|
47
|
-
"
|
|
48
|
-
"
|
|
49
|
-
"
|
|
50
|
-
|
|
51
|
-
"
|
|
52
|
-
"
|
|
53
|
-
"
|
|
54
|
-
"
|
|
55
|
-
"
|
|
56
|
-
"
|
|
57
|
-
"
|
|
58
|
-
"
|
|
59
|
-
|
|
60
|
-
"L40s:4",
|
|
61
|
-
"L40s:5",
|
|
62
|
-
"L40s:6",
|
|
63
|
-
"L40s:7",
|
|
64
|
-
"L40s:8",
|
|
48
|
+
# A10
|
|
49
|
+
"A10:1",
|
|
50
|
+
"A10:2",
|
|
51
|
+
"A10:3",
|
|
52
|
+
"A10:4",
|
|
53
|
+
"A10:5",
|
|
54
|
+
"A10:6",
|
|
55
|
+
"A10:7",
|
|
56
|
+
"A10:8",
|
|
57
|
+
# A10G
|
|
58
|
+
"A10G:1",
|
|
59
|
+
"A10G:2",
|
|
60
|
+
"A10G:3",
|
|
61
|
+
"A10G:4",
|
|
62
|
+
"A10G:5",
|
|
63
|
+
"A10G:6",
|
|
64
|
+
"A10G:7",
|
|
65
|
+
"A10G:8",
|
|
66
|
+
# A100
|
|
65
67
|
"A100:1",
|
|
66
68
|
"A100:2",
|
|
67
69
|
"A100:3",
|
|
@@ -70,6 +72,7 @@ Accelerators = Literal[
|
|
|
70
72
|
"A100:6",
|
|
71
73
|
"A100:7",
|
|
72
74
|
"A100:8",
|
|
75
|
+
# A100 80G
|
|
73
76
|
"A100 80G:1",
|
|
74
77
|
"A100 80G:2",
|
|
75
78
|
"A100 80G:3",
|
|
@@ -78,6 +81,16 @@ Accelerators = Literal[
|
|
|
78
81
|
"A100 80G:6",
|
|
79
82
|
"A100 80G:7",
|
|
80
83
|
"A100 80G:8",
|
|
84
|
+
# B200
|
|
85
|
+
"B200:1",
|
|
86
|
+
"B200:2",
|
|
87
|
+
"B200:3",
|
|
88
|
+
"B200:4",
|
|
89
|
+
"B200:5",
|
|
90
|
+
"B200:6",
|
|
91
|
+
"B200:7",
|
|
92
|
+
"B200:8",
|
|
93
|
+
# H100
|
|
81
94
|
"H100:1",
|
|
82
95
|
"H100:2",
|
|
83
96
|
"H100:3",
|
|
@@ -86,8 +99,97 @@ Accelerators = Literal[
|
|
|
86
99
|
"H100:6",
|
|
87
100
|
"H100:7",
|
|
88
101
|
"H100:8",
|
|
102
|
+
# H200
|
|
103
|
+
"H200:1",
|
|
104
|
+
"H200:2",
|
|
105
|
+
"H200:3",
|
|
106
|
+
"H200:4",
|
|
107
|
+
"H200:5",
|
|
108
|
+
"H200:6",
|
|
109
|
+
"H200:7",
|
|
110
|
+
"H200:8",
|
|
111
|
+
# L4
|
|
112
|
+
"L4:1",
|
|
113
|
+
"L4:2",
|
|
114
|
+
"L4:3",
|
|
115
|
+
"L4:4",
|
|
116
|
+
"L4:5",
|
|
117
|
+
"L4:6",
|
|
118
|
+
"L4:7",
|
|
119
|
+
"L4:8",
|
|
120
|
+
# L40s
|
|
121
|
+
"L40s:1",
|
|
122
|
+
"L40s:2",
|
|
123
|
+
"L40s:3",
|
|
124
|
+
"L40s:4",
|
|
125
|
+
"L40s:5",
|
|
126
|
+
"L40s:6",
|
|
127
|
+
"L40s:7",
|
|
128
|
+
"L40s:8",
|
|
129
|
+
# V100
|
|
130
|
+
"V100:1",
|
|
131
|
+
"V100:2",
|
|
132
|
+
"V100:3",
|
|
133
|
+
"V100:4",
|
|
134
|
+
"V100:5",
|
|
135
|
+
"V100:6",
|
|
136
|
+
"V100:7",
|
|
137
|
+
"V100:8",
|
|
138
|
+
# RTX PRO 6000
|
|
139
|
+
"RTX PRO 6000:1",
|
|
140
|
+
# T4
|
|
141
|
+
"T4:1",
|
|
142
|
+
"T4:2",
|
|
143
|
+
"T4:3",
|
|
144
|
+
"T4:4",
|
|
145
|
+
"T4:5",
|
|
146
|
+
"T4:6",
|
|
147
|
+
"T4:7",
|
|
148
|
+
"T4:8",
|
|
149
|
+
# Trn1
|
|
150
|
+
"Trn1:1",
|
|
151
|
+
# Trn1n
|
|
152
|
+
"Trn1n:1",
|
|
153
|
+
# Trn2
|
|
154
|
+
"Trn2:1",
|
|
155
|
+
# Trn2u
|
|
156
|
+
"Trn2u:1",
|
|
157
|
+
# Inf1
|
|
158
|
+
"Inf1:1",
|
|
159
|
+
# Inf2
|
|
160
|
+
"Inf2:1",
|
|
161
|
+
# MI100
|
|
162
|
+
"MI100:1",
|
|
163
|
+
# MI210
|
|
164
|
+
"MI210:1",
|
|
165
|
+
# MI250
|
|
166
|
+
"MI250:1",
|
|
167
|
+
# MI250X
|
|
168
|
+
"MI250X:1",
|
|
169
|
+
# MI300A
|
|
170
|
+
"MI300A:1",
|
|
171
|
+
# MI300X
|
|
172
|
+
"MI300X:1",
|
|
173
|
+
# MI325X
|
|
174
|
+
"MI325X:1",
|
|
175
|
+
# MI350X
|
|
176
|
+
"MI350X:1",
|
|
177
|
+
# MI355X
|
|
178
|
+
"MI355X:1",
|
|
179
|
+
# Habana Gaudi
|
|
180
|
+
"Gaudi1:1",
|
|
89
181
|
]
|
|
90
182
|
|
|
183
|
+
DeviceClass = Literal["GPU", "TPU", "NEURON", "AMD_GPU", "HABANA_GAUDI"]
|
|
184
|
+
|
|
185
|
+
_DeviceClassType: Dict[typing.Any, str] = {
|
|
186
|
+
GPUType: "GPU",
|
|
187
|
+
TPUType: "TPU",
|
|
188
|
+
NeuronType: "NEURON",
|
|
189
|
+
AMD_GPUType: "AMD_GPU",
|
|
190
|
+
HABANA_GAUDIType: "HABANA_GAUDI",
|
|
191
|
+
}
|
|
192
|
+
|
|
91
193
|
|
|
92
194
|
@rich.repr.auto
|
|
93
195
|
@dataclass(frozen=True, slots=True)
|
|
@@ -100,6 +202,7 @@ class Device:
|
|
|
100
202
|
"""
|
|
101
203
|
|
|
102
204
|
quantity: int
|
|
205
|
+
device_class: DeviceClass
|
|
103
206
|
device: str | None = None
|
|
104
207
|
partition: str | None = None
|
|
105
208
|
|
|
@@ -126,7 +229,7 @@ def GPU(device: GPUType, quantity: GPUQuantity, partition: A100Parts | A100_80GB
|
|
|
126
229
|
elif partition is not None and device == "A100 80G":
|
|
127
230
|
if partition not in get_args(A100_80GBParts):
|
|
128
231
|
raise ValueError(f"Invalid partition for A100 80G: {partition}. Must be one of {get_args(A100_80GBParts)}")
|
|
129
|
-
return Device(device=device, quantity=quantity, partition=partition)
|
|
232
|
+
return Device(device=device, quantity=quantity, partition=partition, device_class="GPU")
|
|
130
233
|
|
|
131
234
|
|
|
132
235
|
def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
|
|
@@ -147,7 +250,42 @@ def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
|
|
|
147
250
|
elif partition is not None and device == "V5E":
|
|
148
251
|
if partition not in get_args(V5EParts):
|
|
149
252
|
raise ValueError(f"Invalid partition for V5E: {partition}. Must be one of {get_args(V5EParts)}")
|
|
150
|
-
return Device(1, device, partition)
|
|
253
|
+
return Device(1, "TPU", device, partition)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def Neuron(device: NeuronType) -> Device:
|
|
257
|
+
"""
|
|
258
|
+
Create a Neuron device instance.
|
|
259
|
+
:param device: Device type (e.g., "Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u").
|
|
260
|
+
Note: no quantity parameter is accepted; the returned Device always has a quantity of 1.
|
|
261
|
+
:return: Device instance.
|
|
262
|
+
"""
|
|
263
|
+
if device not in get_args(NeuronType):
|
|
264
|
+
raise ValueError(f"Invalid Neuron type: {device}. Must be one of {get_args(NeuronType)}")
|
|
265
|
+
return Device(device=device, quantity=1, device_class="NEURON")
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def AMD_GPU(device: AMD_GPUType) -> Device:
|
|
269
|
+
"""
|
|
270
|
+
Create an AMD GPU device instance.
|
|
271
|
+
:param device: Device type (e.g., "MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X", "MI325X", "MI350X",
|
|
272
|
+
"MI355X").
|
|
273
|
+
:return: Device instance.
|
|
274
|
+
"""
|
|
275
|
+
if device not in get_args(AMD_GPUType):
|
|
276
|
+
raise ValueError(f"Invalid AMD GPU type: {device}. Must be one of {get_args(AMD_GPUType)}")
|
|
277
|
+
return Device(device=device, quantity=1, device_class="AMD_GPU")
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def HABANA_GAUDI(device: HABANA_GAUDIType) -> Device:
|
|
281
|
+
"""
|
|
282
|
+
Create a Habana Gaudi device instance.
|
|
283
|
+
:param device: Device type (e.g., "Gaudi1").
|
|
284
|
+
:return: Device instance.
|
|
285
|
+
"""
|
|
286
|
+
if device not in get_args(HABANA_GAUDIType):
|
|
287
|
+
raise ValueError(f"Invalid Habana Gaudi type: {device}. Must be one of {get_args(HABANA_GAUDIType)}")
|
|
288
|
+
return Device(device=device, quantity=1, device_class="HABANA_GAUDI")
|
|
151
289
|
|
|
152
290
|
|
|
153
291
|
CPUBaseType = int | float | str
|
|
@@ -202,7 +340,7 @@ class Resources:
|
|
|
202
340
|
raise ValueError("gpu must be greater than or equal to 0")
|
|
203
341
|
elif isinstance(self.gpu, str):
|
|
204
342
|
if self.gpu not in get_args(Accelerators):
|
|
205
|
-
raise ValueError(f"gpu must be one of {Accelerators}")
|
|
343
|
+
raise ValueError(f"gpu must be one of {Accelerators}, got {self.gpu}")
|
|
206
344
|
|
|
207
345
|
def get_device(self) -> Optional[Device]:
|
|
208
346
|
"""
|
|
@@ -214,10 +352,16 @@ class Resources:
|
|
|
214
352
|
if self.gpu is None:
|
|
215
353
|
return None
|
|
216
354
|
if isinstance(self.gpu, int):
|
|
217
|
-
return Device(quantity=self.gpu)
|
|
355
|
+
return Device(quantity=self.gpu, device_class="GPU")
|
|
218
356
|
if isinstance(self.gpu, str):
|
|
219
357
|
device, portion = self.gpu.split(":")
|
|
220
|
-
|
|
358
|
+
for cls, cls_name in _DeviceClassType.items():
|
|
359
|
+
if device in get_args(cls):
|
|
360
|
+
device_class = cls_name
|
|
361
|
+
break
|
|
362
|
+
else:
|
|
363
|
+
raise ValueError(f"Invalid device type: {device}. Must be one of {list(_DeviceClassType.keys())}")
|
|
364
|
+
return Device(device=device, device_class=device_class, quantity=int(portion)) # type: ignore
|
|
221
365
|
return self.gpu
|
|
222
366
|
|
|
223
367
|
def get_shared_memory(self) -> Optional[str]:
|
flyte/_run.py
CHANGED
|
@@ -12,13 +12,13 @@ from flyte._environment import Environment
|
|
|
12
12
|
from flyte._initialize import (
|
|
13
13
|
_get_init_config,
|
|
14
14
|
get_client,
|
|
15
|
-
|
|
15
|
+
get_init_config,
|
|
16
16
|
get_storage,
|
|
17
17
|
requires_initialization,
|
|
18
18
|
requires_storage,
|
|
19
19
|
)
|
|
20
20
|
from flyte._logging import logger
|
|
21
|
-
from flyte._task import P, R, TaskTemplate
|
|
21
|
+
from flyte._task import F, P, R, TaskTemplate
|
|
22
22
|
from flyte.models import (
|
|
23
23
|
ActionID,
|
|
24
24
|
Checkpoints,
|
|
@@ -94,6 +94,7 @@ class _Runner:
|
|
|
94
94
|
log_level: int | None = None,
|
|
95
95
|
disable_run_cache: bool = False,
|
|
96
96
|
queue: Optional[str] = None,
|
|
97
|
+
custom_context: Dict[str, str] | None = None,
|
|
97
98
|
):
|
|
98
99
|
from flyte._tools import ipython_check
|
|
99
100
|
|
|
@@ -124,11 +125,15 @@ class _Runner:
|
|
|
124
125
|
self._log_level = log_level
|
|
125
126
|
self._disable_run_cache = disable_run_cache
|
|
126
127
|
self._queue = queue
|
|
128
|
+
self._custom_context = custom_context or {}
|
|
127
129
|
|
|
128
130
|
@requires_initialization
|
|
129
|
-
async def _run_remote(self, obj: TaskTemplate[P, R] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
|
|
131
|
+
async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
|
|
130
132
|
import grpc
|
|
131
|
-
from
|
|
133
|
+
from flyteidl2.common import identifier_pb2
|
|
134
|
+
from flyteidl2.core import literals_pb2
|
|
135
|
+
from flyteidl2.task import run_pb2
|
|
136
|
+
from flyteidl2.workflow import run_definition_pb2, run_service_pb2
|
|
132
137
|
from google.protobuf import wrappers_pb2
|
|
133
138
|
|
|
134
139
|
from flyte.remote import Run
|
|
@@ -138,21 +143,21 @@ class _Runner:
|
|
|
138
143
|
from ._deploy import build_images
|
|
139
144
|
from ._internal.runtime.convert import convert_from_native_to_inputs
|
|
140
145
|
from ._internal.runtime.task_serde import translate_task_to_wire
|
|
141
|
-
from ._protos.common import identifier_pb2
|
|
142
|
-
from ._protos.workflow import run_definition_pb2, run_service_pb2
|
|
143
146
|
|
|
144
|
-
cfg =
|
|
147
|
+
cfg = get_init_config()
|
|
145
148
|
project = self._project or cfg.project
|
|
146
149
|
domain = self._domain or cfg.domain
|
|
147
150
|
|
|
148
151
|
if isinstance(obj, LazyEntity):
|
|
149
152
|
task = await obj.fetch.aio()
|
|
150
153
|
task_spec = task.pb2.spec
|
|
151
|
-
inputs = await convert_from_native_to_inputs(
|
|
154
|
+
inputs = await convert_from_native_to_inputs(
|
|
155
|
+
task.interface, *args, custom_context=self._custom_context, **kwargs
|
|
156
|
+
)
|
|
152
157
|
version = task.pb2.task_id.version
|
|
153
158
|
code_bundle = None
|
|
154
159
|
else:
|
|
155
|
-
task = cast(TaskTemplate[P, R], obj)
|
|
160
|
+
task = cast(TaskTemplate[P, R, F], obj)
|
|
156
161
|
if obj.parent_env is None:
|
|
157
162
|
raise ValueError("Task is not attached to an environment. Please attach the task to an environment")
|
|
158
163
|
|
|
@@ -204,7 +209,9 @@ class _Runner:
|
|
|
204
209
|
root_dir=cfg.root_dir,
|
|
205
210
|
)
|
|
206
211
|
task_spec = translate_task_to_wire(obj, s_ctx)
|
|
207
|
-
inputs = await convert_from_native_to_inputs(
|
|
212
|
+
inputs = await convert_from_native_to_inputs(
|
|
213
|
+
obj.native_interface, *args, custom_context=self._custom_context, **kwargs
|
|
214
|
+
)
|
|
208
215
|
|
|
209
216
|
env = self._env_vars or {}
|
|
210
217
|
if env.get("LOG_LEVEL") is None:
|
|
@@ -254,9 +261,9 @@ class _Runner:
|
|
|
254
261
|
raise ValueError(f"Environment variable {k} must be a string, got {type(v)}")
|
|
255
262
|
kv_pairs.append(literals_pb2.KeyValuePair(key=k, value=v))
|
|
256
263
|
|
|
257
|
-
env_kv =
|
|
258
|
-
annotations =
|
|
259
|
-
labels =
|
|
264
|
+
env_kv = run_pb2.Envs(values=kv_pairs)
|
|
265
|
+
annotations = run_pb2.Annotations(values=self._annotations)
|
|
266
|
+
labels = run_pb2.Labels(values=self._labels)
|
|
260
267
|
|
|
261
268
|
try:
|
|
262
269
|
resp = await get_client().run_service.CreateRun(
|
|
@@ -265,7 +272,7 @@ class _Runner:
|
|
|
265
272
|
project_id=project_id,
|
|
266
273
|
task_spec=task_spec,
|
|
267
274
|
inputs=inputs.proto_inputs,
|
|
268
|
-
run_spec=
|
|
275
|
+
run_spec=run_pb2.RunSpec(
|
|
269
276
|
overwrite_cache=self._overwrite_cache,
|
|
270
277
|
interruptible=wrappers_pb2.BoolValue(value=self._interruptible)
|
|
271
278
|
if self._interruptible is not None
|
|
@@ -318,7 +325,7 @@ class _Runner:
|
|
|
318
325
|
|
|
319
326
|
@requires_storage
|
|
320
327
|
@requires_initialization
|
|
321
|
-
async def _run_hybrid(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
|
328
|
+
async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> R:
|
|
322
329
|
"""
|
|
323
330
|
Run a task in hybrid mode. This means that the parent action will be run locally, but the child actions will be
|
|
324
331
|
run in the cluster remotely. This is currently only used for testing,
|
|
@@ -333,7 +340,7 @@ class _Runner:
|
|
|
333
340
|
from ._internal import create_controller
|
|
334
341
|
from ._internal.runtime.taskrunner import run_task
|
|
335
342
|
|
|
336
|
-
cfg =
|
|
343
|
+
cfg = get_init_config()
|
|
337
344
|
|
|
338
345
|
if obj.parent_env is None:
|
|
339
346
|
raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")
|
|
@@ -411,6 +418,7 @@ class _Runner:
|
|
|
411
418
|
compiled_image_cache=image_cache,
|
|
412
419
|
run_base_dir=run_base_dir,
|
|
413
420
|
report=flyte.report.Report(name=action.name),
|
|
421
|
+
custom_context=self._custom_context,
|
|
414
422
|
)
|
|
415
423
|
async with ctx.replace_task_context(tctx):
|
|
416
424
|
return await run_task(tctx=tctx, controller=controller, task=obj, inputs=inputs)
|
|
@@ -420,10 +428,11 @@ class _Runner:
|
|
|
420
428
|
raise err
|
|
421
429
|
return outputs
|
|
422
430
|
|
|
423
|
-
async def _run_local(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> Run:
|
|
431
|
+
async def _run_local(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> Run:
|
|
432
|
+
from flyteidl2.common import identifier_pb2
|
|
433
|
+
|
|
424
434
|
from flyte._internal.controllers import create_controller
|
|
425
435
|
from flyte._internal.controllers._local_controller import LocalController
|
|
426
|
-
from flyte._protos.common import identifier_pb2
|
|
427
436
|
from flyte.remote import Run
|
|
428
437
|
from flyte.report import Report
|
|
429
438
|
|
|
@@ -461,7 +470,9 @@ class _Runner:
|
|
|
461
470
|
compiled_image_cache=None,
|
|
462
471
|
report=Report(name=action.name),
|
|
463
472
|
mode="local",
|
|
473
|
+
custom_context=self._custom_context,
|
|
464
474
|
)
|
|
475
|
+
|
|
465
476
|
with ctx.replace_task_context(tctx):
|
|
466
477
|
# Make the local version always run on a different thread; it returns a wrapped future.
|
|
467
478
|
if obj._call_as_synchronous:
|
|
@@ -473,7 +484,7 @@ class _Runner:
|
|
|
473
484
|
|
|
474
485
|
class _LocalRun(Run):
|
|
475
486
|
def __init__(self, outputs: Tuple[Any, ...] | Any):
|
|
476
|
-
from
|
|
487
|
+
from flyteidl2.workflow import run_definition_pb2
|
|
477
488
|
|
|
478
489
|
self._outputs = outputs
|
|
479
490
|
super().__init__(
|
|
@@ -506,7 +517,7 @@ class _Runner:
|
|
|
506
517
|
@syncify
|
|
507
518
|
async def run(
|
|
508
519
|
self,
|
|
509
|
-
task: TaskTemplate[P, Union[R, Run]] | LazyEntity,
|
|
520
|
+
task: TaskTemplate[P, Union[R, Run], F] | LazyEntity,
|
|
510
521
|
*args: P.args,
|
|
511
522
|
**kwargs: P.kwargs,
|
|
512
523
|
) -> Union[R, Run]:
|
|
@@ -574,6 +585,7 @@ def with_runcontext(
|
|
|
574
585
|
log_level: int | None = None,
|
|
575
586
|
disable_run_cache: bool = False,
|
|
576
587
|
queue: Optional[str] = None,
|
|
588
|
+
custom_context: Dict[str, str] | None = None,
|
|
577
589
|
) -> _Runner:
|
|
578
590
|
"""
|
|
579
591
|
Launch a new run with the given parameters as the context.
|
|
@@ -618,6 +630,9 @@ def with_runcontext(
|
|
|
618
630
|
set using `flyte.init()`
|
|
619
631
|
:param disable_run_cache: Optional If true, the run cache will be disabled. This is useful for testing purposes.
|
|
620
632
|
:param queue: Optional The queue to use for the run. This is used to specify the cluster to use for the run.
|
|
633
|
+
:param custom_context: Optional global input context to pass to the task. This will be available via
|
|
634
|
+
get_custom_context() within the task and will automatically propagate to sub-tasks.
|
|
635
|
+
Acts as base/default values that can be overridden by context managers in the code.
|
|
621
636
|
|
|
622
637
|
:return: runner
|
|
623
638
|
"""
|
|
@@ -646,11 +661,12 @@ def with_runcontext(
|
|
|
646
661
|
log_level=log_level,
|
|
647
662
|
disable_run_cache=disable_run_cache,
|
|
648
663
|
queue=queue,
|
|
664
|
+
custom_context=custom_context,
|
|
649
665
|
)
|
|
650
666
|
|
|
651
667
|
|
|
652
668
|
@syncify
|
|
653
|
-
async def run(task: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
|
|
669
|
+
async def run(task: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
|
|
654
670
|
"""
|
|
655
671
|
Run a task with the given parameters
|
|
656
672
|
:param task: task to run
|