flyte 0.2.0b1__py3-none-any.whl → 2.0.0b46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyte/__init__.py +83 -30
- flyte/_bin/connect.py +61 -0
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +87 -19
- flyte/_bin/serve.py +351 -0
- flyte/_build.py +3 -2
- flyte/_cache/cache.py +6 -5
- flyte/_cache/local_cache.py +216 -0
- flyte/_code_bundle/_ignore.py +31 -5
- flyte/_code_bundle/_packaging.py +42 -11
- flyte/_code_bundle/_utils.py +57 -34
- flyte/_code_bundle/bundle.py +130 -27
- flyte/_constants.py +1 -0
- flyte/_context.py +21 -5
- flyte/_custom_context.py +73 -0
- flyte/_debug/constants.py +37 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +315 -0
- flyte/_deploy.py +396 -75
- flyte/_deployer.py +109 -0
- flyte/_environment.py +94 -11
- flyte/_excepthook.py +37 -0
- flyte/_group.py +2 -1
- flyte/_hash.py +1 -16
- flyte/_image.py +544 -231
- flyte/_initialize.py +456 -316
- flyte/_interface.py +40 -5
- flyte/_internal/controllers/__init__.py +22 -8
- flyte/_internal/controllers/_local_controller.py +159 -35
- flyte/_internal/controllers/_trace.py +18 -10
- flyte/_internal/controllers/remote/__init__.py +38 -9
- flyte/_internal/controllers/remote/_action.py +82 -12
- flyte/_internal/controllers/remote/_client.py +6 -2
- flyte/_internal/controllers/remote/_controller.py +290 -64
- flyte/_internal/controllers/remote/_core.py +155 -95
- flyte/_internal/controllers/remote/_informer.py +40 -20
- flyte/_internal/controllers/remote/_service_protocol.py +2 -2
- flyte/_internal/imagebuild/__init__.py +2 -10
- flyte/_internal/imagebuild/docker_builder.py +391 -84
- flyte/_internal/imagebuild/image_builder.py +111 -55
- flyte/_internal/imagebuild/remote_builder.py +409 -0
- flyte/_internal/imagebuild/utils.py +79 -0
- flyte/_internal/resolvers/_app_env_module.py +92 -0
- flyte/_internal/resolvers/_task_module.py +5 -38
- flyte/_internal/resolvers/app_env.py +26 -0
- flyte/_internal/resolvers/common.py +8 -1
- flyte/_internal/resolvers/default.py +2 -2
- flyte/_internal/runtime/convert.py +319 -36
- flyte/_internal/runtime/entrypoints.py +106 -18
- flyte/_internal/runtime/io.py +71 -23
- flyte/_internal/runtime/resources_serde.py +21 -7
- flyte/_internal/runtime/reuse.py +125 -0
- flyte/_internal/runtime/rusty.py +196 -0
- flyte/_internal/runtime/task_serde.py +239 -66
- flyte/_internal/runtime/taskrunner.py +48 -8
- flyte/_internal/runtime/trigger_serde.py +162 -0
- flyte/_internal/runtime/types_serde.py +7 -16
- flyte/_keyring/file.py +115 -0
- flyte/_link.py +30 -0
- flyte/_logging.py +241 -42
- flyte/_map.py +312 -0
- flyte/_metrics.py +59 -0
- flyte/_module.py +74 -0
- flyte/_pod.py +30 -0
- flyte/_resources.py +296 -33
- flyte/_retry.py +1 -7
- flyte/_reusable_environment.py +72 -7
- flyte/_run.py +462 -132
- flyte/_secret.py +47 -11
- flyte/_serve.py +333 -0
- flyte/_task.py +245 -56
- flyte/_task_environment.py +219 -97
- flyte/_task_plugins.py +47 -0
- flyte/_tools.py +8 -8
- flyte/_trace.py +15 -24
- flyte/_trigger.py +1027 -0
- flyte/_utils/__init__.py +12 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +5 -4
- flyte/_utils/description_parser.py +19 -0
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/helpers.py +45 -19
- flyte/_utils/module_loader.py +123 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +8 -1
- flyte/_version.py +16 -3
- flyte/app/__init__.py +27 -0
- flyte/app/_app_environment.py +362 -0
- flyte/app/_connector_environment.py +40 -0
- flyte/app/_deploy.py +130 -0
- flyte/app/_parameter.py +343 -0
- flyte/app/_runtime/__init__.py +3 -0
- flyte/app/_runtime/app_serde.py +383 -0
- flyte/app/_types.py +113 -0
- flyte/app/extras/__init__.py +9 -0
- flyte/app/extras/_auth_middleware.py +217 -0
- flyte/app/extras/_fastapi.py +93 -0
- flyte/app/extras/_model_loader/__init__.py +3 -0
- flyte/app/extras/_model_loader/config.py +7 -0
- flyte/app/extras/_model_loader/loader.py +288 -0
- flyte/cli/__init__.py +12 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_build.py +114 -0
- flyte/cli/_common.py +493 -0
- flyte/cli/_create.py +371 -0
- flyte/cli/_delete.py +45 -0
- flyte/cli/_deploy.py +401 -0
- flyte/cli/_gen.py +316 -0
- flyte/cli/_get.py +446 -0
- flyte/cli/_option.py +33 -0
- flyte/{_cli → cli}/_params.py +57 -17
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_prefetch.py +292 -0
- flyte/cli/_run.py +690 -0
- flyte/cli/_serve.py +338 -0
- flyte/cli/_update.py +86 -0
- flyte/cli/_user.py +20 -0
- flyte/cli/main.py +246 -0
- flyte/config/__init__.py +2 -167
- flyte/config/_config.py +215 -163
- flyte/config/_internal.py +10 -1
- flyte/config/_reader.py +225 -0
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +330 -0
- flyte/connectors/_server.py +194 -0
- flyte/connectors/utils.py +159 -0
- flyte/errors.py +134 -2
- flyte/extend.py +24 -0
- flyte/extras/_container.py +69 -56
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +279 -0
- flyte/io/__init__.py +8 -1
- flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
- flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
- flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
- flyte/io/_dir.py +575 -113
- flyte/io/_file.py +587 -141
- flyte/io/_hashing_io.py +342 -0
- flyte/io/extend.py +7 -0
- flyte/models.py +635 -0
- flyte/prefetch/__init__.py +22 -0
- flyte/prefetch/_hf_model.py +563 -0
- flyte/remote/__init__.py +14 -3
- flyte/remote/_action.py +879 -0
- flyte/remote/_app.py +346 -0
- flyte/remote/_auth_metadata.py +42 -0
- flyte/remote/_client/_protocols.py +62 -4
- flyte/remote/_client/auth/_auth_utils.py +19 -0
- flyte/remote/_client/auth/_authenticators/base.py +8 -2
- flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
- flyte/remote/_client/auth/_authenticators/factory.py +4 -0
- flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
- flyte/remote/_client/auth/_channel.py +47 -18
- flyte/remote/_client/auth/_client_config.py +5 -3
- flyte/remote/_client/auth/_keyring.py +15 -2
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +206 -18
- flyte/remote/_common.py +66 -0
- flyte/remote/_data.py +107 -22
- flyte/remote/_logs.py +116 -33
- flyte/remote/_project.py +21 -19
- flyte/remote/_run.py +164 -631
- flyte/remote/_secret.py +72 -29
- flyte/remote/_task.py +387 -46
- flyte/remote/_trigger.py +368 -0
- flyte/remote/_user.py +43 -0
- flyte/report/_report.py +10 -6
- flyte/storage/__init__.py +13 -1
- flyte/storage/_config.py +237 -0
- flyte/storage/_parallel_reader.py +289 -0
- flyte/storage/_storage.py +268 -59
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +414 -0
- flyte/types/__init__.py +39 -0
- flyte/types/_interface.py +22 -7
- flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +226 -126
- flyte/types/_utils.py +1 -1
- flyte-2.0.0b46.data/scripts/debug.py +38 -0
- flyte-2.0.0b46.data/scripts/runtime.py +194 -0
- flyte-2.0.0b46.dist-info/METADATA +352 -0
- flyte-2.0.0b46.dist-info/RECORD +221 -0
- flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
- flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
- flyte/_api_commons.py +0 -3
- flyte/_cli/_common.py +0 -299
- flyte/_cli/_create.py +0 -42
- flyte/_cli/_delete.py +0 -23
- flyte/_cli/_deploy.py +0 -140
- flyte/_cli/_get.py +0 -235
- flyte/_cli/_run.py +0 -174
- flyte/_cli/main.py +0 -98
- flyte/_datastructures.py +0 -342
- flyte/_internal/controllers/pbhash.py +0 -39
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -71
- flyte/_protos/common/identifier_pb2.pyi +0 -82
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -69
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -106
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- flyte/_protos/workflow/run_definition_pb2.py +0 -128
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -133
- flyte/_protos/workflow/run_service_pb2.pyi +0 -175
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
- flyte/_protos/workflow/state_service_pb2.py +0 -58
- flyte/_protos/workflow/state_service_pb2.pyi +0 -71
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -72
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -44
- flyte/_protos/workflow/task_service_pb2.pyi +0 -31
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
- flyte/io/_dataframe.py +0 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/remote/_console.py +0 -18
- flyte-0.2.0b1.dist-info/METADATA +0 -179
- flyte-0.2.0b1.dist-info/RECORD +0 -204
- flyte-0.2.0b1.dist-info/entry_points.txt +0 -3
- /flyte/{_cli → _debug}/__init__.py +0 -0
- /flyte/{_protos → _keyring}/__init__.py +0 -0
- {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
flyte/_metrics.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Timer utilities for emitting timing metrics via structured logging.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
from typing import Any, Dict, Optional
|
|
7
|
+
|
|
8
|
+
from flyte._logging import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Stopwatch:
    """
    Simple stopwatch for timing code blocks.
    Emits timing metrics via structured logging when stopped.

    Example:
        sw = Stopwatch("download_inputs")
        sw.start()
        # code to time
        sw.stop()

    It can also be used as a context manager: the stopwatch starts on entry and is
    stopped (emitting the metric) on exit, including when the body raises:

        with Stopwatch("download_inputs"):
            ...  # code to time

    :param metric_name: Name of the metric to emit
    :param extra_fields: Additional fields to include in the log record
    """

    def __init__(self, metric_name: str, extra_fields: Optional[Dict[str, Any]] = None):
        self.metric_name = metric_name
        self.extra_fields = extra_fields
        self._start_time: Optional[float] = None

    def start(self):
        """Start (or restart) the stopwatch."""
        # perf_counter is monotonic, so measured durations ignore wall-clock adjustments.
        self._start_time = time.perf_counter()

    def stop(self) -> float:
        """
        Stop the stopwatch and emit the timing metric.

        The start time is not cleared, so calling stop() again measures from the same
        start() call.

        :return: Elapsed seconds since the last call to start().
        :raises RuntimeError: If start() was never called.
        """
        if self._start_time is None:
            raise RuntimeError(f"Stopwatch '{self.metric_name}' was never started")
        duration = time.perf_counter() - self._start_time
        _emit_metric(self.metric_name, duration, self.extra_fields)
        return duration

    def __enter__(self) -> "Stopwatch":
        """Start timing and return this stopwatch."""
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        """Stop timing and emit the metric; never suppresses exceptions from the block."""
        self.stop()
        return False
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _emit_metric(metric_name: str, duration: float, extra_fields: Optional[Dict[str, Any]] = None):
    """
    Log one timer measurement as a structured record at INFO level.

    The record's ``extra`` always contains ``metric_type`` (``"timer"``), ``metric_name`` and
    ``duration_seconds``. Any ``extra_fields`` are merged on top, so they win on key collisions.

    NOTE(review): if ``logger`` is a standard ``logging.Logger``, an ``extra`` key that clashes
    with a ``LogRecord`` attribute (e.g. ``"message"``) makes logging raise ``KeyError``.

    :param metric_name: Name of the metric (may be hierarchical with dots)
    :param duration: Duration in seconds
    :param extra_fields: Additional fields to include in the log record
    """
    payload: Dict[str, Any] = dict(metric_type="timer", metric_name=metric_name, duration_seconds=duration)
    # An empty or missing mapping contributes nothing.
    payload.update(extra_fields or {})

    message = f"Stopwatch: {metric_name} completed in {duration:.4f}s"
    logger.info(message, extra=payload)
|
flyte/_module.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import os
|
|
3
|
+
import pathlib
|
|
4
|
+
import sys
|
|
5
|
+
from types import ModuleType
|
|
6
|
+
from typing import Tuple
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def extract_obj_module(obj: object, /, source_dir: pathlib.Path | None = None) -> Tuple[str, ModuleType]:
|
|
10
|
+
"""
|
|
11
|
+
Extract the module from the given object. If source_dir is provided, the module will be relative to the source_dir.
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
obj: The object to extract the module from.
|
|
15
|
+
source_dir: The source directory to use for relative paths.
|
|
16
|
+
|
|
17
|
+
Returns:
|
|
18
|
+
The module name as a string.
|
|
19
|
+
"""
|
|
20
|
+
if source_dir is None:
|
|
21
|
+
raise ValueError("extract_obj_module: source_dir cannot be None - specify root-dir")
|
|
22
|
+
# Get the module containing the object
|
|
23
|
+
entity_module = inspect.getmodule(obj)
|
|
24
|
+
if entity_module is None:
|
|
25
|
+
obj_name = getattr(obj, "__name__", str(obj))
|
|
26
|
+
raise ValueError(f"Object {obj_name} has no module.")
|
|
27
|
+
|
|
28
|
+
fp = entity_module.__file__
|
|
29
|
+
if fp is None:
|
|
30
|
+
obj_name = getattr(obj, "__name__", str(obj))
|
|
31
|
+
raise ValueError(f"Object {obj_name} has no module.")
|
|
32
|
+
|
|
33
|
+
file_path = pathlib.Path(fp)
|
|
34
|
+
try:
|
|
35
|
+
# Get the relative path to the current directory
|
|
36
|
+
# Will raise ValueError if the file is not in the source directory
|
|
37
|
+
relative_path = file_path.relative_to(str(pathlib.Path(source_dir).absolute()))
|
|
38
|
+
|
|
39
|
+
if relative_path == pathlib.Path("_internal/resolvers"):
|
|
40
|
+
entity_module_name = entity_module.__name__
|
|
41
|
+
elif "site-packages" in str(file_path) or "dist-packages" in str(file_path):
|
|
42
|
+
raise ValueError("Object from a library")
|
|
43
|
+
else:
|
|
44
|
+
# Replace file separators with dots and remove the '.py' extension
|
|
45
|
+
dotted_path = os.path.splitext(str(relative_path))[0].replace(os.sep, ".")
|
|
46
|
+
entity_module_name = dotted_path
|
|
47
|
+
except ValueError:
|
|
48
|
+
# If source_dir is not provided or file is not in source_dir, fallback to module name
|
|
49
|
+
# File is not relative to source_dir - check if it's an installed package
|
|
50
|
+
file_path_str = str(file_path)
|
|
51
|
+
if "site-packages" in file_path_str or "dist-packages" in file_path_str:
|
|
52
|
+
# It's an installed package - use the module's __name__ directly
|
|
53
|
+
# This will be importable via importlib.import_module()
|
|
54
|
+
entity_module_name = entity_module.__name__
|
|
55
|
+
else:
|
|
56
|
+
# File is not in source_dir and not in site-packages - re-raise the error
|
|
57
|
+
obj_name = getattr(obj, "__name__", str(obj))
|
|
58
|
+
raise ValueError(
|
|
59
|
+
f"Object {obj_name} module file {file_path} is not relative to "
|
|
60
|
+
f"source directory {source_dir} and is not an installed package."
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
if entity_module_name == "__main__":
|
|
64
|
+
"""
|
|
65
|
+
This case is for the case in which the object is run from the main module.
|
|
66
|
+
"""
|
|
67
|
+
fp = sys.modules["__main__"].__file__
|
|
68
|
+
if fp is None:
|
|
69
|
+
obj_name = getattr(obj, "__name__", str(obj))
|
|
70
|
+
raise ValueError(f"Object {obj_name} has no module.")
|
|
71
|
+
main_path = pathlib.Path(fp)
|
|
72
|
+
entity_module_name = main_path.stem
|
|
73
|
+
|
|
74
|
+
return entity_module_name, entity_module
|
flyte/_pod.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, Optional
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from flyteidl2.core.tasks_pb2 import K8sPod
|
|
6
|
+
from kubernetes.client import V1PodSpec
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
_PRIMARY_CONTAINER_DEFAULT_NAME = "primary"


def _default_pod_spec() -> "V1PodSpec":
    """
    Build the default (empty) pod spec for :class:`PodTemplate`.

    ``V1PodSpec`` is imported only under ``TYPE_CHECKING`` at module level, so it must be imported
    here at call time; referencing the module-level name from a default factory raises ``NameError``.
    """
    from kubernetes.client import V1PodSpec

    # NOTE(review): the generated kubernetes client model treats ``containers`` as required and,
    # with client-side validation enabled, rejects ``None``, so an empty list is passed. Confirm.
    return V1PodSpec(containers=[])


@dataclass(init=True, repr=True, eq=True, frozen=False)
class PodTemplate(object):
    """
    Custom PodTemplate specification for a Task.

    :param pod_spec: Kubernetes pod spec used as the template. Defaults to a fresh, empty spec.
    :param primary_container_name: Name of the container the task runs in.
    :param labels: Labels placed on the pod metadata.
    :param annotations: Annotations placed on the pod metadata.
    """

    pod_spec: Optional["V1PodSpec"] = field(default_factory=_default_pod_spec)
    primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME
    labels: Optional[Dict[str, str]] = None
    annotations: Optional[Dict[str, str]] = None

    def to_k8s_pod(self) -> "K8sPod":
        """
        Convert this template into a flyteidl2 ``K8sPod`` message.

        The pod spec is converted with the kubernetes client's ``sanitize_for_serialization``.
        Imports are local so that kubernetes/flyteidl2 are only needed when this is called.
        """
        from flyteidl2.core.tasks_pb2 import K8sObjectMetadata, K8sPod
        from kubernetes.client import ApiClient

        return K8sPod(
            metadata=K8sObjectMetadata(labels=self.labels, annotations=self.annotations),
            pod_spec=ApiClient().sanitize_for_serialization(self.pod_spec),
            primary_container_name=self.primary_container_name,
        )
|
flyte/_resources.py
CHANGED
|
@@ -1,9 +1,19 @@
|
|
|
1
|
-
|
|
2
|
-
from
|
|
1
|
+
import typing
|
|
2
|
+
from dataclasses import dataclass, fields
|
|
3
|
+
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union, get_args
|
|
3
4
|
|
|
4
5
|
import rich.repr
|
|
5
6
|
|
|
6
|
-
|
|
7
|
+
from flyte._pod import _PRIMARY_CONTAINER_DEFAULT_NAME
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from kubernetes.client import V1PodSpec
|
|
11
|
+
|
|
12
|
+
PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
|
|
13
|
+
|
|
14
|
+
GPUType = Literal[
|
|
15
|
+
"A10", "A10G", "A100", "A100 80G", "B200", "H100", "H200", "L4", "L40s", "T4", "V100", "RTX PRO 6000", "GB10"
|
|
16
|
+
]
|
|
7
17
|
GPUQuantity = Literal[1, 2, 3, 4, 5, 6, 7, 8]
|
|
8
18
|
A100Parts = Literal["1g.5gb", "2g.10gb", "3g.20gb", "4g.20gb", "7g.40gb"]
|
|
9
19
|
"""
|
|
@@ -15,6 +25,11 @@ A100_80GBParts = Literal["1g.10gb", "2g.20gb", "3g.40gb", "4g.40gb", "7g.80gb"]
|
|
|
15
25
|
Partitions for NVIDIA A100 80GB GPU.
|
|
16
26
|
"""
|
|
17
27
|
|
|
28
|
+
H200Parts = Literal["1g.18gb", "1g.35gb", "2g.35gb", "3g.71gb", "4g.71gb", "7g.141gb"]
|
|
29
|
+
"""
|
|
30
|
+
Partitions for NVIDIA H200 GPU (141GB HBM3e).
|
|
31
|
+
"""
|
|
32
|
+
|
|
18
33
|
TPUType = Literal["V5P", "V6E"]
|
|
19
34
|
V5EParts = Literal["1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"]
|
|
20
35
|
|
|
@@ -30,31 +45,32 @@ V6EParts = Literal["1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"]
|
|
|
30
45
|
Slices for Google Cloud TPU v6e.
|
|
31
46
|
"""
|
|
32
47
|
|
|
48
|
+
NeuronType = Literal["Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u"]
|
|
49
|
+
|
|
50
|
+
AMD_GPUType = Literal["MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X", "MI325X", "MI350X", "MI355X"]
|
|
51
|
+
|
|
52
|
+
HABANA_GAUDIType = Literal["Gaudi1"]
|
|
53
|
+
|
|
33
54
|
Accelerators = Literal[
|
|
34
|
-
|
|
35
|
-
"
|
|
36
|
-
"
|
|
37
|
-
"
|
|
38
|
-
"
|
|
39
|
-
"
|
|
40
|
-
"
|
|
41
|
-
"
|
|
42
|
-
"
|
|
43
|
-
|
|
44
|
-
"
|
|
45
|
-
"
|
|
46
|
-
"
|
|
47
|
-
"
|
|
48
|
-
"
|
|
49
|
-
"
|
|
50
|
-
"
|
|
51
|
-
"
|
|
52
|
-
|
|
53
|
-
"L40s:4",
|
|
54
|
-
"L40s:5",
|
|
55
|
-
"L40s:6",
|
|
56
|
-
"L40s:7",
|
|
57
|
-
"L40s:8",
|
|
55
|
+
# A10
|
|
56
|
+
"A10:1",
|
|
57
|
+
"A10:2",
|
|
58
|
+
"A10:3",
|
|
59
|
+
"A10:4",
|
|
60
|
+
"A10:5",
|
|
61
|
+
"A10:6",
|
|
62
|
+
"A10:7",
|
|
63
|
+
"A10:8",
|
|
64
|
+
# A10G
|
|
65
|
+
"A10G:1",
|
|
66
|
+
"A10G:2",
|
|
67
|
+
"A10G:3",
|
|
68
|
+
"A10G:4",
|
|
69
|
+
"A10G:5",
|
|
70
|
+
"A10G:6",
|
|
71
|
+
"A10G:7",
|
|
72
|
+
"A10G:8",
|
|
73
|
+
# A100
|
|
58
74
|
"A100:1",
|
|
59
75
|
"A100:2",
|
|
60
76
|
"A100:3",
|
|
@@ -63,6 +79,7 @@ Accelerators = Literal[
|
|
|
63
79
|
"A100:6",
|
|
64
80
|
"A100:7",
|
|
65
81
|
"A100:8",
|
|
82
|
+
# A100 80G
|
|
66
83
|
"A100 80G:1",
|
|
67
84
|
"A100 80G:2",
|
|
68
85
|
"A100 80G:3",
|
|
@@ -71,6 +88,16 @@ Accelerators = Literal[
|
|
|
71
88
|
"A100 80G:6",
|
|
72
89
|
"A100 80G:7",
|
|
73
90
|
"A100 80G:8",
|
|
91
|
+
# B200
|
|
92
|
+
"B200:1",
|
|
93
|
+
"B200:2",
|
|
94
|
+
"B200:3",
|
|
95
|
+
"B200:4",
|
|
96
|
+
"B200:5",
|
|
97
|
+
"B200:6",
|
|
98
|
+
"B200:7",
|
|
99
|
+
"B200:8",
|
|
100
|
+
# H100
|
|
74
101
|
"H100:1",
|
|
75
102
|
"H100:2",
|
|
76
103
|
"H100:3",
|
|
@@ -79,8 +106,137 @@ Accelerators = Literal[
|
|
|
79
106
|
"H100:6",
|
|
80
107
|
"H100:7",
|
|
81
108
|
"H100:8",
|
|
109
|
+
# H200
|
|
110
|
+
"H200:1",
|
|
111
|
+
"H200:2",
|
|
112
|
+
"H200:3",
|
|
113
|
+
"H200:4",
|
|
114
|
+
"H200:5",
|
|
115
|
+
"H200:6",
|
|
116
|
+
"H200:7",
|
|
117
|
+
"H200:8",
|
|
118
|
+
# L4
|
|
119
|
+
"L4:1",
|
|
120
|
+
"L4:2",
|
|
121
|
+
"L4:3",
|
|
122
|
+
"L4:4",
|
|
123
|
+
"L4:5",
|
|
124
|
+
"L4:6",
|
|
125
|
+
"L4:7",
|
|
126
|
+
"L4:8",
|
|
127
|
+
# L40s
|
|
128
|
+
"L40s:1",
|
|
129
|
+
"L40s:2",
|
|
130
|
+
"L40s:3",
|
|
131
|
+
"L40s:4",
|
|
132
|
+
"L40s:5",
|
|
133
|
+
"L40s:6",
|
|
134
|
+
"L40s:7",
|
|
135
|
+
"L40s:8",
|
|
136
|
+
# V100
|
|
137
|
+
"V100:1",
|
|
138
|
+
"V100:2",
|
|
139
|
+
"V100:3",
|
|
140
|
+
"V100:4",
|
|
141
|
+
"V100:5",
|
|
142
|
+
"V100:6",
|
|
143
|
+
"V100:7",
|
|
144
|
+
"V100:8",
|
|
145
|
+
# RTX 6000
|
|
146
|
+
"RTX PRO 6000:1",
|
|
147
|
+
# GB10
|
|
148
|
+
"GB10:1",
|
|
149
|
+
# T4
|
|
150
|
+
"T4:1",
|
|
151
|
+
"T4:2",
|
|
152
|
+
"T4:3",
|
|
153
|
+
"T4:4",
|
|
154
|
+
"T4:5",
|
|
155
|
+
"T4:6",
|
|
156
|
+
"T4:7",
|
|
157
|
+
"T4:8",
|
|
158
|
+
# Trn1
|
|
159
|
+
"Trn1:1",
|
|
160
|
+
"Trn1:4",
|
|
161
|
+
"Trn1:8",
|
|
162
|
+
"Trn1:16",
|
|
163
|
+
# Trn1n
|
|
164
|
+
"Trn1n:1",
|
|
165
|
+
"Trn1n:4",
|
|
166
|
+
"Trn1n:8",
|
|
167
|
+
"Trn1n:16",
|
|
168
|
+
# Trn2
|
|
169
|
+
"Trn2:1",
|
|
170
|
+
"Trn2:4",
|
|
171
|
+
"Trn2:8",
|
|
172
|
+
"Trn2:16",
|
|
173
|
+
# Trn2u
|
|
174
|
+
"Trn2u:1",
|
|
175
|
+
"Trn2u:4",
|
|
176
|
+
"Trn2u:8",
|
|
177
|
+
"Trn2u:16",
|
|
178
|
+
# Inf1
|
|
179
|
+
"Inf1:1",
|
|
180
|
+
"Inf1:2",
|
|
181
|
+
"Inf1:3",
|
|
182
|
+
"Inf1:4",
|
|
183
|
+
"Inf1:5",
|
|
184
|
+
"Inf1:6",
|
|
185
|
+
"Inf1:7",
|
|
186
|
+
"Inf1:8",
|
|
187
|
+
"Inf1:9",
|
|
188
|
+
"Inf1:10",
|
|
189
|
+
"Inf1:11",
|
|
190
|
+
"Inf1:12",
|
|
191
|
+
"Inf1:13",
|
|
192
|
+
"Inf1:14",
|
|
193
|
+
"Inf1:15",
|
|
194
|
+
"Inf1:16",
|
|
195
|
+
# Inf2
|
|
196
|
+
"Inf2:1",
|
|
197
|
+
"Inf2:2",
|
|
198
|
+
"Inf2:3",
|
|
199
|
+
"Inf2:4",
|
|
200
|
+
"Inf2:5",
|
|
201
|
+
"Inf2:6",
|
|
202
|
+
"Inf2:7",
|
|
203
|
+
"Inf2:8",
|
|
204
|
+
"Inf2:9",
|
|
205
|
+
"Inf2:10",
|
|
206
|
+
"Inf2:11",
|
|
207
|
+
"Inf2:12",
|
|
208
|
+
# MI100
|
|
209
|
+
"MI100:1",
|
|
210
|
+
# MI210
|
|
211
|
+
"MI210:1",
|
|
212
|
+
# MI250
|
|
213
|
+
"MI250:1",
|
|
214
|
+
# MI250X
|
|
215
|
+
"MI250X:1",
|
|
216
|
+
# MI300A
|
|
217
|
+
"MI300A:1",
|
|
218
|
+
# MI300X
|
|
219
|
+
"MI300X:1",
|
|
220
|
+
# MI325X
|
|
221
|
+
"MI325X:1",
|
|
222
|
+
# MI350X
|
|
223
|
+
"MI350X:1",
|
|
224
|
+
# MI355X
|
|
225
|
+
"MI355X:1",
|
|
226
|
+
# Habana Gaudi
|
|
227
|
+
"Gaudi1:1",
|
|
82
228
|
]
|
|
83
229
|
|
|
230
|
+
DeviceClass = Literal["GPU", "TPU", "NEURON", "AMD_GPU", "HABANA_GAUDI"]
|
|
231
|
+
|
|
232
|
+
_DeviceClassType: Dict[typing.Any, str] = {
|
|
233
|
+
GPUType: "GPU",
|
|
234
|
+
TPUType: "TPU",
|
|
235
|
+
NeuronType: "NEURON",
|
|
236
|
+
AMD_GPUType: "AMD_GPU",
|
|
237
|
+
HABANA_GAUDIType: "HABANA_GAUDI",
|
|
238
|
+
}
|
|
239
|
+
|
|
84
240
|
|
|
85
241
|
@rich.repr.auto
|
|
86
242
|
@dataclass(frozen=True, slots=True)
|
|
@@ -93,6 +249,7 @@ class Device:
|
|
|
93
249
|
"""
|
|
94
250
|
|
|
95
251
|
quantity: int
|
|
252
|
+
device_class: DeviceClass
|
|
96
253
|
device: str | None = None
|
|
97
254
|
partition: str | None = None
|
|
98
255
|
|
|
@@ -101,7 +258,9 @@ class Device:
|
|
|
101
258
|
raise ValueError("GPU quantity must be at least 1")
|
|
102
259
|
|
|
103
260
|
|
|
104
|
-
def GPU(
|
|
261
|
+
def GPU(
|
|
262
|
+
device: GPUType, quantity: GPUQuantity, partition: A100Parts | A100_80GBParts | H200Parts | None = None
|
|
263
|
+
) -> Device:
|
|
105
264
|
"""
|
|
106
265
|
Create a GPU device instance.
|
|
107
266
|
:param device: The type of GPU (e.g., "T4", "A100").
|
|
@@ -119,7 +278,10 @@ def GPU(device: GPUType, quantity: GPUQuantity, partition: A100Parts | A100_80GB
|
|
|
119
278
|
elif partition is not None and device == "A100 80G":
|
|
120
279
|
if partition not in get_args(A100_80GBParts):
|
|
121
280
|
raise ValueError(f"Invalid partition for A100 80G: {partition}. Must be one of {get_args(A100_80GBParts)}")
|
|
122
|
-
|
|
281
|
+
elif partition is not None and device == "H200":
|
|
282
|
+
if partition not in get_args(H200Parts):
|
|
283
|
+
raise ValueError(f"Invalid partition for H200: {partition}. Must be one of {get_args(H200Parts)}")
|
|
284
|
+
return Device(device=device, quantity=quantity, partition=partition, device_class="GPU")
|
|
123
285
|
|
|
124
286
|
|
|
125
287
|
def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
|
|
@@ -140,7 +302,42 @@ def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
|
|
|
140
302
|
elif partition is not None and device == "V5E":
|
|
141
303
|
if partition not in get_args(V5EParts):
|
|
142
304
|
raise ValueError(f"Invalid partition for V5E: {partition}. Must be one of {get_args(V5EParts)}")
|
|
143
|
-
return Device(1, device, partition)
|
|
305
|
+
return Device(1, "TPU", device, partition)
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def Neuron(device: NeuronType) -> Device:
    """
    Create a Neuron device instance.

    The returned device always has ``quantity=1``. To request several Neuron devices, use an
    accelerator string such as ``"Trn1:4"`` for ``Resources(gpu=...)`` instead.

    :param device: Device type (e.g., "Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u").
    :return: Device instance with ``device_class="NEURON"``.
    :raises ValueError: If ``device`` is not a known Neuron type.
    """
    if device not in get_args(NeuronType):
        raise ValueError(f"Invalid Neuron type: {device}. Must be one of {get_args(NeuronType)}")
    return Device(device=device, quantity=1, device_class="NEURON")
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def AMD_GPU(device: AMD_GPUType) -> Device:
    """
    Build a single-unit AMD GPU device.

    :param device: AMD GPU model, one of "MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X",
        "MI325X", "MI350X", "MI355X".
    :return: A ``Device`` with ``device_class="AMD_GPU"`` and ``quantity=1``.
    :raises ValueError: If ``device`` is not a known AMD GPU model.
    """
    allowed = get_args(AMD_GPUType)
    if device in allowed:
        return Device(device=device, quantity=1, device_class="AMD_GPU")
    raise ValueError(f"Invalid AMD GPU type: {device}. Must be one of {allowed}")
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def HABANA_GAUDI(device: HABANA_GAUDIType) -> Device:
    """
    Build a single-unit Habana Gaudi device.

    :param device: Gaudi model (currently only "Gaudi1").
    :return: A ``Device`` with ``device_class="HABANA_GAUDI"`` and ``quantity=1``.
    :raises ValueError: If ``device`` is not a known Habana Gaudi model.
    """
    allowed = get_args(HABANA_GAUDIType)
    if device in allowed:
        return Device(device=device, quantity=1, device_class="HABANA_GAUDI")
    raise ValueError(f"Invalid Habana Gaudi type: {device}. Must be one of {allowed}")
|
|
144
341
|
|
|
145
342
|
|
|
146
343
|
CPUBaseType = int | float | str
|
|
@@ -171,6 +368,8 @@ class Resources:
|
|
|
171
368
|
strings, or a tuple of two ints or strings.
|
|
172
369
|
:param gpu: The amount of GPU to allocate to the task. This can be an Accelerators enum, an int, or None.
|
|
173
370
|
:param disk: The amount of disk to allocate to the task. This is a string of the form "10GiB".
|
|
371
|
+
:param shm: The amount of shared memory to allocate to the task. This is a string of the form "10GiB" or "auto".
|
|
372
|
+
If "auto", then the shared memory will be set to max amount of shared memory available on the node.
|
|
174
373
|
"""
|
|
175
374
|
|
|
176
375
|
cpu: Union[CPUBaseType, Tuple[CPUBaseType, CPUBaseType], None] = None
|
|
@@ -195,7 +394,7 @@ class Resources:
|
|
|
195
394
|
raise ValueError("gpu must be greater than or equal to 0")
|
|
196
395
|
elif isinstance(self.gpu, str):
|
|
197
396
|
if self.gpu not in get_args(Accelerators):
|
|
198
|
-
raise ValueError(f"gpu must be one of {Accelerators}")
|
|
397
|
+
raise ValueError(f"gpu must be one of {Accelerators}, got {self.gpu}")
|
|
199
398
|
|
|
200
399
|
def get_device(self) -> Optional[Device]:
|
|
201
400
|
"""
|
|
@@ -207,10 +406,16 @@ class Resources:
|
|
|
207
406
|
if self.gpu is None:
|
|
208
407
|
return None
|
|
209
408
|
if isinstance(self.gpu, int):
|
|
210
|
-
return Device(quantity=self.gpu)
|
|
409
|
+
return Device(quantity=self.gpu, device_class="GPU")
|
|
211
410
|
if isinstance(self.gpu, str):
|
|
212
411
|
device, portion = self.gpu.split(":")
|
|
213
|
-
|
|
412
|
+
for cls, cls_name in _DeviceClassType.items():
|
|
413
|
+
if device in get_args(cls):
|
|
414
|
+
device_class = cls_name
|
|
415
|
+
break
|
|
416
|
+
else:
|
|
417
|
+
raise ValueError(f"Invalid device type: {device}. Must be one of {list(_DeviceClassType.keys())}")
|
|
418
|
+
return Device(device=device, device_class=device_class, quantity=int(portion)) # type: ignore
|
|
214
419
|
return self.gpu
|
|
215
420
|
|
|
216
421
|
def get_shared_memory(self) -> Optional[str]:
|
|
@@ -224,3 +429,61 @@ class Resources:
|
|
|
224
429
|
if self.shm == "auto":
|
|
225
430
|
return ""
|
|
226
431
|
return self.shm
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
def _check_resource_is_singular(resource: Resources):
|
|
435
|
+
"""
|
|
436
|
+
Raise a value error if the resource has a tuple.
|
|
437
|
+
"""
|
|
438
|
+
for field in fields(resource):
|
|
439
|
+
value = getattr(resource, field.name)
|
|
440
|
+
if isinstance(value, (tuple, list)):
|
|
441
|
+
raise ValueError(f"{value} can not be a list or tuple")
|
|
442
|
+
return resource
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def pod_spec_from_resources(
    primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME,
    requests: Optional[Resources] = None,
    limits: Optional[Resources] = None,
    k8s_gpu_resource_key: str = "nvidia.com/gpu",
) -> "V1PodSpec":
    """
    Build a pod spec whose single (primary) container carries the given resource requests/limits.

    If only one of ``requests``/``limits`` is given, it is used for both.

    :param primary_container_name: Name of the container in the returned spec.
    :param requests: Resource requests; every field must be a single value (not a tuple/list).
    :param limits: Resource limits; every field must be a single value (not a tuple/list).
    :param k8s_gpu_resource_key: Kubernetes resource name used for the GPU count.
    :return: A ``V1PodSpec`` with one container.
    :raises ValueError: If a resource field is a tuple/list or has no Kubernetes container equivalent.
    """
    from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements

    def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str):
        """Translate a ``Resources`` object into a Kubernetes resource-name -> value mapping."""
        if resources is None:
            return None

        # Resources field name -> Kubernetes resource name. ``None`` marks fields that are not
        # container resource requests and are skipped here.
        resources_map = {
            "cpu": "cpu",
            "memory": "memory",
            "gpu": k8s_gpu_resource_key,
            # NOTE(review): Resources documents disk as e.g. "10GiB", while Kubernetes quantities
            # use "Gi"; confirm whether a unit conversion is needed.
            "disk": "ephemeral-storage",
            "ephemeral_storage": "ephemeral-storage",
            # NOTE(review): shared memory is presumably applied separately
            # (see Resources.get_shared_memory); it is not a container resource request.
            "shm": None,
        }

        k8s_pod_resources = {}

        _check_resource_is_singular(resources)
        for resource in fields(resources):
            resource_value = getattr(resources, resource.name)
            if resource_value is None:
                continue
            if resource.name not in resources_map:
                raise ValueError(f"Resource '{resource.name}' has no Kubernetes container resource equivalent")
            k8s_name = resources_map[resource.name]
            if k8s_name is None:
                continue
            if resource.name == "gpu":
                # gpu may be an int, an accelerator string such as "A100:2", or a Device; Kubernetes
                # expects a plain count, which get_device() exposes as Device.quantity.
                device = resources.get_device()
                if device is not None:
                    resource_value = device.quantity
            k8s_pod_resources[k8s_name] = resource_value

        return k8s_pod_resources

    requests = _construct_k8s_pods_resources(resources=requests, k8s_gpu_resource_key=k8s_gpu_resource_key)
    limits = _construct_k8s_pods_resources(resources=limits, k8s_gpu_resource_key=k8s_gpu_resource_key)
    # Mirror whichever side was provided so the container always gets both.
    requests = requests or limits
    limits = limits or requests

    return V1PodSpec(
        containers=[
            V1Container(
                name=primary_container_name,
                resources=V1ResourceRequirements(
                    requests=requests,
                    limits=limits,
                ),
            )
        ]
    )
|
flyte/_retry.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
from datetime import timedelta
|
|
3
|
-
from typing import Union
|
|
4
2
|
|
|
5
3
|
|
|
6
4
|
@dataclass
|
|
@@ -17,16 +15,12 @@ class RetryStrategy:
|
|
|
17
15
|
```
|
|
18
16
|
- This will retry the task 5 times with a maximum backoff of 10 seconds and a backoff factor of 2.
|
|
19
17
|
```
|
|
20
|
-
@task(retries=RetryStrategy(count=5
|
|
18
|
+
@task(retries=RetryStrategy(count=5))
|
|
21
19
|
def my_task():
|
|
22
20
|
pass
|
|
23
21
|
```
|
|
24
22
|
|
|
25
23
|
:param count: The number of retries.
|
|
26
|
-
:param backoff: The maximum backoff time for retries. This can be a float or a timedelta.
|
|
27
|
-
:param backoff: The backoff exponential factor. This can be an integer or a float.
|
|
28
24
|
"""
|
|
29
25
|
|
|
30
26
|
count: int
|
|
31
|
-
backoff: Union[float, timedelta, None] = None
|
|
32
|
-
backoff_factor: Union[int, float, None] = None
|