flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. See the release advisory for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
|
@@ -0,0 +1,2210 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import collections
|
|
5
|
+
import copy
|
|
6
|
+
import dataclasses
|
|
7
|
+
import datetime
|
|
8
|
+
import enum
|
|
9
|
+
import inspect
|
|
10
|
+
import json
|
|
11
|
+
import os
|
|
12
|
+
import sys
|
|
13
|
+
import textwrap
|
|
14
|
+
import threading
|
|
15
|
+
import typing
|
|
16
|
+
from abc import ABC, abstractmethod
|
|
17
|
+
from collections import OrderedDict
|
|
18
|
+
from functools import lru_cache
|
|
19
|
+
from types import GenericAlias
|
|
20
|
+
from typing import Any, Dict, NamedTuple, Optional, Type, cast
|
|
21
|
+
|
|
22
|
+
import msgpack
|
|
23
|
+
from flyteidl.core import interface_pb2, literals_pb2, types_pb2
|
|
24
|
+
from flyteidl.core.literals_pb2 import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void
|
|
25
|
+
from flyteidl.core.types_pb2 import LiteralType, SimpleType, TypeAnnotation, TypeStructure, UnionType
|
|
26
|
+
from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212
|
|
27
|
+
from google.protobuf import json_format as _json_format
|
|
28
|
+
from google.protobuf import struct_pb2
|
|
29
|
+
from google.protobuf.json_format import MessageToDict as _MessageToDict
|
|
30
|
+
from google.protobuf.json_format import ParseDict as _ParseDict
|
|
31
|
+
from google.protobuf.message import Message
|
|
32
|
+
from mashumaro.codecs.json import JSONDecoder, JSONEncoder
|
|
33
|
+
from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder
|
|
34
|
+
from mashumaro.jsonschema.models import Context, JSONSchema
|
|
35
|
+
from mashumaro.jsonschema.plugins import BasePlugin
|
|
36
|
+
from mashumaro.jsonschema.schema import Instance
|
|
37
|
+
from mashumaro.mixins.json import DataClassJSONMixin
|
|
38
|
+
from typing_extensions import Annotated, get_args, get_origin
|
|
39
|
+
|
|
40
|
+
import flyte.storage as storage
|
|
41
|
+
from flyte._datastructures import NativeInterface
|
|
42
|
+
from flyte._hash import HashMethod
|
|
43
|
+
from flyte._logging import logger
|
|
44
|
+
from flyte._utils.helpers import load_proto_from_file
|
|
45
|
+
|
|
46
|
+
from ._utils import literal_types_match
|
|
47
|
+
|
|
48
|
+
# Generic parameter for the Python value type handled by a TypeTransformer.
T = typing.TypeVar("T")

# Tag placed on Binary literal scalars whose payload is msgpack-encoded.
MESSAGEPACK = "msgpack"
# Metadata keys — presumably recorded alongside literals to describe how a
# cache key / value was serialized; their consumers are not visible in this
# chunk, so confirm against callers before relying on these semantics.
CACHE_KEY_METADATA = "cache-key-metadata"
SERIALIZATION_FORMAT = "serialization-format"

# JSON-schema keys (usage not visible in this chunk — likely referenced when
# walking dataclass JSON schemas; TODO confirm).
DEFINITIONS = "definitions"
TITLE = "title"

# Batch size for coroutines the type engine runs concurrently (see
# fsspec.asyn._run_coros_in_chunks import above); overridable via the
# _F_TE_MAX_COROS environment variable, default 10.
_TYPE_ENGINE_COROS_BATCH_SIZE = int(os.environ.get("_F_TE_MAX_COROS", "10"))
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# In Mashumaro, the default encoder uses strict_map_key=False, while the default
# decoder uses strict_map_key=True. This matters for types like Dict[int, str]:
# without strict_map_key=False, decoding rejects map keys that are not strings.
def _default_msgpack_decoder(data: bytes) -> Any:
    """Unpack raw msgpack ``data``, permitting non-string (e.g. int) map keys."""
    unpacked: Any = msgpack.unpackb(data, strict_map_key=False)
    return unpacked
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def modify_literal_uris(lit: Literal) -> None:
    """
    Recursively rewrite, in place, any URI inside ``lit`` that uses the remote-fs
    scheme ("flyte://") into its resolved native path.

    Recurses through collections, maps, and unions; rewrites the ``uri`` field of
    blob and structured-dataset scalars whose URI starts with
    ``RemoteFSPathResolver.protocol``.

    :param lit: The literal message to mutate in place.
    """
    # Imported lazily — importing at module level would tie this module to
    # flyte.storage import order.
    from flyte.storage._remote_fs import RemoteFSPathResolver

    if lit.HasField("collection"):
        for literal in lit.collection.literals:
            modify_literal_uris(literal)
    elif lit.HasField("map"):
        # Only the mapped values can contain URIs; the keys were previously
        # iterated and discarded, so iterate values() directly.
        for v in lit.map.literals.values():
            modify_literal_uris(v)
    elif lit.HasField("scalar"):
        if (
            lit.scalar.HasField("blob")
            and lit.scalar.blob.uri
            and lit.scalar.blob.uri.startswith(RemoteFSPathResolver.protocol)
        ):
            lit.scalar.blob.uri = RemoteFSPathResolver.resolve_remote_path(lit.scalar.blob.uri)
        elif lit.scalar.HasField("union"):
            modify_literal_uris(lit.scalar.union.value)
        elif (
            lit.scalar.HasField("structured_dataset")
            and lit.scalar.structured_dataset.uri
            and lit.scalar.structured_dataset.uri.startswith(RemoteFSPathResolver.protocol)
        ):
            lit.scalar.structured_dataset.uri = RemoteFSPathResolver.resolve_remote_path(
                lit.scalar.structured_dataset.uri
            )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class TypeTransformerFailedError(TypeError, AssertionError, ValueError):
    """Raised when a type transformer cannot convert between a Python value and a literal.

    Subclasses TypeError, AssertionError, and ValueError so that callers catching
    any of those built-in exception types also catch transformer failures.
    """
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class TypeTransformer(typing.Generic[T]):
    """
    Base transformer type that should be implemented for every python native type that can be handled by flytekit
    """

    def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True):
        # The Python type this transformer handles.
        self._t = t
        # Human-readable name, surfaced via __repr__.
        self._name = name
        self._type_assertions_enabled = enable_type_assertions
        # Per-python-type codec caches: constructing mashumaro (de)coders is
        # not free, so instances are built once per type and reused.
        self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {}
        self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {}

    @property
    def name(self) -> str:
        """The transformer's human-readable name."""
        return self._name

    @property
    def python_type(self) -> Type[T]:
        """
        This returns the python type
        """
        return self._t

    @property
    def type_assertions_enabled(self) -> bool:
        """
        Indicates if the transformer wants type assertions to be enabled at the core type engine layer
        """
        return self._type_assertions_enabled

    def isinstance_generic(self, obj, generic_alias):
        """
        Check that ``obj`` is an instance of the origin container of ``generic_alias``
        (e.g. ``list`` for ``list[int]``). Element types are not inspected here.

        :raises TypeTransformerFailedError: if ``obj`` is not an instance of the origin.
        """
        origin = get_origin(generic_alias)  # list from list[int])

        if not isinstance(obj, origin):
            raise TypeTransformerFailedError(f"Value '{obj}' is not of container type {origin}")

    def assert_type(self, t: Type[T], v: T):
        """
        Assert that value ``v`` conforms to type ``t``.

        On Python 3.10+, builtin subscripted generics (``types.GenericAlias``,
        e.g. ``list[int]``) are checked against their origin container only.
        Other subscripted typing constructs (anything with ``__origin__``) are
        skipped entirely; plain classes are checked with ``isinstance``.

        :raises TypeTransformerFailedError: on a type mismatch.
        """
        if sys.version_info >= (3, 10):
            # Local import: types.GenericAlias checks are only meaningful on 3.10+.
            import types

            if isinstance(t, types.GenericAlias):
                return self.isinstance_generic(v, t)

        if not hasattr(t, "__origin__") and not isinstance(v, t):
            raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}")

    @abstractmethod
    def get_literal_type(self, t: Type[T]) -> LiteralType:
        """
        Converts the python type to a Flyte LiteralType
        """
        raise NotImplementedError("Conversion to LiteralType should be implemented")

    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
        """
        Converts the Flyte LiteralType to a python object type.
        """
        raise ValueError("By default, transformers do not translate from Flyte types back to Python types")

    @abstractmethod
    async def to_literal(self, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """
        Converts a given python_val to a Flyte Literal, assuming the given python_val matches the declared python_type.
        Implementers should refrain from using type(python_val) instead rely on the passed in python_type. If these
        do not match (or are not allowed) the Transformer implementer should raise an AssertionError, clearly stating
        what was the mismatch
        :param python_val: The actual value to be transformed
        :param python_type: The assumed type of the value (this matches the declared type on the function)
        :param expected: Expected Literal Type
        """
        raise NotImplementedError(f"Conversion to Literal for python type {python_type} not implemented")

    @abstractmethod
    async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> Optional[T]:
        """
        Converts the given Literal to a Python Type. If the conversion cannot be done an AssertionError should be raised
        :param lv: The received literal Value
        :param expected_python_type: Expected native python type that should be returned
        """
        raise NotImplementedError(
            f"Conversion to python value expected type {expected_python_type} from literal not implemented"
        )

    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
        """
        This function primarily handles deserialization for untyped dicts, dataclasses, Pydantic BaseModels, and
        attribute access.

        For untyped dict, dataclass, and pydantic basemodel:
            Life Cycle (Untyped Dict as example):
                python val -> msgpack bytes -> binary literal scalar -> msgpack bytes -> python val
                            (to_literal)                            (from_binary_idl)

        For attribute access:
            Life Cycle:
                python val -> msgpack bytes -> binary literal scalar -> resolved golang value -> binary literal scalar
                -> msgpack bytes -> python val
                (to_literal)        (propeller attribute access)        (from_binary_idl)

        :raises TypeTransformerFailedError: if the binary tag is not "msgpack".
        """
        if binary_idl_object.tag == MESSAGEPACK:
            # EAFP cache lookup: build and memoize a decoder for this type on first use.
            try:
                decoder = self._msgpack_decoder[expected_python_type]
            except KeyError:
                decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder)
                self._msgpack_decoder[expected_python_type] = decoder
            python_val = decoder.decode(binary_idl_object.value)

            return python_val
        else:
            raise TypeTransformerFailedError(f"Unsupported binary format `{binary_idl_object.tag}`")

    def to_html(self, python_val: T, expected_python_type: Type[T]) -> str:
        """
        Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div
        """
        return str(python_val)

    def __repr__(self):
        return f"{self._name} Transforms ({self._t}) to Flyte native"

    def __str__(self):
        return str(self.__repr__())
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class SimpleTransformer(TypeTransformer[T]):
|
|
227
|
+
"""
|
|
228
|
+
A Simple implementation of a type transformer that uses simple lambdas to transform and reduces boilerplate
|
|
229
|
+
"""
|
|
230
|
+
|
|
231
|
+
def __init__(
|
|
232
|
+
self,
|
|
233
|
+
name: str,
|
|
234
|
+
t: Type[T],
|
|
235
|
+
lt: types_pb2.LiteralType,
|
|
236
|
+
to_literal_transformer: typing.Callable[[T], Literal],
|
|
237
|
+
from_literal_transformer: typing.Callable[[Literal], Optional[T]],
|
|
238
|
+
):
|
|
239
|
+
super().__init__(name, t)
|
|
240
|
+
self._type = t
|
|
241
|
+
self._lt = lt
|
|
242
|
+
self._to_literal_transformer = to_literal_transformer
|
|
243
|
+
self._from_literal_transformer = from_literal_transformer
|
|
244
|
+
|
|
245
|
+
@property
|
|
246
|
+
def base_type(self) -> Type:
|
|
247
|
+
return self._type
|
|
248
|
+
|
|
249
|
+
def get_literal_type(self, t: Optional[Type[T]] = None) -> types_pb2.LiteralType:
|
|
250
|
+
return self._lt
|
|
251
|
+
|
|
252
|
+
async def to_literal(self, python_val: T, python_type: Type[T], expected: Optional[LiteralType] = None) -> Literal:
|
|
253
|
+
if type(python_val) is not self._type:
|
|
254
|
+
raise TypeTransformerFailedError(
|
|
255
|
+
f"Expected value of type {self._type} but got '{python_val}' of type {type(python_val)}"
|
|
256
|
+
)
|
|
257
|
+
return self._to_literal_transformer(python_val)
|
|
258
|
+
|
|
259
|
+
def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
|
|
260
|
+
if binary_idl_object.tag == MESSAGEPACK:
|
|
261
|
+
if expected_python_type in [datetime.date, datetime.datetime, datetime.timedelta]:
|
|
262
|
+
"""
|
|
263
|
+
MessagePack doesn't support datetime, date, and timedelta.
|
|
264
|
+
However, mashumaro's MessagePackEncoder and MessagePackDecoder can convert them to str and vice versa.
|
|
265
|
+
That's why we need to use mashumaro's MessagePackDecoder here.
|
|
266
|
+
"""
|
|
267
|
+
try:
|
|
268
|
+
decoder = self._msgpack_decoder[expected_python_type]
|
|
269
|
+
except KeyError:
|
|
270
|
+
decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder)
|
|
271
|
+
self._msgpack_decoder[expected_python_type] = decoder
|
|
272
|
+
python_val = decoder.decode(binary_idl_object.value)
|
|
273
|
+
else:
|
|
274
|
+
python_val = msgpack.loads(binary_idl_object.value)
|
|
275
|
+
r"""
|
|
276
|
+
In the case below, when using Union Transformer + Simple Transformer, then `a`
|
|
277
|
+
can be converted to int, bool, str and float if we use MessagePackDecoder[expected_python_type].
|
|
278
|
+
|
|
279
|
+
Life Cycle:
|
|
280
|
+
1 -> msgpack bytes -> (1, true, "1", 1.0)
|
|
281
|
+
|
|
282
|
+
Example Code:
|
|
283
|
+
@dataclass
|
|
284
|
+
class DC:
|
|
285
|
+
a: Union[int, bool, str, float]
|
|
286
|
+
b: Union[int, bool, str, float]
|
|
287
|
+
|
|
288
|
+
@task(container_image=custom_image)
|
|
289
|
+
def add(a: Union[int, bool, str, float],
|
|
290
|
+
b: Union[int, bool, str, float]) -> Union[int, bool, str, float]:
|
|
291
|
+
return a + b
|
|
292
|
+
|
|
293
|
+
@workflow
|
|
294
|
+
def wf(dc: DC) -> Union[int, bool, str, float]:
|
|
295
|
+
return add(dc.a, dc.b)
|
|
296
|
+
|
|
297
|
+
wf(DC(1, 1))
|
|
298
|
+
"""
|
|
299
|
+
assert isinstance(python_val, expected_python_type)
|
|
300
|
+
|
|
301
|
+
return python_val
|
|
302
|
+
else:
|
|
303
|
+
raise TypeTransformerFailedError(f"Unsupported binary format `{binary_idl_object.tag}`")
|
|
304
|
+
|
|
305
|
+
async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> T:
    """Convert a Flyte literal into an instance of this transformer's python type.

    MessagePack binary scalars are routed through ``from_binary_idl``; any other
    literal goes through the ``_from_literal_transformer`` callable this
    instance was constructed with.
    """
    expected_python_type = get_underlying_type(expected_python_type)

    # This transformer handles exactly one python type.
    if expected_python_type is not self._type:
        raise TypeTransformerFailedError(
            f"Cannot convert to type {expected_python_type}, only {self._type} is supported"
        )

    # MessagePack-encoded scalars take the dedicated decode path.
    if lv.HasField("scalar") and lv.scalar.HasField("binary"):
        return self.from_binary_idl(lv.scalar.binary, expected_python_type)  # type: ignore

    try:
        converted = self._from_literal_transformer(lv)
    except AttributeError:
        # Assume that this is because a property on `lv` was None
        raise TypeTransformerFailedError(f"Cannot convert literal {lv} to {self._type}")
    if type(converted) is not self._type:
        raise TypeTransformerFailedError(f"Cannot convert literal {lv} to {self._type}")
    return converted
