flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic; see the package's registry page for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
union/_cli/_params.py
ADDED
|
@@ -0,0 +1,579 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import dataclasses
|
|
3
|
+
import datetime
|
|
4
|
+
import enum
|
|
5
|
+
import importlib
|
|
6
|
+
import importlib.util
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import pathlib
|
|
10
|
+
import sys
|
|
11
|
+
import typing
|
|
12
|
+
import typing as t
|
|
13
|
+
from typing import get_args
|
|
14
|
+
|
|
15
|
+
import rich_click as click
|
|
16
|
+
import yaml
|
|
17
|
+
from flyteidl.core.literals_pb2 import Literal
|
|
18
|
+
from flyteidl.core.types_pb2 import BlobType, LiteralType, SimpleType
|
|
19
|
+
|
|
20
|
+
from union._logging import logger
|
|
21
|
+
from union.io import Dir, File
|
|
22
|
+
from union.io.pickle.transformer import FlytePickleTransformer
|
|
23
|
+
from union.storage._remote_fs import RemoteFSPathResolver
|
|
24
|
+
from union.types import TypeEngine
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# ---------------------------------------------------
# TODO replace these
class ArtifactQuery:
    """Placeholder stand-in for an artifact query type (stub, to be replaced)."""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def is_remote(v: str) -> bool:
    """
    Return True when ``v`` is a URI with a non-local scheme, e.g. ``s3://bucket/key`` or ``gs://bucket/path``.

    Plain filesystem paths (including Windows paths such as ``C:\\data``) and ``file://`` URIs are local.
    Non-string values are never considered remote.
    """
    if not isinstance(v, str):
        return False
    scheme, sep, _ = v.partition("://")
    if not sep or not scheme:
        return False
    # RFC 3986 scheme syntax: a letter followed by letters, digits, "+", "-" or ".".
    if not (scheme[0].isalpha() and all(c.isalnum() or c in "+-." for c in scheme)):
        return False
    return scheme.lower() != "file"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class StructuredDataset:
