flyte 0.1.0__py3-none-any.whl → 0.2.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -2
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +299 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_params.py +538 -0
- flyte/_cli/_run.py +174 -0
- flyte/_cli/main.py +98 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +757 -0
- flyte/_initialize.py +643 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +205 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +410 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/config/__init__.py +168 -0
- flyte/config/_config.py +196 -0
- flyte/config/_internal.py +64 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2211 -0
- flyte/types/_utils.py +80 -0
- flyte-0.2.0b0.dist-info/METADATA +179 -0
- flyte-0.2.0b0.dist-info/RECORD +204 -0
- {flyte-0.1.0.dist-info → flyte-0.2.0b0.dist-info}/WHEEL +2 -1
- flyte-0.2.0b0.dist-info/entry_points.txt +3 -0
- flyte-0.2.0b0.dist-info/top_level.txt +1 -0
- flyte-0.1.0.dist-info/METADATA +0 -6
- flyte-0.1.0.dist-info/RECORD +0 -5
flyte/_cli/_params.py
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import datetime
|
|
3
|
+
import enum
|
|
4
|
+
import importlib
|
|
5
|
+
import importlib.util
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
import pathlib
|
|
9
|
+
import re
|
|
10
|
+
import sys
|
|
11
|
+
import typing
|
|
12
|
+
import typing as t
|
|
13
|
+
from typing import get_args
|
|
14
|
+
|
|
15
|
+
import rich_click as click
|
|
16
|
+
import yaml
|
|
17
|
+
from click import Parameter
|
|
18
|
+
from flyteidl.core.interface_pb2 import Variable
|
|
19
|
+
from flyteidl.core.literals_pb2 import Literal
|
|
20
|
+
from flyteidl.core.types_pb2 import BlobType, LiteralType, SimpleType
|
|
21
|
+
from google.protobuf.json_format import MessageToDict
|
|
22
|
+
from mashumaro.codecs.json import JSONEncoder
|
|
23
|
+
|
|
24
|
+
from flyte._logging import logger
|
|
25
|
+
from flyte.io import Dir, File
|
|
26
|
+
from flyte.io.pickle.transformer import FlytePickleTransformer
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class StructuredDataset:
|
|
30
|
+
def __init__(self, uri: str | None = None, dataframe: typing.Any = None):
|
|
31
|
+
self.uri = uri
|
|
32
|
+
self.dataframe = dataframe
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ---------------------------------------------------
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def key_value_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
    """
    Click callback that turns repeated ``key=value`` strings into a dict.

    Each entry is split on its first ``=`` only, so values may themselves contain ``=``.
    Keys and values are whitespace-stripped, and a later duplicate key overwrites an earlier one.

    :return: ``None`` when no values were given, otherwise the parsed mapping.
    :raises click.BadParameter: if an entry contains no ``=``.
    """
    if not values:
        return None
    parsed: typing.Dict[str, str] = {}
    for v in values:
        key, separator, raw_value = v.partition("=")
        if not separator:
            raise click.BadParameter(f"Expected key-value pair of the form key=value, got {v}")
        parsed[key.strip()] = raw_value.strip()
    return parsed
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
    """
    Click callback that turns repeated label strings into a dict.

    Each entry is either ``key=value`` (split on the first ``=``) or a bare ``key``, which maps
    to the empty string. Keys and values are whitespace-stripped; later duplicates win.

    :return: ``None`` when no values were given, otherwise the parsed mapping.
    """
    if not values:
        return None
    # partition() yields an empty value when there is no "=", which is exactly the bare-label case.
    return {key.strip(): raw_value.strip() for key, _sep, raw_value in (v.partition("=") for v in values)}
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class DirParamType(click.ParamType):
    """Click parameter type for directory inputs; converts a local or remote path into a ``Dir``."""

    name = "directory path"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` into a ``Dir``.

        Values that are already ``Dir`` objects are returned unchanged. Click also passes
        defaults through ``convert``, so a task default of type ``Dir`` must not be treated as a path.
        Local paths must exist and be directories. Paths for which ``is_remote`` is true are
        wrapped without checking.

        :raises click.BadParameter: if a local path does not exist or is not a directory.
        """
        if isinstance(value, Dir):
            return value

        from flyte.storage import is_remote

        if not is_remote(value):
            p = pathlib.Path(value)
            if not p.exists() or not p.is_dir():
                raise click.BadParameter(f"parameter should be a valid flytedirectory path, {value}")
        return Dir(path=value)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class StructuredDatasetParamType(click.ParamType):
    """
    Click parameter type for structured-dataset inputs.

    TODO handle column types
    """

    name = "structured dataset path (dir/file)"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """Wrap ``value`` in a ``StructuredDataset``: strings become its ``uri``, anything else its ``dataframe``."""
        if isinstance(value, StructuredDataset):
            # Already converted (e.g. a default value); keep it as-is.
            return value
        if isinstance(value, str):
            return StructuredDataset(uri=value)
        return StructuredDataset(dataframe=value)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class FileParamType(click.ParamType):
    """Click parameter type for file inputs; converts a local or remote path into a ``File``."""

    name = "file path"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` into a ``File``.

        Values that are already ``File`` objects are returned unchanged. Click also passes
        defaults through ``convert``, so a task default of type ``File`` must not be treated as a path.
        Local paths must exist and be regular files. Paths for which ``is_remote`` is true are
        wrapped without checking.

        :raises click.BadParameter: if a local path does not exist or is not a file.
        """
        if isinstance(value, File):
            return value

        from flyte.storage import is_remote

        if not is_remote(value):
            p = pathlib.Path(value)
            if not p.exists() or not p.is_file():
                raise click.BadParameter(f"parameter should be a valid file path, {value}")
        return File(path=value)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class PickleParamType(click.ParamType):
    """
    Click parameter type that resolves a ``<module>:<attribute>`` reference to a Python object.

    The module is imported with the current working directory placed first on ``sys.path``.
    Strings that are not exactly of the form ``<module>:<attribute>`` are passed through unchanged,
    as are non-string values.
    """

    name = "pickle"

    def get_metavar(self, param: "Parameter", *args) -> t.Optional[str]:
        return "Python Object <Module>:<Object>"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Import ``<module>`` and return its ``<attribute>``.

        :raises click.BadParameter: if the module cannot be found or lacks the attribute.
        """
        if not isinstance(value, str):
            return value
        parts = value.split(":")
        if len(parts) != 2:
            # ctx.obj is whatever object the CLI stored on the context; it may not carry log_level.
            log_level = getattr(ctx.obj, "log_level", None) if ctx and ctx.obj else None
            if log_level is not None and log_level >= 10:  # DEBUG level
                click.echo(f"Did not receive a string in the expected format <MODULE>:<VAR>, falling back to: {value}")
            return value

        module_name, attr_name = parts
        cwd = os.getcwd()
        # Put cwd first so local modules win, but only once. Repeated conversions must not keep
        # growing sys.path.
        if not sys.path or sys.path[0] != cwd:
            sys.path.insert(0, cwd)

        try:
            module = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            raise click.BadParameter(f"Failed to import module {parts[0]}, error: {e}") from e
        try:
            return getattr(module, attr_name)
        except AttributeError as e:
            raise click.BadParameter(f"Failed to find attribute {parts[1]} in module {parts[0]}, error: {e}") from e
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class JSONIteratorParamType(click.ParamType):
    """Placeholder click type for JSON-iterator inputs; it performs no conversion."""

    name = "json iterator"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """Return ``value`` untouched; ``param`` and ``ctx`` are ignored."""
        passthrough = value
        return passthrough
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# Supported ISO 8601 duration subset: weeks, days, and a time part with hours, minutes and
# (optionally fractional) seconds. Years and months are not supported (no fixed timedelta length).
_ISO8601_DURATION_PATTERN = re.compile(
    r"P"  # Starts with 'P'
    r"(?:(?P<weeks>\d+)W)?"  # Optional weeks
    r"(?:(?P<days>\d+)D)?"  # Optional days
    r"(?:T"  # Optional time part
    r"(?:(?P<hours>\d+)H)?"
    r"(?:(?P<minutes>\d+)M)?"
    r"(?:(?P<seconds>\d+(?:\.\d+)?)S)?"
    r")?"
)