|
|
325
|
+
def guess_python_type(self, literal_type: types_pb2.LiteralType) -> Type[T]:
    """Reverse-map a simple LiteralType back to this transformer's python type."""
    is_match = literal_type.HasField("simple") and literal_type.simple == self._lt.simple
    if not is_match:
        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
    return self.python_type
|
|
331
|
+
class RestrictedTypeError(Exception):
    """Raised when a restricted type is used as a task or workflow input/output."""
|
335
|
+
class RestrictedTypeTransformer(TypeTransformer[T], ABC):
    """
    Types registered with the RestrictedTypeTransformer are not allowed to be converted to and from literals.
    In other words,
    Restricted types are not allowed to be used as inputs or outputs of tasks and workflows.
    """

    def __init__(self, name: str, t: Type[T]):
        super().__init__(name, t)

    def _reject(self) -> typing.NoReturn:
        # Single raise site shared by every conversion entry point.
        raise RestrictedTypeError(f"Transformer for type {self.python_type} is restricted currently")

    def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType:
        self._reject()

    async def to_literal(self, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        self._reject()

    async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> T:
        self._reject()
|
355
|
+
class PydanticSchemaPlugin(BasePlugin):
    """This allows us to generate proper schemas for Pydantic models."""

    def get_schema(
        self,
        instance: Instance,
        ctx: Context,
        schema: JSONSchema | None = None,
    ) -> JSONSchema | None:
        from pydantic import BaseModel

        # issubclass (and the schema build) may raise TypeError for non-class
        # instance.type values; in that case there is no schema to contribute.
        try:
            if not issubclass(instance.type, BaseModel):
                return None
            return JSONSchema.from_dict(instance.type.model_json_schema())
        except TypeError:
            return None
|
375
|
+
class DataclassTransformer(TypeTransformer[object]):
    """
    The Dataclass Transformer provides a type transformer for dataclasses.

    The dataclass is converted to and from MessagePack Bytes by the mashumaro library
    and is transported between tasks using the Binary IDL representation.
    Also, the type declaration will try to extract the JSON Schema for the
    object, if possible, and pass it with the definition.

    The lifecycle of the dataclass in the Flyte type system is as follows:

    1. Serialization: The dataclass transformer converts the dataclass to MessagePack Bytes.
       (1) Handle dataclass attributes to make them serializable with mashumaro.
       (2) Use the mashumaro API to serialize the dataclass to MessagePack Bytes.
       (3) Use MessagePack Bytes to create a Flyte Literal.
       (4) Serialize the Flyte Literal to a Binary IDL Object.

    2. Deserialization: The dataclass transformer converts the MessagePack Bytes back to a dataclass.
       (1) Convert MessagePack Bytes to a dataclass using mashumaro.
       (2) Handle dataclass attributes to ensure they are of the correct types.

    TODO: Update the example using mashumaro instead of the older library

    Example

    .. code-block:: python

        @dataclass
        class Test:
            a: int
            b: str

        t = Test(a=10, b="e")
        JSONSchema().dump(t.schema())

    Output will look like

    .. code-block:: json

        {'$schema': 'http://json-schema.org/draft-07/schema#',
         'definitions': {'TestSchema': {'properties': {'a': {'title': 'a',
                                                             'type': 'number',
                                                             'format': 'integer'},
                                                       'b': {'title': 'b', 'type': 'string'}},
                                        'type': 'object',
                                        'additionalProperties': False}},
         '$ref': '#/definitions/TestSchema'}
    """

    def __init__(self) -> None:
        super().__init__("Object-Dataclass-Transformer", object)
        self._json_encoder: Dict[Type, JSONEncoder] = {}
        self._json_decoder: Dict[Type, JSONDecoder] = {}
        # Fix: to_literal() and from_binary_idl() read these caches with
        # `try: ... except KeyError`, but neither attribute was initialized in
        # this class, so the first access raised AttributeError (not the caught
        # KeyError) unless a base class happened to define them — TODO confirm
        # against the TypeTransformer base. Initializing empty caches here is
        # harmless in either case.
        self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {}
        self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {}

    def assert_type(self, expected_type: Type, v: T):
        """Validate that ``v`` (a dataclass instance or a dict) matches ``expected_type``.

        Raises TypeTransformerFailedError on any field-name or field-type mismatch.
        """
        # Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
        expected_type = get_underlying_type(expected_type)
        if type(v) is expected_type or issubclass(type(v), expected_type):
            return

        # @dataclass
        # class Foo:
        #     a: int = 0
        #
        # @task
        # def t1(a: Foo):
        #     ...
        #
        # In above example, the type of v may not equal to the expected_type in some cases
        # For example,
        # 1. The input of t1 is another dataclass (bar), then we should raise an error
        # 2. when using flyte remote to execute the above task, the expected_type is guess_python_type (FooSchema)
        # by default.
        # However, FooSchema is created by flytekit and it's not equal to the user-defined dataclass (Foo).
        # Therefore, we should iterate all attributes in the dataclass and check the type of value in dataclass
        # matches the expected_type.

        expected_fields_dict = {}

        for f in dataclasses.fields(expected_type):
            expected_fields_dict[f.name] = cast(type, f.type)

        if isinstance(v, dict):
            original_dict = v

            # Find the Optional keys in expected_fields_dict
            optional_keys = {k for k, t in expected_fields_dict.items() if UnionTransformer.is_optional_type(t)}

            # Remove the Optional keys from the keys of original_dict
            original_key = set(original_dict.keys()) - optional_keys
            expected_key = set(expected_fields_dict.keys()) - optional_keys

            # Check if original_key is missing any keys from expected_key
            missing_keys = expected_key - original_key
            if missing_keys:
                raise TypeTransformerFailedError(
                    f"The original fields are missing the following keys from the dataclass fields: "
                    f"{list(missing_keys)}"
                )

            # Check if original_key has any extra keys that are not in expected_key
            extra_keys = original_key - expected_key
            if extra_keys:
                raise TypeTransformerFailedError(
                    f"The original fields have the following extra keys that are not in dataclass fields:"
                    f" {list(extra_keys)}"
                )

            for k, v in original_dict.items():
                if k in expected_fields_dict:
                    if isinstance(v, dict):
                        # Nested dict: recurse against the nested dataclass type.
                        self.assert_type(expected_fields_dict[k], v)
                    else:
                        expected_type = expected_fields_dict[k]
                        original_type = type(v)
                        if UnionTransformer.is_optional_type(expected_type):
                            expected_type = UnionTransformer.get_sub_type_in_optional(expected_type)
                        if original_type != expected_type:
                            raise TypeTransformerFailedError(
                                f"Type of Val '{original_type}' is not an instance of {expected_type}"
                            )

        else:
            for f in dataclasses.fields(type(v)):  # type: ignore
                original_type = cast(type, f.type)
                if f.name not in expected_fields_dict:
                    raise TypeTransformerFailedError(
                        f"Field '{f.name}' is not present in the expected dataclass fields {expected_type.__name__}"
                    )
                expected_type = expected_fields_dict[f.name]

                # Compare the unwrapped types when either side is Optional.
                if UnionTransformer.is_optional_type(original_type):
                    original_type = UnionTransformer.get_sub_type_in_optional(original_type)
                if UnionTransformer.is_optional_type(expected_type):
                    expected_type = UnionTransformer.get_sub_type_in_optional(expected_type)

                val = v.__getattribute__(f.name)
                if dataclasses.is_dataclass(val):
                    self.assert_type(expected_type, val)
                elif original_type != expected_type:
                    raise TypeTransformerFailedError(
                        f"Type of Val '{original_type}' is not an instance of {expected_type}"
                    )

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        """
        Extracts the Literal type definition for a Dataclass and returns a type Struct.
        If possible also extracts the JSONSchema for the dataclass.
        """

        if is_annotated(t):
            args = get_args(t)
            logger.info(f"These annotations will be skipped for dataclasses = {args[1:]}")
            # Drop all annotations and handle only the dataclass type passed in.
            t = args[0]

        schema = None
        try:
            # This produce JSON SCHEMA draft 2020-12
            from mashumaro.jsonschema import build_json_schema

            schema = build_json_schema(
                self._get_origin_type_in_annotation(t), plugins=[PydanticSchemaPlugin()]
            ).to_dict()
        except Exception as e:
            # Schema extraction is best-effort; the literal type is still usable without it.
            logger.error(
                f"Failed to extract schema for object {t}, error: {e}\n"
                f"Possibly remove `DataClassJsonMixin` and `dataclass_json` decorator from dataclass declaration"
            )

        # Recursively construct the dataclass_type which contains the literal type of each field
        literal_type = {}

        hints = typing.get_type_hints(t)
        # Get the type of each field from dataclass
        for field in t.__dataclass_fields__.values():  # type: ignore
            try:
                name = field.name
                python_type = hints.get(name, field.type)
                literal_type[name] = TypeEngine.to_literal_type(python_type)
            except Exception as e:
                logger.warning(
                    "Field {} of type {} cannot be converted to a literal type. Error: {}".format(
                        field.name, field.type, e
                    )
                )

        # This is for attribute access in FlytePropeller.
        ts = TypeStructure(tag="", dataclass_type=literal_type)
        from google.protobuf.struct_pb2 import Struct

        meta_struct = Struct()
        meta_struct.update(
            {
                CACHE_KEY_METADATA: {
                    SERIALIZATION_FORMAT: MESSAGEPACK,
                }
            }
        )
        return types_pb2.LiteralType(
            simple=types_pb2.SimpleType.STRUCT,
            metadata=schema,
            structure=ts,
            annotation=TypeAnnotation(annotations=meta_struct),
        )

    async def to_literal(self, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """Serialize a dataclass instance (or an already-dict value) into a
        MessagePack-tagged Binary literal."""
        if isinstance(python_val, dict):
            msgpack_bytes = msgpack.dumps(python_val)
            return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))

        if not dataclasses.is_dataclass(python_val):
            raise TypeTransformerFailedError(
                f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for "
                f"user defined datatypes in Flytekit"
            )

        # The function looks up or creates a MessagePackEncoder specifically designed for the object's type.
        # This encoder is then used to convert a data class into MessagePack Bytes.
        try:
            encoder = self._msgpack_encoder[python_type]
        except KeyError:
            encoder = MessagePackEncoder(python_type)
            self._msgpack_encoder[python_type] = encoder

        try:
            msgpack_bytes = encoder.encode(python_val)
        except NotImplementedError:
            # you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented.
            raise NotImplementedError(
                f"{python_type} should inherit from mashumaro.types.SerializableType"
                f" and implement _serialize and _deserialize methods."
            )

        return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))

    def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
        # dataclass will try to hash a python type when calling dataclass.schema(), but some types in the annotation are
        # not hashable, such as Annotated[StructuredDataset, kwtypes(...)]. Therefore, we should just extract the origin
        # type from annotated.
        if get_origin(python_type) is list:
            return typing.List[self._get_origin_type_in_annotation(get_args(python_type)[0])]  # type: ignore
        elif get_origin(python_type) is dict:
            return typing.Dict[  # type: ignore
                self._get_origin_type_in_annotation(get_args(python_type)[0]),
                self._get_origin_type_in_annotation(get_args(python_type)[1]),
            ]
        elif is_annotated(python_type):
            return get_args(python_type)[0]
        elif dataclasses.is_dataclass(python_type):
            # NOTE(review): the fields mutated here belong to a deepcopy that is
            # then discarded — the original `python_type` is returned unchanged.
            # This mirrors the upstream flytekit behavior; confirm intent before
            # "fixing" it.
            for field in dataclasses.fields(copy.deepcopy(python_type)):
                field.type = self._get_origin_type_in_annotation(cast(type, field.type))
        return python_type

    async def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T:
        """Decode MessagePack Bytes from a Binary IDL object into a dataclass instance."""
        if binary_idl_object.tag == MESSAGEPACK:
            if issubclass(expected_python_type, DataClassJSONMixin):
                # DataClassJSONMixin types customize deserialization via from_json,
                # so round-trip through a JSON string rather than a raw decoder.
                dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
                json_str = json.dumps(dict_obj)
                dc = expected_python_type.from_json(json_str)  # type: ignore
            else:
                try:
                    decoder = self._msgpack_decoder[expected_python_type]
                except KeyError:
                    # Cache one decoder per type; constructing one is relatively expensive.
                    decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder)
                    self._msgpack_decoder[expected_python_type] = decoder
                dc = decoder.decode(binary_idl_object.value)

            return cast(T, dc)
        else:
            raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

    async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> T:
        """Deserialize a literal (binary MessagePack or generic protobuf struct)
        into an instance of ``expected_python_type``."""
        if not dataclasses.is_dataclass(expected_python_type):
            raise TypeTransformerFailedError(
                f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for "
                "user defined datatypes in Flytekit"
            )

        if lv.scalar and lv.scalar.binary:
            return await self.from_binary_idl(lv.scalar.binary, expected_python_type)  # type: ignore

        json_str = _json_format.MessageToJson(lv.scalar.generic)

        # The `from_json` function is provided from mashumaro's `DataClassJSONMixin`.
        # It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder
        # We can't use hasattr(expected_python_type, "from_json") here because we rely on mashumaro's API to
        # customize the deserialization behavior for Flyte types.
        if issubclass(expected_python_type, DataClassJSONMixin):
            dc = expected_python_type.from_json(json_str)  # type: ignore
        else:
            # The function looks up or creates a JSONDecoder specifically designed for the object's type.
            # This decoder is then used to convert a JSON string into a data class.
            try:
                decoder = self._json_decoder[expected_python_type]
            except KeyError:
                decoder = JSONDecoder(expected_python_type)
                self._json_decoder[expected_python_type] = decoder

            dc = decoder.decode(json_str)

        return cast(T, dc)

    # This ensures that calls with the same literal type returns the same dataclass. For example, `pyflyte run``
    # command needs to call guess_python_type to get the TypeEngine-derived dataclass. Without caching here, separate
    # calls to guess_python_type would result in a logically equivalent (but new) dataclass, which
    # TypeEngine.assert_type would not be happy about.
    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:  # type: ignore
        if literal_type.simple == SimpleType.STRUCT:
            if literal_type.HasField("metadata"):
                from google.protobuf import json_format

                metadata = json_format.MessageToDict(literal_type.metadata)
                if TITLE in metadata:
                    schema_name = metadata[TITLE]
                    return convert_mashumaro_json_schema_to_python_class(metadata, schema_name)
        raise ValueError(f"Dataclass transformer cannot reverse {literal_type}")