|
|
38
|
+
def __init__(self, uri: str | None = None, dataframe: typing.Any = None):
|
|
39
|
+
self.uri = uri
|
|
40
|
+
self.dataframe = dataframe
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# ---------------------------------------------------
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def is_pydantic_basemodel(python_type: typing.Type) -> bool:
    """
    Checks if the python type is a pydantic BaseModel (v1 or v2 API).

    Returns False when pydantic is not installed, and also when ``python_type`` is not a class
    (e.g. ``typing.List[int]``, a ``Union``, or ``None``), instead of letting ``issubclass`` raise TypeError.
    """
    try:
        import pydantic  # noqa: F401
    except ImportError:
        return False

    try:
        from pydantic import BaseModel as BaseModelV2
        from pydantic.v1 import BaseModel as BaseModelV1

        bases: typing.Tuple[type, ...] = (BaseModelV1, BaseModelV2)
    except ImportError:
        # Older pydantic without the ``pydantic.v1`` compatibility namespace.
        from pydantic import BaseModel

        bases = (BaseModel,)

    try:
        return issubclass(python_type, bases)
    except TypeError:
        # issubclass() requires a class; generic aliases and other typing constructs are not pydantic models.
        return False
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def key_value_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
    """
    Click callback that turns repeated ``key=value`` options into a dict.

    Only the first ``=`` separates key from value, whitespace around both is stripped, and later duplicates
    overwrite earlier ones. Returns None when no values were given.

    :raises click.BadParameter: for the first entry that contains no ``=``.
    """
    if not values:
        return None
    malformed = next((item for item in values if "=" not in item), None)
    if malformed is not None:
        raise click.BadParameter(f"Expected key-value pair of the form key=value, got {malformed}")
    split_items = (item.partition("=") for item in values)
    return {key.strip(): val.strip() for key, _sep, val in split_items}
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
    """
    Click callback that turns repeated label options into a dict.

    Each entry is either ``key=value`` (split on the first ``=``) or a bare ``key``, which maps to an empty
    string. Keys and values are stripped of surrounding whitespace; later duplicates win. Returns None when no
    values were given.
    """
    if not values:
        return None
    # str.partition yields (entry, "", "") when "=" is absent, so a bare key naturally gets "" as its value.
    split_entries = [entry.partition("=") for entry in values]
    return {key.strip(): val.strip() for key, _sep, val in split_entries}
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class DirParamType(click.ParamType):
    """Click parameter type that converts a local directory path or a remote URI into a ``Dir``."""

    name = "directory path"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` into a ``Dir``.

        ArtifactQuery values and values that are already a ``Dir`` are returned unchanged.

        :raises click.BadParameter: if ``value`` is a local path that does not exist or is not a directory.
        """
        if isinstance(value, (ArtifactQuery, Dir)):
            return value

        # Set remote_directory to False when running locally. This makes sure the original directory is used
        # and not a random one. ctx may be None when convert() is invoked outside a click context.
        running_remote = ctx is not None and getattr(ctx.obj, "is_remote", False)
        remote_directory = None if running_remote else False
        if not is_remote(value):
            p = pathlib.Path(value)
            if not p.exists() or not p.is_dir():
                raise click.BadParameter(f"parameter should be a valid flytedirectory path, {value}")
        return Dir(path=value, remote_directory=remote_directory)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class StructuredDatasetParamType(click.ParamType):
    """
    Click parameter type producing a StructuredDataset from a path/URI (dir or file) or an in-memory dataframe.

    TODO handle column types
    """

    name = "structured dataset path (dir/file)"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        # Artifact queries and already-built datasets pass straight through.
        if isinstance(value, (ArtifactQuery, StructuredDataset)):
            return value
        # A string is treated as a location; any other object is wrapped as the dataframe itself.
        return StructuredDataset(uri=value) if isinstance(value, str) else StructuredDataset(dataframe=value)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class FileParamType(click.ParamType):
    """Click parameter type that converts a local file path or a remote URI into a ``File``."""

    name = "file path"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` into a ``File``.

        ArtifactQuery values and values that are already a ``File`` are returned unchanged.

        :raises click.BadParameter: if ``value`` is a local path that does not exist or is not a regular file.
        """
        if isinstance(value, (ArtifactQuery, File)):
            return value

        # Set remote_path to False when running locally. This makes sure the original file is used and not a
        # random one. ctx may be None when convert() is invoked outside a click context.
        running_remote = ctx is not None and getattr(ctx.obj, "is_remote", False)
        remote_path = None if running_remote else False
        if not is_remote(value):
            p = pathlib.Path(value)
            if not p.exists() or not p.is_file():
                raise click.BadParameter(f"parameter should be a valid file path, {value}")
        return File(path=value, remote_path=remote_path)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class PickleParamType(click.ParamType):
    """Click parameter type that resolves a ``<module>:<attribute>`` string to a Python object by importing it."""

    name = "pickle"

    def get_metavar(self, param: click.Parameter, ctx: typing.Optional[click.Context] = None) -> t.Optional[str]:
        # ``ctx`` is accepted (and ignored) because newer click versions pass it to ParamType.get_metavar.
        return "Python Object <Module>:<Object>"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Resolve ``value`` of the form ``<module>:<attribute>`` to the named attribute of the imported module.

        Non-string values are returned unchanged. Strings not of that two-part form are also returned unchanged;
        a note is echoed when the context object has ``verbose > 0``. The current working directory is put at
        the front of ``sys.path`` so that local modules can be imported.

        :raises click.BadParameter: if the module cannot be imported or lacks the attribute.
        """
        if not isinstance(value, str):
            return value
        parts = value.split(":")
        if len(parts) != 2:
            verbose = getattr(ctx.obj, "verbose", 0) if ctx and ctx.obj else 0
            if (verbose or 0) > 0:
                click.echo(f"Did not receive a string in the expected format <MODULE>:<VAR>, falling back to: {value}")
            return value

        module_name, attr_name = parts
        cwd = os.getcwd()
        # Give the cwd import precedence without piling up a duplicate sys.path entry on every call.
        if not sys.path or sys.path[0] != cwd:
            sys.path.insert(0, cwd)

        try:
            module = importlib.import_module(module_name)
        except (ImportError, ValueError) as e:
            # ValueError covers an empty module name, e.g. ":obj".
            raise click.BadParameter(f"Failed to import module {module_name}, error: {e}") from e
        try:
            return getattr(module, attr_name)
        except AttributeError as e:
            raise click.BadParameter(
                f"Failed to find attribute {attr_name} in module {module_name}, error: {e}"
            ) from e
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class JSONIteratorParamType(click.ParamType):
    """Click parameter type for JSON iterator inputs; no conversion is performed."""

    name = "json iterator"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        # Identity conversion: the raw value is handed back as-is.
        passthrough = value
        return passthrough
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
import re
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# ISO 8601 duration of the form P[nW][nD][T[nH][nM][n[.f]S]]. The lookahead after "T" requires at least one
# time component, so "PT" and "P1DT" are rejected.
_ISO8601_DURATION_PATTERN = re.compile(
    r"^P"  # Starts with 'P'
    r"(?:(?P<weeks>\d+)W)?"  # Optional weeks
    r"(?:(?P<days>\d+)D)?"  # Optional days
    r"(?:T(?=\d)"  # Optional time part; 'T' must be followed by a component
    r"(?:(?P<hours>\d+)H)?"
    r"(?:(?P<minutes>\d+)M)?"
    r"(?:(?P<seconds>\d+(?:\.\d+)?)S)?"  # Seconds may carry a fractional part
    r")?$"
)


