flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
flyte/_resources.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Literal, Optional, Tuple, Union, get_args
|
|
3
|
+
|
|
4
|
+
import rich.repr
|
|
5
|
+
|
|
6
|
+
# GPU models accepted by GPU(); GPU() rejects any device string not listed here.
GPUType = Literal["T4", "A100", "A100 80G", "H100", "L4", "L40s"]
# Type-level bound for the `quantity` argument of GPU().
GPUQuantity = Literal[1, 2, 3, 4, 5, 6, 7, 8]
A100Parts = Literal["1g.5gb", "2g.10gb", "3g.20gb", "4g.20gb", "7g.40gb"]
"""
Partitions for NVIDIA A100 GPU.
"""

A100_80GBParts = Literal["1g.10gb", "2g.20gb", "3g.40gb", "4g.40gb", "7g.80gb"]
"""
Partitions for NVIDIA A100 80GB GPU.
"""

# TPU models accepted by TPU(); TPU() rejects any device string not listed here.
TPUType = Literal["V5P", "V6E"]
# NOTE(review): presumably slices for Google Cloud TPU v5e. "V5E" is not part of TPUType, so TPU()
# rejects that device before these values are ever consulted -- confirm whether TPUType should list it.
V5EParts = Literal["1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"]

V5PParts = Literal[
    "2x2x1", "2x2x2", "2x4x4", "4x4x4", "4x4x8", "4x8x8", "8x8x8", "8x8x16", "8x16x16", "16x16x16", "16x16x24"
]
"""
Slices for Google Cloud TPU v5p.
"""

V6EParts = Literal["1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"]
"""
Slices for Google Cloud TPU v6e.
"""
|
|
32
|
+
|
|
33
|
+
# Shorthand accelerator strings of the form "<GPU type>:<quantity>" accepted by Resources(gpu=...).
# Resources.get_device() splits such a string on ":" and converts the second part to an int.
Accelerators = Literal[
    "T4:1",
    "T4:2",
    "T4:3",
    "T4:4",
    "T4:5",
    "T4:6",
    "T4:7",
    "T4:8",
    "L4:1",
    "L4:2",
    "L4:3",
    "L4:4",
    "L4:5",
    "L4:6",
    "L4:7",
    "L4:8",
    "L40s:1",
    "L40s:2",
    "L40s:3",
    "L40s:4",
    "L40s:5",
    "L40s:6",
    "L40s:7",
    "L40s:8",
    "A100:1",
    "A100:2",
    "A100:3",
    "A100:4",
    "A100:5",
    "A100:6",
    "A100:7",
    "A100:8",
    "A100 80G:1",
    "A100 80G:2",
    "A100 80G:3",
    "A100 80G:4",
    "A100 80G:5",
    "A100 80G:6",
    "A100 80G:7",
    "A100 80G:8",
    "H100:1",
    "H100:2",
    "H100:3",
    "H100:4",
    "H100:5",
    "H100:6",
    "H100:7",
    "H100:8",
]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@rich.repr.auto
@dataclass(frozen=True, slots=True)
class Device:
    """
    Represents a device type, its quantity and partition if applicable.

    Instances are immutable (frozen dataclass).

    :param quantity: The number of devices of this type. Must be at least 1.
    :param device: The type of device (e.g., "T4", "A100"). May be None when only a count is given,
        as Resources.get_device() does for an integer ``gpu`` value.
    :param partition: The partition of the device (e.g., "1g.5gb", "2g.10gb" for gpus) or ("1x1", ... for tpus).
    :raises ValueError: If ``quantity`` is less than 1.
    """

    quantity: int
    device: str | None = None
    partition: str | None = None

    def __post_init__(self) -> None:
        # This check runs for every Device, including the ones built by TPU(),
        # even though the message only mentions GPUs.
        if self.quantity < 1:
            raise ValueError("GPU quantity must be at least 1")
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def GPU(device: GPUType, quantity: GPUQuantity, partition: A100Parts | A100_80GBParts | None = None) -> Device:
    """
    Build a Device describing one or more GPUs.

    :param device: The GPU model; must be one of the values listed in ``GPUType``.
    :param quantity: How many GPUs of that model are requested; must be at least 1.
    :param partition: Optional GPU partition (e.g., "1g.5gb", "2g.10gb"). It is validated against
        ``A100Parts`` for "A100" and against ``A100_80GBParts`` for "A100 80G"; for any other model
        it is forwarded to the Device without validation.
    :return: Device instance.
    :raises ValueError: If the quantity is below 1, the GPU type is unknown, or the partition is not
        valid for an "A100" / "A100 80G" device.
    """
    if quantity < 1:
        raise ValueError("GPU quantity must be at least 1")

    if device not in get_args(GPUType):
        raise ValueError(f"Invalid GPU type: {device}. Must be one of {get_args(GPUType)}")

    # Only the two A100 variants have a known partition table; other models skip validation.
    if partition is not None:
        if device == "A100":
            if partition not in get_args(A100Parts):
                raise ValueError(f"Invalid partition for A100: {partition}. Must be one of {get_args(A100Parts)}")
        elif device == "A100 80G":
            if partition not in get_args(A100_80GBParts):
                raise ValueError(
                    f"Invalid partition for A100 80G: {partition}. Must be one of {get_args(A100_80GBParts)}"
                )

    return Device(device=device, quantity=quantity, partition=partition)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
    """
    Build a Device describing a TPU; the resulting Device always has a quantity of 1.

    :param device: The TPU model; must be one of the values listed in ``TPUType`` (e.g., "V5P", "V6E").
    :param partition: Optional TPU slice (e.g., "1x1", "2x2", ...). When given, it must belong to the
        slice list that matches ``device``.
    :return: Device instance.
    :raises ValueError: If the TPU type is unknown or the partition is not valid for that type.
    """
    if device not in get_args(TPUType):
        raise ValueError(f"Invalid TPU type: {device}. Must be one of {get_args(TPUType)}")

    if partition is not None:
        if device == "V5P":
            if partition not in get_args(V5PParts):
                raise ValueError(f"Invalid partition for V5P: {partition}. Must be one of {get_args(V5PParts)}")
        elif device == "V6E":
            if partition not in get_args(V6EParts):
                raise ValueError(f"Invalid partition for V6E: {partition}. Must be one of {get_args(V6EParts)}")
        elif device == "V5E":
            # NOTE(review): unreachable while TPUType lacks "V5E" -- the membership check above raises first.
            if partition not in get_args(V5EParts):
                raise ValueError(f"Invalid partition for V5E: {partition}. Must be one of {get_args(V5EParts)}")

    return Device(quantity=1, device=device, partition=partition)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# Scalar forms accepted for a CPU amount.