|
695
|
+
|
|
696
|
+
class ProtobufTransformer(TypeTransformer[Message]):
    """Transformer for google.protobuf messages, transported as generic structs."""

    # Metadata key under which the fully-qualified protobuf type name is stored.
    PB_FIELD_KEY = "pb_type"

    def __init__(self):
        super().__init__("Protobuf-Transformer", Message)

    @staticmethod
    def tag(expected_python_type: Type[T]) -> str:
        """Fully-qualified dotted name of the protobuf message class."""
        return f"{expected_python_type.__module__}.{expected_python_type.__name__}"

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        return LiteralType(simple=SimpleType.STRUCT, metadata={ProtobufTransformer.PB_FIELD_KEY: self.tag(t)})

    async def to_literal(self, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """
        Convert the protobuf struct to literal.

        This conversion supports two types of python_val:
        1. google.protobuf.struct_pb2.Struct: A dictionary-like message
        2. google.protobuf.struct_pb2.ListValue: An ordered collection of values

        For details, please refer to the following issue:
        https://github.com/flyteorg/flyte/issues/5959

        Because the remote handling works without errors, we implement conversion with the logic as below:
        https://github.com/flyteorg/flyte/blob/a87585ab7cbb6a047c76d994b3f127c4210070fd/flytepropeller/pkg/controller/nodes/attr_path_resolver.go#L72-L106
        """
        try:
            if type(python_val) is struct_pb2.ListValue:
                literals = []
                for v in python_val:
                    literal_type = TypeEngine.to_literal_type(type(v))
                    # Recursively convert python native values to literals
                    literal = await TypeEngine.to_literal(v, type(v), literal_type)
                    literals.append(literal)
                return Literal(collection=LiteralCollection(literals=literals))
            else:
                struct = struct_pb2.Struct()
                struct.update(_MessageToDict(cast(Message, python_val)))
                return Literal(scalar=Scalar(generic=struct))
        except Exception as e:
            # Fix: chain the original exception so the real cause is not lost
            # when the broad catch converts it to a transformer failure.
            raise TypeTransformerFailedError("Failed to convert to generic protobuf struct") from e

    async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> T:
        """Rebuild a protobuf message of ``expected_python_type`` from a generic struct literal."""
        if not (lv and lv.HasField("scalar") and lv.scalar.HasField("generic")):
            raise TypeTransformerFailedError("Can only convert a generic literal to a Protobuf")

        pb_obj = expected_python_type()
        dictionary = _MessageToDict(lv.scalar.generic)
        pb_obj = _ParseDict(dictionary, pb_obj)  # type: ignore
        return pb_obj

    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
        # avoid loading
        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
753
|
+
class EnumTransformer(TypeTransformer[enum.Enum]):
    """
    Enables converting a python type enum.Enum to LiteralType.EnumType
    """

    def __init__(self):
        super().__init__(name="DefaultEnumTransformer", t=enum.Enum)

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        """Build an EnumType literal type from a string-valued Enum class."""
        if is_annotated(t):
            raise ValueError(
                f"Flytekit does not currently have support \
                    for FlyteAnnotations applied to enums. {t} cannot be \
                        parsed."
            )

        values = [v.value for v in t]  # type: ignore
        # Fix: also guard the empty-enum case, which previously raised an
        # uninformative IndexError on `values[0]`.
        if not values or not isinstance(values[0], str):
            raise TypeTransformerFailedError("Only EnumTypes with value of string are supported")
        return LiteralType(enum_type=types_pb2.EnumType(values=values))

    async def to_literal(self, python_val: enum.Enum, python_type: Type[T], expected: LiteralType) -> Literal:
        # Fix: the previous check `type(python_val).__class__ != enum.EnumMeta`
        # wrongly rejected enums whose metaclass is a *subclass* of EnumMeta;
        # isinstance against enum.Enum is the robust membership test.
        if not isinstance(python_val, enum.Enum):
            raise TypeTransformerFailedError("Expected an enum")
        if type(python_val.value) is not str:
            raise TypeTransformerFailedError("Only string-valued enums are supported")

        return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value)))  # type: ignore

    async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> T:
        """Recover the enum member from either a binary or string-primitive literal."""
        if lv.HasField("scalar") and lv.scalar.HasField("binary"):
            return self.from_binary_idl(lv.scalar.binary, expected_python_type)  # type: ignore
        return expected_python_type(lv.scalar.primitive.string_value)  # type: ignore

    def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
        if literal_type.HasField("enum_type"):
            # Synthesize an enum whose names equal its (string) values.
            return enum.Enum("DynamicEnum", {f"{i}": i for i in literal_type.enum_type.values})  # type: ignore
        raise ValueError(f"Enum transformer cannot reverse {literal_type}")

    def assert_type(self, t: Type[enum.Enum], v: T):
        # Accept either an enum member or a raw value; compare by value.
        val = v.value if isinstance(v, enum.Enum) else v
        if val not in [t_item.value for t_item in t]:
            raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")
|
797
|
+
|
|
798
|
+
def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any):
    """Translate a mashumaro-produced JSON schema into (attribute name, python type) pairs.

    Used when reconstructing a dataclass from a literal type's schema. File/Dir
    offloaded types are recognized via their ``schema_match`` helpers; nested
    object schemas recurse through ``convert_mashumaro_json_schema_to_python_class``.
    """
    from flyte.io._dir import Dir
    from flyte.io._file import File

    attribute_list: typing.List[typing.Tuple[Any, Any]] = []
    for property_key, property_val in schema["properties"].items():
        if property_val.get("anyOf"):
            # Optional[...] renders as anyOf; take the first (non-None) member's type.
            property_type = property_val["anyOf"][0]["type"]
        elif property_val.get("enum"):
            property_type = "enum"
        else:
            property_type = property_val["type"]
        # Handle list
        if property_type == "array":
            attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])]))  # type: ignore
        # Handle dataclass and dict
        elif property_type == "object":
            if property_val.get("anyOf"):
                # For optional with dataclass
                sub_schema = property_val["anyOf"][0]
                sub_schema_name = sub_schema["title"]
                # Check Flyte offloaded types first
                if File.schema_match(property_val):
                    attribute_list.append(
                        (
                            property_key,
                            typing.cast(
                                GenericAlias,
                                File,
                            ),
                        )
                    )
                    continue
                elif Dir.schema_match(property_val):
                    attribute_list.append(
                        (
                            property_key,
                            typing.cast(
                                GenericAlias,
                                Dir,
                            ),
                        )
                    )
                    continue
                attribute_list.append(
                    (
                        property_key,
                        typing.cast(
                            GenericAlias, convert_mashumaro_json_schema_to_python_class(sub_schema, sub_schema_name)
                        ),
                    )
                )
            elif property_val.get("additionalProperties"):
                # For typing.Dict type
                elem_type = _get_element_type(property_val["additionalProperties"])
                attribute_list.append((property_key, typing.Dict[str, elem_type]))  # type: ignore
            elif property_val.get("title"):
                # For nested dataclass
                sub_schema_name = property_val["title"]
                # Check Flyte offloaded types
                if File.schema_match(property_val):
                    attribute_list.append(
                        (
                            property_key,
                            typing.cast(
                                GenericAlias,
                                File,
                            ),
                        )
                    )
                    continue
                elif Dir.schema_match(property_val):
                    attribute_list.append(
                        (
                            property_key,
                            typing.cast(
                                GenericAlias,
                                Dir,
                            ),
                        )
                    )
                    continue
                attribute_list.append(
                    (
                        property_key,
                        typing.cast(
                            GenericAlias, convert_mashumaro_json_schema_to_python_class(property_val, sub_schema_name)
                        ),
                    )
                )
            else:
                # For untyped dict
                attribute_list.append((property_key, dict))  # type: ignore
        elif property_type == "enum":
            # Fix: was appended as a list; use a tuple to match the declared
            # List[Tuple[...]] element type used everywhere else.
            attribute_list.append((property_key, str))  # type: ignore
        # Handle int, float, bool or str
        else:
            # Fix: was appended as a list; use a tuple (see above).
            attribute_list.append((property_key, _get_element_type(property_val)))  # type: ignore
    return attribute_list
