flyte 0.1.0__py3-none-any.whl → 0.2.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the package's registry page for more details.

Files changed (205):
  1. flyte/__init__.py +62 -2
  2. flyte/_api_commons.py +3 -0
  3. flyte/_bin/__init__.py +0 -0
  4. flyte/_bin/runtime.py +126 -0
  5. flyte/_build.py +25 -0
  6. flyte/_cache/__init__.py +12 -0
  7. flyte/_cache/cache.py +146 -0
  8. flyte/_cache/defaults.py +9 -0
  9. flyte/_cache/policy_function_body.py +42 -0
  10. flyte/_cli/__init__.py +0 -0
  11. flyte/_cli/_common.py +299 -0
  12. flyte/_cli/_create.py +42 -0
  13. flyte/_cli/_delete.py +23 -0
  14. flyte/_cli/_deploy.py +140 -0
  15. flyte/_cli/_get.py +235 -0
  16. flyte/_cli/_params.py +538 -0
  17. flyte/_cli/_run.py +174 -0
  18. flyte/_cli/main.py +98 -0
  19. flyte/_code_bundle/__init__.py +8 -0
  20. flyte/_code_bundle/_ignore.py +113 -0
  21. flyte/_code_bundle/_packaging.py +187 -0
  22. flyte/_code_bundle/_utils.py +339 -0
  23. flyte/_code_bundle/bundle.py +178 -0
  24. flyte/_context.py +146 -0
  25. flyte/_datastructures.py +342 -0
  26. flyte/_deploy.py +202 -0
  27. flyte/_doc.py +29 -0
  28. flyte/_docstring.py +32 -0
  29. flyte/_environment.py +43 -0
  30. flyte/_group.py +31 -0
  31. flyte/_hash.py +23 -0
  32. flyte/_image.py +757 -0
  33. flyte/_initialize.py +643 -0
  34. flyte/_interface.py +84 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +115 -0
  37. flyte/_internal/controllers/_local_controller.py +118 -0
  38. flyte/_internal/controllers/_trace.py +40 -0
  39. flyte/_internal/controllers/pbhash.py +39 -0
  40. flyte/_internal/controllers/remote/__init__.py +40 -0
  41. flyte/_internal/controllers/remote/_action.py +141 -0
  42. flyte/_internal/controllers/remote/_client.py +43 -0
  43. flyte/_internal/controllers/remote/_controller.py +361 -0
  44. flyte/_internal/controllers/remote/_core.py +402 -0
  45. flyte/_internal/controllers/remote/_informer.py +361 -0
  46. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  47. flyte/_internal/imagebuild/__init__.py +11 -0
  48. flyte/_internal/imagebuild/docker_builder.py +416 -0
  49. flyte/_internal/imagebuild/image_builder.py +241 -0
  50. flyte/_internal/imagebuild/remote_builder.py +0 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +54 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +205 -0
  57. flyte/_internal/runtime/entrypoints.py +135 -0
  58. flyte/_internal/runtime/io.py +136 -0
  59. flyte/_internal/runtime/resources_serde.py +138 -0
  60. flyte/_internal/runtime/task_serde.py +210 -0
  61. flyte/_internal/runtime/taskrunner.py +190 -0
  62. flyte/_internal/runtime/types_serde.py +54 -0
  63. flyte/_logging.py +124 -0
  64. flyte/_protos/__init__.py +0 -0
  65. flyte/_protos/common/authorization_pb2.py +66 -0
  66. flyte/_protos/common/authorization_pb2.pyi +108 -0
  67. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  68. flyte/_protos/common/identifier_pb2.py +71 -0
  69. flyte/_protos/common/identifier_pb2.pyi +82 -0
  70. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  71. flyte/_protos/common/identity_pb2.py +48 -0
  72. flyte/_protos/common/identity_pb2.pyi +72 -0
  73. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  74. flyte/_protos/common/list_pb2.py +36 -0
  75. flyte/_protos/common/list_pb2.pyi +69 -0
  76. flyte/_protos/common/list_pb2_grpc.py +4 -0
  77. flyte/_protos/common/policy_pb2.py +37 -0
  78. flyte/_protos/common/policy_pb2.pyi +27 -0
  79. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  80. flyte/_protos/common/role_pb2.py +37 -0
  81. flyte/_protos/common/role_pb2.pyi +53 -0
  82. flyte/_protos/common/role_pb2_grpc.py +4 -0
  83. flyte/_protos/common/runtime_version_pb2.py +28 -0
  84. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  85. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  86. flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
  87. flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  88. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  89. flyte/_protos/secret/definition_pb2.py +49 -0
  90. flyte/_protos/secret/definition_pb2.pyi +93 -0
  91. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  92. flyte/_protos/secret/payload_pb2.py +62 -0
  93. flyte/_protos/secret/payload_pb2.pyi +94 -0
  94. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  95. flyte/_protos/secret/secret_pb2.py +38 -0
  96. flyte/_protos/secret/secret_pb2.pyi +6 -0
  97. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  98. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  99. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  100. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  101. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  102. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  103. flyte/_protos/workflow/queue_service_pb2.py +106 -0
  104. flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
  105. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  106. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  107. flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
  108. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  109. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  110. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  111. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  112. flyte/_protos/workflow/run_service_pb2.py +133 -0
  113. flyte/_protos/workflow/run_service_pb2.pyi +175 -0
  114. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  115. flyte/_protos/workflow/state_service_pb2.py +58 -0
  116. flyte/_protos/workflow/state_service_pb2.pyi +71 -0
  117. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  118. flyte/_protos/workflow/task_definition_pb2.py +72 -0
  119. flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
  120. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  121. flyte/_protos/workflow/task_service_pb2.py +44 -0
  122. flyte/_protos/workflow/task_service_pb2.pyi +31 -0
  123. flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
  124. flyte/_resources.py +226 -0
  125. flyte/_retry.py +32 -0
  126. flyte/_reusable_environment.py +25 -0
  127. flyte/_run.py +410 -0
  128. flyte/_secret.py +61 -0
  129. flyte/_task.py +367 -0
  130. flyte/_task_environment.py +200 -0
  131. flyte/_timeout.py +47 -0
  132. flyte/_tools.py +27 -0
  133. flyte/_trace.py +128 -0
  134. flyte/_utils/__init__.py +20 -0
  135. flyte/_utils/asyn.py +119 -0
  136. flyte/_utils/coro_management.py +25 -0
  137. flyte/_utils/file_handling.py +72 -0
  138. flyte/_utils/helpers.py +108 -0
  139. flyte/_utils/lazy_module.py +54 -0
  140. flyte/_utils/uv_script_parser.py +49 -0
  141. flyte/_version.py +21 -0
  142. flyte/config/__init__.py +168 -0
  143. flyte/config/_config.py +196 -0
  144. flyte/config/_internal.py +64 -0
  145. flyte/connectors/__init__.py +0 -0
  146. flyte/errors.py +143 -0
  147. flyte/extras/__init__.py +5 -0
  148. flyte/extras/_container.py +273 -0
  149. flyte/io/__init__.py +11 -0
  150. flyte/io/_dataframe.py +0 -0
  151. flyte/io/_dir.py +448 -0
  152. flyte/io/_file.py +468 -0
  153. flyte/io/pickle/__init__.py +0 -0
  154. flyte/io/pickle/transformer.py +117 -0
  155. flyte/io/structured_dataset/__init__.py +129 -0
  156. flyte/io/structured_dataset/basic_dfs.py +219 -0
  157. flyte/io/structured_dataset/structured_dataset.py +1061 -0
  158. flyte/remote/__init__.py +25 -0
  159. flyte/remote/_client/__init__.py +0 -0
  160. flyte/remote/_client/_protocols.py +131 -0
  161. flyte/remote/_client/auth/__init__.py +12 -0
  162. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  163. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  164. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  165. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  166. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  167. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  168. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  169. flyte/remote/_client/auth/_channel.py +184 -0
  170. flyte/remote/_client/auth/_client_config.py +83 -0
  171. flyte/remote/_client/auth/_default_html.py +32 -0
  172. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  173. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  174. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  175. flyte/remote/_client/auth/_keyring.py +143 -0
  176. flyte/remote/_client/auth/_token_client.py +260 -0
  177. flyte/remote/_client/auth/errors.py +16 -0
  178. flyte/remote/_client/controlplane.py +95 -0
  179. flyte/remote/_console.py +18 -0
  180. flyte/remote/_data.py +155 -0
  181. flyte/remote/_logs.py +116 -0
  182. flyte/remote/_project.py +86 -0
  183. flyte/remote/_run.py +873 -0
  184. flyte/remote/_secret.py +132 -0
  185. flyte/remote/_task.py +227 -0
  186. flyte/report/__init__.py +3 -0
  187. flyte/report/_report.py +178 -0
  188. flyte/report/_template.html +124 -0
  189. flyte/storage/__init__.py +24 -0
  190. flyte/storage/_remote_fs.py +34 -0
  191. flyte/storage/_storage.py +251 -0
  192. flyte/storage/_utils.py +5 -0
  193. flyte/types/__init__.py +13 -0
  194. flyte/types/_interface.py +25 -0
  195. flyte/types/_renderer.py +162 -0
  196. flyte/types/_string_literals.py +120 -0
  197. flyte/types/_type_engine.py +2211 -0
  198. flyte/types/_utils.py +80 -0
  199. flyte-0.2.0b1.dist-info/METADATA +179 -0
  200. flyte-0.2.0b1.dist-info/RECORD +204 -0
  201. {flyte-0.1.0.dist-info → flyte-0.2.0b1.dist-info}/WHEEL +2 -1
  202. flyte-0.2.0b1.dist-info/entry_points.txt +3 -0
  203. flyte-0.2.0b1.dist-info/top_level.txt +1 -0
  204. flyte-0.1.0.dist-info/METADATA +0 -6
  205. flyte-0.1.0.dist-info/RECORD +0 -5
