flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import gzip
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import pathlib
|
|
6
|
+
import tempfile
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Type
|
|
9
|
+
|
|
10
|
+
from flyteidl.core.tasks_pb2 import TaskTemplate
|
|
11
|
+
|
|
12
|
+
import flyte.storage as storage
|
|
13
|
+
from flyte._datastructures import CodeBundle
|
|
14
|
+
from flyte._logging import log, logger
|
|
15
|
+
|
|
16
|
+
from ._ignore import GitIgnore, Ignore, StandardIgnore
|
|
17
|
+
from ._packaging import create_bundle, list_files_to_bundle, print_ls_tree
|
|
18
|
+
from ._utils import CopyFiles, hash_file
|
|
19
|
+
|
|
20
|
+
# Suffix of the gzip-compressed cloudpickle file written by build_pkl_bundle.
_pickled_file_extension = ".pkl.gz"
# Gzipped-tarball suffix. NOTE(review): not referenced by the code visible in this module; confirm it is still needed.
_tar_file_extension = ".tar.gz"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
async def build_pkl_bundle(
    o: TaskTemplate,
    upload_to_controlplane: bool = True,
    upload_from_dataplane_path: str | None = None,
    copy_bundle_to: pathlib.Path | None = None,
) -> CodeBundle:
    """
    Build a pickled (cloudpickle + gzip) code bundle for the given task.

    Exactly one of three modes applies:

    * ``upload_to_controlplane=True`` (default): the bundle is uploaded with ``flyte.remote.upload_file`` and the
      returned ``CodeBundle.pkl`` is the remote path reported by that call.
    * ``upload_from_dataplane_path`` set (requires ``upload_to_controlplane=False``): the bundle is written to that
      path with ``flyte.storage.put``.
    * neither (dry run): nothing is uploaded; the bundle is optionally copied into ``copy_bundle_to``.

    TODO We can optimize this by having an LRU cache for the function, this is so that if the same task is being
    pickled multiple times, we can avoid the overhead of pickling it multiple times, by copying to a common place
    and reusing based on task hash.

    :param o: Object to be pickled. This is the task template.
    :param upload_to_controlplane: Whether to upload the pickled file to the control plane.
    :param upload_from_dataplane_path: If we are on the dataplane, the path the pickled file should be uploaded to.
        ``upload_to_controlplane`` has to be False in this case.
    :param copy_bundle_to: Dry-run only. Directory the bundle is copied into; the returned path is
        ``copy_bundle_to / "code_bundle.pkl.gz"``. Used for testing purposes.
    :return: CodeBundle containing the pickled file path and the computed version.
    :raises ValueError: If both ``upload_to_controlplane`` and ``upload_from_dataplane_path`` are set.
    """
    import cloudpickle

    if upload_to_controlplane and upload_from_dataplane_path:
        raise ValueError("Cannot upload to control plane and upload from dataplane path at the same time.")

    logger.debug("Building pickled code bundle.")
    with tempfile.TemporaryDirectory() as tmp_dir:
        dest = pathlib.Path(tmp_dir) / f"code_bundle{_pickled_file_extension}"
        # mtime=0 keeps the build time out of the gzip header, so the content hash below is stable across builds.
        with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
            cloudpickle.dump(o, gzipped)

        if upload_to_controlplane:
            logger.debug("Uploading pickled code bundle to control plane.")
            from flyte.remote import upload_file

            hash_digest, remote_path = await upload_file(dest)
            return CodeBundle(pkl=remote_path, computed_version=hash_digest)

        if upload_from_dataplane_path:
            logger.debug(f"Uploading pickled code bundle to dataplane path {upload_from_dataplane_path}.")
            _, str_digest, _ = hash_file(file_path=dest)
            final_path = await storage.put(str(dest), upload_from_dataplane_path)
            return CodeBundle(pkl=final_path, computed_version=str_digest)

        logger.debug("Dryrun enabled, not uploading pickled code bundle.")
        _, str_digest, _ = hash_file(file_path=dest)
        if copy_bundle_to:
            import shutil

            shutil.copy(dest, copy_bundle_to)
            return CodeBundle(pkl=str(copy_bundle_to / dest.name), computed_version=str_digest)
        # NOTE(review): ``dest`` lives in a TemporaryDirectory that is removed when this ``with`` block exits, so
        # the returned path no longer exists once this function returns; only ``computed_version`` is usable here.
        return CodeBundle(pkl=str(dest), computed_version=str_digest)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