CPUBaseType = int | float | str


@dataclass
class Resources:
    """
    Resources such as CPU, Memory, and GPU that can be allocated to a task.

    Example:
    - Single CPU, 1GiB of memory, and 1 T4 GPU:
    ```python
    @task(resources=Resources(cpu=1, memory="1GiB", gpu="T4:1"))
    def my_task() -> int:
        return 42
    ```
    - 1 CPU with a limit of up to 2 CPUs, 2GiB of memory, 8 A100 GPUs and 10GiB of disk:
    ```python
    @task(resources=Resources(cpu=(1, 2), memory="2GiB", gpu="A100:8", disk="10GiB"))
    def my_task() -> int:
        return 42
    ```

    :param cpu: The amount of CPU to allocate to the task. This can be an int, a float or a string, or a
        tuple of exactly two such values (as in the second example above). A plain number must not be negative.
    :param memory: The amount of memory to allocate to the task. This can be a string such as "1GiB" or a
        tuple of exactly two such strings.
    :param gpu: The accelerators to allocate to the task. This can be one of the ``Accelerators`` strings
        (e.g. "T4:1"), a non-negative int, a ``Device`` instance, or None.
    :param disk: The amount of disk to allocate to the task. This is a string of the form "10GiB".
    :param shm: The shared memory setting for the task: a string, the literal "auto", or None.
    :raises ValueError: If a ``cpu`` or ``memory`` tuple does not have exactly two elements, if a numeric
        ``cpu`` or an integer ``gpu`` is negative, or if a string ``gpu`` is not one of ``Accelerators``.
    """

    cpu: Union[CPUBaseType, Tuple[CPUBaseType, CPUBaseType], None] = None
    memory: Union[str, Tuple[str, str], None] = None
    gpu: Union[Accelerators, int, Device, None] = None
    disk: Union[str, None] = None
    shm: Union[str, Literal["auto"], None] = None

    def __post_init__(self):
        if isinstance(self.cpu, tuple):
            if len(self.cpu) != 2:
                raise ValueError("cpu tuple must have exactly two elements")
        if isinstance(self.memory, tuple):
            if len(self.memory) != 2:
                raise ValueError("memory tuple must have exactly two elements")
        if isinstance(self.cpu, (int, float)):
            if self.cpu < 0:
                raise ValueError("cpu must be greater than or equal to 0")
        if self.gpu is not None:
            if isinstance(self.gpu, int):
                if self.gpu < 0:
                    raise ValueError("gpu must be greater than or equal to 0")
            elif isinstance(self.gpu, str):
                if self.gpu not in get_args(Accelerators):
                    raise ValueError(f"gpu must be one of {Accelerators}")

    def get_device(self) -> Optional[Device]:
        """
        Get the accelerator requested for the task as a ``Device``.

        :return: None when no accelerator is requested (``gpu`` is None or 0). For a positive int, a Device
          carrying only that quantity. For an ``Accelerators`` string such as "A100:8", a Device with the
          type and quantity parsed from it. A ``Device`` passed as ``gpu`` is returned unchanged.
        """
        if self.gpu is None:
            return None
        if isinstance(self.gpu, int):
            # gpu=0 is accepted by __post_init__, but Device rejects quantities below 1;
            # treat a zero count as "no accelerator" instead of raising here.
            if self.gpu == 0:
                return None
            return Device(quantity=self.gpu)
        if isinstance(self.gpu, str):
            device, quantity = self.gpu.split(":")
            return Device(device=device, quantity=int(quantity))
        return self.gpu

    def get_shared_memory(self) -> Optional[str]:
        """
        Get the shared memory string for the task.

        :return: None when ``shm`` is unset, an empty string when ``shm`` is "auto", otherwise the
          ``shm`` value itself.
        """
        if self.shm is None:
            return None
        if self.shm == "auto":
            # NOTE(review): the empty string presumably lets a downstream consumer pick the size -- confirm.
            return ""
        return self.shm