flyte/_cli/_params.py ADDED
@@ -0,0 +1,538 @@
1
+ import dataclasses
2
+ import datetime
3
+ import enum
4
+ import importlib
5
+ import importlib.util
6
+ import json
7
+ import os
8
+ import pathlib
9
+ import re
10
+ import sys
11
+ import typing
12
+ import typing as t
13
+ from typing import get_args
14
+
15
+ import rich_click as click
16
+ import yaml
17
+ from click import Parameter
18
+ from flyteidl.core.interface_pb2 import Variable
19
+ from flyteidl.core.literals_pb2 import Literal
20
+ from flyteidl.core.types_pb2 import BlobType, LiteralType, SimpleType
21
+ from google.protobuf.json_format import MessageToDict
22
+ from mashumaro.codecs.json import JSONEncoder
23
+
24
+ from flyte._logging import logger
25
+ from flyte.io import Dir, File
26
+ from flyte.io.pickle.transformer import FlytePickleTransformer
27
+
28
+
29
+ class StructuredDataset:
30
+ def __init__(self, uri: str | None = None, dataframe: typing.Any = None):
31
+ self.uri = uri
32
+ self.dataframe = dataframe
33
+
34
+
35
+ # ---------------------------------------------------
36
+
37
+
38
def key_value_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
    """
    Click callback that turns repeated ``key=value`` arguments into a dict.

    Only the first ``=`` splits an entry, so values may themselves contain ``=``. Keys and values are
    whitespace-stripped, and later duplicates overwrite earlier ones.

    :return: the parsed mapping, or None when no values were supplied.
    :raises click.BadParameter: if an entry contains no ``=``.
    """
    if not values:
        return None
    parsed: typing.Dict[str, str] = {}
    for item in values:
        key, separator, remainder = item.partition("=")
        if not separator:
            raise click.BadParameter(f"Expected key-value pair of the form key=value, got {item}")
        parsed[key.strip()] = remainder.strip()
    return parsed