async def build_code_bundle(
    from_dir: Path,
    *ignore: Type[Ignore],
    extract_dir: str = ".",
    dryrun: bool = False,
    copy_bundle_to: pathlib.Path | None = None,
    copy_style: CopyFiles = "loaded_modules",
) -> CodeBundle:
    """
    Build a tarball code bundle of ``from_dir`` for the current environment.

    :param from_dir: Root directory of the source code to bundle.
    :param ignore: Ignore classes applied when selecting files. Defaults to ``(StandardIgnore, GitIgnore)``.
    :param extract_dir: The directory to extract the code bundle to, when in the container. It defaults to the
        current working directory.
    :param dryrun: If enabled, the bundle is not uploaded to the control plane.
    :param copy_bundle_to: Dry-run only. Directory the bundle is copied into; the returned ``tgz`` then points at
        the copy. Used for testing purposes.
    :param copy_style: What to put into the tarball (e.g. ``"all"`` or ``"loaded_modules"``); forwarded to
        ``list_files_to_bundle``.
    :return: The code bundle. ``tgz`` is the uploaded remote path, the local copy path, or ``"na"`` for a dry run
        without ``copy_bundle_to``. ``computed_version`` is the digest returned by ``upload_file`` (upload) or
        ``hash_file`` (dry run).
    """
    logger.debug("Building code bundle.")
    from flyte.remote import upload_file

    if not ignore:
        ignore = (StandardIgnore, GitIgnore)

    logger.debug(f"Finding files to bundle, ignoring as configured by: {ignore}")
    files, digest = list_files_to_bundle(from_dir, True, *ignore, copy_style=copy_style)
    if logger.getEffectiveLevel() <= logging.INFO:
        print_ls_tree(from_dir, files)

    # Distinct message so this phase can be told apart from the start-of-build log above.
    logger.debug("Creating code bundle archive.")
    with tempfile.TemporaryDirectory() as tmp_dir:
        bundle_path, tar_size, archive_size = create_bundle(from_dir, pathlib.Path(tmp_dir), files, digest)
        logger.info(f"Code bundle created at {bundle_path}, size: {tar_size} MB, archive size: {archive_size} MB")
        if not dryrun:
            hash_digest, remote_path = await upload_file(bundle_path)
            logger.info(f"Code bundle uploaded to {remote_path}")
        else:
            remote_path = "na"
            if copy_bundle_to:
                import shutil

                # Copy out of the temporary directory, which is removed when this ``with`` block exits.
                shutil.copy(bundle_path, copy_bundle_to)
                remote_path = str(copy_bundle_to / bundle_path.name)
            _, hash_digest, _ = hash_file(file_path=bundle_path)
        return CodeBundle(tgz=remote_path, destination=extract_dir, computed_version=hash_digest)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@log(level=logging.INFO)