|
flyte/_retry.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from datetime import timedelta
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class RetryStrategy:
    """
    Retry strategy for the task or task environment. Retry strategy is optional or can be a simple number of retries.

    Example:
    - This will retry the task 5 times.
    ```
    @task(retries=5)
    def my_task():
        pass
    ```
    - This will retry the task 5 times with a maximum backoff of 10 seconds and a backoff factor of 2.
    ```
    @task(retries=RetryStrategy(count=5, backoff=10, backoff_factor=2))
    def my_task():
        pass
    ```

    :param count: The number of retries.
    :param backoff: The maximum backoff time for retries. This can be a float or a timedelta.
    :param backoff_factor: The backoff exponential factor. This can be an integer or a float.
    """

    count: int
    backoff: Union[float, timedelta, None] = None
    backoff_factor: Union[int, float, None] = None
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from datetime import timedelta
|
|
3
|
+
from typing import Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class ReusePolicy:
    """
    ReusePolicy can be used to configure a task to reuse the environment. This is useful when the environment creation
    is expensive and the runtime of the task is short. The environment will be reused for the next invocation of the
    task, and even the Python process may be reused by subsequent task invocations. A good mental model is to think of
    the environment as a container that is reused for multiple tasks, more like a long-running service.

    Caution: It is important to note that the environment is shared, so managing memory and resources is important.

    :param replicas: Either a single int representing the number of replicas or a tuple of two ints representing
        the min and max number of replicas. Defaults to 1.
    :param idle_ttl: The maximum idle duration for an environment replica, specified as either seconds (int) or a
        timedelta. If not set, the environment's global default will be used.
        When a replica remains idle — meaning no tasks are running — for this duration, it will be automatically
        terminated.
    """

    replicas: Union[int, Tuple[int, int]] = 1
    idle_ttl: Optional[Union[int, timedelta]] = None