51
+
52
+
53
def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
    """
    Click callback that turns repeated label arguments into a dict.

    Each entry is either ``key=value`` (split on the first ``=``) or a bare ``key``, which maps to the empty
    string. Keys and values are whitespace-stripped.

    :return: the parsed labels, or None when no values were supplied.
    """
    if not values:
        return None
    labels: typing.Dict[str, str] = {}
    for entry in values:
        # partition() yields an empty remainder for a bare key, giving the "" label value.
        name, _separator, label_value = entry.partition("=")
        labels[name.strip()] = label_value.strip()
    return labels
67
+
68
+
69
class DirParamType(click.ParamType):
    """
    Click parameter type for directory inputs.

    Local paths must exist and be directories. Values that ``flyte.storage.is_remote`` considers remote are
    accepted without a local check. The result is wrapped in a ``Dir``.
    """

    name = "directory path"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        # click may hand convert() a value that is already converted (e.g. a Dir default); keep it unchanged.
        if isinstance(value, Dir):
            return value

        from flyte.storage import is_remote

        if not is_remote(value):
            p = pathlib.Path(value)
            if not p.exists() or not p.is_dir():
                raise click.BadParameter(f"parameter should be a valid flytedirectory path, {value}")
        return Dir(path=value)
82
+
83
+
84
class StructuredDatasetParamType(click.ParamType):
    """
    Click parameter type for structured dataset inputs.

    Strings become a ``StructuredDataset`` with ``uri`` set. Existing ``StructuredDataset`` objects are returned
    unchanged, and any other value is wrapped as the ``dataframe``.

    TODO handle column types
    """

    name = "structured dataset path (dir/file)"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        if isinstance(value, StructuredDataset):
            return value
        if isinstance(value, str):
            return StructuredDataset(uri=value)
        return StructuredDataset(dataframe=value)