+
|
|
898
|
+
class TypeEngine(typing.Generic[T]):
    """
    Core Extensible TypeEngine of Flytekit. This should be used to extend the capabilities of FlyteKits type system.
    Users can implement their own TypeTransformers and register them with the TypeEngine. This will allow special
    handling
    of user objects
    """

    # Maps a registered python type to its transformer; mutated only through
    # register()/register_additional_type()/register_restricted_type().
    _REGISTRY: typing.ClassVar[typing.Dict[type, TypeTransformer]] = {}
    # Types explicitly disallowed as task/workflow inputs or outputs.
    _RESTRICTED_TYPES: typing.ClassVar[typing.List[type]] = []
    # Fallback transformer applied to any dataclass not otherwise registered.
    _DATACLASS_TRANSFORMER: typing.ClassVar[TypeTransformer] = DataclassTransformer()
    # Dedicated transformer for enum.Enum subclasses (consulted before the registry).
    _ENUM_TRANSFORMER: typing.ClassVar[TypeTransformer] = EnumTransformer()
    # Serializes lazy_import_transformers() across threads.
    lazy_import_lock: typing.ClassVar[threading.Lock] = threading.Lock()
|
912
|
+
@classmethod
def register(
    cls,
    transformer: TypeTransformer,
    additional_types: Optional[typing.List[Type]] = None,
):
    """
    This should be used for all types that respond with the right type annotation when you use type(...) function
    """
    extra = additional_types or []
    for candidate in (transformer.python_type, *extra):
        if candidate in cls._REGISTRY:
            existing = cls._REGISTRY[candidate]
            raise ValueError(
                f"Transformer {existing.name} for type {candidate} is already registered."
                f" Cannot override with {transformer.name}"
            )
        cls._REGISTRY[candidate] = transformer
|
931
|
+
@classmethod
def register_restricted_type(
    cls,
    name: str,
    type: Type[T],
):
    """Forbid ``type`` as task/workflow IO by registering a guard transformer for it."""
    cls._RESTRICTED_TYPES.append(type)
    cls.register(RestrictedTypeTransformer(name, type))  # type: ignore
|
940
|
+
@classmethod
def register_additional_type(cls, transformer: TypeTransformer[T], additional_type: Type[T], override=False):
    """Map ``additional_type`` to ``transformer``; existing entries win unless ``override``."""
    already_registered = additional_type in cls._REGISTRY
    if override or not already_registered:
        cls._REGISTRY[additional_type] = transformer
|
945
|
+
@classmethod
def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]:
    """Single-level transformer lookup for ``python_type``.

    Resolution order: Annotated unwrapping (an embedded TypeTransformer wins),
    enum special-case, generic-origin lookup, PEP 604 unions, then an exact
    registry hit. Returns None when nothing matches; ``get_transformer`` then
    walks the MRO and applies fallbacks.
    """
    cls.lazy_import_transformers()
    if is_annotated(python_type):
        args = get_args(python_type)
        for annotation in args:
            # An explicit transformer attached via Annotated[...] takes precedence.
            if isinstance(annotation, TypeTransformer):
                return annotation
        # Otherwise resolve the underlying (first) type argument recursively.
        return cls.get_transformer(args[0])

    if inspect.isclass(python_type) and issubclass(python_type, enum.Enum):
        # Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
        return cls._ENUM_TRANSFORMER

    if hasattr(python_type, "__origin__"):
        # If the type is a generic type, we should check the origin type. But consider the case like Iterator[JSON]
        # or List[int] has been specifically registered; we should check for the entire type.
        # The challenge is for StructuredDataset, example List[StructuredDataset] the column names is an OrderedDict
        # are not hashable, thus looking up this type is not possible.
        # In such as case, we will have to skip the "type" lookup and use the origin type only
        try:
            if python_type in cls._REGISTRY:
                return cls._REGISTRY[python_type]
        except TypeError:
            # Unhashable parameterized type: fall through to the origin lookup.
            pass
        if python_type.__origin__ in cls._REGISTRY:
            return cls._REGISTRY[python_type.__origin__]

    # Handling UnionType specially - PEP 604
    if sys.version_info >= (3, 10):
        import types

        if isinstance(python_type, types.UnionType):
            return cls._REGISTRY[types.UnionType]

    if python_type in cls._REGISTRY:
        return cls._REGISTRY[python_type]

    return None
|
985
|
+
@classmethod
def get_transformer(cls, python_type: Type) -> TypeTransformer:
    """
    Implements a recursive search for the transformer.

    Tries an exact/_get_transformer match first, then each class in the MRO,
    then the dataclass fallback, and finally the pickle transformer (with a
    warning) so every type resolves to *some* transformer.
    """
    v = cls._get_transformer(python_type)
    if v is not None:
        return v

    if hasattr(python_type, "__mro__"):
        # Walk base classes so subclasses of registered types are handled.
        class_tree = inspect.getmro(python_type)
        for t in class_tree:
            v = cls._get_transformer(t)
            if v is not None:
                return v

    # dataclass type transformer is left for last to give users a chance to register a type transformer
    # to handle dataclass-like objects as part of the mro evaluation.
    #
    # NB: keep in mind that there are no compatibility guarantees between these user-defined dataclass transformers
    # and the flytekit one. This incompatibility is *not* a new behavior introduced by the recent type engine
    # refactor (https://github.com/flyteorg/flytekit/pull/2815), but it is worth calling out explicitly as a known
    # limitation nonetheless.
    if dataclasses.is_dataclass(python_type):
        return cls._DATACLASS_TRANSFORMER

    # Last resort: pickle. Warn so users know serialization is opaque/fragile.
    display_pickle_warning(str(python_type))
    from flyte.io.pickle.transformer import FlytePickleTransformer

    return FlytePickleTransformer()
|
1016
|
+
    @classmethod
    def lazy_import_transformers(cls):
        """
        Only load the transformers if needed.

        Thread-safe: the lock guarantees every caller returns only after the
        imports (and the registrations they trigger) have completed.
        """
        with cls.lazy_import_lock:
            # Avoid a race condition where concurrent threads may exit lazy_import_transformers before the transformers
            # have been imported. This could be implemented without a lock if you assume python assignments are atomic
            # and re-registering transformers is acceptable, but I decided to play it safe.
            from flyte.io.structured_dataset import lazy_import_structured_dataset_handler

            # todo: bring in extras transformers (pytorch, etc.)
            lazy_import_structured_dataset_handler()
|
|
1029
|
+
|
|
1030
|
+
@classmethod
|
|
1031
|
+
def to_literal_type(cls, python_type: Type[T]) -> LiteralType:
|
|
1032
|
+
"""
|
|
1033
|
+
Converts a python type into a flyte specific ``LiteralType``
|
|
1034
|
+
"""
|
|
1035
|
+
transformer = cls.get_transformer(python_type)
|
|
1036
|
+
res = transformer.get_literal_type(python_type)
|
|
1037
|
+
return res
|
|
1038
|
+
|
|
1039
|
+
@classmethod
|
|
1040
|
+
def to_literal_checks(cls, python_val: typing.Any, python_type: Type[T], expected: LiteralType):
|
|
1041
|
+
if isinstance(python_val, tuple):
|
|
1042
|
+
raise AssertionError(
|
|
1043
|
+
"Tuples are not a supported type for individual values in Flyte - got a tuple -"
|
|
1044
|
+
f" {python_val}. If using named tuple in an inner task, please, de-reference the"
|
|
1045
|
+
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
|
|
1046
|
+
"return v.x, instead of v, even if this has a single element"
|
|
1047
|
+
)
|
|
1048
|
+
if (
|
|
1049
|
+
(python_val is None and python_type is not type(None))
|
|
1050
|
+
and expected
|
|
1051
|
+
and expected.union_type is None
|
|
1052
|
+
and python_type is not Any
|
|
1053
|
+
):
|
|
1054
|
+
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
|
|
1055
|
+
|
|
1056
|
+
@classmethod
|
|
1057
|
+
def calculate_hash(cls, python_val: typing.Any, python_type: Type[T]) -> Optional[str]:
|
|
1058
|
+
# In case the value is an annotated type we inspect the annotations and look for hash-related annotations.
|
|
1059
|
+
hsh = None
|
|
1060
|
+
if is_annotated(python_type):
|
|
1061
|
+
# We are now dealing with one of two cases:
|
|
1062
|
+
# 1. The annotated type is a `HashMethod`, which indicates that we should produce the hash using
|
|
1063
|
+
# the method indicated in the annotation.
|
|
1064
|
+
# 2. The annotated type is being used for a different purpose other than calculating hash values,
|
|
1065
|
+
# in which case we should just continue.
|
|
1066
|
+
for annotation in get_args(python_type)[1:]:
|
|
1067
|
+
if not isinstance(annotation, HashMethod):
|
|
1068
|
+
continue
|
|
1069
|
+
hsh = annotation.calculate(python_val)
|
|
1070
|
+
break
|
|
1071
|
+
return hsh
|
|
1072
|
+
|
|
1073
|
+
    @classmethod
    async def to_literal(
        cls, python_val: typing.Any, python_type: Type[T], expected: types_pb2.LiteralType
    ) -> literals_pb2.Literal:
        """
        Convert a python value into a flyte ``Literal`` using the transformer
        registered for ``python_type``, attaching a content hash when the type
        is annotated with a ``HashMethod``.
        """
        transformer = cls.get_transformer(python_type)

        if transformer.type_assertions_enabled:
            transformer.assert_type(python_type, python_val)

        lv = await transformer.to_literal(python_val, python_type, expected)

        # Rewrite any blob/file URIs for the current execution context.
        modify_literal_uris(lv)
        # Empty string means "no hash": proto string fields cannot hold None.
        calculated_hash = cls.calculate_hash(python_val, python_type) or ""
        lv.hash = calculated_hash
        return lv
|
|
1088
|
+
|
|
1089
|
+
    @classmethod
    async def unwrap_offloaded_literal(cls, lv: literals_pb2.Literal) -> literals_pb2.Literal:
        """
        Resolve an offloaded literal by downloading and parsing the full Literal
        proto it points at. Returns ``lv`` unchanged when nothing is offloaded.
        """
        if not lv.HasField("offloaded_metadata"):
            return lv

        # Download the offloaded proto to a scratch path, then parse it.
        literal_local_file = storage.get_random_local_path()
        assert lv.offloaded_metadata.uri, "missing offloaded uri"
        await storage.get(lv.offloaded_metadata.uri, str(literal_local_file))
        input_proto = load_proto_from_file(literals_pb2.Literal, literal_local_file)
        return input_proto
