flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (204)
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/cli/_params.py ADDED
@@ -0,0 +1,554 @@
1
+ import dataclasses
2
+ import datetime
3
+ import enum
4
+ import importlib
5
+ import importlib.util
6
+ import json
7
+ import os
8
+ import pathlib
9
+ import re
10
+ import sys
11
+ import typing
12
+ import typing as t
13
+ from typing import get_args
14
+
15
+ import rich_click as click
16
+ import yaml
17
+ from click import Parameter
18
+ from flyteidl2.core.interface_pb2 import Variable
19
+ from flyteidl2.core.literals_pb2 import Literal
20
+ from flyteidl2.core.types_pb2 import BlobType, LiteralType, SimpleType
21
+ from google.protobuf.json_format import MessageToDict
22
+ from mashumaro.codecs.json import JSONEncoder
23
+
24
+ from flyte._logging import logger
25
+ from flyte.io import Dir, File
26
+ from flyte.types._pickle import FlytePickleTransformer
27
+
28
+
29
+ class StructuredDataset:
30
+ def __init__(self, uri: str | None = None, dataframe: typing.Any = None):
31
+ self.uri = uri
32
+ self.dataframe = dataframe
33
+
34
+
35
+ # ---------------------------------------------------
36
+
37
+
38
def key_value_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
    """
    Click callback that turns repeated ``key=value`` strings into a dictionary.

    :param _: The click context (unused).
    :param param: The click parameter (unused).
    :param values: Raw strings collected by click, each expected to look like ``key=value``.
    :return: ``None`` when no values were supplied, otherwise a dict with whitespace-stripped keys and values.
        Only the first ``=`` separates key from value, and a repeated key keeps its last value.
    :raises click.BadParameter: If an entry contains no ``=``.
    """
    if not values:
        return None
    parsed: typing.Dict[str, str] = {}
    for item in values:
        key, sep, raw = item.partition("=")
        if not sep:
            raise click.BadParameter(f"Expected key-value pair of the form key=value, got {item}")
        parsed[key.strip()] = raw.strip()
    return parsed
51
+
52
+
53
def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
    """
    Click callback that parses labels given as ``key`` or ``key=value``.

    :param _: The click context (unused).
    :param param: The click parameter (unused).
    :param values: Raw label strings collected by click.
    :return: ``None`` when no values were supplied, otherwise a dict of stripped keys to stripped values.
        A bare ``key`` (no ``=``) maps to the empty string; only the first ``=`` splits key from value.
    """
    if not values:
        return None
    labels: typing.Dict[str, str] = {}
    for entry in values:
        # partition() yields an empty value when there is no "=", which is exactly the bare-key case.
        name, _sep, label_value = entry.partition("=")
        labels[name.strip()] = label_value.strip()
    return labels
67
+
68
+
69
class DirParamType(click.ParamType):
    """Click parameter type that turns a local or remote directory path into a :class:`flyte.io.Dir`."""

    name = "directory path"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Validate ``value`` and wrap it in a ``Dir``.

        Paths that ``flyte.storage.is_remote`` reports as remote are accepted without checks. Local paths must
        exist and be directories.

        :raises click.BadParameter: If a local path is missing or is not a directory.
        """
        from flyte.storage import is_remote

        if is_remote(value):
            return Dir(path=value)
        local_dir = pathlib.Path(value)
        if not (local_dir.exists() and local_dir.is_dir()):
            raise click.BadParameter(f"parameter should be a valid flytedirectory path, {value}")
        return Dir(path=value)
82
+
83
+
84
class StructuredDatasetParamType(click.ParamType):
    """
    Click parameter type for structured datasets: a dir/file path or an in-memory dataframe.

    TODO handle column types
    """

    name = "structured dataset path (dir/file)"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Wrap ``value`` in a :class:`StructuredDataset`.

        Existing ``StructuredDataset`` objects pass through untouched, strings become the ``uri``, and any other
        object is stored as the ``dataframe``.
        """
        if isinstance(value, StructuredDataset):
            return value
        if isinstance(value, str):
            return StructuredDataset(uri=value)
        return StructuredDataset(dataframe=value)