99
+
100
+
101
class FileParamType(click.ParamType):
    """
    Click parameter type for file inputs.

    Local paths must exist and be regular files. Values that ``flyte.storage.is_remote`` considers remote are
    accepted without a local check. The result is wrapped in a ``File``.
    """

    name = "file path"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        # click may hand convert() a value that is already converted (e.g. a File default); keep it unchanged.
        if isinstance(value, File):
            return value

        from flyte.storage import is_remote

        if not is_remote(value):
            p = pathlib.Path(value)
            if not p.exists() or not p.is_file():
                raise click.BadParameter(f"parameter should be a valid file path, {value}")
        return File(path=value)
114
+
115
+
116
class PickleParamType(click.ParamType):
    """
    Click parameter type for inputs that are pickled Python objects.

    A string of the form ``<module>:<attribute>`` is resolved by importing ``module`` and returning its
    ``attribute``; the current working directory is made importable first. Any other string, and any non-string
    value, is returned unchanged.

    NOTE: this imports (and therefore executes) the user-named module, so it is only suitable for trusted CLI input.
    """

    name = "pickle"

    def get_metavar(self, param: "Parameter", *args) -> t.Optional[str]:
        return "Python Object <Module>:<Object>"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        import logging

        if not isinstance(value, str):
            return value
        parts = value.split(":")
        if len(parts) != 2:
            # Only mention the fallback when debug output is enabled. Assumes ctx.obj.log_level uses stdlib logging
            # numbers (DEBUG == 10, larger == less verbose) -- TODO confirm against the CLI context object.
            log_level = getattr(ctx.obj, "log_level", None) if ctx and ctx.obj else None
            if isinstance(log_level, int) and log_level <= logging.DEBUG:
                click.echo(f"Did not receive a string in the expected format <MODULE>:<VAR>, falling back to: {value}")
            return value

        module_name, attr_name = parts
        # Make modules in the current working directory importable. Add it only once so that repeated
        # conversions do not keep growing sys.path.
        cwd = os.getcwd()
        if cwd not in sys.path:
            sys.path.insert(0, cwd)
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr_name)
        except ModuleNotFoundError as e:
            raise click.BadParameter(f"Failed to import module {parts[0]}, error: {e}") from e
        except AttributeError as e:
            raise click.BadParameter(f"Failed to find attribute {parts[1]} in module {parts[0]}, error: {e}") from e
140
+
141
+
142
class JSONIteratorParamType(click.ParamType):
    """Placeholder click type for JSON-iterator inputs; it performs no parsing and returns the value untouched."""

    name = "json iterator"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        passthrough = value
        return passthrough
149
+
150
+
151
# Compiled once at import time. Every component is optional in the pattern; the function below additionally
# requires at least one component and at least one time component after a 'T'.
_ISO8601_DURATION_PATTERN = re.compile(
    r"^P"  # Starts with 'P'
    r"(?:(?P<weeks>\d+)W)?"  # Optional weeks
    r"(?:(?P<days>\d+)D)?"  # Optional days
    r"(?:T"  # Optional time part
    r"(?:(?P<hours>\d+)H)?"
    r"(?:(?P<minutes>\d+)M)?"
    r"(?:(?P<seconds>\d+)S)?"
    r")?$"
)


def parse_iso8601_duration(iso_duration: str) -> datetime.timedelta:
    """
    Parse an integer-valued ISO 8601 duration such as ``P1DT2H3M4S``, ``PT90M`` or ``P2W``.

    Supported designators are W, D (date part) and H, M, S (time part, after ``T``). Fractional values, years
    and months are not supported.

    :raises ValueError: if the string is not of that form, has no components at all (``P``, ``PT``), or has a
        ``T`` with no time component after it (``P1DT``).
    """
    match = _ISO8601_DURATION_PATTERN.match(iso_duration)
    if not match:
        raise ValueError(f"Invalid ISO 8601 duration format: {iso_duration}")

    components = match.groupdict()
    if all(v is None for v in components.values()):
        raise ValueError(f"Invalid ISO 8601 duration format: {iso_duration}")
    # The pattern only admits 'T' as the time designator, so a 'T' here means a time part was opened.
    if "T" in iso_duration and all(components[k] is None for k in ("hours", "minutes", "seconds")):
        raise ValueError(f"Invalid ISO 8601 duration format: {iso_duration}")

    return datetime.timedelta(**{k: int(v) for k, v in components.items() if v is not None})