def parse_iso8601_duration(iso_duration: str) -> datetime.timedelta:
    """
    Parse an ISO 8601 duration such as ``P1DT2H30M`` into a timedelta.

    Supports weeks, days, hours, minutes and (optionally fractional) seconds; years and months are not
    supported. Surrounding whitespace is ignored.

    :raises ValueError: if the string is not a duration of that form, including the empty forms ``P`` and ``PT``.
    """
    text = iso_duration.strip() if isinstance(iso_duration, str) else iso_duration
    match = _ISO8601_DURATION_PATTERN.match(text)
    if not match:
        raise ValueError(f"Invalid ISO 8601 duration format: {iso_duration}")

    components = {k: v for k, v in match.groupdict().items() if v is not None}
    if not components:
        # A bare "P" matches the pattern but carries no value; ISO 8601 requires at least one component.
        raise ValueError(f"Invalid ISO 8601 duration format: {iso_duration}")
    return datetime.timedelta(**{k: float(v) if "." in v else int(v) for k, v in components.items()})
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# "1:24" (minutes:seconds), ":45" or "45" (seconds only). The minutes part may be empty.
_COLON_DURATION_PATTERN = re.compile(r"^(?:(\d*):)?(\d+)$")
# "10 days", "1 minute", "2 weeks", ... (optional whitespace, optional plural "s").
_UNIT_DURATION_PATTERN = re.compile(r"^(\d+)\s*(week|day|hour|minute|second)s?$")