|
flyte/_run.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import pathlib
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union, cast
|
|
6
|
+
|
|
7
|
+
import flyte
|
|
8
|
+
import flyte.report
|
|
9
|
+
from flyte import S3
|
|
10
|
+
|
|
11
|
+
from ._api_commons import syncer
|
|
12
|
+
from ._context import contextual_run, internal_ctx
|
|
13
|
+
from ._datastructures import ActionID, Checkpoints, RawDataPath, SerializationContext, TaskContext
|
|
14
|
+
from ._environment import Environment
|
|
15
|
+
from ._initialize import (
|
|
16
|
+
ABFS,
|
|
17
|
+
GCS,
|
|
18
|
+
_get_init_config,
|
|
19
|
+
get_client,
|
|
20
|
+
get_common_config,
|
|
21
|
+
get_storage,
|
|
22
|
+
requires_initialization,
|
|
23
|
+
requires_storage,
|
|
24
|
+
)
|
|
25
|
+
from ._internal import create_controller
|
|
26
|
+
from ._internal.runtime.io import _CHECKPOINT_FILE_NAME
|
|
27
|
+
from ._internal.runtime.taskrunner import run_task
|
|
28
|
+
from ._logging import logger
|
|
29
|
+
from ._protos.common import identifier_pb2
|
|
30
|
+
from ._task import P, R, TaskTemplate
|
|
31
|
+
from ._tools import ipython_check
|
|
32
|
+
from .errors import InitializationError
|
|
33
|
+
|
|
34
|
+
if TYPE_CHECKING:
|
|
35
|
+
from flyte.remote import Run
|
|
36
|
+
|
|
37
|
+
from ._code_bundle import CopyFiles
|
|
38
|
+
|
|
39
|
+
# Run modes accepted by _Runner(force_mode=...). When no mode is forced, _Runner picks "remote" if a
# client is configured in the init config and "local" otherwise.
Mode = Literal["local", "remote", "hybrid"]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@syncer.wrap
class _Runner:
    """
    Runs a task locally, remotely, or in hybrid mode, using the options captured at construction time.

    Instances are normally created through `with_runcontext(...)` or implicitly by the module-level `run`.
    """

    def __init__(
        self,
        force_mode: Mode | None = None,
        name: Optional[str] = None,
        service_account: Optional[str] = None,
        version: Optional[str] = None,
        copy_style: CopyFiles = "loaded_modules",
        dry_run: bool = False,
        copy_bundle_to: pathlib.Path | None = None,
        interactive_mode: bool | None = None,
        raw_data_path: str | None = None,
        metadata_path: str | None = None,
        run_base_dir: str | None = None,
    ):
        """
        Capture the run options and resolve the effective run mode.

        :param force_mode: Mode to use. When not given, "remote" is used if the init config has a client,
            otherwise "local".
        :param name: Optional name of the run.
        :param service_account: Optional service account; stored on the runner.
        :param version: Optional version; when omitted the code bundle's computed version is used.
        :param copy_style: Copy style for the code bundle; "none" disables building a code bundle.
        :param dry_run: When True, remote mode builds everything but does not call the run service.
        :param copy_bundle_to: Optional location passed to the bundle builders.
        :param interactive_mode: Force interactive mode on (True) or off (False). When None, it is detected
            with `ipython_check()`.
        :param raw_data_path: Local folder used for raw data in local mode.
        :param metadata_path: Output path / run base dir used in local mode. Defaults to "/tmp".
        :param run_base_dir: Base directory used in hybrid mode. Defaults to "/tmp/base".
        """
        init_config = _get_init_config()
        client = init_config.client if init_config else None
        if not force_mode and client is not None:
            force_mode = "remote"
        force_mode = force_mode or "local"
        logger.debug(f"Effective run mode: {force_mode}, client configured: {client is not None}")
        self._mode = force_mode
        self._name = name
        self._service_account = service_account
        self._version = version
        self._copy_files = copy_style
        self._dry_run = dry_run
        self._copy_bundle_to = copy_bundle_to
        # Only auto-detect when the caller did not express a preference; an explicit False must be honoured.
        self._interactive_mode = interactive_mode if interactive_mode is not None else ipython_check()
        self._raw_data_path = raw_data_path
        self._metadata_path = metadata_path or "/tmp"
        self._run_base_dir = run_base_dir or "/tmp/base"

    @requires_initialization
    async def _run_remote(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> Run:
        """
        Build images and the code bundle, serialize the task and create a run through the run service.

        :param obj: Task to run; it must be attached to an environment.
        :param args: Positional inputs for the task.
        :param kwargs: Keyword inputs for the task.
        :return: The created `Run`, or a local `DryRun` placeholder (carrying `task_spec`, `inputs` and
            `code_bundle`) when `dry_run` is set.
        :raises ValueError: If the task has no parent environment or no version can be determined.
        :raises InitializationError: If no client is configured and this is not a dry run.
        """
        from flyte.remote import Run

        from ._code_bundle import build_code_bundle, build_pkl_bundle
        from ._deploy import build_images, plan_deploy
        from ._internal.runtime.convert import convert_from_native_to_inputs
        from ._internal.runtime.task_serde import translate_task_to_wire
        from ._protos.workflow import run_definition_pb2, run_service_pb2

        cfg = get_common_config()

        if obj.parent_env is None:
            raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")

        deploy_plan = plan_deploy(cast(Environment, obj.parent_env()))
        image_cache = await build_images(deploy_plan)

        if self._interactive_mode:
            # Interactive sessions ship the task as a pickle bundle.
            code_bundle = await build_pkl_bundle(
                obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
            )
        elif self._copy_files != "none":
            code_bundle = await build_code_bundle(
                from_dir=cfg.root_dir, dryrun=self._dry_run, copy_bundle_to=self._copy_bundle_to
            )
        else:
            code_bundle = None

        # An explicit version wins; otherwise fall back to the version computed for the bundle.
        version = self._version or (
            code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
        )
        if not version:
            raise ValueError("Version is required when running a task")
        s_ctx = SerializationContext(
            code_bundle=code_bundle,
            version=version,
            image_cache=image_cache,
            root_dir=cfg.root_dir,
        )
        task_spec = translate_task_to_wire(obj, s_ctx)
        inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)

        if not self._dry_run:
            if get_client() is None:
                # This can only happen, if the user forces flyte.run(mode="remote") without initializing the client
                raise InitializationError(
                    "ClientNotInitializedError",
                    "user",
                    "flyte.run requires client to be initialized. "
                    "Call flyte.init() with a valid endpoint or api-key before using this function.",
                )
            # A named run is addressed by a full run identifier; an unnamed one only by its project.
            run_id = None
            project_id = None
            if self._name:
                run_id = run_definition_pb2.RunIdentifier(
                    project=cfg.project,
                    domain=cfg.domain,
                    org=cfg.org,
                    name=self._name,
                )
            else:
                project_id = identifier_pb2.ProjectIdentifier(
                    name=cfg.project,
                    domain=cfg.domain,
                    organization=cfg.org,
                )
            # Fill in task id inside the task template if it's not provided.
            # Maybe this should be done here, or the backend.
            task_id = task_spec.task_template.id
            if task_id.project == "":
                task_id.project = cfg.project or ""
            if task_id.domain == "":
                task_id.domain = cfg.domain or ""
            if task_id.org == "":
                task_id.org = cfg.org or ""
            if task_id.version == "":
                task_id.version = version

            resp = await get_client().run_service.CreateRun(
                run_service_pb2.CreateRunRequest(
                    run_id=run_id,
                    project_id=project_id,
                    task_spec=task_spec,
                    inputs=inputs.proto_inputs,
                ),
            )
            return Run(pb2=resp.run)

        class DryRun(Run):
            """Placeholder run returned for dry runs; exposes what would have been submitted."""

            def __init__(self, _task_spec, _inputs, _code_bundle):
                super().__init__(
                    pb2=run_definition_pb2.Run(
                        action=run_definition_pb2.Action(
                            id=run_definition_pb2.ActionIdentifier(
                                name="a0", run=run_definition_pb2.RunIdentifier(name="dry-run")
                            )
                        )
                    )
                )
                self.task_spec = _task_spec
                self.inputs = _inputs
                self.code_bundle = _code_bundle

        return DryRun(_task_spec=task_spec, _inputs=inputs, _code_bundle=code_bundle)

    @requires_storage
    @requires_initialization
    async def _run_hybrid(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
        """
        Run a task in hybrid mode. This means that the parent action will be run locally, but the child actions will be
        run in the cluster remotely. This is currently only used for testing,
        over the longer term we will productize this.

        :param obj: Task to run; it must be attached to an environment.
        :param args: Positional inputs for the task.
        :param kwargs: Keyword inputs for the task.
        :return: The outputs returned by `run_task`.
        :raises ValueError: If the task has no parent environment, no version can be determined, the configured
            storage is not S3/GCS/ABFS, or no run base dir is set.
        """
        from flyte._code_bundle import build_code_bundle, build_pkl_bundle
        from flyte._datastructures import RawDataPath
        from flyte._deploy import build_images, plan_deploy

        cfg = get_common_config()

        if obj.parent_env is None:
            raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")

        deploy_plan = plan_deploy(cast(Environment, obj.parent_env()))
        image_cache = await build_images(deploy_plan)

        if self._interactive_mode:
            code_bundle = await build_pkl_bundle(
                obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
            )
        elif self._copy_files != "none":
            code_bundle = await build_code_bundle(
                from_dir=cfg.root_dir, dryrun=self._dry_run, copy_bundle_to=self._copy_bundle_to
            )
        else:
            code_bundle = None

        # An explicit version wins; otherwise fall back to the version computed for the bundle.
        version = self._version or (
            code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
        )
        if not version:
            raise ValueError("Version is required when running a task")

        project = cfg.project or "testproject"
        domain = cfg.domain or "development"
        org = cfg.org or "testorg"
        action_name = "a0"
        run_name = self._name
        random_id = str(uuid.uuid4())[:6]

        controller = create_controller(ct="remote", endpoint="localhost:8090", insecure=True)
        action = ActionID(name=action_name, run_name=run_name, project=project, domain=domain, org=org)

        inputs = obj.native_interface.convert_to_kwargs(*args, **kwargs)
        # TODO: Ideally we should get this from runService
        # The API should be:
        # create new run, from run, in mode hybrid -> new run id, output_base, raw_data_path, inputs_path
        storage = get_storage()
        if not isinstance(storage, (S3, GCS, ABFS)):
            raise ValueError(f"Unsupported storage type: {type(storage)}")
        if self._run_base_dir is None:
            # Defensive guard: __init__ currently substitutes a default, so this is not expected to trigger.
            raise ValueError(
                "Raw data path is required when running task, please set it in the run context:"
                " flyte.with_runcontext(run_base_dir='s3://bucket/metadata/outputs')"
            )
        output_path = self._run_base_dir
        raw_data_path = f"{output_path}/rd/{random_id}"
        raw_data_path_obj = RawDataPath(path=raw_data_path)
        checkpoint_path = f"{raw_data_path}/{_CHECKPOINT_FILE_NAME}"
        prev_checkpoint = f"{raw_data_path}/prev_checkpoint"
        # Keywords (as in `_run_local`) so the two paths cannot be assigned to the wrong fields.
        checkpoints = Checkpoints(prev_checkpoint_path=prev_checkpoint, checkpoint_path=checkpoint_path)

        async def _run_task() -> Tuple[Any, Optional[Exception]]:
            ctx = internal_ctx()
            tctx = TaskContext(
                action=action,
                checkpoints=checkpoints,
                code_bundle=code_bundle,
                output_path=output_path,
                version=version,
                raw_data_path=raw_data_path_obj,
                compiled_image_cache=image_cache,
                run_base_dir=self._run_base_dir,
                report=flyte.report.Report(name=action.name),
            )
            async with ctx.replace_task_context(tctx):
                return await run_task(tctx=tctx, controller=controller, task=obj, inputs=inputs)

        outputs, err = await contextual_run(_run_task)
        if err:
            raise err
        return outputs

    async def _run_local(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
        """
        Run a task in-process using the local controller.

        :param obj: Task to run.
        :param args: Positional inputs for the task.
        :param kwargs: Keyword inputs for the task.
        :return: The task's native outputs, or None when the task declares no outputs or produced none.
        :raises Exception: The native error converted from the dispatch error, when one is produced.
        """
        from flyte._internal.controllers import create_controller
        from flyte._internal.runtime.convert import (
            convert_error_to_native,
            convert_from_native_to_inputs,
            convert_outputs_to_native,
        )
        from flyte._internal.runtime.entrypoints import direct_dispatch

        controller = create_controller(ct="local")

        inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
        if self._name is None:
            action = ActionID.create_random()
        else:
            action = ActionID(name=self._name)
        out, err = await direct_dispatch(
            obj,
            action=action,
            raw_data_path=internal_ctx().raw_data,
            version="na",
            controller=controller,
            inputs=inputs,
            output_path=self._metadata_path,
            run_base_dir=self._metadata_path,
            checkpoints=Checkpoints(
                prev_checkpoint_path=internal_ctx().raw_data.path, checkpoint_path=internal_ctx().raw_data.path
            ),
        )  # type: ignore
        if err:
            native_err = convert_error_to_native(err)
            # NOTE: if the conversion yields nothing, the error is not raised and we fall through to the outputs.
            if native_err:
                raise native_err
        if obj.native_interface.outputs and len(obj.native_interface.outputs) > 0:
            if out is not None:
                return cast(R, await convert_outputs_to_native(obj.native_interface, out))
        return cast(R, None)

    async def run(self, task: TaskTemplate[P, Union[R, Run]], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
        """
        Run an async `@env.task` or `TaskTemplate` instance. The existing async context will be used.

        Example:
        ```python
        import flyte
        env = flyte.TaskEnvironment("example")

        @env.task
        async def example_task(x: int, y: str) -> str:
            return f"{x} {y}"

        if __name__ == "__main__":
            flyte.run(example_task, 1, y="hello")
        ```

        :param task: TaskTemplate instance `@env.task` or `TaskTemplate`
        :param args: Arguments to pass to the Task
        :param kwargs: Keyword arguments to pass to the Task
        :return: Run instance or the result of the task
        """
        if self._mode == "remote":
            return await self._run_remote(task, *args, **kwargs)
        if self._mode == "hybrid":
            return await self._run_hybrid(task, *args, **kwargs)

        # TODO We could use this for remote as well and users could simply pass flyte:// or s3:// or file://
        async with internal_ctx().new_raw_data_path(
            raw_data_path=RawDataPath.from_local_folder(local_folder=self._raw_data_path)
        ):
            return await self._run_local(task, *args, **kwargs)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def with_runcontext(
    mode: Mode | None = None,
    *,
    name: Optional[str] = None,
    service_account: Optional[str] = None,
    version: Optional[str] = None,
    copy_style: CopyFiles = "loaded_modules",
    dry_run: bool = False,
    copy_bundle_to: pathlib.Path | None = None,
    interactive_mode: bool | None = None,
    raw_data_path: str | None = None,
    run_base_dir: str | None = None,
) -> _Runner:
    """
    Launch a new run with the given parameters as the context.

    Example:
    ```python
    import flyte
    env = flyte.TaskEnvironment("example")

    @env.task
    async def example_task(x: int, y: str) -> str:
        return f"{x} {y}"

    if __name__ == "__main__":
        flyte.with_runcontext(name="example_run_id").run(example_task, 1, y="hello")
    ```

    :param mode: Optional The mode to use for the run, if not provided, it will be computed from flyte.init
    :param version: Optional The version to use for the run, if not provided, it will be computed from the code bundle
    :param name: Optional The name to use for the run
    :param service_account: Optional The service account to use for the run context
    :param copy_style: Optional The copy style to use for the run context
    :param dry_run: Optional If true, the run will not be executed, but the bundle will be created
    :param copy_bundle_to: When dry_run is True, the bundle will be copied to this location if specified
    :param interactive_mode: Optional, can be forced to True or False.
         If not provided, it will be set based on the current environment. For example Jupyter notebooks are considered
         interactive mode, while scripts are not. This is used to determine how the code bundle is created.
    :param raw_data_path: Use this path to store the raw data for the run. Currently only supported for local runs,
      and can be used to store raw data in specific locations. TODO coming soon for remote runs as well.
    :param run_base_dir: Base directory for the run's outputs. Required, together with `name`, for hybrid mode.
    :return: runner
    :raises ValueError: If mode is "hybrid" and either `name` or `run_base_dir` is missing.
    """
    # Hybrid mode needs BOTH values, so reject the call when either one is missing.
    if mode == "hybrid" and not (name and run_base_dir):
        raise ValueError("Run name and run base dir are required for hybrid mode")
    return _Runner(
        force_mode=mode,
        name=name,
        service_account=service_account,
        version=version,
        copy_style=copy_style,
        dry_run=dry_run,
        copy_bundle_to=copy_bundle_to,
        interactive_mode=interactive_mode,
        raw_data_path=raw_data_path,
        run_base_dir=run_base_dir,
    )
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
@syncer.wrap
async def run(task: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
    """
    Run a task using a runner created with default options.

    :param task: task to run
    :param args: args to pass to the task
    :param kwargs: kwargs to pass to the task
    :return: Run | Result of the task
    """
    default_runner = _Runner()
    # using syncer causes problems, so the runner's async entry point (`.aio`) is awaited directly
    outcome = await default_runner.run.aio(task, *args, **kwargs)  # type: ignore
    return outcome
|