167
+
168
+
169
def parse_human_durations(text: str) -> list[datetime.timedelta]:
    """
    Parse a ``|``-separated list of human-friendly durations.

    Surrounding ``[`` / ``]`` characters are stripped first. Each part (case-insensitive, whitespace-trimmed) may be:

    * ``M:SS``, ``:SS`` or ``SS``: minutes and seconds, e.g. ``1:24``, ``:45``, ``45``;
    * ``<n> <unit>`` or ``<n><unit>``, where unit is week/day/hour/minute/second, optionally plural,
      e.g. ``10 days``, ``1 minute``.

    Parts matching neither form are reported with a printed warning and skipped.

    :return: one timedelta per successfully parsed part, in input order.
    """
    durations: list[datetime.timedelta] = []

    for part in text.strip("[]").split("|"):
        new_part = part.strip().lower()

        # Match 1:24, :45 or 45. The minutes group uses \d* so that an empty minutes field (":45") is accepted.
        m_colon = re.match(r"^(?:(\d*):)?(\d+)$", new_part)
        if m_colon:
            minutes = int(m_colon.group(1)) if m_colon.group(1) else 0
            seconds = int(m_colon.group(2))
            durations.append(datetime.timedelta(minutes=minutes, seconds=seconds))
            continue

        # Match "10 days", "1 minute", "2 weeks", etc.; the unit name plus "s" is a timedelta keyword.
        m_units = re.match(r"^(\d+)\s*(week|day|hour|minute|second)s?$", new_part)
        if m_units:
            value = int(m_units.group(1))
            unit = m_units.group(2)
            durations.append(datetime.timedelta(**{unit + "s": value}))
            continue

        print(f"Warning: could not parse '{part}'")

    return durations
195
+
196
+
197
def parse_duration(s: str) -> datetime.timedelta:
    """
    Parse a duration written either in ISO 8601 form or as human-friendly parts.

    ISO 8601 is attempted first. If that raises ValueError, the text is parsed as ``|``-separated human
    durations and the parts are summed.

    :raises ValueError: if neither form yields any duration.
    """
    try:
        return parse_iso8601_duration(s)
    except ValueError:
        pass

    pieces = parse_human_durations(s)
    if pieces:
        return sum(pieces, datetime.timedelta())
    raise ValueError(f"Could not parse duration: {s}")