def parse_iso8601_duration(iso_duration: str) -> datetime.timedelta:
    """
    Parse an ISO 8601 duration such as ``P1DT2H30M`` or ``PT1.5S`` into a ``timedelta``.

    :param iso_duration: the duration string; it must match the whole pattern exactly.
    :return: the corresponding ``timedelta``.
    :raises ValueError: if the string is not a duration in the supported subset. This includes
        the degenerate forms ``P`` and ``PT``, and a ``T`` separator with no time component (``P1DT``).
    """
    match = _ISO8601_DURATION_PATTERN.fullmatch(iso_duration)
    if not match:
        raise ValueError(f"Invalid ISO 8601 duration format: {iso_duration}")

    fields = match.groupdict()
    time_fields = (fields["hours"], fields["minutes"], fields["seconds"])
    # "T" can only appear as the time separator, since every other character is a digit or designator.
    has_time_separator = "T" in iso_duration
    if all(v is None for v in fields.values()) or (has_time_separator and all(v is None for v in time_fields)):
        raise ValueError(f"Invalid ISO 8601 duration format: {iso_duration}")

    parts = {k: (float(v) if k == "seconds" else int(v)) for k, v in fields.items() if v is not None}
    return datetime.timedelta(**parts)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# "1:24" (minutes:seconds), ":45" (seconds with empty minutes) or a bare "45" (seconds).