|
|
1099
|
+
|
|
1100
|
+
@classmethod
|
|
1101
|
+
async def to_python_value(cls, lv: Literal, expected_python_type: Type) -> typing.Any:
|
|
1102
|
+
"""
|
|
1103
|
+
Converts a Literal value with an expected python type into a python value.
|
|
1104
|
+
"""
|
|
1105
|
+
# Initiate the process of loading the offloaded literal if offloaded_metadata is set
|
|
1106
|
+
if lv.HasField("offloaded_metadata"):
|
|
1107
|
+
lv = await cls.unwrap_offloaded_literal(lv)
|
|
1108
|
+
|
|
1109
|
+
transformer = cls.get_transformer(expected_python_type)
|
|
1110
|
+
res = await transformer.to_python_value(lv, expected_python_type)
|
|
1111
|
+
return res
|
|
1112
|
+
|
|
1113
|
+
@classmethod
|
|
1114
|
+
def to_html(cls, python_val: typing.Any, expected_python_type: Type[typing.Any]) -> str:
|
|
1115
|
+
transformer = cls.get_transformer(expected_python_type)
|
|
1116
|
+
if is_annotated(expected_python_type):
|
|
1117
|
+
expected_python_type, *annotate_args = get_args(expected_python_type)
|
|
1118
|
+
from flyte.types._renderer import Renderable
|
|
1119
|
+
|
|
1120
|
+
for arg in annotate_args:
|
|
1121
|
+
if isinstance(arg, Renderable):
|
|
1122
|
+
return arg.to_html(python_val)
|
|
1123
|
+
return transformer.to_html(python_val, expected_python_type)
|
|
1124
|
+
|
|
1125
|
+
@classmethod
|
|
1126
|
+
def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> interface_pb2.VariableMap:
|
|
1127
|
+
"""
|
|
1128
|
+
Converts a python-native ``NamedTuple`` to a flyte-specific VariableMap of named literals.
|
|
1129
|
+
"""
|
|
1130
|
+
variables = {}
|
|
1131
|
+
for idx, (var_name, var_type) in enumerate(t.__annotations__.items()):
|
|
1132
|
+
literal_type = cls.to_literal_type(var_type)
|
|
1133
|
+
variables[var_name] = interface_pb2.Variable(type=literal_type, description=f"{idx}")
|
|
1134
|
+
return interface_pb2.VariableMap(variables=variables)
|
|
1135
|
+
|
|
1136
|
+
@classmethod
|
|
1137
|
+
async def literal_map_to_kwargs(
|
|
1138
|
+
cls,
|
|
1139
|
+
lm: LiteralMap,
|
|
1140
|
+
python_types: typing.Optional[typing.Dict[str, type]] = None,
|
|
1141
|
+
literal_types: typing.Optional[typing.Dict[str, interface_pb2.Variable]] = None,
|
|
1142
|
+
) -> typing.Dict[str, typing.Any]:
|
|
1143
|
+
"""
|
|
1144
|
+
Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task
|
|
1145
|
+
"""
|
|
1146
|
+
if python_types is None and literal_types is None:
|
|
1147
|
+
raise ValueError("At least one of python_types or literal_types must be provided")
|
|
1148
|
+
|
|
1149
|
+
if literal_types:
|
|
1150
|
+
python_interface_inputs: dict[str, Type[T]] = {
|
|
1151
|
+
name: TypeEngine.guess_python_type(lt.type) for name, lt in literal_types.items()
|
|
1152
|
+
}
|
|
1153
|
+
else:
|
|
1154
|
+
python_interface_inputs = python_types # type: ignore
|
|
1155
|
+
|
|
1156
|
+
if not python_interface_inputs or len(python_interface_inputs) == 0:
|
|
1157
|
+
return {}
|
|
1158
|
+
|
|
1159
|
+
if len(lm.literals) > len(python_interface_inputs):
|
|
1160
|
+
raise ValueError(
|
|
1161
|
+
f"Received more input values {len(lm.literals)}"
|
|
1162
|
+
f" than allowed by the input spec {len(python_interface_inputs)}"
|
|
1163
|
+
)
|
|
1164
|
+
kwargs = {}
|
|
1165
|
+
try:
|
|
1166
|
+
for i, k in enumerate(lm.literals):
|
|
1167
|
+
kwargs[k] = asyncio.create_task(TypeEngine.to_python_value(lm.literals[k], python_interface_inputs[k]))
|
|
1168
|
+
await asyncio.gather(*kwargs.values())
|
|
1169
|
+
except Exception as e:
|
|
1170
|
+
raise TypeTransformerFailedError(
|
|
1171
|
+
f"Error converting input:\n"
|
|
1172
|
+
f"Literal value: {lm.literals[k]}\n"
|
|
1173
|
+
f"Expected Python type: {python_interface_inputs[k]}\n"
|
|
1174
|
+
f"Exception: {e}"
|
|
1175
|
+
)
|
|
1176
|
+
|
|
1177
|
+
kwargs = {k: v.result() for k, v in kwargs.items() if v is not None}
|
|
1178
|
+
return kwargs
|
|
1179
|
+
|
|
1180
|
+
@classmethod
|
|
1181
|
+
async def dict_to_literal_map(
|
|
1182
|
+
cls,
|
|
1183
|
+
d: typing.Dict[str, typing.Any],
|
|
1184
|
+
type_hints: Optional[typing.Dict[str, type]] = None,
|
|
1185
|
+
) -> LiteralMap:
|
|
1186
|
+
"""
|
|
1187
|
+
Given a dictionary mapping string keys to python values and a dictionary containing guessed types for such
|
|
1188
|
+
string keys,
|
|
1189
|
+
convert to a LiteralMap.
|
|
1190
|
+
"""
|
|
1191
|
+
type_hints = type_hints or {}
|
|
1192
|
+
literal_map = {}
|
|
1193
|
+
for k, v in d.items():
|
|
1194
|
+
# The guessed type takes precedence over the type returned by the python runtime. This is needed
|
|
1195
|
+
# to account for the type erasure that happens in the case of built-in collection containers, such as
|
|
1196
|
+
# `list` and `dict`.
|
|
1197
|
+
python_type = type_hints.get(k, type(v))
|
|
1198
|
+
literal_map[k] = asyncio.create_task(
|
|
1199
|
+
TypeEngine.to_literal(
|
|
1200
|
+
python_val=v,
|
|
1201
|
+
python_type=python_type,
|
|
1202
|
+
expected=TypeEngine.to_literal_type(python_type),
|
|
1203
|
+
)
|
|
1204
|
+
)
|
|
1205
|
+
await asyncio.gather(*literal_map.values(), return_exceptions=True)
|
|
1206
|
+
for idx, (k, v) in enumerate(literal_map.items()):
|
|
1207
|
+
if literal_map[k].exception() is not None:
|
|
1208
|
+
python_type = type_hints.get(k, type(d[k]))
|
|
1209
|
+
e: BaseException = literal_map[k].exception() # type: ignore
|
|
1210
|
+
if isinstance(e, TypeError):
|
|
1211
|
+
raise TypeError(f"Error converting: {type(v)}, {python_type}, received_value {v}")
|
|
1212
|
+
else:
|
|
1213
|
+
raise e
|
|
1214
|
+
literal_map[k] = v.result()
|
|
1215
|
+
|
|
1216
|
+
return LiteralMap(literals=literal_map)
|
|
1217
|
+
|
|
1218
|
+
    @classmethod
    def get_available_transformers(cls) -> typing.KeysView[Type]:
        """
        Returns all python types for which transformers are available.

        Note: this is a live view over the registry, not a snapshot.
        """
        return cls._REGISTRY.keys()