def parse_human_durations(text: str) -> list[datetime.timedelta]:
    """
    Parse human-friendly durations separated by ``|``, optionally wrapped in ``[...]``.

    Each part may be ``M:SS``, ``:SS`` or ``SS``, or ``<n> <unit>[s]`` where the unit is week, day, hour,
    minute or second (case-insensitive). Empty parts are ignored. Unparseable parts are logged as a warning
    and skipped.

    :return: the parsed durations in input order (possibly empty).
    """
    durations: list[datetime.timedelta] = []

    for raw in text.strip("[]").split("|"):
        part = raw.strip().lower()
        if not part:
            continue

        m_colon = _COLON_DURATION_PATTERN.match(part)
        if m_colon:
            minutes = int(m_colon.group(1)) if m_colon.group(1) else 0
            seconds = int(m_colon.group(2))
            durations.append(datetime.timedelta(minutes=minutes, seconds=seconds))
            continue

        m_units = _UNIT_DURATION_PATTERN.match(part)
        if m_units:
            amount = int(m_units.group(1))
            unit = m_units.group(2)
            # timedelta keywords are the plural unit names (weeks, days, hours, minutes, seconds).
            durations.append(datetime.timedelta(**{unit + "s": amount}))
            continue

        logger.warning(f"Could not parse duration part '{part}'")

    return durations
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def parse_duration(s: str) -> datetime.timedelta:
    """
    Parse ``s`` as an ISO 8601 duration, falling back to the human-friendly ``|``-separated syntax.

    With the fallback, all successfully parsed parts are summed.

    :raises ValueError: if neither syntax yields any duration.
    """
    try:
        return parse_iso8601_duration(s)
    except ValueError:
        pieces = parse_human_durations(s)
        if pieces:
            total = datetime.timedelta()
            for piece in pieces:
                total += piece
            return total
        raise ValueError(f"Could not parse duration: {s}")
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
class DateTimeType(click.DateTime):
    """
    ``click.DateTime`` extended with relative date/time inputs.

    In addition to click's own formats it accepts:

    * ``now`` (current local time) and ``today`` (local midnight of the current day), both naive;
    * ``<FORMAT> - <duration>`` or ``<FORMAT> + <duration>``, where ``<FORMAT>`` is any accepted datetime input
      and ``<duration>`` is anything ``parse_duration`` understands;
    * strings containing a space that ``datetime.datetime.fromisoformat`` accepts.
    """

    _NOW_FMT = "now"
    _TODAY_FMT = "today"
    _FIXED_FORMATS: typing.ClassVar[typing.List[str]] = [_NOW_FMT, _TODAY_FMT]
    _FLOATING_FORMATS: typing.ClassVar[typing.List[str]] = ["<FORMAT> - <ISO8601 duration>"]
    # NOTE: the attribute name is misspelled ("ADDITONAL"); kept unchanged for compatibility.
    _ADDITONAL_FORMATS: typing.ClassVar[typing.List[str]] = [*_FIXED_FORMATS, *_FLOATING_FORMATS]
    # Groups: (base datetime, sign, duration). Greedy base, so the last " - " / " + " is the operator.
    _FLOATING_FORMAT_PATTERN = r"(.+)\s+([-+])\s+(.+)"

    def __init__(self):
        super().__init__()
        # Advertise the extra formats alongside click's default ones.
        self.formats.extend(self._ADDITONAL_FORMATS)

    def _datetime_from_format(
        self, value: str, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> datetime.datetime:
        """Resolve ``now``/``today``; anything else is delegated to ``click.DateTime.convert``."""
        if value in self._FIXED_FORMATS:
            if value == self._NOW_FMT:
                return datetime.datetime.now()
            if value == self._TODAY_FMT:
                n = datetime.datetime.now()
                return datetime.datetime(n.year, n.month, n.day)
        return super().convert(value, param, ctx)

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` to a datetime, handling relative ``<FORMAT> +/- <duration>`` expressions.

        :raises click.BadParameter: if a relative expression has an unparseable duration, or (via click) if the
            value matches none of the accepted formats.
        """
        if isinstance(value, ArtifactQuery):
            return value

        if isinstance(value, str) and " " in value:
            m = re.match(self._FLOATING_FORMAT_PATTERN, value)
            if m:
                base, sign, duration = m.groups()
                dt = self._datetime_from_format(base, param, ctx)
                try:
                    delta = parse_duration(duration)
                except Exception as e:
                    raise click.BadParameter(
                        f"Matched format {self._FLOATING_FORMATS}, but failed to parse duration {duration}, error: {e}"
                    ) from e
                return dt - delta if sign == "-" else dt + delta

            try:
                value = datetime.datetime.fromisoformat(value)
            except ValueError:
                # Not ISO 8601 either: keep the string and let click.DateTime try its formats (and report its
                # own error) instead of leaking a raw ValueError traceback.
                pass

        return self._datetime_from_format(value, param, ctx)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class DurationParamType(click.ParamType):
    """Click parameter type that parses a duration (ISO 8601 or human-friendly) into a ``datetime.timedelta``."""

    name = "[1:24 | :22 | 1 minute | 10 days | ...]"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` via ``parse_duration``.

        ArtifactQuery values and values that are already a timedelta (e.g. defaults) are returned unchanged.

        :raises click.BadParameter: if ``value`` is None or cannot be parsed as a duration.
        """
        if isinstance(value, (ArtifactQuery, datetime.timedelta)):
            return value
        if value is None:
            raise click.BadParameter("None value cannot be converted to a Duration type.")
        try:
            return parse_duration(value)
        except (TypeError, ValueError) as e:
            # Surface parse failures as a click usage error rather than an unhandled traceback.
            raise click.BadParameter(str(e)) from e
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
class EnumParamType(click.Choice):
    """Click choice over the ``str()`` of an Enum's values; converts the chosen value back to the Enum member."""

    def __init__(self, enum_type: typing.Type[enum.Enum]):
        super().__init__([str(e.value) for e in enum_type])
        self._enum_type = enum_type
        # Map each displayed choice back to its member. enum_type("1") would fail for int-valued enums, since
        # the choices are the str() of the values.
        self._members_by_str: typing.Dict[str, enum.Enum] = {str(e.value): e for e in enum_type}

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> enum.Enum:
        """Return the Enum member matching ``value``; ArtifactQuery values and members pass through unchanged."""
        if isinstance(value, ArtifactQuery):
            return value
        if isinstance(value, self._enum_type):
            return value
        choice = super().convert(value, param, ctx)
        member = self._members_by_str.get(choice)
        return member if member is not None else self._enum_type(choice)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
class UnionParamType(click.ParamType):
    """
    Composite click type for union annotations: each member type is tried in turn.

    Members are ordered so that specific types come first, then string types, then unprocessed types. The
    given order is kept within each group.
    """

    def __init__(self, types: typing.List[click.ParamType]):
        super().__init__()
        self._types = self._sort_precedence(types)

    @property
    def name(self) -> str:
        return "|".join(member.name for member in self._types)

    @staticmethod
    def _sort_precedence(tp: typing.List[click.ParamType]) -> typing.List[click.ParamType]:
        def rank(p: click.ParamType) -> int:
            if isinstance(p, type(click.UNPROCESSED)):
                return 2
            if isinstance(p, type(click.STRING)):
                return 1
            return 0

        # sorted() is stable, so members keep their relative order within the same rank.
        return sorted(tp, key=rank)

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Return the result of the first member type that converts ``value`` without raising.

        TODO: implement NoneType / Optional; consider deriving the click types directly from the python types.
        """
        if isinstance(value, ArtifactQuery):
            return value
        for candidate in self._types:
            try:
                return candidate.convert(value, param, ctx)
            except Exception as e:
                logger.debug(f"Ignoring conversion error for type {candidate} trying other variants in Union. Error: {e}")
        raise click.BadParameter(f"Failed to convert {value} to any of the types {self._types}")
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
class JsonParamType(click.ParamType):
    """
    Click parameter type for structured inputs (lists, dicts, pydantic models, ...).

    The raw CLI value may be an inline JSON string, or a path to a JSON file or a ``.yaml``/``.yml`` file.
    """

    name = "json object OR json/yaml file path"

    def __init__(self, python_type: typing.Type):
        super().__init__()
        # Target python type; drives the post-parse shaping done in ``convert``.
        self._python_type = python_type

    @staticmethod
    def _load_file(path: str, param: typing.Optional[click.Parameter]) -> typing.Any:
        """
        Load ``path`` as YAML if it ends in ``.yaml``/``.yml``, otherwise as JSON.

        :raises click.BadParameter: if the file cannot be read or its contents cannot be parsed.
        """
        try:
            with open(path, "r", encoding="utf-8") as f:
                if path.endswith((".yaml", ".yml")):
                    return yaml.safe_load(f)
                return json.load(f)
        except json.JSONDecodeError as e:
            raise click.BadParameter(f"parameter {param} should be a valid json object, {path}, error: {e}") from e
        except yaml.YAMLError as e:
            raise click.BadParameter(f"parameter {param} should be a valid yaml document, {path}, error: {e}") from e
        except OSError as e:
            # e.g. the path is a directory or is not readable.
            raise click.BadParameter(f"parameter {param} could not read file {path}, error: {e}") from e

    def _parse(self, value: typing.Any, param: typing.Optional[click.Parameter]):
        """
        Turn a raw CLI value into a parsed JSON/YAML value.

        ``dict``/``list`` values are returned unchanged. Otherwise the value is first tried as inline JSON;
        if that fails and the value is a path to an existing file, the file is loaded instead.

        :raises click.BadParameter: if the value is neither valid inline JSON nor a readable, parseable file.
        """
        if isinstance(value, (dict, list)):
            return value
        try:
            return json.loads(value)
        except (ValueError, TypeError) as e:
            inline_error = e
        # Not inline JSON, so try the value as a file path. Only strings are treated as paths:
        # os.path.exists() would otherwise interpret an int as an open file descriptor.
        if isinstance(value, str) and os.path.exists(value):
            return self._load_file(value, param)
        raise click.BadParameter(f"parameter {param} should be a valid json object, {value}, error: {inline_error}")

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Parse ``value`` and shape the result toward ``self._python_type``.

        - list/dict targets: the parsed container is returned. Elements (lists) or values (dicts) are
          converted recursively when the element/value type argument is a dataclass.
        - pydantic BaseModel targets: the parsed value is validated into a model instance.
        - anything else: the parsed value is returned as-is.

        :raises click.BadParameter: if ``value`` is None or cannot be parsed.
        """
        if isinstance(value, ArtifactQuery):
            return value
        if value is None:
            raise click.BadParameter("None value cannot be converted to a Json type.")

        parsed_value = self._parse(value, param)

        # We compare the origin type because the json parsed value for list or dict is always a list or dict without
        # the covariant type information.
        if type(parsed_value) is typing.get_origin(self._python_type) or type(parsed_value) is self._python_type:
            type_args = get_args(self._python_type)
            # Indexing the return value of get_args will raise an error for native dict and list types.
            # We don't support native list/dict types with nested dataclasses.
            if type_args == ():
                return parsed_value
            elif isinstance(parsed_value, list) and dataclasses.is_dataclass(type_args[0]):
                j = JsonParamType(type_args[0])
                return [j.convert(v, param, ctx) for v in parsed_value]
            elif isinstance(parsed_value, dict) and dataclasses.is_dataclass(type_args[1]):
                j = JsonParamType(type_args[1])
                return {k: j.convert(v, param, ctx) for k, v in parsed_value.items()}

            return parsed_value

        if is_pydantic_basemodel(self._python_type):
            # Backward compatibility for the Pydantic v1 plugin: when the class is a Pydantic v2 BaseModel,
            # validate with the v2 API; otherwise fall back to the v1 ``parse_raw`` API.
            try:
                if importlib.util.find_spec("pydantic.v1") is not None:
                    from pydantic import BaseModel as BaseModelV2

                    if issubclass(self._python_type, BaseModelV2):
                        return self._python_type.model_validate_json(
                            json.dumps(parsed_value), strict=False, context={"deserialize": True}
                        )
            except ImportError:
                pass

            # The behavior of the Pydantic v1 plugin.
            return self._python_type.parse_raw(json.dumps(parsed_value))

        # No dedicated conversion exists here for other target types (e.g. dataclasses). Pass the parsed value
        # through rather than returning None, which silently discarded the user's input (and turned every
        # element of a list of dataclasses into None).
        # NOTE(review): confirm downstream (TypeEngine / local execution) accepts a dict for dataclass inputs.
        return parsed_value
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def modify_literal_uris(lit: Literal):
    """
    Walk ``lit`` recursively and, in place, rewrite blob and structured-dataset URIs that start with
    ``RemoteFSPathResolver.protocol`` to the path returned by ``RemoteFSPathResolver.resolve_remote_path``.

    Collections, maps and union scalars are descended into; nothing is returned.
    """
    if lit.collection:
        for child in lit.collection.literals:
            modify_literal_uris(child)
        return

    if lit.map:
        for child in lit.map.literals.values():
            modify_literal_uris(child)
        return

    if not lit.scalar:
        return

    scalar = lit.scalar
    blob = scalar.blob
    if blob and blob.uri and blob.uri.startswith(RemoteFSPathResolver.protocol):
        blob._uri = RemoteFSPathResolver.resolve_remote_path(blob.uri)
        return

    if scalar.union:
        # A union scalar wraps another literal; rewrite whatever it holds.
        modify_literal_uris(scalar.union.value)
        return

    dataset = scalar.structured_dataset
    if dataset and dataset.uri and dataset.uri.startswith(RemoteFSPathResolver.protocol):
        dataset._uri = RemoteFSPathResolver.resolve_remote_path(dataset.uri)
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
# Flyte simple types that map directly onto a single click ParamType.
# SimpleType.STRUCT is deliberately absent: literal_type_to_click_type builds a JsonParamType for it
# before consulting this table. The instances here are shared by every parameter that uses them.
SIMPLE_TYPE_CONVERTER: typing.Dict[SimpleType, click.ParamType] = {
    # Built-in click scalar types.
    SimpleType.FLOAT: click.FLOAT,
    SimpleType.INTEGER: click.INT,
    SimpleType.STRING: click.STRING,
    SimpleType.BOOLEAN: click.BOOL,
    # Custom param types for temporal values.
    SimpleType.DURATION: DurationParamType(),
    SimpleType.DATETIME: DateTimeType(),
}
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> click.ParamType:
    """
    Converts a Flyte LiteralType given a python_type to a click.ParamType

    :param lt: the Flyte literal type of the input.
    :param python_type: the matching python type. For unions, its ``typing.get_args`` entries are paired
        positionally with ``lt.union_type.variants``.
    :return: the click ParamType used to parse the CLI value; ``click.UNPROCESSED`` when nothing more
        specific applies.
    :raises NotImplementedError: if ``lt`` is a simple type that is neither STRUCT nor in SIMPLE_TYPE_CONVERTER.
    """
    if lt.simple:
        if lt.simple == SimpleType.STRUCT:
            ct = JsonParamType(python_type)
            # typing generics (e.g. typing.Dict[str, Any] on Python < 3.10) may not define __name__.
            type_name = getattr(python_type, "__name__", str(python_type))
            ct.name = f"JSON object {type_name}"
            return ct
        if lt.simple in SIMPLE_TYPE_CONVERTER:
            return SIMPLE_TYPE_CONVERTER[lt.simple]
        raise NotImplementedError(f"Type {lt.simple} is not supported in pyflyte run")

    if lt.enum_type:
        return EnumParamType(python_type)  # type: ignore

    if lt.structured_dataset_type:
        return StructuredDatasetParamType()

    if lt.collection_type or lt.map_value_type:
        ct = JsonParamType(python_type)
        ct.name = "json list" if lt.collection_type else "json dictionary"
        return ct

    if lt.blob:
        if lt.blob.dimensionality == BlobType.BlobDimensionality.SINGLE:
            if lt.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT:
                return PickleParamType()
            # elif lt.blob.format == JSONIteratorTransformer.JSON_ITERATOR_FORMAT:
            #     return JSONIteratorParamType()
            return FileParamType()
        return DirParamType()

    if lt.union_type:
        variant_python_types = typing.get_args(python_type)
        # Variants and python type args are paired by position; a length mismatch raises IndexError.
        return UnionParamType(
            [
                literal_type_to_click_type(variant, variant_python_types[i])
                for i, variant in enumerate(lt.union_type.variants)
            ]
        )

    return click.UNPROCESSED
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
class FlyteLiteralConverter(object):
    """
    Bridges click-parsed CLI values to task inputs.

    It picks the click ParamType for an input (via ``literal_type_to_click_type``). Its ``convert`` turns the
    click-parsed value into either the python value (local runs) or a Flyte Literal (remote runs).
    """

    name = "literal_type"

    def __init__(
        self,
        literal_type: LiteralType,
        python_type: typing.Type,
        is_remote: bool,
    ):
        self._is_remote = is_remote
        self._literal_type = literal_type
        self._python_type = python_type
        self._click_type = literal_type_to_click_type(literal_type, python_type)

    @property
    def click_type(self) -> click.ParamType:
        return self._click_type

    def is_bool(self) -> bool:
        return self.click_type == click.BOOL

    def convert(
        self, ctx: click.Context, param: typing.Optional[click.Parameter], value: typing.Any
    ) -> typing.Union[Literal, typing.Any]:
        """
        Convert the value to a Flyte Literal or a python native type. This is used by click to convert the input.

        :return: ``value`` unchanged for ArtifactQuery inputs; None when ``value`` equals ``param.default``;
            the python value for local runs; a Literal built by ``TypeEngine.to_literal`` for remote runs.
        :raises click.BadParameter: if any step of the conversion fails.
        """
        if isinstance(value, ArtifactQuery):
            return value
        try:
            # Click produces datetime, so convert to date to avoid a type mismatch error. Only datetimes are
            # narrowed: None (unset optional) or an actual date value has no .date() method and used to fail here.
            if self._python_type is datetime.date and isinstance(value, datetime.datetime):
                value = value.date()
            # If the input matches the default value in the launch plan, serialization can be skipped.
            if param and value == param.default:
                return None

            # Local execution consumes the python native value directly; only remote execution needs a Literal.
            if not self._is_remote:
                return value

            lit = asyncio.run(TypeEngine.to_literal(value, self._python_type, self._literal_type))
            return lit
        except click.BadParameter:
            raise
        except Exception as e:
            raise click.BadParameter(
                f"Failed to convert param: {param if param else 'NA'}, value: {value} to type: {self._python_type}."
                f" Reason {e}"
            ) from e
|