99
+
100
+
101
class FileParamType(click.ParamType):
    """Click parameter type that turns a local or remote file path into a :class:`flyte.io.File`."""

    name = "file path"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Validate ``value`` and wrap it via ``File.from_existing_remote``.

        Paths that ``flyte.storage.is_remote`` reports as remote are accepted without checks. Local paths must
        exist and be regular files.

        :raises click.BadParameter: If a local path is missing or is not a file.
        """
        from flyte.storage import is_remote

        if is_remote(value):
            return File.from_existing_remote(value)
        candidate = pathlib.Path(value)
        if not (candidate.exists() and candidate.is_file()):
            raise click.BadParameter(f"parameter should be a valid file path, {value}")
        return File.from_existing_remote(value)
114
+
115
+
116
class PickleParamType(click.ParamType):
    """
    Click parameter type for inputs serialized with the pickle transformer.

    The user names an existing Python object as ``<module>:<attribute>``. The module is imported, with the current
    working directory on ``sys.path``, and the attribute is returned. Non-string values pass through unchanged, and
    strings that are not in the two-part form are returned verbatim.
    """

    name = "pickle"

    def get_metavar(self, param: "Parameter", *args) -> t.Optional[str]:
        return "Python Object <Module>:<Object>"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Resolve ``<module>:<attribute>`` to the referenced Python object.

        :raises click.BadParameter: If the module cannot be imported or does not define the attribute.
        """
        if not isinstance(value, str):
            return value
        parts = value.split(":")
        if len(parts) != 2:
            # Guard the attribute lookup: ctx.obj is whatever object the CLI attached and may lack ``log_level``.
            if ctx and ctx.obj and getattr(ctx.obj, "log_level", 0) >= 10:  # DEBUG level
                click.echo(f"Did not receive a string in the expected format <MODULE>:<VAR>, falling back to: {value}")
            return value
        module_name, attr_name = parts
        # Put the cwd first on sys.path so local modules are importable. Skip the insert when it is already in
        # front, so repeated conversions do not keep growing sys.path.
        cwd = os.getcwd()
        if not sys.path or sys.path[0] != cwd:
            sys.path.insert(0, cwd)
        try:
            m = importlib.import_module(module_name)
            # getattr (unlike __getattribute__) also honours a module-level __getattr__ (PEP 562).
            return getattr(m, attr_name)
        except ModuleNotFoundError as e:
            raise click.BadParameter(f"Failed to import module {parts[0]}, error: {e}") from e
        except AttributeError as e:
            raise click.BadParameter(f"Failed to find attribute {parts[1]} in module {parts[0]}, error: {e}") from e
140
+
141
+
142
class JSONIteratorParamType(click.ParamType):
    """Click parameter type for JSON-iterator inputs; the value is returned without any conversion."""

    name = "json iterator"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        # Deliberately a pass-through: whatever click received is handed straight back.
        passthrough = value
        return passthrough
149
+
150
+
151
# Compiled once at import time; see parse_iso8601_duration for the supported subset of ISO 8601.
_ISO8601_DURATION_PATTERN = re.compile(
    r"^P"  # Starts with 'P'
    r"(?:(?P<weeks>\d+)W)?"  # Optional weeks
    r"(?:(?P<days>\d+)D)?"  # Optional days
    r"(?:T"  # Optional time part
    r"(?:(?P<hours>\d+)H)?"
    r"(?:(?P<minutes>\d+)M)?"
    r"(?:(?P<seconds>\d+(?:\.\d+)?)S)?"  # Seconds may be fractional, e.g. PT1.5S
    r")?$"
)


def parse_iso8601_duration(iso_duration: str) -> datetime.timedelta:
    """
    Parse a subset of ISO 8601 durations into a ``timedelta``.

    Supported designators are weeks (``W``) and days (``D``), plus hours (``H``), minutes (``M``) and seconds
    (``S``, optionally fractional) after ``T``. Years and months are not supported. Matching is case-sensitive.

    Examples: ``P1D``, ``PT30M``, ``P2DT3H4M5S``, ``PT0.5S``, ``P1W``.

    :param iso_duration: The duration string.
    :return: The equivalent ``timedelta``.
    :raises ValueError: If the string is not a supported ISO 8601 duration. This includes the empty forms ``P``
        and ``PT`` and a trailing ``T`` with no time components, which carry no duration at all.
    """
    match = _ISO8601_DURATION_PATTERN.match(iso_duration)
    if not match:
        raise ValueError(f"Invalid ISO 8601 duration format: {iso_duration}")

    groups = match.groupdict()
    no_components = all(v is None for v in groups.values())
    # After a successful match, "T" can only be the time designator.
    empty_time_part = "T" in iso_duration and all(groups[k] is None for k in ("hours", "minutes", "seconds"))
    if no_components or empty_time_part:
        raise ValueError(f"Invalid ISO 8601 duration format: {iso_duration}")

    parts = {k: float(v) if k == "seconds" else int(v) for k, v in groups.items() if v is not None}
    return datetime.timedelta(**parts)