|
|
1224
|
+
|
|
1225
|
+
@classmethod
|
|
1226
|
+
def guess_python_types(
|
|
1227
|
+
cls, flyte_variable_dict: typing.Dict[str, interface_pb2.Variable]
|
|
1228
|
+
) -> typing.Dict[str, Type[Any]]:
|
|
1229
|
+
"""
|
|
1230
|
+
Transforms a dictionary of flyte-specific ``Variable`` objects to a dictionary of regular python values.
|
|
1231
|
+
"""
|
|
1232
|
+
python_types = {}
|
|
1233
|
+
for k, v in flyte_variable_dict.items():
|
|
1234
|
+
python_types[k] = cls.guess_python_type(v.type)
|
|
1235
|
+
return python_types
|
|
1236
|
+
|
|
1237
|
+
@classmethod
|
|
1238
|
+
def guess_python_type(cls, flyte_type: LiteralType) -> Type[T]:
|
|
1239
|
+
"""
|
|
1240
|
+
Transforms a flyte-specific ``LiteralType`` to a regular python value.
|
|
1241
|
+
"""
|
|
1242
|
+
for _, transformer in cls._REGISTRY.items():
|
|
1243
|
+
try:
|
|
1244
|
+
return transformer.guess_python_type(flyte_type)
|
|
1245
|
+
except ValueError:
|
|
1246
|
+
logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")
|
|
1247
|
+
|
|
1248
|
+
# Because the dataclass transformer is handled explicitly in the get_transformer code, we have to handle it
|
|
1249
|
+
# separately here too.
|
|
1250
|
+
try:
|
|
1251
|
+
return cls._DATACLASS_TRANSFORMER.guess_python_type(literal_type=flyte_type)
|
|
1252
|
+
except ValueError:
|
|
1253
|
+
logger.debug(f"Skipping transformer {cls._DATACLASS_TRANSFORMER.name} for {flyte_type}")
|
|
1254
|
+
raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}")
|
|
1255
|
+
|
|
1256
|
+
|
|
1257
|
+
class ListTransformer(TypeTransformer[T]):
    """
    Transformer that handles a univariate typing.List[T]
    """

    def __init__(self):
        super().__init__("Typed List", list)

    @staticmethod
    def get_sub_type(t: Type[T]) -> Type[T]:
        """
        Return the generic Type T of the List

        :raises ValueError: when ``t`` is not a univariate ``typing.List[T]``.
        """
        if (sub_type := ListTransformer.get_sub_type_or_none(t)) is not None:
            return sub_type

        raise ValueError("Only generic univariate typing.List[T] type is supported.")

    @staticmethod
    def get_sub_type_or_none(t: Type[T]) -> Optional[Type[T]]:
        """
        Return the generic Type T of the List, or None if the generic type cannot be inferred
        """
        if hasattr(t, "__origin__"):
            # Handle annotation on list generic, eg:
            # Annotated[typing.List[int], 'foo']
            if is_annotated(t):
                return ListTransformer.get_sub_type(get_args(t)[0])

            if getattr(t, "__origin__") is list and hasattr(t, "__args__"):
                return getattr(t, "__args__")[0]

        return None

    def get_literal_type(self, t: Type[T]) -> Optional[types_pb2.LiteralType]:
        """
        Only univariate Lists are supported in Flyte

        :raises ValueError: when the element type cannot be resolved or converted.
        """
        try:
            sub_type = TypeEngine.to_literal_type(self.get_sub_type(t))
            return types_pb2.LiteralType(collection_type=sub_type)
        except Exception as e:
            raise ValueError(f"Type of Generic List type is not supported, {e}")

    async def to_literal(self, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        # Strict check: list subclasses and other sequences are rejected on purpose.
        if type(python_val) is not list:
            raise TypeTransformerFailedError("Expected a list")

        t = self.get_sub_type(python_type)
        # Build one conversion coroutine per element, then await them with
        # bounded concurrency to avoid creating too many tasks at once.
        lit_list = [TypeEngine.to_literal(x, t, expected.collection_type) for x in python_val]

        lit_list = await _run_coros_in_chunks(lit_list, batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)

        return Literal(collection=LiteralCollection(literals=lit_list))

    async def to_python_value(  # type: ignore
        self, lv: Literal, expected_python_type: Type[T]
    ) -> typing.Optional[typing.List[T]]:
        # A msgpack-encoded binary scalar can also represent a list; decode it directly.
        if lv and lv.HasField("scalar") and lv.scalar.HasField("binary"):
            return self.from_binary_idl(lv.scalar.binary, expected_python_type)  # type: ignore

        try:
            lits = lv.collection.literals
        except AttributeError:
            raise TypeTransformerFailedError(
                (
                    f"The expected python type is '{expected_python_type}' but the received Flyte literal value "
                    f"is not a collection (Flyte's representation of Python lists)."
                )
            )

        st = self.get_sub_type(expected_python_type)
        # Element conversions run with the same bounded concurrency as to_literal.
        result = [TypeEngine.to_python_value(x, st) for x in lits]
        result = await _run_coros_in_chunks(result, batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
        return result  # type: ignore # should be a list, thinks its a tuple

    def guess_python_type(self, literal_type: types_pb2.LiteralType) -> list:  # type: ignore
        """Reverse-map a collection LiteralType to ``typing.List[element]``."""
        if literal_type.HasField("collection_type"):
            ct: Type = TypeEngine.guess_python_type(literal_type.collection_type)
            return typing.List[ct]  # type: ignore
        raise ValueError(f"List transformer cannot reverse {literal_type}")
|
|
1338
|
+
|
|
1339
|
+
|
|
1340
|
+
@lru_cache
def display_pickle_warning(python_type: str):
    """Log a warning that *python_type* falls back to pickle transport.

    Memoized with ``lru_cache`` so each distinct type string warns only once.
    """
    message = (
        f"Unsupported Type {python_type} found, Flyte will default to use PickleFile as the transport. "
        "Pickle can only be used to send objects between the exact same version of Python, "
        "and we strongly recommend to use python type that flyte support."
    )
    logger.warning(message)
|
|
1348
|
+
|
|
1349
|
+
|
|
1350
|
+
def _add_tag_to_type(x: types_pb2.LiteralType, tag: str) -> types_pb2.LiteralType:
    """Return a copy of *x* whose structure carries *tag*; *x* itself is untouched."""
    tagged = types_pb2.LiteralType()
    tagged.CopyFrom(x)
    tagged.structure.CopyFrom(TypeStructure(tag=tag))
    return tagged
|
|
1355
|
+
|
|
1356
|
+
|
|
1357
|
+
def _type_essence(x: types_pb2.LiteralType) -> types_pb2.LiteralType:
    """Strip metadata/structure/annotation from *x* for structural comparison.

    Returns *x* itself (no copy) when none of those fields are set.
    """
    if not (x.HasField("metadata") or x.HasField("structure") or x.HasField("annotation")):
        return x
    essence = types_pb2.LiteralType()
    essence.CopyFrom(x)
    for field_name in ("metadata", "structure", "annotation"):
        essence.ClearField(field_name)
    return essence
|
|
1366
|
+
|
|
1367
|
+
|
|
1368
|
+
def _are_types_castable(upstream: types_pb2.LiteralType, downstream: types_pb2.LiteralType) -> bool:
    """
    Return True when a value of ``upstream`` type can be accepted where
    ``downstream`` is expected.

    Recurses through collections, maps, structured datasets and unions; any
    other pair falls back to equality of the types with metadata/structure/
    annotation stripped.
    """
    if upstream.HasField("collection_type"):
        if not downstream.HasField("collection_type"):
            return False

        # Element types must be castable recursively.
        return _are_types_castable(upstream.collection_type, downstream.collection_type)

    if upstream.HasField("map_value_type"):
        if not downstream.HasField("map_value_type"):
            return False

        # Map value types must be castable recursively (keys are always strings).
        return _are_types_castable(upstream.map_value_type, downstream.map_value_type)

    # TODO: Structured dataset type matching requires that downstream structured datasets
    # are a strict sub-set of the upstream structured dataset.
    if upstream.HasField("structured_dataset_type"):
        if not downstream.HasField("structured_dataset_type"):
            return False

        usdt = upstream.structured_dataset_type
        dsdt = downstream.structured_dataset_type

        if usdt.format != dsdt.format:
            return False

        if usdt.external_schema_type != dsdt.external_schema_type:
            return False

        if usdt.external_schema_bytes != dsdt.external_schema_bytes:
            return False

        ucols = usdt.columns
        dcols = dsdt.columns

        # Columns must match pairwise, in order, by name and castable type.
        if len(ucols) != len(dcols):
            return False

        for u, d in zip(ucols, dcols):
            if u.name != d.name:
                return False

            if not _are_types_castable(u.literal_type, d.literal_type):
                return False

        return True

    if upstream.HasField("union_type"):
        # for each upstream variant, there must be a compatible type downstream
        for v in upstream.union_type.variants:
            if not _are_types_castable(v, downstream):
                return False
        return True

    if downstream.HasField("union_type"):
        # there must be a compatible downstream type
        for v in downstream.union_type.variants:
            if _are_types_castable(upstream, v):
                return True

    if upstream.HasField("enum_type"):
        # enums are castable to string
        if downstream.simple == SimpleType.STRING:
            return True

    # Fallback: identical types (ignoring metadata/structure/annotation) cast trivially.
    if _type_essence(upstream) == _type_essence(downstream):
        return True

    return False
|
|
1436
|
+
|
|
1437
|
+
|
|
1438
|
+
def _is_union_type(t):
    """Returns a truthy value if t is a Union type (typing.Union or PEP 604 ``X | Y``)."""
    if sys.version_info >= (3, 10):
        import types

        union_cls = types.UnionType
    else:
        # PEP 604 unions do not exist before 3.10.
        union_cls = None

    if t is typing.Union:
        return True
    if get_origin(t) is typing.Union:
        return True
    return union_cls and isinstance(t, union_cls)
|
|
1449
|
+
|
|
1450
|
+
|
|
1451
|
+
class UnionTransformer(TypeTransformer[T]):
    """
    Transformer that handles a typing.Union[T1, T2, ...]
    """

    def __init__(self):
        super().__init__("Typed Union", typing.Union)

    @staticmethod
    def is_optional_type(t: Type[Any]) -> bool:
        # Optional[X] is Union[X, None]; detect by NoneType membership.
        return _is_union_type(t) and type(None) in get_args(t)

    @staticmethod
    def get_sub_type_in_optional(t: Type[T]) -> Type[T]:
        """
        Return the generic Type T of the Optional type
        """
        return get_args(t)[0]

    def assert_type(self, t: Type[T], v: T):
        """Accept ``v`` if any union variant's transformer accepts it; raise otherwise."""
        python_type = get_underlying_type(t)
        if _is_union_type(python_type):
            for sub_type in get_args(python_type):
                if sub_type == typing.Any:
                    # this is an edge case
                    return
                try:
                    sub_trans: TypeTransformer = TypeEngine.get_transformer(sub_type)
                    if sub_trans.type_assertions_enabled:
                        sub_trans.assert_type(sub_type, v)
                        return
                    else:
                        # Transformer opted out of assertions: treat as a match.
                        return
                except TypeTransformerFailedError:
                    continue
                except TypeError:
                    continue
            raise TypeTransformerFailedError(f"Value {v} is not of type {t}")

    def get_literal_type(self, t: Type[T]) -> Optional[types_pb2.LiteralType]:
        """Build a union LiteralType whose variants are tagged with their transformer names."""
        t = get_underlying_type(t)

        try:
            trans: typing.List[typing.Tuple[TypeTransformer, typing.Any]] = [
                (TypeEngine.get_transformer(x), x) for x in get_args(t)
            ]
            # must go through TypeEngine.to_literal_type instead of trans.get_literal_type
            # to handle Annotated
            variants = [_add_tag_to_type(TypeEngine.to_literal_type(x), t.name) for (t, x) in trans]
            return types_pb2.LiteralType(union_type=UnionType(variants=variants))
        except Exception as e:
            raise ValueError(f"Type of Generic Union type is not supported, {e}")

    async def to_literal(
        self, python_val: T, python_type: Type[T], expected: types_pb2.LiteralType
    ) -> literals_pb2.Literal:
        """
        Try every variant in order; exactly one must succeed. More than one
        success means the variants are structurally indistinguishable and the
        conversion is ambiguous.
        """
        python_type = get_underlying_type(python_type)

        potential_types = []
        found_res = False
        is_ambiguous = False
        res = None
        res_type = None
        t = None
        for i in range(len(get_args(python_type))):
            try:
                t = get_args(python_type)[i]
                trans: TypeTransformer[T] = TypeEngine.get_transformer(t)
                attempt = trans.to_literal(python_val, t, expected.union_type.variants[i])
                res = await attempt
                if found_res:
                    # A previous variant already converted successfully: ambiguous.
                    logger.debug(f"Current type {get_args(python_type)[i]} old res {res_type}")
                    is_ambiguous = True
                res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name)
                found_res = True
                potential_types.append(t)
            except Exception as e:
                # Failure just means this variant does not match; try the next one.
                logger.debug(
                    f"UnionTransformer failed attempt to convert from {python_val} to {t} error: {e}",
                )
                continue

        if is_ambiguous:
            raise TypeError(
                f"Ambiguous choice of variant for union type.\n"
                f"Potential types: {potential_types}\n"
                "These types are structurally the same, because it's attributes have the same names and associated"
                " types."
            )

        if found_res:
            return Literal(scalar=Scalar(union=Union(value=res, type=res_type)))

        raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}")

    async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> Optional[typing.Any]:
        """
        Convert a (possibly tagged) union literal back to a python value by
        trying each expected variant; a stored structure tag narrows candidates.
        """
        expected_python_type = get_underlying_type(expected_python_type)

        union_tag = None
        union_type = None
        if lv.HasField("scalar") and lv.scalar.HasField("union"):
            union_type = lv.scalar.union.type
            if union_type.HasField("structure"):
                union_tag = union_type.structure.tag

        found_res = False
        is_ambiguous = False
        cur_transformer = ""
        res = None
        res_tag = None
        # This is serial, not actually async, but should be okay since it's more reasonable for Unions.
        for v in get_args(expected_python_type):
            try:
                trans: TypeTransformer[T] = TypeEngine.get_transformer(v)
                if union_tag is not None:
                    # Tagged literal: only consider the variant whose transformer
                    # name matches and whose literal type is castable.
                    if trans.name != union_tag:
                        continue

                    expected_literal_type = TypeEngine.to_literal_type(v)
                    if not _are_types_castable(union_type, expected_literal_type):
                        continue

                    assert lv.scalar.HasField("union"), f"Literal {lv} is not a union"  # type checker

                    if lv.scalar.HasField("binary"):
                        res = await trans.to_python_value(lv, v)
                    else:
                        res = await trans.to_python_value(lv.scalar.union.value, v)

                    if found_res:
                        is_ambiguous = True
                        cur_transformer = trans.name
                        break
                else:
                    # Untagged literal: attempt a direct conversion per variant.
                    res = await trans.to_python_value(lv, v)
                    if found_res:
                        is_ambiguous = True
                        cur_transformer = trans.name
                        break
                res_tag = trans.name
                found_res = True
            except Exception as e:
                logger.debug(f"Failed to convert from {lv} to {v} with error: {e}")

        if is_ambiguous:
            raise TypeError(
                f"Ambiguous choice of variant for union type. Both {res_tag} and {cur_transformer} transformers match"
            )

        if found_res:
            return res

        raise TypeError(f"Cannot convert from {lv} to {expected_python_type} (using tag {union_tag})")

    def guess_python_type(self, literal_type: LiteralType) -> type:
        """Reverse-map a union LiteralType to ``typing.Union[...]`` over its variants."""
        if literal_type.HasField("union_type"):
            return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)]  # type: ignore

        raise ValueError(f"Union transformer cannot reverse {literal_type}")
|
|
1610
|
+
|
|
1611
|
+
|
|
1612
|
+
class DictTransformer(TypeTransformer[dict]):
|
|
1613
|
+
"""
|
|
1614
|
+
Transformer that transforms an univariate dictionary Dict[str, T] to a Literal Map or
|
|
1615
|
+
transforms an untyped dictionary to a Binary Scalar Literal with a Struct Literal Type.
|
|
1616
|
+
"""
|
|
1617
|
+
|
|
1618
|
+
    def __init__(self):
        # Registered against the built-in ``dict``; handles both typed Dict[str, T]
        # maps and untyped dicts (serialized as msgpack binary scalars).
        super().__init__("Typed Dict", dict)
|
|
1620
|
+
|
|
1621
|
+
    @staticmethod
    def extract_types(t: Optional[Type[dict]]) -> typing.Tuple:
        """
        Return the (key_type, value_type) arguments of a dict type, unwrapping any
        ``Annotated`` layers; ``(None, None)`` when ``t`` is None or unparameterized.

        :raises ValueError: when ``t`` is parameterized but is not a dict type.
        """
        if t is None:
            return None, None

        # Get the origin and type arguments.
        _origin = get_origin(t)
        _args = get_args(t)

        # If not annotated or dict, return None, None.
        if _origin is None:
            return None, None

        # If this is something like Annotated[dict[int, str], FlyteAnnotation("abc")],
        # we need to check if there's a FlyteAnnotation in the metadata.
        if _origin is Annotated:
            # This case should never happen since Python's typing system requires at least two arguments
            # for Annotated[...] - a type and an annotation. Including this check for completeness.
            if not _args:
                return None, None

            first_arg = _args[0]
            # Recursively process the first argument if it's Annotated (or dict).
            return DictTransformer.extract_types(first_arg)

        # If the origin is dict, return the type arguments if they exist.
        if _origin is dict:
            # _args can be ().
            if _args is not None:
                return _args  # type: ignore

        # Otherwise, we do not support this type in extract_types.
        raise ValueError(f"Trying to extract dictionary type information from a non-dict type {t}")
|
|
1654
|
+
|
|
1655
|
+
    @staticmethod
    async def dict_to_binary_literal(v: dict, python_type: Type[dict], allow_pickle: bool) -> Literal:
        """
        Converts a Python dictionary to a Flyte-specific ``Literal`` using MessagePack encoding.
        Falls back to Pickle if encoding fails and `allow_pickle` is True.

        :raises TypeTransformerFailedError: when encoding fails and pickling is not allowed.
        """
        from flyte.io.pickle.transformer import FlytePickle

        try:
            # Handle dictionaries with non-string keys (e.g., Dict[int, Type])
            encoder = MessagePackEncoder(python_type)
            msgpack_bytes = encoder.encode(v)
            return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))
        except TypeError as e:
            if allow_pickle:
                # Pickle the dict to storage and record the remote path in a generic struct.
                remote_path = await FlytePickle.to_pickle(v)
                return Literal(
                    scalar=Scalar(
                        generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), struct_pb2.Struct())
                    ),
                    metadata={"format": "pickle"},
                )
            raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\nError Message: {e}")
