flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyte/__init__.py +18 -2
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +62 -8
- flyte/_cache/cache.py +4 -2
- flyte/_cache/local_cache.py +216 -0
- flyte/_code_bundle/_ignore.py +12 -4
- flyte/_code_bundle/_packaging.py +13 -9
- flyte/_code_bundle/_utils.py +18 -10
- flyte/_code_bundle/bundle.py +17 -9
- flyte/_constants.py +1 -0
- flyte/_context.py +4 -1
- flyte/_custom_context.py +73 -0
- flyte/_debug/constants.py +38 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +307 -0
- flyte/_deploy.py +235 -61
- flyte/_environment.py +20 -6
- flyte/_excepthook.py +1 -1
- flyte/_hash.py +1 -16
- flyte/_image.py +178 -81
- flyte/_initialize.py +132 -51
- flyte/_interface.py +39 -2
- flyte/_internal/controllers/__init__.py +4 -5
- flyte/_internal/controllers/_local_controller.py +70 -29
- flyte/_internal/controllers/_trace.py +1 -1
- flyte/_internal/controllers/remote/__init__.py +0 -2
- flyte/_internal/controllers/remote/_action.py +14 -16
- flyte/_internal/controllers/remote/_client.py +1 -1
- flyte/_internal/controllers/remote/_controller.py +68 -70
- flyte/_internal/controllers/remote/_core.py +127 -99
- flyte/_internal/controllers/remote/_informer.py +19 -10
- flyte/_internal/controllers/remote/_service_protocol.py +7 -7
- flyte/_internal/imagebuild/docker_builder.py +181 -69
- flyte/_internal/imagebuild/image_builder.py +0 -5
- flyte/_internal/imagebuild/remote_builder.py +155 -64
- flyte/_internal/imagebuild/utils.py +51 -2
- flyte/_internal/resolvers/_task_module.py +5 -38
- flyte/_internal/resolvers/default.py +2 -2
- flyte/_internal/runtime/convert.py +110 -21
- flyte/_internal/runtime/entrypoints.py +27 -1
- flyte/_internal/runtime/io.py +21 -8
- flyte/_internal/runtime/resources_serde.py +20 -6
- flyte/_internal/runtime/reuse.py +1 -1
- flyte/_internal/runtime/rusty.py +20 -5
- flyte/_internal/runtime/task_serde.py +34 -19
- flyte/_internal/runtime/taskrunner.py +22 -4
- flyte/_internal/runtime/trigger_serde.py +160 -0
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_keyring/__init__.py +0 -0
- flyte/_keyring/file.py +115 -0
- flyte/_logging.py +201 -39
- flyte/_map.py +111 -14
- flyte/_module.py +70 -0
- flyte/_pod.py +4 -3
- flyte/_resources.py +213 -31
- flyte/_run.py +110 -39
- flyte/_task.py +75 -16
- flyte/_task_environment.py +105 -29
- flyte/_task_plugins.py +4 -2
- flyte/_trace.py +5 -0
- flyte/_trigger.py +1000 -0
- flyte/_utils/__init__.py +2 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/coro_management.py +2 -1
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/module_loader.py +17 -2
- flyte/_version.py +3 -3
- flyte/cli/_abort.py +3 -3
- flyte/cli/_build.py +3 -6
- flyte/cli/_common.py +78 -7
- flyte/cli/_create.py +182 -4
- flyte/cli/_delete.py +23 -1
- flyte/cli/_deploy.py +63 -16
- flyte/cli/_get.py +79 -34
- flyte/cli/_params.py +26 -10
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_run.py +151 -26
- flyte/cli/_serve.py +64 -0
- flyte/cli/_update.py +37 -0
- flyte/cli/_user.py +17 -0
- flyte/cli/main.py +30 -4
- flyte/config/_config.py +10 -6
- flyte/config/_internal.py +1 -0
- flyte/config/_reader.py +29 -8
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +270 -0
- flyte/connectors/_server.py +197 -0
- flyte/connectors/utils.py +135 -0
- flyte/errors.py +22 -2
- flyte/extend.py +8 -1
- flyte/extras/_container.py +6 -1
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +21 -0
- flyte/io/__init__.py +2 -0
- flyte/io/_dataframe/__init__.py +2 -0
- flyte/io/_dataframe/basic_dfs.py +17 -8
- flyte/io/_dataframe/dataframe.py +98 -132
- flyte/io/_dir.py +575 -113
- flyte/io/_file.py +582 -139
- flyte/io/_hashing_io.py +342 -0
- flyte/models.py +74 -15
- flyte/remote/__init__.py +6 -1
- flyte/remote/_action.py +34 -26
- flyte/remote/_client/_protocols.py +39 -4
- flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
- flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
- flyte/remote/_client/auth/_channel.py +10 -6
- flyte/remote/_client/controlplane.py +17 -5
- flyte/remote/_console.py +3 -2
- flyte/remote/_data.py +6 -6
- flyte/remote/_logs.py +3 -3
- flyte/remote/_run.py +64 -8
- flyte/remote/_secret.py +26 -17
- flyte/remote/_task.py +75 -33
- flyte/remote/_trigger.py +306 -0
- flyte/remote/_user.py +33 -0
- flyte/report/_report.py +1 -1
- flyte/storage/__init__.py +6 -1
- flyte/storage/_config.py +5 -1
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_storage.py +200 -103
- flyte/types/__init__.py +16 -0
- flyte/types/_interface.py +2 -2
- flyte/types/_pickle.py +35 -8
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +40 -70
- flyte/types/_utils.py +1 -1
- flyte-2.0.0b30.data/scripts/debug.py +38 -0
- {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
- flyte-2.0.0b30.dist-info/RECORD +192 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -93
- flyte/_protos/common/identifier_pb2.pyi +0 -110
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -71
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/definition_pb2.py +0 -59
- flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
- flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/payload_pb2.py +0 -32
- flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
- flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/service_pb2.py +0 -29
- flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
- flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/common_pb2.py +0 -27
- flyte/_protos/workflow/common_pb2.pyi +0 -14
- flyte/_protos/workflow/common_pb2_grpc.py +0 -4
- flyte/_protos/workflow/environment_pb2.py +0 -29
- flyte/_protos/workflow/environment_pb2.pyi +0 -12
- flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -109
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- flyte/_protos/workflow/run_definition_pb2.py +0 -121
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -137
- flyte/_protos/workflow/run_service_pb2.pyi +0 -185
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
- flyte/_protos/workflow/state_service_pb2.py +0 -67
- flyte/_protos/workflow/state_service_pb2.pyi +0 -76
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -79
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -60
- flyte/_protos/workflow/task_service_pb2.pyi +0 -59
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
- flyte-2.0.0b13.dist-info/RECORD +0 -239
- /flyte/{_protos → _debug}/__init__.py +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/_map.py
CHANGED
|
@@ -1,17 +1,26 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
|
|
2
|
+
import functools
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast, overload
|
|
3
5
|
|
|
4
6
|
from flyte.syncify import syncify
|
|
5
7
|
|
|
6
8
|
from ._group import group
|
|
7
9
|
from ._logging import logger
|
|
8
|
-
from ._task import P, R
|
|
10
|
+
from ._task import AsyncFunctionTaskTemplate, F, P, R
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
class MapAsyncIterator(Generic[P, R]):
|
|
12
14
|
"""AsyncIterator implementation for the map function results"""
|
|
13
15
|
|
|
14
|
-
def __init__(
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
19
|
+
args: tuple,
|
|
20
|
+
name: str,
|
|
21
|
+
concurrency: int,
|
|
22
|
+
return_exceptions: bool,
|
|
23
|
+
):
|
|
15
24
|
self.func = func
|
|
16
25
|
self.args = args
|
|
17
26
|
self.name = name
|
|
@@ -49,13 +58,16 @@ class MapAsyncIterator(Generic[P, R]):
|
|
|
49
58
|
return result
|
|
50
59
|
except Exception as e:
|
|
51
60
|
self._exception_count += 1
|
|
52
|
-
logger.debug(
|
|
61
|
+
logger.debug(
|
|
62
|
+
f"Task {self._current_index - 1} failed with exception: {e}, return_exceptions={self.return_exceptions}"
|
|
63
|
+
)
|
|
53
64
|
if self.return_exceptions:
|
|
54
65
|
return e
|
|
55
66
|
else:
|
|
56
67
|
# Cancel remaining tasks
|
|
57
68
|
for remaining_task in self._tasks[self._current_index + 1 :]:
|
|
58
69
|
remaining_task.cancel()
|
|
70
|
+
logger.warning("Exception raising is `ON`, raising exception and cancelling remaining tasks")
|
|
59
71
|
raise e
|
|
60
72
|
|
|
61
73
|
async def _initialize(self):
|
|
@@ -64,10 +76,26 @@ class MapAsyncIterator(Generic[P, R]):
|
|
|
64
76
|
tasks = []
|
|
65
77
|
task_count = 0
|
|
66
78
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
79
|
+
if isinstance(self.func, functools.partial):
|
|
80
|
+
# Handle partial functions by merging bound args/kwargs with mapped args
|
|
81
|
+
base_func = cast(AsyncFunctionTaskTemplate, self.func.func)
|
|
82
|
+
bound_args = self.func.args
|
|
83
|
+
bound_kwargs = self.func.keywords or {}
|
|
84
|
+
|
|
85
|
+
for arg_tuple in zip(*self.args):
|
|
86
|
+
# Merge bound positional args with mapped args
|
|
87
|
+
merged_args = bound_args + arg_tuple
|
|
88
|
+
if logger.isEnabledFor(logging.DEBUG):
|
|
89
|
+
logger.debug(f"Running {base_func.name} with args: {merged_args} and kwargs: {bound_kwargs}")
|
|
90
|
+
task = asyncio.create_task(base_func.aio(*merged_args, **bound_kwargs))
|
|
91
|
+
tasks.append(task)
|
|
92
|
+
task_count += 1
|
|
93
|
+
else:
|
|
94
|
+
# Handle regular TaskTemplate functions
|
|
95
|
+
for arg_tuple in zip(*self.args):
|
|
96
|
+
task = asyncio.create_task(self.func.aio(*arg_tuple))
|
|
97
|
+
tasks.append(task)
|
|
98
|
+
task_count += 1
|
|
71
99
|
|
|
72
100
|
if task_count == 0:
|
|
73
101
|
logger.info(f"Group '{self.name}' has no tasks to process")
|
|
@@ -107,9 +135,65 @@ class _Mapper(Generic[P, R]):
|
|
|
107
135
|
"""Get the name of the group, defaulting to 'map' if not provided."""
|
|
108
136
|
return f"{task_name}_{group_name or 'map'}"
|
|
109
137
|
|
|
138
|
+
@staticmethod
def validate_partial(func: functools.partial[R]):
    """
    Validate that a partial function can be mapped over.

    The mapper calls the wrapped task as ``task(*func.args, *mapped_values, **func.keywords)``: each mapped
    value is passed positionally, directly after the partial's bound positional arguments. A partial is
    therefore valid only if exactly one task input is left unbound, and that input is the one immediately
    following the positionally-bound ones.

    :param func: partial function to validate; ``func.func`` is expected to be a task template
    :raises TypeError: if the partial function is not valid for mapping
    """
    f = cast(AsyncFunctionTaskTemplate, func.func)
    params = list(f.native_interface.inputs.keys())
    n_positional = len(func.args)
    bound_kwargs = func.keywords or {}

    # A keyword must not bind an input that is already filled positionally, nor the input at index
    # ``n_positional`` -- that is the slot the mapped value is passed into positionally.
    for i, arg_name in enumerate(params[: n_positional + 1]):
        if arg_name not in bound_kwargs:
            continue
        if i < n_positional:
            raise TypeError(
                f"Parameter '{arg_name}' is provided both as positional argument and keyword argument "
                f"in partial function {f.name}."
            )
        raise TypeError(
            f"Parameter '{arg_name}' is bound by keyword in partial function {f.name}, but mapped values are "
            f"passed positionally and would also be assigned to it. Bind '{arg_name}' (and any parameter "
            f"before it) positionally instead."
        )

    # Keywords that name no task input can never be satisfied by the task call.
    unknown = [name for name in bound_kwargs if name not in params]
    if unknown:
        raise TypeError(
            f"Partial function {f.name} binds unknown parameters by keyword: {unknown}, params: {params}"
        )

    # Count unbound inputs by name rather than by subtracting argument counts, so overlapping or
    # out-of-range bindings cannot make an invalid partial look valid.
    unbound = [name for name in params[n_positional:] if name not in bound_kwargs]
    if len(unbound) != 1:
        raise TypeError(
            f"Partial function must leave exactly one parameter unspecified for mapping. "
            f"Found {len(unbound)} unspecified parameters in {f.name}, "
            f"params: {params}"
        )
|
|
174
|
+
|
|
175
|
+
@overload
|
|
176
|
+
def __call__(
|
|
177
|
+
self,
|
|
178
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
179
|
+
*args: Iterable[Any],
|
|
180
|
+
group_name: str | None = None,
|
|
181
|
+
concurrency: int = 0,
|
|
182
|
+
) -> Iterator[R]: ...
|
|
183
|
+
|
|
184
|
+
@overload
|
|
110
185
|
def __call__(
|
|
111
186
|
self,
|
|
112
|
-
func:
|
|
187
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
188
|
+
*args: Iterable[Any],
|
|
189
|
+
group_name: str | None = None,
|
|
190
|
+
concurrency: int = 0,
|
|
191
|
+
return_exceptions: bool = True,
|
|
192
|
+
) -> Iterator[Union[R, Exception]]: ...
|
|
193
|
+
|
|
194
|
+
def __call__(
|
|
195
|
+
self,
|
|
196
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
113
197
|
*args: Iterable[Any],
|
|
114
198
|
group_name: str | None = None,
|
|
115
199
|
concurrency: int = 0,
|
|
@@ -128,7 +212,13 @@ class _Mapper(Generic[P, R]):
|
|
|
128
212
|
if not args:
|
|
129
213
|
return
|
|
130
214
|
|
|
131
|
-
|
|
215
|
+
if isinstance(func, functools.partial):
|
|
216
|
+
f = cast(AsyncFunctionTaskTemplate, func.func)
|
|
217
|
+
self.validate_partial(func)
|
|
218
|
+
else:
|
|
219
|
+
f = cast(AsyncFunctionTaskTemplate, func)
|
|
220
|
+
|
|
221
|
+
name = self._get_name(f.name, group_name)
|
|
132
222
|
logger.debug(f"Blocking Map for {name}")
|
|
133
223
|
with group(name):
|
|
134
224
|
import flyte
|
|
@@ -154,7 +244,7 @@ class _Mapper(Generic[P, R]):
|
|
|
154
244
|
*args,
|
|
155
245
|
name=name,
|
|
156
246
|
concurrency=concurrency,
|
|
157
|
-
return_exceptions=
|
|
247
|
+
return_exceptions=return_exceptions,
|
|
158
248
|
),
|
|
159
249
|
):
|
|
160
250
|
logger.debug(f"Mapped {x}, task {i}")
|
|
@@ -163,7 +253,7 @@ class _Mapper(Generic[P, R]):
|
|
|
163
253
|
|
|
164
254
|
async def aio(
|
|
165
255
|
self,
|
|
166
|
-
func:
|
|
256
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
167
257
|
*args: Iterable[Any],
|
|
168
258
|
group_name: str | None = None,
|
|
169
259
|
concurrency: int = 0,
|
|
@@ -171,7 +261,14 @@ class _Mapper(Generic[P, R]):
|
|
|
171
261
|
) -> AsyncGenerator[Union[R, Exception], None]:
|
|
172
262
|
if not args:
|
|
173
263
|
return
|
|
174
|
-
|
|
264
|
+
|
|
265
|
+
if isinstance(func, functools.partial):
|
|
266
|
+
f = cast(AsyncFunctionTaskTemplate, func.func)
|
|
267
|
+
self.validate_partial(func)
|
|
268
|
+
else:
|
|
269
|
+
f = cast(AsyncFunctionTaskTemplate, func)
|
|
270
|
+
|
|
271
|
+
name = self._get_name(f.name, group_name)
|
|
175
272
|
with group(name):
|
|
176
273
|
import flyte
|
|
177
274
|
|
|
@@ -199,7 +296,7 @@ class _Mapper(Generic[P, R]):
|
|
|
199
296
|
|
|
200
297
|
@syncify
|
|
201
298
|
async def _map(
|
|
202
|
-
func:
|
|
299
|
+
func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
|
|
203
300
|
*args: Iterable[Any],
|
|
204
301
|
name: str = "map",
|
|
205
302
|
concurrency: int = 0,
|
flyte/_module.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import os
|
|
3
|
+
import pathlib
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def extract_obj_module(obj: object, /, source_dir: pathlib.Path) -> str:
    """
    Extract the module name for the given object.

    If the object's module file lives under ``source_dir``, the name is that file's path relative to
    ``source_dir`` in dotted form with the extension removed (``pkg/sub/mod.py`` -> ``pkg.sub.mod``).
    Otherwise, if the file lives inside a ``site-packages`` or ``dist-packages`` directory, the module's own
    ``__name__`` is used. If the resulting name is ``__main__``, the stem of the ``__main__`` module's file
    is returned instead.

    Args:
        obj: The object to extract the module from.
        source_dir: The source directory to use for relative paths. Must not be None.

    Returns:
        The module name as a string.

    Raises:
        ValueError: if ``source_dir`` is None; if the object has no module, or its module has no file
            (e.g. built-in modules); or if the file is neither under ``source_dir`` nor in an
            installed-package directory.
    """
    if source_dir is None:
        raise ValueError("extract_obj_module: source_dir cannot be None - specify root-dir")

    def _obj_name() -> str:
        # Computed lazily: only needed for error messages.
        return getattr(obj, "__name__", str(obj))

    entity_module = inspect.getmodule(obj)
    if entity_module is None:
        raise ValueError(f"Object {_obj_name()} has no module.")

    # Built-in modules (e.g. ``builtins`` for ``len``) have no ``__file__`` attribute at all, so use
    # getattr to report them as ValueError instead of leaking an AttributeError.
    fp = getattr(entity_module, "__file__", None)
    if fp is None:
        raise ValueError(f"Object {_obj_name()} has no module.")

    file_path = pathlib.Path(fp)
    try:
        # Raises ValueError if the file is not inside the source directory.
        relative_path = file_path.relative_to(str(pathlib.Path(source_dir).absolute()))
    except ValueError as err:
        # Compare whole path components so directories that merely contain the text
        # (e.g. "my-site-packages-backup") are not mistaken for installed-package roots.
        if {"site-packages", "dist-packages"}.isdisjoint(file_path.parts):
            raise ValueError(
                f"Object {_obj_name()} module file {file_path} is not relative to "
                f"source directory {source_dir} and is not an installed package."
            ) from err
        # Installed package: the module's __name__ is importable via importlib.import_module().
        entity_module_name = entity_module.__name__
    else:
        # NOTE(review): relative_path is a *file* path (e.g. "_internal/resolvers/default.py"), so this
        # equality with a directory path never holds as written; `relative_path.parent == ...` may have
        # been intended -- confirm before changing.
        if relative_path == pathlib.Path("_internal/resolvers"):
            entity_module_name = entity_module.__name__
        else:
            # Replace file separators with dots and drop the file extension.
            entity_module_name = os.path.splitext(str(relative_path))[0].replace(os.sep, ".")

    if entity_module_name == "__main__":
        # The object was defined in the module being run as __main__; name it after that file.
        main_fp = getattr(sys.modules["__main__"], "__file__", None)
        if main_fp is None:
            raise ValueError(f"Object {_obj_name()} has no module.")
        entity_module_name = pathlib.Path(main_fp).stem

    return entity_module_name
|
flyte/_pod.py
CHANGED
|
@@ -2,8 +2,8 @@ from dataclasses import dataclass, field
|
|
|
2
2
|
from typing import TYPE_CHECKING, Dict, Optional
|
|
3
3
|
|
|
4
4
|
if TYPE_CHECKING:
|
|
5
|
-
from
|
|
6
|
-
from kubernetes.client import
|
|
5
|
+
from flyteidl2.core.tasks_pb2 import K8sPod
|
|
6
|
+
from kubernetes.client import V1PodSpec
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
|
|
@@ -20,7 +20,8 @@ class PodTemplate(object):
|
|
|
20
20
|
annotations: Optional[Dict[str, str]] = None
|
|
21
21
|
|
|
22
22
|
def to_k8s_pod(self) -> "K8sPod":
|
|
23
|
-
from
|
|
23
|
+
from flyteidl2.core.tasks_pb2 import K8sObjectMetadata, K8sPod
|
|
24
|
+
from kubernetes.client import ApiClient
|
|
24
25
|
|
|
25
26
|
return K8sPod(
|
|
26
27
|
metadata=K8sObjectMetadata(labels=self.labels, annotations=self.annotations),
|
flyte/_resources.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
|
+
import typing
|
|
1
2
|
from dataclasses import dataclass, fields
|
|
2
|
-
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union, get_args
|
|
3
|
+
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union, get_args
|
|
3
4
|
|
|
4
5
|
import rich.repr
|
|
5
6
|
|
|
@@ -10,7 +11,7 @@ if TYPE_CHECKING:
|
|
|
10
11
|
|
|
11
12
|
PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
|
|
12
13
|
|
|
13
|
-
GPUType = Literal["
|
|
14
|
+
# Every GPU model used in a "<model>:<count>" entry of Accelerators must appear here; otherwise
# Resources.get_device() cannot classify the model and raises "Invalid device type".
GPUType = Literal["A10", "A10G", "A100", "A100 80G", "B200", "H100", "H200", "L4", "L40s", "T4", "V100", "RTX PRO 6000"]
|
|
14
15
|
GPUQuantity = Literal[1, 2, 3, 4, 5, 6, 7, 8]
|
|
15
16
|
A100Parts = Literal["1g.5gb", "2g.10gb", "3g.20gb", "4g.20gb", "7g.40gb"]
|
|
16
17
|
"""
|
|
@@ -37,31 +38,32 @@ V6EParts = Literal["1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"]
|
|
|
37
38
|
Slices for Google Cloud TPU v6e.
|
|
38
39
|
"""
|
|
39
40
|
|
|
41
|
+
NeuronType = Literal["Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u"]
|
|
42
|
+
|
|
43
|
+
AMD_GPUType = Literal["MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X", "MI325X", "MI350X", "MI355X"]
|
|
44
|
+
|
|
45
|
+
HABANA_GAUDIType = Literal["Gaudi1"]
|
|
46
|
+
|
|
40
47
|
Accelerators = Literal[
|
|
41
|
-
|
|
42
|
-
"
|
|
43
|
-
"
|
|
44
|
-
"
|
|
45
|
-
"
|
|
46
|
-
"
|
|
47
|
-
"
|
|
48
|
-
"
|
|
49
|
-
"
|
|
50
|
-
|
|
51
|
-
"
|
|
52
|
-
"
|
|
53
|
-
"
|
|
54
|
-
"
|
|
55
|
-
"
|
|
56
|
-
"
|
|
57
|
-
"
|
|
58
|
-
"
|
|
59
|
-
|
|
60
|
-
"L40s:4",
|
|
61
|
-
"L40s:5",
|
|
62
|
-
"L40s:6",
|
|
63
|
-
"L40s:7",
|
|
64
|
-
"L40s:8",
|
|
48
|
+
# A10
|
|
49
|
+
"A10:1",
|
|
50
|
+
"A10:2",
|
|
51
|
+
"A10:3",
|
|
52
|
+
"A10:4",
|
|
53
|
+
"A10:5",
|
|
54
|
+
"A10:6",
|
|
55
|
+
"A10:7",
|
|
56
|
+
"A10:8",
|
|
57
|
+
# A10G
|
|
58
|
+
"A10G:1",
|
|
59
|
+
"A10G:2",
|
|
60
|
+
"A10G:3",
|
|
61
|
+
"A10G:4",
|
|
62
|
+
"A10G:5",
|
|
63
|
+
"A10G:6",
|
|
64
|
+
"A10G:7",
|
|
65
|
+
"A10G:8",
|
|
66
|
+
# A100
|
|
65
67
|
"A100:1",
|
|
66
68
|
"A100:2",
|
|
67
69
|
"A100:3",
|
|
@@ -70,6 +72,7 @@ Accelerators = Literal[
|
|
|
70
72
|
"A100:6",
|
|
71
73
|
"A100:7",
|
|
72
74
|
"A100:8",
|
|
75
|
+
# A100 80G
|
|
73
76
|
"A100 80G:1",
|
|
74
77
|
"A100 80G:2",
|
|
75
78
|
"A100 80G:3",
|
|
@@ -78,6 +81,16 @@ Accelerators = Literal[
|
|
|
78
81
|
"A100 80G:6",
|
|
79
82
|
"A100 80G:7",
|
|
80
83
|
"A100 80G:8",
|
|
84
|
+
# B200
|
|
85
|
+
"B200:1",
|
|
86
|
+
"B200:2",
|
|
87
|
+
"B200:3",
|
|
88
|
+
"B200:4",
|
|
89
|
+
"B200:5",
|
|
90
|
+
"B200:6",
|
|
91
|
+
"B200:7",
|
|
92
|
+
"B200:8",
|
|
93
|
+
# H100
|
|
81
94
|
"H100:1",
|
|
82
95
|
"H100:2",
|
|
83
96
|
"H100:3",
|
|
@@ -86,8 +99,135 @@ Accelerators = Literal[
|
|
|
86
99
|
"H100:6",
|
|
87
100
|
"H100:7",
|
|
88
101
|
"H100:8",
|
|
102
|
+
# H200
|
|
103
|
+
"H200:1",
|
|
104
|
+
"H200:2",
|
|
105
|
+
"H200:3",
|
|
106
|
+
"H200:4",
|
|
107
|
+
"H200:5",
|
|
108
|
+
"H200:6",
|
|
109
|
+
"H200:7",
|
|
110
|
+
"H200:8",
|
|
111
|
+
# L4
|
|
112
|
+
"L4:1",
|
|
113
|
+
"L4:2",
|
|
114
|
+
"L4:3",
|
|
115
|
+
"L4:4",
|
|
116
|
+
"L4:5",
|
|
117
|
+
"L4:6",
|
|
118
|
+
"L4:7",
|
|
119
|
+
"L4:8",
|
|
120
|
+
# L40s
|
|
121
|
+
"L40s:1",
|
|
122
|
+
"L40s:2",
|
|
123
|
+
"L40s:3",
|
|
124
|
+
"L40s:4",
|
|
125
|
+
"L40s:5",
|
|
126
|
+
"L40s:6",
|
|
127
|
+
"L40s:7",
|
|
128
|
+
"L40s:8",
|
|
129
|
+
# V100
|
|
130
|
+
"V100:1",
|
|
131
|
+
"V100:2",
|
|
132
|
+
"V100:3",
|
|
133
|
+
"V100:4",
|
|
134
|
+
"V100:5",
|
|
135
|
+
"V100:6",
|
|
136
|
+
"V100:7",
|
|
137
|
+
"V100:8",
|
|
138
|
+
# RTX 6000
|
|
139
|
+
"RTX PRO 6000:1",
|
|
140
|
+
# T4
|
|
141
|
+
"T4:1",
|
|
142
|
+
"T4:2",
|
|
143
|
+
"T4:3",
|
|
144
|
+
"T4:4",
|
|
145
|
+
"T4:5",
|
|
146
|
+
"T4:6",
|
|
147
|
+
"T4:7",
|
|
148
|
+
"T4:8",
|
|
149
|
+
# Trn1
|
|
150
|
+
"Trn1:1",
|
|
151
|
+
"Trn1:4",
|
|
152
|
+
"Trn1:8",
|
|
153
|
+
"Trn1:16",
|
|
154
|
+
# Trn1n
|
|
155
|
+
"Trn1n:1",
|
|
156
|
+
"Trn1n:4",
|
|
157
|
+
"Trn1n:8",
|
|
158
|
+
"Trn1n:16",
|
|
159
|
+
# Trn2
|
|
160
|
+
"Trn2:1",
|
|
161
|
+
"Trn2:4",
|
|
162
|
+
"Trn2:8",
|
|
163
|
+
"Trn2:16",
|
|
164
|
+
# Trn2u
|
|
165
|
+
"Trn2u:1",
|
|
166
|
+
"Trn2u:4",
|
|
167
|
+
"Trn2u:8",
|
|
168
|
+
"Trn2u:16",
|
|
169
|
+
# Inf1
|
|
170
|
+
"Inf1:1",
|
|
171
|
+
"Inf1:2",
|
|
172
|
+
"Inf1:3",
|
|
173
|
+
"Inf1:4",
|
|
174
|
+
"Inf1:5",
|
|
175
|
+
"Inf1:6",
|
|
176
|
+
"Inf1:7",
|
|
177
|
+
"Inf1:8",
|
|
178
|
+
"Inf1:9",
|
|
179
|
+
"Inf1:10",
|
|
180
|
+
"Inf1:11",
|
|
181
|
+
"Inf1:12",
|
|
182
|
+
"Inf1:13",
|
|
183
|
+
"Inf1:14",
|
|
184
|
+
"Inf1:15",
|
|
185
|
+
"Inf1:16",
|
|
186
|
+
# Inf2
|
|
187
|
+
"Inf2:1",
|
|
188
|
+
"Inf2:2",
|
|
189
|
+
"Inf2:3",
|
|
190
|
+
"Inf2:4",
|
|
191
|
+
"Inf2:5",
|
|
192
|
+
"Inf2:6",
|
|
193
|
+
"Inf2:7",
|
|
194
|
+
"Inf2:8",
|
|
195
|
+
"Inf2:9",
|
|
196
|
+
"Inf2:10",
|
|
197
|
+
"Inf2:11",
|
|
198
|
+
"Inf2:12",
|
|
199
|
+
# MI100
|
|
200
|
+
"MI100:1",
|
|
201
|
+
# MI210
|
|
202
|
+
"MI210:1",
|
|
203
|
+
# MI250
|
|
204
|
+
"MI250:1",
|
|
205
|
+
# MI250X
|
|
206
|
+
"MI250X:1",
|
|
207
|
+
# MI300A
|
|
208
|
+
"MI300A:1",
|
|
209
|
+
# MI300X
|
|
210
|
+
"MI300X:1",
|
|
211
|
+
# MI325X
|
|
212
|
+
"MI325X:1",
|
|
213
|
+
# MI350X
|
|
214
|
+
"MI350X:1",
|
|
215
|
+
# MI355X
|
|
216
|
+
"MI355X:1",
|
|
217
|
+
# Habana Gaudi
|
|
218
|
+
"Gaudi1:1",
|
|
89
219
|
]
|
|
90
220
|
|
|
221
|
+
# Accelerator families that a Device can belong to.
DeviceClass = Literal["GPU", "TPU", "NEURON", "AMD_GPU", "HABANA_GAUDI"]

# Maps each accelerator-model Literal to its DeviceClass. Resources.get_device() scans this mapping in
# insertion order to classify the model part of a "<model>:<count>" accelerator string. Values are typed
# as DeviceClass (not plain str) so they can be passed straight to Device.device_class.
_DeviceClassType: Dict[typing.Any, DeviceClass] = {
    GPUType: "GPU",
    TPUType: "TPU",
    NeuronType: "NEURON",
    AMD_GPUType: "AMD_GPU",
    HABANA_GAUDIType: "HABANA_GAUDI",
}
|
|
230
|
+
|
|
91
231
|
|
|
92
232
|
@rich.repr.auto
|
|
93
233
|
@dataclass(frozen=True, slots=True)
|
|
@@ -100,6 +240,7 @@ class Device:
|
|
|
100
240
|
"""
|
|
101
241
|
|
|
102
242
|
quantity: int
|
|
243
|
+
device_class: DeviceClass
|
|
103
244
|
device: str | None = None
|
|
104
245
|
partition: str | None = None
|
|
105
246
|
|
|
@@ -126,7 +267,7 @@ def GPU(device: GPUType, quantity: GPUQuantity, partition: A100Parts | A100_80GB
|
|
|
126
267
|
elif partition is not None and device == "A100 80G":
|
|
127
268
|
if partition not in get_args(A100_80GBParts):
|
|
128
269
|
raise ValueError(f"Invalid partition for A100 80G: {partition}. Must be one of {get_args(A100_80GBParts)}")
|
|
129
|
-
return Device(device=device, quantity=quantity, partition=partition)
|
|
270
|
+
return Device(device=device, quantity=quantity, partition=partition, device_class="GPU")
|
|
130
271
|
|
|
131
272
|
|
|
132
273
|
def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
|
|
@@ -147,7 +288,42 @@ def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
|
|
|
147
288
|
elif partition is not None and device == "V5E":
|
|
148
289
|
if partition not in get_args(V5EParts):
|
|
149
290
|
raise ValueError(f"Invalid partition for V5E: {partition}. Must be one of {get_args(V5EParts)}")
|
|
150
|
-
return Device(1, device, partition)
|
|
291
|
+
return Device(1, "TPU", device, partition)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def Neuron(device: NeuronType, quantity: int = 1) -> Device:
    """
    Create a Neuron device instance.

    :param device: Device type (e.g., "Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u").
    :param quantity: The number of Neuron devices of this type. Defaults to 1.
    :return: Device instance.
    :raises ValueError: if ``device`` is not a known Neuron type or ``quantity`` is less than 1.
    """
    if device not in get_args(NeuronType):
        raise ValueError(f"Invalid Neuron type: {device}. Must be one of {get_args(NeuronType)}")
    if quantity < 1:
        raise ValueError(f"Neuron quantity must be at least 1, got {quantity}")
    return Device(device=device, quantity=quantity, device_class="NEURON")
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def AMD_GPU(device: AMD_GPUType) -> Device:
    """
    Build a Device describing a single AMD GPU.

    :param device: AMD GPU model; one of "MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X",
        "MI325X", "MI350X" or "MI355X".
    :return: Device instance with quantity 1 and device class "AMD_GPU".
    :raises ValueError: if ``device`` is not a known AMD GPU model.
    """
    supported = get_args(AMD_GPUType)
    if device in supported:
        return Device(quantity=1, device_class="AMD_GPU", device=device)
    raise ValueError(f"Invalid AMD GPU type: {device}. Must be one of {supported}")
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def HABANA_GAUDI(device: HABANA_GAUDIType) -> Device:
    """
    Create a Habana Gaudi device instance.

    :param device: Device type; currently only "Gaudi1" is supported.
    :return: Device instance with quantity 1.
    :raises ValueError: if ``device`` is not a known Habana Gaudi type.
    """
    if device not in get_args(HABANA_GAUDIType):
        raise ValueError(f"Invalid Habana Gaudi type: {device}. Must be one of {get_args(HABANA_GAUDIType)}")
    return Device(device=device, quantity=1, device_class="HABANA_GAUDI")
|
|
151
327
|
|
|
152
328
|
|
|
153
329
|
CPUBaseType = int | float | str
|
|
@@ -202,7 +378,7 @@ class Resources:
|
|
|
202
378
|
raise ValueError("gpu must be greater than or equal to 0")
|
|
203
379
|
elif isinstance(self.gpu, str):
|
|
204
380
|
if self.gpu not in get_args(Accelerators):
|
|
205
|
-
raise ValueError(f"gpu must be one of {Accelerators}")
|
|
381
|
+
raise ValueError(f"gpu must be one of {Accelerators}, got {self.gpu}")
|
|
206
382
|
|
|
207
383
|
def get_device(self) -> Optional[Device]:
|
|
208
384
|
"""
|
|
@@ -214,10 +390,16 @@ class Resources:
|
|
|
214
390
|
if self.gpu is None:
|
|
215
391
|
return None
|
|
216
392
|
if isinstance(self.gpu, int):
|
|
217
|
-
return Device(quantity=self.gpu)
|
|
393
|
+
return Device(quantity=self.gpu, device_class="GPU")
|
|
218
394
|
if isinstance(self.gpu, str):
|
|
219
395
|
device, portion = self.gpu.split(":")
|
|
220
|
-
|
|
396
|
+
for cls, cls_name in _DeviceClassType.items():
|
|
397
|
+
if device in get_args(cls):
|
|
398
|
+
device_class = cls_name
|
|
399
|
+
break
|
|
400
|
+
else:
|
|
401
|
+
raise ValueError(f"Invalid device type: {device}. Must be one of {list(_DeviceClassType.keys())}")
|
|
402
|
+
return Device(device=device, device_class=device_class, quantity=int(portion)) # type: ignore
|
|
221
403
|
return self.gpu
|
|
222
404
|
|
|
223
405
|
def get_shared_memory(self) -> Optional[str]:
|