167
+
168
+
169
def parse_human_durations(text: str) -> list[datetime.timedelta]:
    """
    Parse a ``|``-separated list of human-friendly durations.

    The whole string may optionally be wrapped in ``[...]``. Each part (case-insensitive, surrounding whitespace
    ignored) may be:

    * ``M:SS``, ``:SS`` or ``SS`` for minutes and seconds, e.g. ``1:24``, ``:45`` or ``45``;
    * ``<N> <unit>`` where unit is ``week``, ``day``, ``hour``, ``minute`` or ``second``, optionally plural,
      e.g. ``10 days`` or ``1 minute``.

    Empty parts are ignored. Other unparseable parts are skipped after printing a warning.

    :param text: The text to parse.
    :return: One ``timedelta`` per successfully parsed part, in input order (possibly empty).
    """
    durations: list[datetime.timedelta] = []

    for part in text.strip("[]").split("|"):
        new_part = part.strip().lower()
        if not new_part:
            # e.g. a trailing "|" or an empty string: nothing to parse, and not worth a warning.
            continue

        # Match "1:24", ":45" or "45". The minutes are optional, so the digits before ":" may be absent.
        m_colon = re.match(r"^(?:(\d*):)?(\d+)$", new_part)
        if m_colon:
            minutes = int(m_colon.group(1)) if m_colon.group(1) else 0
            seconds = int(m_colon.group(2))
            durations.append(datetime.timedelta(minutes=minutes, seconds=seconds))
            continue

        # Match "10 days", "1 minute", "2 weeks", etc.
        m_units = re.match(r"^(\d+)\s*(week|day|hour|minute|second)s?$", new_part)
        if m_units:
            value = int(m_units.group(1))
            unit = m_units.group(2)
            durations.append(datetime.timedelta(**{unit + "s": value}))
            continue

        print(f"Warning: could not parse '{part}'")

    return durations
195
+
196
+
197
def parse_duration(s: str) -> datetime.timedelta:
    """
    Parse ``s`` as a duration: ISO 8601 first, then the human-friendly forms.

    When the human-friendly parser is used, all ``|``-separated parts it recognises are summed.

    :raises ValueError: If neither parser produces a duration.
    """
    try:
        iso = parse_iso8601_duration(s)
    except ValueError:
        pieces = parse_human_durations(s)
        if pieces:
            return sum(pieces, datetime.timedelta())
        raise ValueError(f"Could not parse duration: {s}")
    return iso