|
|
1678
|
+
|
|
1679
|
+
@staticmethod
|
|
1680
|
+
def is_pickle(python_type: Type[dict]) -> bool:
|
|
1681
|
+
_origin = get_origin(python_type)
|
|
1682
|
+
metadata: typing.Tuple = ()
|
|
1683
|
+
if _origin is Annotated:
|
|
1684
|
+
metadata = get_args(python_type)[1:]
|
|
1685
|
+
|
|
1686
|
+
for each_metadata in metadata:
|
|
1687
|
+
if isinstance(each_metadata, OrderedDict):
|
|
1688
|
+
allow_pickle = each_metadata.get("allow_pickle", False)
|
|
1689
|
+
return allow_pickle
|
|
1690
|
+
|
|
1691
|
+
return False
|
|
1692
|
+
|
|
1693
|
+
def get_literal_type(self, t: Type[dict]) -> LiteralType:
    """
    Transform a native Python dict type into a Flyte ``LiteralType``.

    ``Dict[str, X]`` maps to a Flyte map type over X's literal type; any other dict
    (untyped, or with non-string keys) falls back to a STRUCT annotated as
    MessagePack-serialized.

    :raises ValueError: when the value type of a ``Dict[str, X]`` cannot be converted.
    """
    tp = DictTransformer.extract_types(t)

    if tp:
        if tp[0] is str:
            try:
                sub_type = TypeEngine.to_literal_type(cast(type, tp[1]))
                return types_pb2.LiteralType(map_value_type=sub_type)
            except Exception as e:
                # Fixed message: previously said "Generic List", a copy-paste from the
                # list transformer; this is the dict transformer.
                raise ValueError(f"Type of Generic Dict type is not supported, {e}")
    return types_pb2.LiteralType(
        simple=types_pb2.SimpleType.STRUCT,
        annotation=TypeAnnotation(annotations={CACHE_KEY_METADATA: {SERIALIZATION_FORMAT: MESSAGEPACK}}),
    )
|
|
1710
|
+
|
|
1711
|
+
async def to_literal(self, python_val: typing.Any, python_type: Type[dict], expected: LiteralType) -> Literal:
    """
    Convert a Python dict into a Flyte ``Literal``.

    STRUCT targets are encoded as MessagePack binary (optionally falling back to pickle
    when the type is ``Annotated[..., OrderedDict(allow_pickle=True)]``). Map targets
    convert each value through the TypeEngine and await the conversions in bounded batches.

    :raises TypeTransformerFailedError: when ``python_val`` is not a dict.
    :raises ValueError: when a map target receives non-string keys.
    """
    if type(python_val) is not dict:
        raise TypeTransformerFailedError("Expected a dict")

    allow_pickle = False

    if get_origin(python_type) is Annotated:
        allow_pickle = DictTransformer.is_pickle(python_type)

    if expected and expected.HasField("simple") and expected.simple == SimpleType.STRUCT:
        return await self.dict_to_binary_literal(python_val, python_type, allow_pickle)

    lit_map = {}
    if python_val:
        # Hoisted out of the loop (was recomputed per item): the value type depends only
        # on python_type, not on the items. Guarded so the empty-dict path is unchanged.
        _, v_type = self.extract_types(python_type)
    for k, v in python_val.items():
        if type(k) is not str:
            raise ValueError("Flyte MapType expects all keys to be strings")
        # TODO: log a warning for Annotated objects that contain HashMethod
        lit_map[k] = TypeEngine.to_literal(v, cast(type, v_type), expected.map_value_type)
    # The values collected above are unawaited coroutines; run them in bounded batches.
    vals = await _run_coros_in_chunks(list(lit_map.values()), batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
    for idx, k in zip(range(len(vals)), lit_map.keys()):
        lit_map[k] = vals[idx]

    return Literal(map=LiteralMap(literals=lit_map))
|
|
1736
|
+
|
|
1737
|
+
async def to_python_value(self, lv: Literal, expected_python_type: Type[dict]) -> dict:
    """
    Convert a Flyte ``Literal`` back into a Python dict.

    Branches handled, in order: binary (MessagePack) scalars, literal maps (values
    converted through the TypeEngine in bounded batches), and protobuf-struct scalars
    (either a remote-pickle pointer or a struct produced by Flyte Console).

    :raises TypeError: when a literal map's sub-type hints are missing or its key type is not str.
    :raises TypeTransformerFailedError: when no branch can convert the literal.
    """
    # MessagePack-encoded binary scalar: delegate to the shared binary-IDL decoder.
    if lv and lv.HasField("scalar") and lv.scalar.HasField("binary"):
        return self.from_binary_idl(lv.scalar.binary, expected_python_type)  # type: ignore

    if lv and lv.HasField("map"):
        tp = DictTransformer.extract_types(expected_python_type)

        if tp is None or len(tp) == 0 or tp[0] is None:
            raise TypeError(
                "TypeMismatch: Cannot convert to python dictionary from Flyte Literal Dictionary as the given "
                "dictionary does not have sub-type hints or they do not match with the originating dictionary "
                "source. Flytekit does not currently support implicit conversions"
            )
        if tp[0] is not str:
            raise TypeError("TypeMismatch. Destination dictionary does not accept 'str' key")
        py_map = {}
        # Collect unawaited conversion coroutines first...
        for k, v in lv.map.literals.items():
            py_map[k] = TypeEngine.to_python_value(v, cast(Type, tp[1]))

        # ...then await them in bounded batches and write results back in place.
        vals = await _run_coros_in_chunks(list(py_map.values()), batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
        for idx, k in zip(range(len(vals)), py_map.keys()):
            py_map[k] = vals[idx]

        return py_map

    # for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict
    # evaluates to false
    # pr: han-ru is this part still necessary?
    if lv and lv.HasField("scalar") and lv.scalar.HasField("generic"):
        # A generic tagged "pickle" holds a pointer to a remote pickle file, not the data itself.
        if lv.metadata and lv.metadata.get("format", None) == "pickle":
            from flyte.io.pickle.transformer import FlytePickle

            uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file")
            return await FlytePickle.from_pickle(uri)

        try:
            """
            Handles the case where Flyte Console provides input as a protobuf struct.
            When resolving an attribute like 'dc.dict_int_ff', FlytePropeller retrieves a dictionary.
            Mashumaro's decoder can convert this dictionary to the expected Python object if the correct type
            is provided.
            Since Flyte Types handle their own deserialization, the dictionary is automatically converted to
            the expected Python object.

            Example Code:
            @dataclass
            class DC:
                dict_int_ff: Dict[int, FlyteFile]

            @workflow
            def wf(dc: DC):
                t_ff(dc.dict_int_ff)

            Life Cycle:
            json str -> protobuf struct -> resolved protobuf struct -> dictionary
                                                                    -> expected Python object
            (console user input) (console output) (propeller)
            (flytekit dict transformer) (mashumaro decoder)

            Related PR:
            - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro
            - Link: https://github.com/flyteorg/flytekit/pull/2554
            - Title: Binary IDL With MessagePack
            - Link: https://github.com/flyteorg/flytekit/pull/2760
            """

            # Round-trip: struct -> JSON dict -> msgpack bytes, so the (cached) msgpack
            # decoder can produce the expected Python type.
            dict_obj = json.loads(_json_format.MessageToJson(lv.scalar.generic))
            msgpack_bytes = msgpack.dumps(dict_obj)

            try:
                decoder = self._msgpack_decoder[expected_python_type]
            except KeyError:
                # Decoders are cached per target type; constructing one is comparatively expensive.
                decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder)
                self._msgpack_decoder[expected_python_type] = decoder

            return decoder.decode(msgpack_bytes)
        except TypeError:
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

    raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
|
|
1817
|
+
|
|
1818
|
+
def guess_python_type(self, literal_type: LiteralType) -> Union[Type[dict], typing.Dict[Type, Type]]:
    """
    Reverse-map a Flyte ``LiteralType`` back to a Python dict type.

    Map types become ``Dict[str, <guessed value type>]``; a plain STRUCT without
    metadata becomes the bare ``dict``. Anything else cannot be reversed.

    :raises ValueError: when the literal type has no dict equivalent.
    """
    if literal_type.HasField("map_value_type"):
        value_type: typing.Type = TypeEngine.guess_python_type(literal_type.map_value_type)
        return typing.Dict[str, value_type]  # type: ignore

    is_plain_struct = literal_type.simple == SimpleType.STRUCT and not literal_type.HasField("metadata")
    if is_plain_struct:
        return dict  # type: ignore

    raise ValueError(f"Dictionary transformer cannot reverse {literal_type}")
|
|
1828
|
+
|
|
1829
|
+
|
|
1830
|
+
def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> Type[T]:
    """
    Build a dataclass type from a mashumaro-style JSON schema.

    :param schema: dict representing valid JSON schema
    :param schema_name: name for the generated dataclass (also the return type's name)
    :return: a freshly created dataclass type with fields derived from the schema
    """
    attributes = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name)
    return dataclasses.make_dataclass(schema_name, attributes)
|
|
1839
|
+
|
|
1840
|
+
|
|
1841
|
+
def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
    """
    Map a single JSON-schema property description to a Python type.

    Flyte ``File``/``Dir`` schemas are detected first; ``anyOf`` unions are treated as
    ``Optional`` of their first member; "number" honors an optional "integer" format;
    unknown primitive names fall back to ``str``.
    """
    from flyte.io._dir import Dir
    from flyte.io._file import File

    if File.schema_match(element_property):
        return File
    if Dir.schema_match(element_property):
        return Dir

    if element_property.get("anyOf"):
        # e.g. Optional[int] appears as anyOf [integer, None]; resolve the first member.
        member_types = [e_property["type"] for e_property in element_property["anyOf"]]  # type: ignore
        return typing.Optional[_get_element_type({"type": member_types[0]})]  # type: ignore

    element_type = element_property["type"]
    element_format = element_property.get("format")

    if element_type == "number":
        return int if element_format == "integer" else float

    simple_types = {"string": str, "integer": int, "boolean": bool}
    return simple_types.get(element_type, str)
|
|
1872
|
+
|
|
1873
|
+
|
|
1874
|
+
# pr: han-ru is this still needed?
def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.Any:
    """
    Recursively construct an instance of dataclass ``cls`` from a plain dict.

    Keys of ``src`` must match field names of ``cls``; values for fields whose declared
    type is itself a dataclass are converted recursively.
    """
    field_types = {f.name: f.type for f in dataclasses.fields(cls)}

    kwargs = {
        name: dataclass_from_dict(cast(type, field_types[name]), value)
        if dataclasses.is_dataclass(field_types[name])
        else value
        for name, value in src.items()
    }
    return cls(**kwargs)
|
|
1889
|
+
|
|
1890
|
+
|
|
1891
|
+
def strict_type_hint_matching(input_val: typing.Any, target_literal_type: LiteralType) -> typing.Type:
    """
    Resolve the Python type for ``input_val`` only when its transformer's literal type
    matches ``target_literal_type`` exactly.

    This avoids blindly trusting ``type(v)`` — e.g. ``[1, 2, 3]`` yields plain ``list``,
    which is useless as a hint — by validating the round-trip through the TypeEngine.

    :raises ValueError: when the inferred literal type does not match the target (callers
        then fall back to guessing the Python type from the literal type).
    """
    native_type = type(input_val)
    # If there is no good match, this resolves to the pickle transformer, whose literal
    # type will not match unless the target is itself the pickle type.
    matched_transformer: TypeTransformer = TypeEngine.get_transformer(native_type)
    inferred = matched_transformer.get_literal_type(native_type)
    if not literal_types_match(inferred, target_literal_type):
        raise ValueError(
            f"Transformer for {native_type} returned literal type {inferred} "
            f"which doesn't match {target_literal_type}"
        )
    return native_type
|
|
1911
|
+
|
|
1912
|
+
|
|
1913
|
+
def _check_and_covert_float(lv: literals_pb2.Literal) -> float:
    """
    Extract a float from a primitive literal, widening an integer primitive if needed.

    NOTE(review): "covert" is a long-standing typo for "convert"; the name is kept
    because other code references this symbol.

    :raises TypeTransformerFailedError: when the literal holds neither kind of number.
    """
    primitive = lv.scalar.primitive
    if primitive.HasField("float_value"):
        return primitive.float_value
    if primitive.HasField("integer"):
        return float(primitive.integer)
    raise TypeTransformerFailedError(f"Cannot convert literal {lv} to float")