_COLON_DURATION_PATTERN = re.compile(r"(?:(\d*):)?(\d+)")
# "10 days", "1 minute", "2 weeks", ... (singular or plural, optional space before the unit).
_UNIT_DURATION_PATTERN = re.compile(r"(\d+)\s*(week|day|hour|minute|second)s?")


def parse_human_durations(text: str) -> list[datetime.timedelta]:
    """
    Parse human-friendly duration components into a list of ``timedelta`` values.

    ``text`` may be wrapped in ``[...]`` and may hold several ``|``-separated components, for
    example ``"[1:24 | 10 days]"``. Matching is case-insensitive and ignores surrounding whitespace.
    Components that cannot be parsed are skipped with a logged warning.

    :return: one ``timedelta`` per successfully parsed component, in input order.
    """
    durations = []

    for part in text.strip("[]").split("|"):
        new_part = part.strip().lower()

        m_colon = _COLON_DURATION_PATTERN.fullmatch(new_part)
        if m_colon:
            # group(1) is None for a bare number and "" for ":45"; both mean zero minutes.
            minutes = int(m_colon.group(1)) if m_colon.group(1) else 0
            seconds = int(m_colon.group(2))
            durations.append(datetime.timedelta(minutes=minutes, seconds=seconds))
            continue

        m_units = _UNIT_DURATION_PATTERN.fullmatch(new_part)
        if m_units:
            value = int(m_units.group(1))
            unit = m_units.group(2)
            # timedelta keyword arguments are the plural unit names (weeks, days, hours, ...).
            durations.append(datetime.timedelta(**{unit + "s": value}))
            continue

        # Use the module logger rather than print so CLI stdout is not polluted.
        logger.warning(f"Could not parse duration component '{part}'")

    return durations
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def parse_duration(s: str) -> datetime.timedelta:
    """
    Parse a duration string into a ``timedelta``.

    ISO 8601 (``parse_iso8601_duration``) is tried first. If that raises ``ValueError``, the
    human-friendly ``|``-separated form (``parse_human_durations``) is tried and its components
    are summed.

    :raises ValueError: if neither parser yields anything.
    """
    try:
        return parse_iso8601_duration(s)
    except ValueError:
        components = parse_human_durations(s)
        if components:
            return sum(components, datetime.timedelta())
        raise ValueError(f"Could not parse duration: {s}")
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class DateTimeType(click.DateTime):
    """
    ``click.DateTime`` extended with relative forms.

    Accepted inputs, in addition to click's own formats:

    - the keywords ``now`` and ``today`` (midnight of the current local date);
    - ``<datetime> - <duration>`` or ``<datetime> + <duration>``, where the duration is parsed by
      ``parse_duration``;
    - any other string containing a space is tried with ``datetime.fromisoformat``.
    """

    _NOW_FMT = "now"
    _TODAY_FMT = "today"
    _FIXED_FORMATS: typing.ClassVar[typing.List[str]] = [_NOW_FMT, _TODAY_FMT]
    _FLOATING_FORMATS: typing.ClassVar[typing.List[str]] = ["<FORMAT> - <ISO8601 duration>"]
    _ADDITONAL_FORMATS: typing.ClassVar[typing.List[str]] = [*_FIXED_FORMATS, *_FLOATING_FORMATS]
    _FLOATING_FORMAT_PATTERN = r"(.+)\s+([-+])\s+(.+)"

    def __init__(self):
        super().__init__()
        # The extra entries are listed alongside click's formats (e.g. in help output). "now" and
        # "today" are resolved in _datetime_from_format before click's own parsing runs.
        self.formats.extend(self._ADDITONAL_FORMATS)

    def _datetime_from_format(
        self, value: str, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> datetime.datetime:
        """Resolve ``now``/``today``; delegate everything else to ``click.DateTime.convert``."""
        if value in self._FIXED_FORMATS:
            if value == self._NOW_FMT:
                return datetime.datetime.now()
            if value == self._TODAY_FMT:
                n = datetime.datetime.now()
                return datetime.datetime(n.year, n.month, n.day)
        return super().convert(value, param, ctx)

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` to a ``datetime.datetime``.

        :raises click.BadParameter: if the duration part of a relative expression, or a
            space-containing ISO string, cannot be parsed.
        """
        if isinstance(value, str) and " " in value:
            m = re.match(self._FLOATING_FORMAT_PATTERN, value)
            if m:
                # The pattern has exactly three groups: base datetime, operator, duration.
                base, operator, duration = m.groups()
                dt = self._datetime_from_format(base, param, ctx)
                try:
                    delta = parse_duration(duration)
                except Exception as e:
                    raise click.BadParameter(
                        f"Matched format {self._FLOATING_FORMATS}, but failed to parse duration {duration}, error: {e}"
                    ) from e
                return dt - delta if operator == "-" else dt + delta
            try:
                value = datetime.datetime.fromisoformat(value)
            except ValueError as e:
                # Surface as a click usage error instead of an unhandled traceback.
                raise click.BadParameter(f"Could not parse datetime {value}, error: {e}") from e

        return self._datetime_from_format(value, param, ctx)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class DurationParamType(click.ParamType):
    """Click parameter type producing a ``datetime.timedelta`` via ``parse_duration``."""

    name = "[1:24 | :22 | 1 minute | 10 days | ...]"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` to a ``timedelta``.

        Values that are already ``timedelta`` objects (e.g. task defaults, which click also runs
        through ``convert``) are returned unchanged.

        :raises click.BadParameter: for ``None`` or an unparseable duration string.
        """
        if value is None:
            raise click.BadParameter("None value cannot be converted to a Duration type.")
        if isinstance(value, datetime.timedelta):
            return value
        try:
            return parse_duration(value)
        except ValueError as e:
            # Report as a click usage error rather than crashing with a raw ValueError.
            raise click.BadParameter(str(e)) from e
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
class EnumParamType(click.Choice):
    """
    ``click.Choice`` over an Enum's values that converts the selected choice back into a member.

    Choices are presented as ``str(member.value)``.
    """

    def __init__(self, enum_type: typing.Type[enum.Enum]):
        super().__init__([str(e.value) for e in enum_type])
        self._enum_type = enum_type
        # click.Choice works on strings, so keep a reverse map from choice text to member.
        # Calling enum_type(choice) directly would fail for non-string values (e.g. value 1 vs "1").
        self._member_by_choice: typing.Dict[str, enum.Enum] = {str(e.value): e for e in enum_type}

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> enum.Enum:
        """
        Return the Enum member selected by ``value``.

        Members are passed through unchanged; strings are validated by ``click.Choice``.
        """
        if isinstance(value, self._enum_type):
            return value
        choice = super().convert(value, param, ctx)
        return self._member_by_choice[str(choice)]
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
class UnionParamType(click.ParamType):
    """
    Composite click type for Union inputs: tries each member type in turn.

    Specific types are tried first, then string types, then ``click.UNPROCESSED``. Within each
    group the original order is kept. The first successful conversion wins.
    """

    def __init__(self, types: typing.List[click.ParamType]):
        super().__init__()
        self._types = self._sort_precedence(types)
        self.name = "|".join(ptype.name for ptype in self._types)

    @staticmethod
    def _sort_precedence(tp: typing.List[click.ParamType]) -> typing.List[click.ParamType]:
        """Order types as: others, then string types, then unprocessed (stable within each group)."""
        unprocessed_cls = type(click.UNPROCESSED)
        string_cls = type(click.STRING)

        def _rank(ptype: click.ParamType) -> int:
            if isinstance(ptype, unprocessed_cls):
                return 2
            if isinstance(ptype, string_cls):
                return 1
            return 0

        # sorted() is stable, so members keep their relative order inside each rank.
        return sorted(tp, key=_rank)  # type: ignore

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Return the result of the first member type that converts ``value`` without raising.

        :raises click.BadParameter: if every member type fails.
        """
        for candidate in self._types:
            try:
                return candidate.convert(value, param, ctx)
            except Exception as e:
                logger.debug(f"Ignoring conversion error for type {candidate} trying other variants in Union. Error: {e}")
        raise click.BadParameter(f"Failed to convert {value} to any of the types {self._types}")
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
class JsonParamType(click.ParamType):
    """
    Click parameter type accepting inline JSON, or a path to a JSON/YAML file.

    The decoded value is then coerced towards ``python_type``:

    - lists or dicts of dataclasses are converted element-wise;
    - pydantic models are validated;
    - dataclasses are decoded with mashumaro;
    - anything else is returned as decoded.
    """

    name = "json object OR json/yaml file path"

    def __init__(self, python_type: typing.Type):
        super().__init__()
        self._python_type = python_type

    def _parse(self, value: typing.Any, param: typing.Optional[click.Parameter]):
        """
        Decode ``value`` into Python objects.

        Accepts an already-decoded list/dict, an inline JSON string, or the path of an existing
        ``.yaml``/``.yml`` (YAML) or other (JSON) file.

        :raises click.BadParameter: if the JSON or YAML content is invalid.
        """
        if isinstance(value, (dict, list)):
            return value
        try:
            return json.loads(value)
        except Exception:
            try:
                # Not inline JSON; try to treat the value as a file path instead.
                if os.path.exists(value):
                    if value.endswith((".yaml", ".yml")):
                        with open(value, "r") as f:
                            return yaml.safe_load(f)
                    with open(value, "r") as f:
                        return json.load(f)
                # Not a file either: re-raise the original json.loads error.
                raise
            except json.JSONDecodeError as e:
                raise click.BadParameter(f"parameter {param} should be a valid json object, {value}, error: {e}")
            except yaml.YAMLError as e:
                raise click.BadParameter(f"parameter {param} should be a valid yaml file, {value}, error: {e}") from e

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Decode ``value`` and coerce it towards the configured Python type.

        :raises click.BadParameter: for ``None`` or undecodable input.
        """
        if value is None:
            raise click.BadParameter("None value cannot be converted to a Json type.")

        parsed_value = self._parse(value, param)

        # We compare the origin type because the json parsed value for list or dict is always a list or dict without
        # the covariant type information.
        if type(parsed_value) is typing.get_origin(self._python_type) or type(parsed_value) is self._python_type:
            type_args = get_args(self._python_type)
            # Native (unparameterized) list/dict: nothing to recurse into. Nested dataclasses are
            # not supported for these.
            if type_args == ():
                return parsed_value
            elif isinstance(parsed_value, list) and dataclasses.is_dataclass(type_args[0]):
                j = JsonParamType(type_args[0])
                # turn object back into json string
                return [j.convert(json.dumps(v), param, ctx) for v in parsed_value]
            elif isinstance(parsed_value, dict) and dataclasses.is_dataclass(type_args[1]):
                j = JsonParamType(type_args[1])
                # turn object back into json string
                return {k: j.convert(json.dumps(v), param, ctx) for k, v in parsed_value.items()}

            return parsed_value

        from pydantic import BaseModel

        # issubclass() raises TypeError for generic aliases such as list[int], so guard with isinstance.
        if isinstance(self._python_type, type) and issubclass(self._python_type, BaseModel):
            return typing.cast(BaseModel, self._python_type).model_validate_json(
                json.dumps(parsed_value), strict=False, context={"deserialize": True}
            )
        elif dataclasses.is_dataclass(self._python_type):
            from mashumaro.codecs.json import JSONDecoder

            decoder = JSONDecoder(self._python_type)
            # Decode the parsed value, not the raw input, so file paths and dict inputs work too.
            return decoder.decode(json.dumps(parsed_value))

        return parsed_value
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# Click parameter types for Flyte simple literal types that map directly onto a CLI value.
# STRUCT is not listed; literal_type_to_click_type turns it into a JsonParamType instead.
SIMPLE_TYPE_CONVERTER = dict(
    [
        (SimpleType.FLOAT, click.FLOAT),
        (SimpleType.INTEGER, click.INT),
        (SimpleType.STRING, click.STRING),
        (SimpleType.BOOLEAN, click.BOOL),
        (SimpleType.DURATION, DurationParamType()),
        (SimpleType.DATETIME, DateTimeType()),
    ]
)
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> click.ParamType:
    """
    Converts a Flyte LiteralType given a python_type to a click.ParamType

    :param lt: the Flyte literal type of the input.
    :param python_type: the matching Python type annotation. It is used for JSON/enum decoding
        and, for unions, indexed in parallel with the union variants.
    :return: the click type used to parse the command-line value.
    :raises NotImplementedError: for simple types with no click mapping.
    """
    if lt.HasField("simple"):
        if lt.simple == SimpleType.STRUCT:
            ct = JsonParamType(python_type)
            # Generic aliases may not expose __name__; fall back to their repr.
            ct.name = f"JSON object {getattr(python_type, '__name__', str(python_type))}"
            return ct
        if lt.simple in SIMPLE_TYPE_CONVERTER:
            return SIMPLE_TYPE_CONVERTER[lt.simple]
        raise NotImplementedError(f"Type {lt.simple} is not supported in pyflyte run")

    if lt.HasField("structured_dataset_type"):
        return StructuredDatasetParamType()

    if lt.HasField("collection_type") or lt.HasField("map_value_type"):
        ct = JsonParamType(python_type)
        if lt.HasField("collection_type"):
            ct.name = "json list"
        else:
            ct.name = "json dictionary"
        return ct

    if lt.HasField("blob"):
        if lt.blob.dimensionality == BlobType.BlobDimensionality.SINGLE:
            if lt.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT:
                return PickleParamType()
            # TODO: Add JSONIteratorTransformer
            # elif lt.blob.format == JSONIteratorTransformer.JSON_ITERATOR_FORMAT:
            #     return JSONIteratorParamType()
            return FileParamType()
        return DirParamType()

    if lt.HasField("union_type"):
        python_variants = typing.get_args(python_type)
        cts = []
        for i, variant in enumerate(lt.union_type.variants):
            # A NONE variant (from Optional[...]) has no entry in SIMPLE_TYPE_CONVERTER, so
            # recursing would raise NotImplementedError. Skip it; an omitted option falls back
            # to its default.
            if variant.HasField("simple") and variant.simple == SimpleType.NONE:
                continue
            cts.append(literal_type_to_click_type(variant, python_variants[i]))
        if not cts:
            # Union of only None variants: there is nothing meaningful to parse.
            return click.UNPROCESSED
        return UnionParamType(cts)

    if lt.HasField("enum_type"):
        return EnumParamType(python_type)  # type: ignore

    return click.UNPROCESSED
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
class FlyteLiteralConverter:
    """Pairs a task input's Flyte LiteralType and Python type with the click type used to parse it."""

    name = "literal_type"

    def __init__(
        self,
        literal_type: LiteralType,
        python_type: typing.Type,
    ):
        self._literal_type = literal_type
        self._python_type = python_type
        self._click_type = literal_type_to_click_type(literal_type, python_type)

    @property
    def click_type(self) -> click.ParamType:
        return self._click_type

    def is_bool(self) -> bool:
        """Whether the input is parsed as a plain click boolean (and is therefore exposed as a flag)."""
        return self.click_type == click.BOOL

    def convert(
        self, ctx: click.Context, param: typing.Optional[click.Parameter], value: typing.Any
    ) -> typing.Union[Literal, typing.Any]:
        """
        Convert the value to a python native type. This is used by click to convert the input.

        It is installed as the option callback, so ``value`` has already been converted by
        ``click_type``. The only adjustment made here is narrowing a ``datetime`` to a ``date``
        when the Python type is ``datetime.date``.
        """
        # Click produces datetime, so converting to date to avoid type mismatch error.
        # Only narrow real datetimes: None (option omitted) and plain date defaults have no .date().
        if self._python_type is datetime.date and isinstance(value, datetime.datetime):
            value = value.date()
        return value
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def to_click_option(
    input_name: str,
    literal_var: Variable,
    python_type: typing.Type,
    default_val: typing.Any,
) -> click.Option:
    """
    This handles converting workflow input types to supported click parameters with callbacks to initialize
    the input values to their expected types.

    :param input_name: the task input name; it becomes ``--<input_name>`` and must be lowercase.
    :param literal_var: the interface variable, supplying the literal type and help text.
    :param python_type: the Python annotation of the input.
    :param default_val: the default value, or ``None`` if the input has no default.
    :raises ValueError: if ``input_name`` is not lowercase.
    """
    if input_name != input_name.lower():
        # Click does not support uppercase option names: https://github.com/pallets/click/issues/837
        raise ValueError(f"Workflow input name must be lowercase: {input_name!r}")

    literal_converter = FlyteLiteralConverter(
        literal_type=literal_var.type,
        python_type=python_type,
    )

    if literal_converter.is_bool() and not default_val:
        default_val = False

    description_extra = ""
    if literal_var.type.simple == SimpleType.STRUCT:
        if default_val:
            # Struct defaults are shown and parsed as JSON text, so serialize them up front.
            # pydantic v2
            if hasattr(default_val, "model_dump_json"):
                default_val = default_val.model_dump_json()
            else:
                encoder = JSONEncoder(python_type)
                default_val = encoder.encode(default_val)
        if literal_var.type.metadata:
            description_extra = f": {MessageToDict(literal_var.type.metadata)}"

    # An input with a default may be omitted; boolean inputs become optional flags.
    required = default_val is None
    is_flag: typing.Optional[bool] = None
    if literal_converter.is_bool():
        required = False
        is_flag = True

    return click.Option(
        param_decls=[f"--{input_name}"],
        type=literal_converter.click_type,
        is_flag=is_flag,
        default=default_val,
        show_default=True,
        required=required,
        help=literal_var.description + description_extra,
        callback=literal_converter.convert,
    )
|