205
+
206
+
207
class DateTimeType(click.DateTime):
    """
    Click ``DateTime`` type extended with relative expressions.

    In addition to click's own formats it accepts:

    * ``now``: the current naive local time;
    * ``today``: naive local midnight of the current day;
    * ``<FORMAT> - <duration>`` / ``<FORMAT> + <duration>``: any accepted value shifted by a duration that
      ``parse_duration`` understands;
    * strings containing a space that ``datetime.datetime.fromisoformat`` accepts.

    ``datetime.date`` values are promoted to a ``datetime`` at midnight.
    """

    _NOW_FMT = "now"
    _TODAY_FMT = "today"
    _FIXED_FORMATS: typing.ClassVar[typing.List[str]] = [_NOW_FMT, _TODAY_FMT]
    _FLOATING_FORMATS: typing.ClassVar[typing.List[str]] = ["<FORMAT> - <ISO8601 duration>"]
    # Appended to ``self.formats`` in __init__ so the extra forms are listed with click's formats.
    _ADDITONAL_FORMATS: typing.ClassVar[typing.List[str]] = [*_FIXED_FORMATS, *_FLOATING_FORMATS]
    _FLOATING_FORMAT_PATTERN = r"(.+)\s+([-+])\s+(.+)"

    def __init__(self):
        super().__init__()
        self.formats.extend(self._ADDITONAL_FORMATS)

    def _datetime_from_format(
        self, value: str, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> datetime.datetime:
        """Resolve ``now`` / ``today`` directly and delegate everything else to ``click.DateTime.convert``."""
        if value == self._NOW_FMT:
            return datetime.datetime.now()
        if value == self._TODAY_FMT:
            n = datetime.datetime.now()
            return datetime.datetime(n.year, n.month, n.day)
        return super().convert(value, param, ctx)

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` into a ``datetime.datetime``.

        :raises click.BadParameter: If the value (or the duration part of a relative expression) cannot be parsed.
        """
        # click's DateTime would try to strptime() a plain date and fail with a TypeError (e.g. for a date
        # default), so promote it to midnight first.
        if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
            value = datetime.datetime(value.year, value.month, value.day)

        if isinstance(value, str) and " " in value:
            m = re.match(self._FLOATING_FORMAT_PATTERN, value)
            if m:
                base, sign, duration = m.groups()
                dt = self._datetime_from_format(base, param, ctx)
                try:
                    delta = parse_duration(duration)
                except Exception as e:
                    raise click.BadParameter(
                        f"Matched format {self._FLOATING_FORMATS}, but failed to parse duration {duration}, error: {e}"
                    ) from e
                return dt - delta if sign == "-" else dt + delta
            try:
                value = datetime.datetime.fromisoformat(value)
            except ValueError:
                # Not ISO 8601 either: leave the string to click's DateTime, which tries its formats and reports
                # a proper usage error instead of an unhandled ValueError.
                pass

        return self._datetime_from_format(value, param, ctx)
255
+
256
+
257
class DurationParamType(click.ParamType):
    """
    Click parameter type that produces ``datetime.timedelta`` values.

    Accepts anything ``parse_duration`` understands: ISO 8601 (e.g. ``PT1H30M``) or human-friendly forms such as
    ``1:24``, ``:22``, ``1 minute`` or ``10 days``. Values that are already ``timedelta`` objects pass through.
    """

    name = "[1:24 | :22 | 1 minute | 10 days | ...]"

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` to a ``timedelta``.

        :raises click.BadParameter: If ``value`` is ``None`` or cannot be parsed as a duration.
        """
        if value is None:
            raise click.BadParameter("None value cannot be converted to a Duration type.")
        # Task defaults are handed to click as-is (see to_click_option) and click converts defaults too, so a
        # timedelta must not be re-parsed as text.
        if isinstance(value, datetime.timedelta):
            return value
        try:
            return parse_duration(value)
        except (ValueError, TypeError) as e:
            # Report unparseable input as a usage error rather than an unhandled traceback.
            raise click.BadParameter(str(e)) from e
266
+
267
+
268
class EnumParamType(click.Choice):
    """
    Click choice type backed by an :class:`enum.Enum`.

    The choices offered on the CLI are ``str(member.value)`` for each member. The selected string is mapped back to
    its member through a lookup table. This also works for enums with non-string values (e.g. ``IntEnum``), where
    calling the enum with the raw CLI string would fail.
    """

    def __init__(self, enum_type: typing.Type[enum.Enum]):
        super().__init__([str(e.value) for e in enum_type])
        self._enum_type = enum_type
        # Reverse lookup from the displayed choice string to its member.
        self._members_by_choice: typing.Dict[str, enum.Enum] = {str(e.value): e for e in enum_type}

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> enum.Enum:
        """
        Convert ``value`` to a member of the wrapped enum.

        :raises click.BadParameter: If ``value`` is not one of the choices (raised by ``click.Choice``).
        """
        if isinstance(value, self._enum_type):
            return value
        choice = super().convert(value, param, ctx)
        # Compare against None explicitly: members such as IntEnum(0) are falsy.
        member = self._members_by_choice.get(choice)
        if member is not None:
            return member
        # Fallback for anything the reverse table does not cover.
        return self._enum_type(choice)
279
+
280
+
281
class UnionParamType(click.ParamType):
    """
    A composite type that allows for multiple types to be specified. This is used for union types.

    Variants are tried in precedence order: specific types first, then string types, then ``click.UNPROCESSED``.
    The first successful conversion wins. A ``None`` entry in ``types`` marks the union as optional.
    """

    def __init__(self, types: typing.List[click.ParamType | None]):
        super().__init__()
        self._types = self._sort_precedence(types)
        self.name = "|".join([p.name for p in self._types if p is not None])
        self.optional = False
        if None in types:
            self.name = f"Optional[{self.name}]"
            self.optional = True

    @staticmethod
    def _sort_precedence(tp: typing.List[click.ParamType | None]) -> typing.List[click.ParamType]:
        """Order variants so permissive string/unprocessed types are tried last; relative order is otherwise kept."""
        unprocessed = []
        str_types = []
        others = []
        for p in tp:
            if isinstance(p, type(click.UNPROCESSED)):
                unprocessed.append(p)
            elif isinstance(p, type(click.STRING)):
                str_types.append(p)
            else:
                others.append(p)
        return others + str_types + unprocessed  # type: ignore

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Convert ``value`` with the first variant that accepts it.

        For optional unions, ``None`` is returned directly without consulting any variant. Otherwise a variant
        that happens to accept ``None``, such as the structured-dataset type, would wrap it into a non-``None``
        object.

        :raises click.BadParameter: If no variant can convert ``value``.
        """
        if value is None and self.optional:
            return None
        for p in self._types:
            if p is None:
                # Placeholder for NoneType; only meaningful for a None value, handled above.
                continue
            try:
                return p.convert(value, param, ctx)
            except Exception as e:
                logger.debug(f"Ignoring conversion error for type {p} trying other variants in Union. Error: {e}")
        raise click.BadParameter(f"Failed to convert {value} to any of the types {self._types}")
324
+
325
+
326
class JsonParamType(click.ParamType):
    """
    Click parameter type for JSON-shaped inputs: dicts, lists, dataclasses and pydantic models.

    The value may be an inline JSON string, a path to a JSON or YAML (``.yaml`` / ``.yml``) file, or an
    already-parsed ``dict`` / ``list``. The parsed data is then shaped into ``python_type`` where supported.
    """

    name = "json object OR json/yaml file path"

    def __init__(self, python_type: typing.Type):
        super().__init__()
        self._python_type = python_type

    def _parse(self, value: typing.Any, param: typing.Optional[click.Parameter]):
        """
        Turn ``value`` into plain JSON data.

        Inline JSON is tried first. If that fails and ``value`` is a string naming an existing file, the file is
        read as YAML (``.yaml`` / ``.yml``) or otherwise as JSON.

        :raises click.BadParameter: If the input is neither valid inline JSON nor a parseable JSON/YAML file.
        """
        if isinstance(value, (dict, list)):
            return value
        try:
            return json.loads(value)
        except Exception:
            try:
                # We failed to load the json, so we'll try to load it as a file
                if isinstance(value, str) and os.path.exists(value):
                    # The dot matters: a bare "yml" suffix would also match names such as "configyml".
                    if value.endswith((".yaml", ".yml")):
                        with open(value, "r") as f:
                            return yaml.safe_load(f)
                    with open(value, "r") as f:
                        return json.load(f)
                # Re-raise the original json.loads error.
                raise
            except (json.JSONDecodeError, yaml.YAMLError) as e:
                raise click.BadParameter(f"parameter {param} should be a valid json object, {value}, error: {e}")

    def convert(
        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
    ) -> typing.Any:
        """
        Parse ``value`` and shape it into ``python_type``.

        Handled shapes:

        * Plain ``list`` / ``dict`` annotations: the parsed data is returned.
        * ``list[Dataclass]`` / ``dict[str, Dataclass]``: elements are decoded recursively.
        * Pydantic models: validated from the parsed data.
        * Dataclasses: decoded with mashumaro.

        Anything else is returned as parsed.

        :raises click.BadParameter: If ``value`` is ``None`` or cannot be parsed.
        """
        if value is None:
            raise click.BadParameter("None value cannot be converted to a Json type.")

        parsed_value = self._parse(value, param)

        # We compare the origin type because the json parsed value for list or dict is always a list or dict without
        # the covariant type information.
        if type(parsed_value) is typing.get_origin(self._python_type) or type(parsed_value) is self._python_type:
            type_args = get_args(self._python_type)
            # Indexing the return value of get_args will raise an error for native dict and list types.
            # We don't support native list/dict types with nested dataclasses.
            if type_args == ():
                return parsed_value
            elif isinstance(parsed_value, list) and dataclasses.is_dataclass(type_args[0]):
                j = JsonParamType(type_args[0])
                # turn object back into json string
                return [j.convert(json.dumps(v), param, ctx) for v in parsed_value]
            elif isinstance(parsed_value, dict) and dataclasses.is_dataclass(type_args[1]):
                j = JsonParamType(type_args[1])
                # turn object back into json string
                return {k: j.convert(json.dumps(v), param, ctx) for k, v in parsed_value.items()}

            return parsed_value

        from pydantic import BaseModel

        try:
            is_pydantic_model = issubclass(self._python_type, BaseModel)
        except TypeError:
            # Not a class. For example, a generic alias such as dict[str, int] reaches this point when the parsed
            # data has a different container type.
            is_pydantic_model = False

        if is_pydantic_model:
            return typing.cast(BaseModel, self._python_type).model_validate_json(
                json.dumps(parsed_value), strict=False, context={"deserialize": True}
            )
        elif dataclasses.is_dataclass(self._python_type):
            from mashumaro.codecs.json import JSONDecoder

            decoder = JSONDecoder(self._python_type)
            # Decode the parsed data, not the raw input: the raw value may be a file path or an already-parsed dict.
            return decoder.decode(json.dumps(parsed_value))

        return parsed_value
391
+
392
+
393
# Lookup table from Flyte SimpleType values to the click parameter type used to parse them on the command line.
# STRUCT is not listed: literal_type_to_click_type handles it before consulting this table.
SIMPLE_TYPE_CONVERTER: typing.Dict[typing.Any, click.ParamType] = {
    # Scalars map directly onto click's built-in types.
    SimpleType.FLOAT: click.FLOAT,
    SimpleType.INTEGER: click.INT,
    SimpleType.STRING: click.STRING,
    SimpleType.BOOLEAN: click.BOOL,
    # Temporal values use the custom parameter types defined in this module.
    SimpleType.DURATION: DurationParamType(),
    SimpleType.DATETIME: DateTimeType(),
}
401
+
402
+
403
def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> click.ParamType:
    """
    Converts a Flyte LiteralType given a python_type to a click.ParamType

    Mapping:

    * simple ``STRUCT`` -> ``JsonParamType``; other simple types via ``SIMPLE_TYPE_CONVERTER``;
    * structured datasets -> ``StructuredDatasetParamType``;
    * collections and maps -> ``JsonParamType``;
    * blobs: ``SINGLE`` dimensionality -> ``PickleParamType`` for the pickle format, else ``FileParamType``;
      any other dimensionality -> ``DirParamType``;
    * unions -> ``UnionParamType`` built recursively from the variants (``NoneType`` becomes ``None``);
    * enums -> ``EnumParamType``;
    * anything else -> ``click.UNPROCESSED``.

    :param lt: The Flyte literal type of the input.
    :param python_type: The matching Python annotation. For unions, ``typing.get_args(python_type)`` is assumed to
        line up index-for-index with ``lt.union_type.variants``.
    :raises NotImplementedError: For simple types with no click equivalent.
    """
    if lt.HasField("simple"):
        if lt.simple == SimpleType.STRUCT:
            ct = JsonParamType(python_type)
            # Fall back to the annotation itself if it has no __name__ (e.g. some typing special forms).
            ct.name = f"JSON object {getattr(python_type, '__name__', python_type)}"
            return ct
        if lt.simple in SIMPLE_TYPE_CONVERTER:
            return SIMPLE_TYPE_CONVERTER[lt.simple]
        raise NotImplementedError(f"Type {lt.simple} is not supported in `flyte run`")

    if lt.HasField("structured_dataset_type"):
        return StructuredDatasetParamType()

    if lt.HasField("collection_type") or lt.HasField("map_value_type"):
        ct = JsonParamType(python_type)
        if lt.HasField("collection_type"):
            ct.name = "json list"
        else:
            ct.name = "json dictionary"
        return ct

    if lt.HasField("blob"):
        if lt.blob.dimensionality == BlobType.BlobDimensionality.SINGLE:
            if lt.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT:
                return PickleParamType()
            # TODO: Add JSONIteratorTransformer
            # elif lt.blob.format == JSONIteratorTransformer.JSON_ITERATOR_FORMAT:
            #     return JSONIteratorParamType()
            return FileParamType()
        return DirParamType()

    if lt.HasField("union_type"):
        python_variants = typing.get_args(python_type)
        cts: typing.List[click.ParamType | None] = []
        for i, variant in enumerate(lt.union_type.variants):
            variant_python_type = python_variants[i]
            if variant_python_type is type(None):
                # UnionParamType treats a None entry as "this union is optional".
                cts.append(None)
            else:
                cts.append(literal_type_to_click_type(variant, variant_python_type))
        return UnionParamType(cts)

    if lt.HasField("enum_type"):
        return EnumParamType(python_type)  # type: ignore

    return click.UNPROCESSED
452
+
453
+
454
class FlyteLiteralConverter(object):
    """
    Pairs a Flyte ``LiteralType`` with its Python type and the click type derived from them.

    The ``convert`` method has click's callback signature ``(ctx, param, value)`` and post-processes parsed values.
    """

    name = "literal_type"

    def __init__(
        self,
        literal_type: LiteralType,
        python_type: typing.Type,
    ):
        self._literal_type = literal_type
        self._python_type = python_type
        self._click_type = literal_type_to_click_type(literal_type, python_type)

    @property
    def click_type(self) -> click.ParamType:
        """The click parameter type derived in ``__init__``."""
        return self._click_type

    def is_bool(self) -> bool:
        """Whether the derived click type is ``click.BOOL``."""
        return self.click_type == click.BOOL

    def is_optional(self) -> bool:
        """Whether the derived click type is a union that includes ``None``."""
        return isinstance(self.click_type, UnionParamType) and self.click_type.optional

    def convert(
        self, ctx: click.Context, param: typing.Optional[click.Parameter], value: typing.Any
    ) -> typing.Union[Literal, typing.Any]:
        """
        Convert the value to a python native type. This is used by click to convert the input.

        Click's DateTime type produces ``datetime`` objects. For inputs annotated as ``datetime.date`` these are
        narrowed to ``date`` to avoid a type mismatch. Values that are not ``datetime`` objects (already a
        ``date``, or ``None``) are returned unchanged instead of failing on ``.date()``.
        """
        if self._python_type is datetime.date and isinstance(value, datetime.datetime):
            return value.date()
        return value
496
+
497
+
498
def to_click_option(
    input_name: str,
    literal_var: Variable,
    python_type: typing.Type,
    default_val: typing.Any,
) -> click.Option:
    """
    Build a ``click.Option`` for one task input.

    The option's type comes from the input's literal type and its callback converts values to the expected Python
    types.

    * Boolean inputs become flags; a ``True`` default adds a ``--no-<name>`` switch.
    * Struct defaults are rendered as JSON text, and struct metadata is appended to the help text.
    * The option is required only when there is no default and the input is neither boolean nor optional.

    :raises ValueError: If ``input_name`` contains uppercase characters, which click options cannot express.
    """
    if input_name.lower() != input_name:
        # Click does not support uppercase option names: https://github.com/pallets/click/issues/837
        raise ValueError(f"Workflow input name must be lowercase: {input_name!r}")

    converter = FlyteLiteralConverter(literal_type=literal_var.type, python_type=python_type)
    bool_input = converter.is_bool()

    # A boolean without a truthy default is a plain switch that defaults to False.
    if bool_input and not default_val:
        default_val = False

    help_suffix = ""
    if literal_var.type.simple == SimpleType.STRUCT:
        if default_val:
            # pydantic v2 models serialize themselves; everything else goes through mashumaro's JSON encoder.
            default_val = (
                default_val.model_dump_json()
                if hasattr(default_val, "model_dump_json")
                else JSONEncoder(python_type).encode(default_val)
            )
        if literal_var.type.metadata:
            help_suffix = f": {MessageToDict(literal_var.type.metadata)}"

    param_decls = [f"--{input_name}"]
    is_flag: typing.Optional[bool] = None
    required = default_val is None
    if bool_input:
        is_flag = True
        required = False
        # A True default needs an explicit way to turn it off.
        if default_val is True:
            param_decls = [f"--{input_name}/--no-{input_name}"]
    if converter.is_optional():
        required = False

    return click.Option(
        param_decls=param_decls,
        type=converter.click_type,
        is_flag=is_flag,
        default=default_val,
        show_default=True,
        required=required,
        help=literal_var.description + help_suffix,
        callback=converter.convert,
    )