205
+
206
+
207
class DateTimeType(click.DateTime):
    """
    ``click.DateTime`` extended with relative expressions.

    In addition to click's own formats this accepts:

    * ``now``: the current local time (naive ``datetime.now()``);
    * ``today``: local midnight of the current day;
    * ``<base> - <duration>`` / ``<base> + <duration>``: ``<base>`` is any form this type accepts without
      spaces, and ``<duration>`` is anything ``parse_duration`` accepts (e.g. ``now - 1 day``, ``today + PT2H``);
    * any space-containing string ``datetime.datetime.fromisoformat`` can parse (e.g. ``2024-01-01 10:00``).
    """

    _NOW_FMT = "now"
    _TODAY_FMT = "today"
    _FIXED_FORMATS: typing.ClassVar[typing.List[str]] = [_NOW_FMT, _TODAY_FMT]
    _FLOATING_FORMATS: typing.ClassVar[typing.List[str]] = ["<FORMAT> - <ISO8601 duration>"]
    _ADDITONAL_FORMATS: typing.ClassVar[typing.List[str]] = [*_FIXED_FORMATS, *_FLOATING_FORMATS]
    _FLOATING_FORMAT_PATTERN = r"(.+)\s+([-+])\s+(.+)"

    def __init__(self):
        super().__init__()
        # Add the relative forms handled in convert() to click's default format list.
        self.formats.extend(self._ADDITONAL_FORMATS)

    def _datetime_from_format(
        self, value: str, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> datetime.datetime:
        """Resolve ``now`` / ``today``; delegate everything else to ``click.DateTime.convert``."""
        if value in self._FIXED_FORMATS:
            if value == self._NOW_FMT:
                return datetime.datetime.now()
            if value == self._TODAY_FMT:
                n = datetime.datetime.now()
                return datetime.datetime(n.year, n.month, n.day)
        return super().convert(value, param, ctx)

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        if isinstance(value, str) and " " in value:
            m = re.match(self._FLOATING_FORMAT_PATTERN, value)
            if m:
                # The pattern has exactly three groups: base, sign and duration text.
                base, sign, duration_text = m.groups()
                dt = self._datetime_from_format(base, param, ctx)
                try:
                    delta = parse_duration(duration_text)
                except Exception as e:
                    raise click.BadParameter(
                        f"Matched format {self._FLOATING_FORMATS}, but failed to parse duration {duration_text}, "
                        f"error: {e}"
                    ) from e
                return dt - delta if sign == "-" else dt + delta
            try:
                value = datetime.datetime.fromisoformat(value)
            except ValueError:
                # Not ISO either: keep the string so click.DateTime.convert (via _datetime_from_format) reports
                # a usage error instead of an unhandled traceback.
                pass

        return self._datetime_from_format(value, param, ctx)
255
+
256
+
257
class DurationParamType(click.ParamType):
    """
    Click parameter type for ``datetime.timedelta`` inputs.

    Strings are parsed with ``parse_duration``: ISO 8601 durations or the human-friendly forms advertised in
    ``name``. Values that are already ``timedelta`` objects are returned unchanged.
    """

    name = "[1:24 | :22 | 1 minute | 10 days | ...]"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        if value is None:
            raise click.BadParameter("None value cannot be converted to a Duration type.")
        # click may pass an already-converted value (e.g. a timedelta default); regex parsing would fail on it.
        if isinstance(value, datetime.timedelta):
            return value
        try:
            return parse_duration(value)
        except ValueError as e:
            # Report unparseable input as a click usage error rather than an unhandled exception.
            raise click.BadParameter(str(e)) from e
266
+
267
+
268
class EnumParamType(click.Choice):
    """
    Click choice type backed by an ``enum.Enum``.

    The choices offered are ``str(member.value)`` for each member. A chosen string is turned back into a member
    by calling the enum class with it, and existing members are passed through unchanged.
    """

    def __init__(self, enum_type: typing.Type[enum.Enum]):
        choices = [str(member.value) for member in enum_type]
        super().__init__(choices)
        self._enum_type = enum_type

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> enum.Enum:
        if not isinstance(value, self._enum_type):
            chosen = super().convert(value, param, ctx)
            return self._enum_type(chosen)
        return value
279
+
280
+
281
class UnionParamType(click.ParamType):
    """
    A composite type that allows for multiple types to be specified. This is used for union types.

    Variants are tried in precedence order: specific types first, then string types, then ``click.UNPROCESSED``
    as the last resort. Within each group, the original order is kept.
    """

    def __init__(self, types: typing.List[click.ParamType]):
        super().__init__()
        self._types = self._sort_precedence(types)
        self.name = "|".join(variant.name for variant in self._types)

    @staticmethod
    def _sort_precedence(tp: typing.List[click.ParamType]) -> typing.List[click.ParamType]:
        def rank(candidate: click.ParamType) -> int:
            if isinstance(candidate, type(click.UNPROCESSED)):
                return 2
            if isinstance(candidate, type(click.STRING)):
                return 1
            return 0

        # sorted() is stable, so this is a partition that keeps the original order inside each group.
        return sorted(tp, key=rank)  # type: ignore

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Return the result of the first variant whose conversion succeeds.

        Failures of individual variants are logged at debug level and skipped.

        :raises click.BadParameter: if every variant fails.
        """
        for variant in self._types:
            try:
                return variant.convert(value, param, ctx)
            except Exception as e:
                logger.debug(f"Ignoring conversion error for type {variant} trying other variants in Union. Error: {e}")
        raise click.BadParameter(f"Failed to convert {value} to any of the types {self._types}")
318
+
319
+
320
class JsonParamType(click.ParamType):
    """
    Click parameter type for JSON-shaped inputs: lists, dicts, dataclasses and pydantic models.

    The value may be inline JSON, a path to a JSON file, or a path to a ``.yaml`` / ``.yml`` file.
    """

    name = "json object OR json/yaml file path"

    def __init__(self, python_type: typing.Type):
        super().__init__()
        self._python_type = python_type

    def _parse(self, value: typing.Any, param: typing.Optional[click.Parameter]):
        """
        Parse ``value`` into plain Python data.

        Dicts and lists are returned as-is. Otherwise the value is tried as inline JSON, then, if it is an existing
        path, as a YAML (by extension) or JSON file.

        :raises click.BadParameter: if none of these succeed.
        """
        if isinstance(value, (dict, list)):
            return value
        try:
            return json.loads(value)
        except (ValueError, TypeError) as e:  # json.JSONDecodeError is a ValueError subclass
            error: Exception = e

        # Not inline JSON: fall back to treating the value as a file path.
        if isinstance(value, str) and os.path.exists(value):
            is_yaml = value.endswith((".yaml", ".yml"))
            try:
                with open(value, "r", encoding="utf-8") as f:
                    return yaml.safe_load(f) if is_yaml else json.load(f)
            except (ValueError, OSError, yaml.YAMLError) as e:
                error = e
        raise click.BadParameter(f"parameter {param} should be a valid json object, {value}, error: {error}")

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        if value is None:
            raise click.BadParameter("None value cannot be converted to a Json type.")

        parsed_value = self._parse(value, param)

        # We compare the origin type because the json parsed value for list or dict is always a list or dict without
        # the covariant type information.
        if type(parsed_value) is typing.get_origin(self._python_type) or type(parsed_value) is self._python_type:
            type_args = get_args(self._python_type)
            # Indexing the return value of get_args will raise an error for native dict and list types.
            # We don't support native list/dict types with nested dataclasses.
            if type_args == ():
                return parsed_value
            if isinstance(parsed_value, list) and dataclasses.is_dataclass(type_args[0]):
                item_type = JsonParamType(type_args[0])
                # turn each element back into a json string so the nested type can decode it
                return [item_type.convert(json.dumps(v), param, ctx) for v in parsed_value]
            if isinstance(parsed_value, dict) and dataclasses.is_dataclass(type_args[1]):
                value_type = JsonParamType(type_args[1])
                return {k: value_type.convert(json.dumps(v), param, ctx) for k, v in parsed_value.items()}
            return parsed_value

        from pydantic import BaseModel

        # issubclass() only accepts classes; typing constructs (e.g. a generic whose origin did not match) skip this.
        if isinstance(self._python_type, type) and issubclass(self._python_type, BaseModel):
            return typing.cast(BaseModel, self._python_type).model_validate_json(
                json.dumps(parsed_value), strict=False, context={"deserialize": True}
            )
        if dataclasses.is_dataclass(self._python_type):
            from mashumaro.codecs.json import JSONDecoder

            decoder = JSONDecoder(self._python_type)
            # Decode the parsed data (re-serialised), not the raw value, so JSON/YAML file paths work as well.
            return decoder.decode(json.dumps(parsed_value))

        return parsed_value
385
+
386
+
387
# Scalar Flyte SimpleTypes that map directly onto a click ParamType. SimpleTypes absent from this table have
# no direct entry here.
SIMPLE_TYPE_CONVERTER = dict(
    [
        (SimpleType.FLOAT, click.FLOAT),
        (SimpleType.INTEGER, click.INT),
        (SimpleType.STRING, click.STRING),
        (SimpleType.BOOLEAN, click.BOOL),
        (SimpleType.DURATION, DurationParamType()),
        (SimpleType.DATETIME, DateTimeType()),
    ]
)
395
+
396
+
397
def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> click.ParamType:
    """
    Converts a Flyte LiteralType given a python_type to a click.ParamType

    :param lt: the Flyte literal type of the input.
    :param python_type: the matching Python type. It is used to build JSON and enum types, and to pair each union
        variant with the entry at the same index of ``typing.get_args(python_type)``.
    :return: the click type used to parse the input; ``click.UNPROCESSED`` when there is nothing to parse.
    :raises NotImplementedError: for a simple type that has no click mapping.
    """
    if lt.HasField("simple"):
        if lt.simple == SimpleType.NONE:
            # NONE (e.g. the NoneType variant of an Optional) carries no value to parse. ``simple`` is a oneof
            # member, so HasField can be True even for this zero-valued enum; don't let it reach the error below.
            return click.UNPROCESSED
        if lt.simple == SimpleType.STRUCT:
            ct = JsonParamType(python_type)
            # Some typing constructs have no __name__; fall back to their string form.
            ct.name = f"JSON object {getattr(python_type, '__name__', python_type)}"
            return ct
        if lt.simple in SIMPLE_TYPE_CONVERTER:
            return SIMPLE_TYPE_CONVERTER[lt.simple]
        raise NotImplementedError(f"Type {lt.simple} is not supported in pyflyte run")

    if lt.HasField("structured_dataset_type"):
        return StructuredDatasetParamType()

    if lt.HasField("collection_type") or lt.HasField("map_value_type"):
        ct = JsonParamType(python_type)
        ct.name = "json list" if lt.HasField("collection_type") else "json dictionary"
        return ct

    if lt.HasField("blob"):
        if lt.blob.dimensionality == BlobType.BlobDimensionality.SINGLE:
            if lt.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT:
                return PickleParamType()
            # TODO: Add JSONIteratorTransformer
            # elif lt.blob.format == JSONIteratorTransformer.JSON_ITERATOR_FORMAT:
            #     return JSONIteratorParamType()
            return FileParamType()
        return DirParamType()

    if lt.HasField("union_type"):
        variant_python_types = typing.get_args(python_type)
        return UnionParamType(
            [
                literal_type_to_click_type(variant, variant_python_types[i])
                for i, variant in enumerate(lt.union_type.variants)
            ]
        )

    if lt.HasField("enum_type"):
        return EnumParamType(python_type)  # type: ignore

    return click.UNPROCESSED
443
+
444
+
445
class FlyteLiteralConverter(object):
    """
    Pairs a Flyte literal type with its Python type.

    It builds the click ParamType used to parse the input, and post-processes the parsed value when used as a
    click callback.
    """

    name = "literal_type"

    def __init__(
        self,
        literal_type: LiteralType,
        python_type: typing.Type,
    ):
        self._literal_type = literal_type
        self._python_type = python_type
        self._click_type = literal_type_to_click_type(literal_type, python_type)

    @property
    def click_type(self) -> click.ParamType:
        return self._click_type

    def is_bool(self) -> bool:
        return self.click_type == click.BOOL

    def convert(
        self, ctx: click.Context, param: typing.Optional[click.Parameter], value: typing.Any
    ) -> typing.Union[Literal, typing.Any]:
        """
        Convert the value to a python native type. This is used by click to convert the input.

        The value produced by the click type is returned unchanged, except that a ``datetime.datetime`` is
        narrowed to a ``datetime.date`` when the declared Python type is ``datetime.date``.
        """
        # Click produces datetime for date inputs, so narrow it to avoid a type mismatch. Only actual datetimes
        # are narrowed; None or an existing date value passes through instead of failing on `.date()`.
        if self._python_type is datetime.date and isinstance(value, datetime.datetime):
            return value.date()
        return value
484
+
485
+
486
def to_click_option(
    input_name: str,
    literal_var: Variable,
    python_type: typing.Type,
    default_val: typing.Any,
) -> click.Option:
    """
    This handles converting workflow input types to supported click parameters with callbacks to initialize
    the input values to their expected types.

    :param input_name: name of the task input; becomes ``--<input_name>``. Must be lowercase.
    :param literal_var: the Flyte interface Variable (type and description) for the input.
    :param python_type: the Python type of the input.
    :param default_val: the input's default. None means no default, so the option is required unless it is a
        boolean flag.
    :return: the configured click.Option. Boolean inputs become ``--<name>/--no-<name>`` flags.
    :raises ValueError: if ``input_name`` contains uppercase characters.
    """
    if input_name != input_name.lower():
        # Click does not support uppercase option names: https://github.com/pallets/click/issues/837
        raise ValueError(f"Workflow input name must be lowercase: {input_name!r}")

    literal_converter = FlyteLiteralConverter(
        literal_type=literal_var.type,
        python_type=python_type,
    )
    is_bool = literal_converter.is_bool()

    if is_bool and not default_val:
        default_val = False

    description_extra = ""
    if literal_var.type.simple == SimpleType.STRUCT:
        if default_val:
            # Encode structured defaults as JSON text.
            if hasattr(default_val, "model_dump_json"):  # pydantic v2 model
                default_val = default_val.model_dump_json()
            else:
                default_val = JSONEncoder(python_type).encode(default_val)
        if literal_var.type.metadata:
            description_extra = f": {MessageToDict(literal_var.type.metadata)}"

    param_decls = [f"--{input_name}"]
    # An input without a default must be supplied on the command line.
    required = default_val is None
    is_flag: typing.Optional[bool] = None
    if is_bool:
        # Booleans are flags and always carry a default (False unless one was given).
        required = False
        is_flag = True
        # Provide an explicit off-switch so that a flag whose default is True can still be set to False.
        param_decls = [f"--{input_name}/--no-{input_name}"]

    return click.Option(
        param_decls=param_decls,
        type=literal_converter.click_type,
        is_flag=is_flag,
        default=default_val,
        show_default=True,
        required=required,
        help=literal_var.description + description_extra,
        callback=literal_converter.convert,
    )