async def download_bundle(bundle: CodeBundle) -> pathlib.Path:
    """
    Download a code bundle (tgz | pkl) into ``bundle.destination``.

    A ``tgz`` bundle is downloaded and then extracted into the destination with ``tar -xvf <file> -C <dest>``.
    A ``pkl`` bundle is only downloaded; it is not decompressed here.

    :param bundle: The code bundle to download.
    :return: Absolute path of the downloaded bundle file. For ``tgz`` this is the archive itself, not the
        extracted tree.
    :raises ValueError: If ``bundle.destination`` is not an existing directory, or the bundle has neither ``tgz``
        nor ``pkl`` set.
    :raises RuntimeError: If ``tar`` exits non-zero; the message is tar's stderr.
    """
    dest = pathlib.Path(bundle.destination)
    if not dest.is_dir():
        # Only stat() a path that exists. On a missing path stat() raises FileNotFoundError, which would mask
        # the ValueError below.
        details = dest.stat() if dest.exists() else "path does not exist"
        raise ValueError(f"Destination path should be a directory, found {dest}, {details}")

    # TODO make storage apis better to accept pathlib.Path
    if bundle.tgz:
        downloaded_bundle = dest / os.path.basename(bundle.tgz)
        path = await storage.get(bundle.tgz, str(downloaded_bundle.absolute()))
        downloaded_bundle = pathlib.Path(path)
        process = await asyncio.create_subprocess_exec(
            "tar",
            "-xvf",
            str(downloaded_bundle),
            "-C",
            str(dest),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        # tar's verbose listing goes to stdout and is discarded; only stderr is surfaced on failure.
        _, stderr = await process.communicate()
        if process.returncode != 0:
            raise RuntimeError(stderr.decode(errors="replace"))
        return downloaded_bundle.absolute()

    if bundle.pkl:
        downloaded_bundle = dest / os.path.basename(bundle.pkl)
        path = await storage.get(bundle.pkl, str(downloaded_bundle.absolute()))
        return pathlib.Path(path).absolute()

    raise ValueError("Code bundle should be either tgz or pkl, found neither.")
|
flyte/_context.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextvars
|
|
4
|
+
from dataclasses import dataclass, replace
|
|
5
|
+
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, ParamSpec, TypeVar
|
|
6
|
+
|
|
7
|
+
from flyte._datastructures import GroupData, RawDataPath, TaskContext
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from flyte.report import Report
|
|
11
|
+
|
|
12
|
+
P = ParamSpec("P")  # parameter specification of the callable passed to contextual_run
R = TypeVar("R")  # result type produced by awaiting that callable
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True, kw_only=True)
|
|
17
|
+
class ContextData:
|
|
18
|
+
"""
|
|
19
|
+
A ContextData cannot be created without an execution. Even for local execution's there should be an execution ID
|
|
20
|
+
|
|
21
|
+
:param: action The action ID of the current execution. This is always set, within a run.
|
|
22
|
+
:param: group_data If nested in a group the current group information
|
|
23
|
+
:param: task_context The context of the current task execution, this is what is available to the user, it is set
|
|
24
|
+
when the task is executed through `run` methods. If the Task is executed as regular python methods, this
|
|
25
|
+
will be None.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
group_data: Optional[GroupData] = None
|
|
29
|
+
task_context: Optional[TaskContext] = None
|
|
30
|
+
raw_data_path: Optional[RawDataPath] = None
|
|
31
|
+
|
|
32
|
+
def replace(self, **kwargs) -> ContextData:
|
|
33
|
+
return replace(self, **kwargs)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Context:
    """
    Holder for the active ``ContextData``.

    Entering an instance (``with`` or ``async with``) makes it the value of ``root_context_var``; leaving restores
    the previous value. Each instance keeps a single reset token, so the same instance should not be entered again,
    or from another coroutine, while it is already active.

    ``contextual_run`` is intended for running a function in a new context tree, so that context mutations made
    inside that tree are not visible to coroutines created outside of it.
    """

    def __init__(self, data: ContextData):
        if data is None:
            raise ValueError("Cannot create a new context without contextdata.")
        self._data = data
        # Captured once at construction; equals the builtin id() of this instance.
        self._id = id(self)

    @property
    def data(self) -> ContextData:
        """The wrapped ContextData (read-only view)."""
        return self._data

    @property
    def raw_data(self) -> RawDataPath:
        """
        Raw data prefix for this context.

        The task context's ``raw_data_path`` wins when present; otherwise the context-level ``raw_data_path``.

        :raises ValueError: If neither is set.
        """
        data = self.data
        task_context = data.task_context if data else None
        if task_context and task_context.raw_data_path:
            return task_context.raw_data_path
        if data and data.raw_data_path:
            return data.raw_data_path
        raise ValueError("Raw data path has not been set in the context.")

    @property
    def id(self) -> int:
        """Identifier captured at construction (the builtin ``id`` of this instance)."""
        return self._id

    def replace_task_context(self, tctx: TaskContext) -> Context:
        """Return a new Context whose data carries ``tctx`` as task context; this instance is unchanged."""
        updated = self.data.replace(task_context=tctx)
        return Context(updated)

    def new_raw_data_path(self, raw_data_path: RawDataPath) -> Context:
        """Return a new Context whose data carries ``raw_data_path``; this instance is unchanged."""
        updated = self.data.replace(raw_data_path=raw_data_path)
        return Context(updated)

    def get_report(self) -> Optional[Report]:
        """Return the report of the current task context, or None when no task context is set."""
        task_context = self.data.task_context
        return task_context.report if task_context else None

    def is_task_context(self) -> bool:
        """True when a task context is set on this context's data."""
        return self.data.task_context is not None

    def __enter__(self):
        """Install this context as the current one, remembering the token needed to undo it."""
        self._token = root_context_var.set(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore whichever context was current before ``__enter__``."""
        root_context_var.reset(self._token)

    async def __aenter__(self):
        """Async counterpart of ``__enter__``."""
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async counterpart of ``__exit__``."""
        self.__exit__(exc_type, exc_val, exc_tb)

    def __repr__(self):
        return str(self.data)

    def __str__(self):
        return repr(self)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# Global context variable to hold the current context
|
|
128
|
+
# Global context variable holding the active Context. The default is a single Context instance, created once at
# import time, wrapping an empty ContextData (no group data, task context, or raw data path).
root_context_var = contextvars.ContextVar("root", default=Context(data=ContextData()))
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def ctx() -> Optional[TaskContext]:
    """Return the task context of the currently active Context, or None when no task context is set."""
    current = internal_ctx()
    return current.data.task_context
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def internal_ctx() -> Context:
    """
    Return the Context currently stored in ``root_context_var``.

    If no Context has been set, this is the module-level default Context.
    """
    active = root_context_var.get()
    return active
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
async def contextual_run(func: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs) -> R:
    """
    Run a function with a new context subtree.

    ``func(*args, **kwargs)`` is called and awaited inside a dedicated asyncio Task. asyncio runs each Task in its
    own copy of the current ``contextvars.Context``, so ContextVar values set while ``func`` runs are visible to
    ``func`` (and anything it awaits) but do not leak back into the caller's context.

    :param func: Callable returning an awaitable, typically an ``async def`` function.
    :return: The result of awaiting ``func(*args, **kwargs)``.
    :raises: Whatever ``func`` raises; cancelling the caller also cancels the inner Task.
    """
    import asyncio

    async def _run_isolated():
        # Both the call (any synchronous prelude) and the awaited body execute inside the Task's context copy.
        return await func(*args, **kwargs)

    # Note: running ``func`` via ``contextvars.copy_context().run`` only *creates* the coroutine in the copy; its
    # body would still execute in whichever context awaits it. Wrapping it in a Task is what isolates it.
    return await asyncio.create_task(_run_isolated())
|
flyte/_datastructures.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import os
|
|
5
|
+
import pathlib
|
|
6
|
+
import tempfile
|
|
7
|
+
from dataclasses import dataclass, field, replace
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
|
|
9
|
+
|
|
10
|
+
from flyte._docstring import Docstring
|
|
11
|
+
from flyte._interface import extract_return_annotation
|
|
12
|
+
from flyte._logging import logger
|
|
13
|
+
from flyte._utils.helpers import base36_encode
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from flyte._internal.imagebuild.image_builder import ImageCache
|
|
17
|
+
from flyte.report import Report
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def generate_random_name() -> str:
    """
    Generate a random name for a task, used to create unique task names.

    For now it's just the canonical string form of a random UUID4
    (e.g. ``"1b4e28ba-2fa1-41d2-883f-0016d3cca427"``).
    TODO: switch to a human-friendly generator (e.g. unique-namer) in the future.
    """
    import uuid

    return str(uuid.uuid4())
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True, kw_only=True)
|
|
31
|
+
class ActionID:
|
|
32
|
+
"""
|
|
33
|
+
A class representing the ID of an Action, nested within a Run. This is used to identify a specific action on a task.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
name: str
|
|
37
|
+
run_name: str | None = None
|
|
38
|
+
project: str | None = None
|
|
39
|
+
domain: str | None = None
|
|
40
|
+
org: str | None = None
|
|
41
|
+
|
|
42
|
+
def __post_init__(self):
|
|
43
|
+
if self.run_name is None:
|
|
44
|
+
object.__setattr__(self, "run_name", self.name)
|
|
45
|
+
|
|
46
|
+
@classmethod
|
|
47
|
+
def create_random(cls):
|
|
48
|
+
name = generate_random_name()
|
|
49
|
+
return cls(name=name, run_name=name)
|
|
50
|
+
|
|
51
|
+
def new_sub_action(self, name: str | None = None) -> ActionID:
|
|
52
|
+
"""
|
|
53
|
+
Create a new sub-run with the given name. If name is None, a random name will be generated.
|
|
54
|
+
"""
|
|
55
|
+
if name is None:
|
|
56
|
+
name = generate_random_name()
|
|
57
|
+
return replace(self, name=name)
|
|
58
|
+
|
|
59
|
+
def new_sub_action_from(self, task_name: str, input_hash: str, group: str | None) -> ActionID:
|
|
60
|
+
"""Make a deterministic name"""
|
|
61
|
+
import hashlib
|
|
62
|
+
|
|
63
|
+
components = f"{self.run_name}-{self.name}-{input_hash}-{task_name}" + (f"-{group}" if group else "")
|
|
64
|
+
# has the components into something deterministic
|
|
65
|
+
bytes_digest = hashlib.md5(components.encode()).digest()
|
|
66
|
+
new_name = base36_encode(bytes_digest)
|
|
67
|
+
return self.new_sub_action(new_name)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@dataclass(frozen=True, kw_only=True)
|
|
71
|
+
class RawDataPath:
|
|
72
|
+
"""
|
|
73
|
+
A class representing the raw data path for a task. This is used to store the raw data for the task execution and
|
|
74
|
+
also get mutations on the path.
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
path: str
|
|
78
|
+
|
|
79
|
+
@classmethod
|
|
80
|
+
def from_local_folder(cls, local_folder: str | pathlib.Path | None = None) -> RawDataPath:
|
|
81
|
+
"""
|
|
82
|
+
Create a new context attribute object, with local path given. Will be created if it doesn't exist.
|
|
83
|
+
:return: Path to the temporary directory
|
|
84
|
+
"""
|
|
85
|
+
match local_folder:
|
|
86
|
+
case pathlib.Path():
|
|
87
|
+
local_folder.mkdir(parents=True, exist_ok=True)
|
|
88
|
+
return RawDataPath(path=str(local_folder))
|
|
89
|
+
case None:
|
|
90
|
+
# Create a temporary directory for data storage
|
|
91
|
+
p = tempfile.mkdtemp()
|
|
92
|
+
logger.debug(f"Creating temporary directory for data storage: {p}")
|
|
93
|
+
return RawDataPath(path=p)
|
|
94
|
+
case str():
|
|
95
|
+
return RawDataPath(path=local_folder)
|
|
96
|
+
case _:
|
|
97
|
+
raise ValueError(f"Invalid local path {local_folder}")
|
|
98
|
+
|
|
99
|
+
def get_random_remote_path(self, file_name: Optional[str] = None) -> str:
|
|
100
|
+
"""
|
|
101
|
+
Returns a random path for uploading a file/directory to.
|
|
102
|
+
|
|
103
|
+
:param file_name: If given, will be joined after a randomly generated portion.
|
|
104
|
+
:return:
|
|
105
|
+
"""
|
|
106
|
+
import random
|
|
107
|
+
from uuid import UUID
|
|
108
|
+
|
|
109
|
+
import fsspec
|
|
110
|
+
from fsspec.utils import get_protocol
|
|
111
|
+
|
|
112
|
+
random_string = UUID(int=random.getrandbits(128)).hex
|
|
113
|
+
file_prefix = self.path
|
|
114
|
+
|
|
115
|
+
protocol = get_protocol(file_prefix)
|
|
116
|
+
if "file" in protocol:
|
|
117
|
+
local_path = pathlib.Path(file_prefix) / random_string
|
|
118
|
+
if file_name:
|
|
119
|
+
# Only if file name is given do we create the parent, because it may be needed as a folder otherwise
|
|
120
|
+
local_path = local_path / file_name
|
|
121
|
+
if not local_path.exists():
|
|
122
|
+
local_path.parent.mkdir(exist_ok=True, parents=True)
|
|
123
|
+
local_path.touch()
|
|
124
|
+
return str(local_path.absolute())
|
|
125
|
+
|
|
126
|
+
fs = fsspec.filesystem(protocol)
|
|
127
|
+
if file_prefix.endswith(fs.sep):
|
|
128
|
+
file_prefix = file_prefix[:-1]
|
|
129
|
+
remote_path = fs.sep.join([file_prefix, random_string])
|
|
130
|
+
if file_name:
|
|
131
|
+
remote_path = fs.sep.join([remote_path, file_name])
|
|
132
|
+
return remote_path
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@dataclass(frozen=True)
class GroupData:
    """
    Immutable, hashable record identifying a group by its name.
    """

    name: str  # the group's name
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@dataclass(frozen=True, kw_only=True)
|
|
141
|
+
class TaskContext:
|
|
142
|
+
"""
|
|
143
|
+
A context class to hold the current task executions context.
|
|
144
|
+
This can be used to access various contextual parameters in the task execution by the user.
|
|
145
|
+
|
|
146
|
+
:param action: The action ID of the current execution. This is always set, within a run.
|
|
147
|
+
:param version: The version of the executed task. This is set when the task is executed by an action and will be
|
|
148
|
+
set on all sub-actions.
|
|
149
|
+
"""
|
|
150
|
+
|
|
151
|
+
action: ActionID
|
|
152
|
+
version: str
|
|
153
|
+
raw_data_path: RawDataPath
|
|
154
|
+
output_path: str
|
|
155
|
+
run_base_dir: str
|
|
156
|
+
report: Report
|
|
157
|
+
group_data: GroupData | None = None
|
|
158
|
+
checkpoints: Checkpoints | None = None
|
|
159
|
+
code_bundle: CodeBundle | None = None
|
|
160
|
+
compiled_image_cache: ImageCache | None = None
|
|
161
|
+
data: Dict[str, Any] = field(default_factory=dict)
|
|
162
|
+
|
|
163
|
+
def replace(self, **kwargs) -> TaskContext:
|
|
164
|
+
if "data" in kwargs:
|
|
165
|
+
rec_data = kwargs.pop("data")
|
|
166
|
+
if rec_data is None:
|
|
167
|
+
return replace(self, **kwargs)
|
|
168
|
+
data = {}
|
|
169
|
+
if self.data is not None:
|
|
170
|
+
data = self.data.copy()
|
|
171
|
+
data.update(rec_data)
|
|
172
|
+
kwargs.update({"data": data})
|
|
173
|
+
return replace(self, **kwargs)
|
|
174
|
+
|
|
175
|
+
def __getitem__(self, key: str) -> Optional[Any]:
|
|
176
|
+
return self.data.get(key)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
@dataclass(frozen=True, kw_only=True)
|
|
180
|
+
class CodeBundle:
|
|
181
|
+
"""
|
|
182
|
+
A class representing a code bundle for a task. This is used to package the code and the inflation path.
|
|
183
|
+
The code bundle computes the version of the code using the hash of the code.
|
|
184
|
+
|
|
185
|
+
:param computed_version: The version of the code bundle. This is the hash of the code.
|
|
186
|
+
:param destination: The destination path for the code bundle to be inflated to.
|
|
187
|
+
:param tgz: Optional path to the tgz file.
|
|
188
|
+
:param pkl: Optional path to the pkl file.
|
|
189
|
+
:param downloaded_path: The path to the downloaded code bundle. This is only available during runtime, when
|
|
190
|
+
the code bundle has been downloaded and inflated.
|
|
191
|
+
"""
|
|
192
|
+
|
|
193
|
+
computed_version: str
|
|
194
|
+
destination: str = "."
|
|
195
|
+
tgz: str | None = None
|
|
196
|
+
pkl: str | None = None
|
|
197
|
+
downloaded_path: pathlib.Path | None = None
|
|
198
|
+
|
|
199
|
+
# runtime_dependencies: Tuple[str, ...] = field(default_factory=tuple) In the future if we want we could add this
|
|
200
|
+
# but this messes up actors, spark etc
|
|
201
|
+
|
|
202
|
+
def __post_init__(self):
|
|
203
|
+
if self.tgz is None and self.pkl is None:
|
|
204
|
+
raise ValueError("Either tgz or pkl must be provided")
|
|
205
|
+
|
|
206
|
+
def with_downloaded_path(self, path: pathlib.Path) -> CodeBundle:
|
|
207
|
+
"""
|
|
208
|
+
Create a new CodeBundle with the given downloaded path.
|
|
209
|
+
"""
|
|
210
|
+
return replace(self, downloaded_path=path)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@dataclass(frozen=True)
class Checkpoints:
    """
    Checkpoint locations for a task execution. Both fields are required, but either may be None.

    :param prev_checkpoint_path: Path of the previous checkpoint, if any.
    :param checkpoint_path: Path of the current checkpoint, if any.
    """

    prev_checkpoint_path: str | None
    checkpoint_path: str | None
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
@dataclass(frozen=True)
|
|
224
|
+
class NativeInterface:
|
|
225
|
+
"""
|
|
226
|
+
A class representing the native interface for a task. This is used to interact with the task and its execution
|
|
227
|
+
context.
|
|
228
|
+
"""
|
|
229
|
+
|
|
230
|
+
inputs: Dict[str, Tuple[Type, Any]]
|
|
231
|
+
outputs: Dict[str, Type]
|
|
232
|
+
docstring: Optional[Docstring] = field(default=None)
|
|
233
|
+
|
|
234
|
+
def has_outputs(self) -> bool:
|
|
235
|
+
"""
|
|
236
|
+
Check if the task has outputs. This is used to determine if the task has outputs or not.
|
|
237
|
+
"""
|
|
238
|
+
return self.outputs is not None and len(self.outputs) > 0
|
|
239
|
+
|
|
240
|
+
@classmethod
|
|
241
|
+
def from_types(cls, inputs: Dict[str, Type], outputs: Dict[str, Type]) -> NativeInterface:
|
|
242
|
+
"""
|
|
243
|
+
Create a new NativeInterface from the given types. This is used to create a native interface for the task.
|
|
244
|
+
"""
|
|
245
|
+
return cls(inputs={k: (v, inspect.Parameter.empty) for k, v in inputs.items()}, outputs=outputs)
|
|
246
|
+
|
|
247
|
+
@classmethod
|
|
248
|
+
def from_callable(cls, func: Callable) -> NativeInterface:
|
|
249
|
+
"""
|
|
250
|
+
Extract the native interface from the given function. This is used to create a native interface for the task.
|
|
251
|
+
"""
|
|
252
|
+
sig = inspect.signature(func)
|
|
253
|
+
|
|
254
|
+
# Extract parameter details (name, type, default value)
|
|
255
|
+
param_info = {name: (param.annotation, param.default) for name, param in sig.parameters.items()}
|
|
256
|
+
|
|
257
|
+
# Get return type
|
|
258
|
+
outputs = extract_return_annotation(sig.return_annotation)
|
|
259
|
+
return cls(inputs=param_info, outputs=outputs)
|
|
260
|
+
|
|
261
|
+
def convert_to_kwargs(self, *args, **kwargs) -> Dict[str, Any]:
|
|
262
|
+
"""
|
|
263
|
+
Convert the given arguments to keyword arguments based on the native interface. This is used to convert the
|
|
264
|
+
arguments to the correct types for the task execution.
|
|
265
|
+
"""
|
|
266
|
+
# Convert positional arguments to keyword arguments
|
|
267
|
+
if len(args) > len(self.inputs):
|
|
268
|
+
raise ValueError(f"Too many positional arguments provided, inputs {self.inputs.keys()}, args {len(args)}")
|
|
269
|
+
for arg, input_name in zip(args, self.inputs.keys()):
|
|
270
|
+
kwargs[input_name] = arg
|
|
271
|
+
return kwargs
|
|
272
|
+
|
|
273
|
+
def get_input_types(self) -> Dict[str, Type]:
|
|
274
|
+
"""
|
|
275
|
+
Get the input types for the task. This is used to get the types of the inputs for the task execution.
|
|
276
|
+
"""
|
|
277
|
+
return {k: v[0] for k, v in self.inputs.items()}
|
|
278
|
+
|
|
279
|
+
def __repr__(self):
|
|
280
|
+
"""
|
|
281
|
+
Returns a string representation of the task interface.
|
|
282
|
+
"""
|
|
283
|
+
i = "("
|
|
284
|
+
if self.inputs:
|
|
285
|
+
initial = True
|
|
286
|
+
for key, tpe in self.inputs.items():
|
|
287
|
+
if not initial:
|
|
288
|
+
i += ", "
|
|
289
|
+
initial = False
|
|
290
|
+
tp = tpe[0] if isinstance(tpe[0], str) else tpe[0].__name__
|
|
291
|
+
i += f"{key}: {tp}"
|
|
292
|
+
if tpe[1] is not inspect.Parameter.empty:
|
|
293
|
+
i += f" = {tpe[1]}"
|
|
294
|
+
i += ")"
|
|
295
|
+
if self.outputs:
|
|
296
|
+
initial = True
|
|
297
|
+
multi = len(self.outputs) > 1
|
|
298
|
+
i += " -> "
|
|
299
|
+
if multi:
|
|
300
|
+
i += "("
|
|
301
|
+
for key, tpe in self.outputs.items():
|
|
302
|
+
if not initial:
|
|
303
|
+
i += ", "
|
|
304
|
+
initial = False
|
|
305
|
+
tp = tpe.__name__ if isinstance(tpe, type) else tpe
|
|
306
|
+
i += f"{key}: {tp}"
|
|
307
|
+
if multi:
|
|
308
|
+
i += ")"
|
|
309
|
+
return i + ":"
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
@dataclass
class SerializationContext:
    """
    Serialization-time contextual information, used when serializing a task and the various parameters of its task
    template. It is only available while the task is being serialized, which can happen either at deployment time
    or at runtime.

    :param version: The version of the task
    :param project: The project, if known.
    :param domain: The domain, if known.
    :param org: The org, if known.
    :param code_bundle: The code bundle for the task. This is used to package the code and the inflation path.
    :param input_path: The path to the inputs for the task. This is used to determine where the inputs will be located
    :param output_path: The path to the outputs for the task. This is used to determine where the outputs will be
        located
    :param image_cache: Image cache, if any.
    :param root_dir: Optional root directory.
    """

    version: str
    project: str | None = None
    domain: str | None = None
    org: str | None = None
    code_bundle: Optional[CodeBundle] = None
    # Template placeholders; presumably substituted by the backend at execution time — confirm.
    input_path: str = "{{.input}}"
    output_path: str = "{{.outputPrefix}}"
    # Not an __init__ argument; joined onto the interpreter's prefix by get_entrypoint_path.
    _entrypoint_path: str = field(default="_bin/runtime.py", init=False)
    image_cache: ImageCache | None = None
    root_dir: Optional[pathlib.Path] = None

    def get_entrypoint_path(self, interpreter_path: str) -> str:
        """
        Get the entrypoint path for the task execution.

        The entrypoint is resolved two directory levels above ``interpreter_path``. For example,
        ``/opt/venv/bin/python`` resolves under ``/opt/venv``.

        :param interpreter_path: The path to the interpreter (python)
        :return: That prefix joined with ``_entrypoint_path``.
        """
        install_prefix = os.path.dirname(os.path.dirname(interpreter_path))
        return os.path.join(install_prefix, self._entrypoint_path)
|