|
|
1919
|
+
|
|
1920
|
+
|
|
1921
|
+
def _handle_flyte_console_float_input_to_int(lv: Literal) -> int:
    """
    Coerce a primitive literal to ``int``, tolerating floats produced by Flyte Console.

    JavaScript has a single Number type, so the console sometimes serializes an integer
    input (e.g. a dataclass field typed ``int``) as a float. When that happens we
    truncate back to int, logging that precision may be lost.

    Life Cycle:
        json str -> protobuf struct -> resolved float -> float -> int
        (console user input) (console output) (propeller) (flytekit simple transformer)

    :raises TypeTransformerFailedError: when the literal holds neither an integer nor a
        float primitive.
    """
    primitive = lv.scalar.primitive
    if primitive.HasField("integer"):
        return primitive.integer

    if primitive.HasField("float_value"):
        logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, might have precision loss.")
        return int(primitive.float_value)

    raise TypeTransformerFailedError(f"Cannot convert literal {lv} to int")
|
|
1950
|
+
|
|
1951
|
+
|
|
1952
|
+
def _check_and_convert_void(lv: Literal) -> None:
    """
    Validate that ``lv`` carries a none-type scalar; always return ``None``.

    :raises TypeTransformerFailedError: when the scalar is not a none-type.
    """
    if lv.scalar.HasField("none_type"):
        return None
    raise TypeTransformerFailedError(f"Cannot convert literal {lv} to None")
|
|
1956
|
+
|
|
1957
|
+
|
|
1958
|
+
# --- Built-in SimpleTransformer instances for the primitive Python types. ----------
# Each bundles: a short name, the Python type, the Flyte LiteralType, a serializer
# (python value -> Literal), and a deserializer (Literal -> python value).

IntTransformer = SimpleTransformer(
    "int",
    int,
    types_pb2.LiteralType(simple=types_pb2.SimpleType.INTEGER),
    lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))),
    # Accepts integer OR float primitives: Flyte Console (JavaScript) may send ints as floats.
    _handle_flyte_console_float_input_to_int,
)

FloatTransformer = SimpleTransformer(
    "float",
    float,
    types_pb2.LiteralType(simple=types_pb2.SimpleType.FLOAT),
    lambda x: Literal(scalar=Scalar(primitive=Primitive(float_value=x))),
    # Widens integer primitives to float; raises on anything else. (Name typo is historical.)
    _check_and_covert_float,
)

BoolTransformer = SimpleTransformer(
    "bool",
    bool,
    types_pb2.LiteralType(simple=types_pb2.SimpleType.BOOLEAN),
    lambda x: Literal(scalar=Scalar(primitive=Primitive(boolean=x))),
    lambda x: x.scalar.primitive.boolean,
)

StrTransformer = SimpleTransformer(
    "str",
    str,
    types_pb2.LiteralType(simple=types_pb2.SimpleType.STRING),
    lambda x: Literal(scalar=Scalar(primitive=Primitive(string_value=x))),
    # Returns None (rather than raising) when the primitive is not a string.
    lambda x: x.scalar.primitive.string_value if x.scalar.primitive.HasField("string_value") else None,
)

DatetimeTransformer = SimpleTransformer(
    "datetime",
    datetime.datetime,
    types_pb2.LiteralType(simple=types_pb2.SimpleType.DATETIME),
    lambda x: Literal(scalar=Scalar(primitive=Primitive(datetime=x))),
    lambda x: x.scalar.primitive.datetime if x.scalar.primitive.HasField("datetime") else None,
)

TimedeltaTransformer = SimpleTransformer(
    "timedelta",
    datetime.timedelta,
    types_pb2.LiteralType(simple=types_pb2.SimpleType.DURATION),
    lambda x: Literal(scalar=Scalar(primitive=Primitive(duration=x))),
    lambda x: x.scalar.primitive.duration if x.scalar.primitive.HasField("duration") else None,
)

DateTransformer = SimpleTransformer(
    "date",
    datetime.date,
    # Dates share the DATETIME literal type; midnight is used as the time component.
    types_pb2.LiteralType(simple=types_pb2.SimpleType.DATETIME),
    lambda x: Literal(
        scalar=Scalar(primitive=Primitive(datetime=datetime.datetime.combine(x, datetime.time.min)))
    ),  # convert datetime to date
    lambda x: x.scalar.primitive.datetime.date() if x.scalar.primitive.HasField("datetime") else None,
)

NoneTransformer = SimpleTransformer(
    "none",
    type(None),
    types_pb2.LiteralType(simple=types_pb2.SimpleType.NONE),
    lambda x: Literal(scalar=Scalar(none_type=Void())),
    lambda x: _check_and_convert_void(x),
)
|
|
2023
|
+
|
|
2024
|
+
|
|
2025
|
+
def _register_default_type_transformers():
    """
    Register the built-in transformers (primitives, containers, enums, protobufs) with
    the TypeEngine, and restrict types Flyte's type system cannot express.
    """
    from types import UnionType

    # Registration order mirrors the original explicit calls; the second element lists
    # additional Python types a transformer should serve beyond its primary one.
    registrations = [
        (IntTransformer, None),
        (FloatTransformer, None),
        (StrTransformer, None),
        (DatetimeTransformer, None),
        (DateTransformer, None),
        (TimedeltaTransformer, None),
        (BoolTransformer, None),
        (NoneTransformer, [None]),
        (ListTransformer(), None),
        (UnionTransformer(), [UnionType]),
        (DictTransformer(), None),
        (EnumTransformer(), None),
        (ProtobufTransformer(), None),
    ]
    for transformer, additional_types in registrations:
        if additional_types is None:
            TypeEngine.register(transformer)
        else:
            TypeEngine.register(transformer, additional_types)

    # Tuples (typed and untyped) are unsupported: the IDL could fake them as structs, but
    # we hold off pending a formal IDL amendment. NamedTuple is restricted only as a
    # *nested* value; a task's own return signature may still be a NamedTuple.
    TypeEngine.register_restricted_type("non typed tuple", tuple)
    TypeEngine.register_restricted_type("non typed tuple", typing.Tuple)
    TypeEngine.register_restricted_type("named tuple", NamedTuple)
|
|
2051
|
+
|
|
2052
|
+
|
|
2053
|
+
class LiteralsResolver(collections.UserDict):
    """
    LiteralsResolver is a helper class meant primarily for use with the FlyteRemote experience or any other situation
    where you might be working with LiteralMaps. This object allows the caller to specify the Python type that should
    correspond to an element of the map.

    Converted values are cached in ``_native_values`` so each literal is converted at most once.
    """

    def __init__(
        self,
        literals: typing.Dict[str, Literal],
        variable_map: Optional[Dict[str, interface_pb2.Variable]] = None,
    ):
        """
        :param literals: A Python map of strings to Flyte Literal models.
        :param variable_map: This map should be basically one side (either input or output) of the Flyte
            TypedInterface model and is used to guess the Python type through the TypeEngine if a Python type is not
            specified by the user. TypeEngine guessing is flaky though, so calls to get() should specify the as_type
            parameter when possible.
        """
        super().__init__(literals)
        if literals is None:
            raise ValueError("Cannot instantiate LiteralsResolver without a map of Literals.")
        self._literals = literals
        self._variable_map = variable_map
        # Cache of already-converted native values, keyed by literal name.
        self._native_values: Dict[str, typing.Any] = {}
        # User-supplied Python type hints, consulted by get() when as_type is omitted.
        self._type_hints: Dict[str, type] = {}

    def __str__(self) -> str:
        # Three display cases: everything converted (show native values), partially
        # converted (mix of native values and pretty-printed literals), none converted.
        if self.literals:
            if len(self.literals) == len(self.native_values):
                return str(self.native_values)
            if self.native_values:
                header = "Partially converted to native values, call get(key, <type_hint>) to convert rest...\n"
                strs = []
                for key, literal in self._literals.items():
                    if key in self._native_values:
                        strs.append(f"{key}: " + str(self._native_values[key]) + "\n")
                    else:
                        lit_txt = str(self._literals[key])
                        # Indent the literal's repr so it nests visually under its key.
                        lit_txt = textwrap.indent(lit_txt, " " * (len(key) + 2))
                        strs.append(f"{key}: \n" + lit_txt)

                return header + "{\n" + textwrap.indent("".join(strs), " " * 2) + "\n}"
            else:
                return str(self.literals)
        return "{}"

    def __repr__(self):
        return self.__str__()

    @property
    def native_values(self) -> typing.Dict[str, typing.Any]:
        # Mapping of keys to already-converted Python values (the conversion cache).
        return self._native_values

    @property
    def variable_map(self) -> Optional[Dict[str, interface_pb2.Variable]]:
        # Interface variables used to guess Python types when no hint is supplied.
        return self._variable_map

    @property
    def literals(self):
        # The raw Literal map this resolver was constructed with.
        return self._literals

    def update_type_hints(self, type_hints: typing.Dict[str, typing.Type]):
        # Merge in Python type hints; get() consults these when as_type is omitted.
        self._type_hints.update(type_hints)

    def get_literal(self, key: str) -> Literal:
        # Fetch the raw (unconverted) Literal for a key.
        if key not in self._literals:
            raise ValueError(f"Key {key} is not in the literal map")

        return self._literals[key]

    def as_python_native(self, python_interface: NativeInterface) -> typing.Any:
        """
        This should return the native Python representation, compatible with unpacking.
        This function relies on Python interface outputs being ordered correctly.

        Returns None for an empty literal map, the bare value for a single output, and a
        plain tuple (ordered by the interface's outputs) for multiple outputs.

        :param python_interface: Only outputs are used but easier to pass the whole interface.
        """
        if len(self.literals) == 0:
            return None

        if self.variable_map is None:
            raise AssertionError(f"Variable map is empty in literals resolver with {self.literals}")

        # Trigger get() on everything to make sure native values are present using the python interface as type hint
        # NOTE(review): asyncio.run per item — this will raise if called from a running event loop.
        for lit_key, lit in self.literals.items():
            asyncio.run(self.get(lit_key, as_type=python_interface.outputs.get(lit_key)))

        # if 1 item, then return 1 item
        if len(self.native_values) == 1:
            return next(iter(self.native_values.values()))

        # if more than 1 item, then return a tuple - can ignore naming the tuple unless it becomes a problem
        # This relies on python_interface.outputs being ordered correctly.
        res = cast(typing.Tuple[typing.Any, ...], ())
        for var_name, _ in python_interface.outputs.items():
            if var_name not in self.native_values:
                raise ValueError(f"Key {var_name} is not in the native values")

            res += (self.native_values[var_name],)

        return res

    def __getitem__(self, key: str):
        # First check to see if it's even in the literal map.
        if key not in self._literals:
            raise ValueError(f"Key {key} is not in the literal map")

        # Return the cached value if it's cached
        if key in self._native_values:
            return self._native_values[key]

        # NOTE(review): get() is an async def, so an uncached key returns a coroutine
        # the caller must await — unlike the cached branch above, which returns the value.
        return self.get(key)

    async def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any:  # type: ignore
        """
        This will get the ``attr`` value from the Literal map, and invoke the TypeEngine to convert it into a Python
        native value. A Python type can optionally be supplied. If successful, the native value will be cached and
        future calls will return the cached value instead.

        Type resolution order: the explicit ``as_type`` argument, then hints registered
        via ``update_type_hints``, then a TypeEngine guess from ``variable_map``.

        :param attr: name of the literal to resolve
        :param as_type: optional Python type to convert the literal to
        :return: Python native value from the LiteralMap
        """
        if attr not in self._literals:
            raise AttributeError(f"Attribute {attr} not found")
        if attr in self.native_values:
            return self.native_values[attr]

        if as_type is None:
            if attr in self._type_hints:
                as_type = self._type_hints[attr]
            else:
                if self.variable_map and attr in self.variable_map:
                    try:
                        as_type = TypeEngine.guess_python_type(self.variable_map[attr].type)
                    except ValueError as e:
                        logger.error(f"Could not guess a type for Variable {self.variable_map[attr]}")
                        raise e
                else:
                    raise ValueError("as_type argument not supplied and Variable map not specified in LiteralsResolver")
        val = await TypeEngine.to_python_value(self._literals[attr], cast(Type, as_type))
        self._native_values[attr] = val
        return val
|
|
2197
|
+
|
|
2198
|
+
|
|
2199
|
+
# Populate the TypeEngine with the default transformers at import time.
_register_default_type_transformers()
|
|
2200
|
+
|
|
2201
|
+
|
|
2202
|
+
def is_annotated(t: Type) -> bool:
    """Return True when ``t`` is a ``typing.Annotated`` alias."""
    origin = get_origin(t)
    return origin is Annotated
|
|
2204
|
+
|
|
2205
|
+
|
|
2206
|
+
def get_underlying_type(t: Type) -> Type:
    """Strip a ``typing.Annotated`` wrapper, returning the underlying type; any other
    type passes through unchanged."""
    if get_origin(t) is Annotated:
        return get_args(t)[0]
    return t