flyte 0.0.1b1__py3-none-any.whl → 0.0.1b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (210)
  1. flyte/_cli/_common.py +0 -12
  2. flyte/_cli/_run.py +2 -24
  3. flyte/_cli/main.py +2 -28
  4. flyte/_image.py +6 -10
  5. flyte/_initialize.py +15 -24
  6. flyte/_internal/imagebuild/docker_builder.py +2 -2
  7. flyte/_internal/runtime/convert.py +0 -6
  8. flyte/_run.py +1 -0
  9. flyte/_version.py +2 -2
  10. flyte/remote/_console.py +1 -1
  11. flyte/types/_type_engine.py +3 -4
  12. {flyte-0.0.1b1.dist-info → flyte-0.0.1b2.dist-info}/METADATA +1 -1
  13. flyte-0.0.1b2.dist-info/RECORD +390 -0
  14. union/__init__.py +54 -0
  15. union/_api_commons.py +3 -0
  16. union/_bin/__init__.py +0 -0
  17. union/_bin/runtime.py +113 -0
  18. union/_build.py +25 -0
  19. union/_cache/__init__.py +12 -0
  20. union/_cache/cache.py +141 -0
  21. union/_cache/defaults.py +9 -0
  22. union/_cache/policy_function_body.py +42 -0
  23. union/_cli/__init__.py +0 -0
  24. union/_cli/_common.py +263 -0
  25. union/_cli/_create.py +40 -0
  26. union/_cli/_delete.py +23 -0
  27. union/_cli/_deploy.py +120 -0
  28. union/_cli/_get.py +162 -0
  29. {flyte → union}/_cli/_params.py +147 -106
  30. union/_cli/_run.py +150 -0
  31. union/_cli/main.py +72 -0
  32. union/_code_bundle/__init__.py +8 -0
  33. union/_code_bundle/_ignore.py +113 -0
  34. union/_code_bundle/_packaging.py +187 -0
  35. union/_code_bundle/_utils.py +342 -0
  36. union/_code_bundle/bundle.py +176 -0
  37. union/_context.py +146 -0
  38. union/_datastructures.py +295 -0
  39. union/_deploy.py +185 -0
  40. union/_doc.py +29 -0
  41. union/_docstring.py +26 -0
  42. union/_environment.py +43 -0
  43. union/_group.py +31 -0
  44. union/_hash.py +23 -0
  45. union/_image.py +760 -0
  46. union/_initialize.py +585 -0
  47. union/_interface.py +84 -0
  48. union/_internal/__init__.py +3 -0
  49. union/_internal/controllers/__init__.py +77 -0
  50. union/_internal/controllers/_local_controller.py +77 -0
  51. union/_internal/controllers/pbhash.py +39 -0
  52. union/_internal/controllers/remote/__init__.py +40 -0
  53. union/_internal/controllers/remote/_action.py +131 -0
  54. union/_internal/controllers/remote/_client.py +43 -0
  55. union/_internal/controllers/remote/_controller.py +169 -0
  56. union/_internal/controllers/remote/_core.py +341 -0
  57. union/_internal/controllers/remote/_informer.py +260 -0
  58. union/_internal/controllers/remote/_service_protocol.py +44 -0
  59. union/_internal/imagebuild/__init__.py +11 -0
  60. union/_internal/imagebuild/docker_builder.py +416 -0
  61. union/_internal/imagebuild/image_builder.py +243 -0
  62. union/_internal/imagebuild/remote_builder.py +0 -0
  63. union/_internal/resolvers/__init__.py +0 -0
  64. union/_internal/resolvers/_task_module.py +31 -0
  65. union/_internal/resolvers/common.py +24 -0
  66. union/_internal/resolvers/default.py +27 -0
  67. union/_internal/runtime/__init__.py +0 -0
  68. union/_internal/runtime/convert.py +163 -0
  69. union/_internal/runtime/entrypoints.py +121 -0
  70. union/_internal/runtime/io.py +136 -0
  71. union/_internal/runtime/resources_serde.py +134 -0
  72. union/_internal/runtime/task_serde.py +202 -0
  73. union/_internal/runtime/taskrunner.py +179 -0
  74. union/_internal/runtime/types_serde.py +53 -0
  75. union/_logging.py +124 -0
  76. union/_protos/__init__.py +0 -0
  77. union/_protos/common/authorization_pb2.py +66 -0
  78. union/_protos/common/authorization_pb2.pyi +106 -0
  79. union/_protos/common/authorization_pb2_grpc.py +4 -0
  80. union/_protos/common/identifier_pb2.py +71 -0
  81. union/_protos/common/identifier_pb2.pyi +82 -0
  82. union/_protos/common/identifier_pb2_grpc.py +4 -0
  83. union/_protos/common/identity_pb2.py +48 -0
  84. union/_protos/common/identity_pb2.pyi +72 -0
  85. union/_protos/common/identity_pb2_grpc.py +4 -0
  86. union/_protos/common/list_pb2.py +36 -0
  87. union/_protos/common/list_pb2.pyi +69 -0
  88. union/_protos/common/list_pb2_grpc.py +4 -0
  89. union/_protos/common/policy_pb2.py +37 -0
  90. union/_protos/common/policy_pb2.pyi +27 -0
  91. union/_protos/common/policy_pb2_grpc.py +4 -0
  92. union/_protos/common/role_pb2.py +37 -0
  93. union/_protos/common/role_pb2.pyi +51 -0
  94. union/_protos/common/role_pb2_grpc.py +4 -0
  95. union/_protos/common/runtime_version_pb2.py +28 -0
  96. union/_protos/common/runtime_version_pb2.pyi +24 -0
  97. union/_protos/common/runtime_version_pb2_grpc.py +4 -0
  98. union/_protos/logs/dataplane/payload_pb2.py +96 -0
  99. union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  100. union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  101. union/_protos/secret/definition_pb2.py +49 -0
  102. union/_protos/secret/definition_pb2.pyi +93 -0
  103. union/_protos/secret/definition_pb2_grpc.py +4 -0
  104. union/_protos/secret/payload_pb2.py +62 -0
  105. union/_protos/secret/payload_pb2.pyi +94 -0
  106. union/_protos/secret/payload_pb2_grpc.py +4 -0
  107. union/_protos/secret/secret_pb2.py +38 -0
  108. union/_protos/secret/secret_pb2.pyi +6 -0
  109. union/_protos/secret/secret_pb2_grpc.py +198 -0
  110. union/_protos/validate/validate/validate_pb2.py +76 -0
  111. union/_protos/workflow/node_execution_service_pb2.py +26 -0
  112. union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  113. union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  114. union/_protos/workflow/queue_service_pb2.py +75 -0
  115. union/_protos/workflow/queue_service_pb2.pyi +103 -0
  116. union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  117. union/_protos/workflow/run_definition_pb2.py +100 -0
  118. union/_protos/workflow/run_definition_pb2.pyi +256 -0
  119. union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  120. union/_protos/workflow/run_logs_service_pb2.py +41 -0
  121. union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  122. union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  123. union/_protos/workflow/run_service_pb2.py +133 -0
  124. union/_protos/workflow/run_service_pb2.pyi +173 -0
  125. union/_protos/workflow/run_service_pb2_grpc.py +412 -0
  126. union/_protos/workflow/state_service_pb2.py +58 -0
  127. union/_protos/workflow/state_service_pb2.pyi +69 -0
  128. union/_protos/workflow/state_service_pb2_grpc.py +138 -0
  129. union/_protos/workflow/task_definition_pb2.py +72 -0
  130. union/_protos/workflow/task_definition_pb2.pyi +65 -0
  131. union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  132. union/_protos/workflow/task_service_pb2.py +44 -0
  133. union/_protos/workflow/task_service_pb2.pyi +31 -0
  134. union/_protos/workflow/task_service_pb2_grpc.py +104 -0
  135. union/_resources.py +226 -0
  136. union/_retry.py +32 -0
  137. union/_reusable_environment.py +25 -0
  138. union/_run.py +374 -0
  139. union/_secret.py +61 -0
  140. union/_task.py +354 -0
  141. union/_task_environment.py +186 -0
  142. union/_timeout.py +47 -0
  143. union/_tools.py +27 -0
  144. union/_utils/__init__.py +11 -0
  145. union/_utils/asyn.py +119 -0
  146. union/_utils/file_handling.py +71 -0
  147. union/_utils/helpers.py +46 -0
  148. union/_utils/lazy_module.py +54 -0
  149. union/_utils/uv_script_parser.py +49 -0
  150. union/_version.py +21 -0
  151. union/connectors/__init__.py +0 -0
  152. union/errors.py +128 -0
  153. union/extras/__init__.py +5 -0
  154. union/extras/_container.py +263 -0
  155. union/io/__init__.py +11 -0
  156. union/io/_dataframe.py +0 -0
  157. union/io/_dir.py +425 -0
  158. union/io/_file.py +418 -0
  159. union/io/pickle/__init__.py +0 -0
  160. union/io/pickle/transformer.py +117 -0
  161. union/io/structured_dataset/__init__.py +122 -0
  162. union/io/structured_dataset/basic_dfs.py +219 -0
  163. union/io/structured_dataset/structured_dataset.py +1057 -0
  164. union/py.typed +0 -0
  165. union/remote/__init__.py +23 -0
  166. union/remote/_client/__init__.py +0 -0
  167. union/remote/_client/_protocols.py +129 -0
  168. union/remote/_client/auth/__init__.py +12 -0
  169. union/remote/_client/auth/_authenticators/__init__.py +0 -0
  170. union/remote/_client/auth/_authenticators/base.py +391 -0
  171. union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  172. union/remote/_client/auth/_authenticators/device_code.py +120 -0
  173. union/remote/_client/auth/_authenticators/external_command.py +77 -0
  174. union/remote/_client/auth/_authenticators/factory.py +200 -0
  175. union/remote/_client/auth/_authenticators/pkce.py +515 -0
  176. union/remote/_client/auth/_channel.py +184 -0
  177. union/remote/_client/auth/_client_config.py +83 -0
  178. union/remote/_client/auth/_default_html.py +32 -0
  179. union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  180. union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
  181. union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
  182. union/remote/_client/auth/_keyring.py +154 -0
  183. union/remote/_client/auth/_token_client.py +258 -0
  184. union/remote/_client/auth/errors.py +16 -0
  185. union/remote/_client/controlplane.py +86 -0
  186. union/remote/_data.py +149 -0
  187. union/remote/_logs.py +74 -0
  188. union/remote/_project.py +86 -0
  189. union/remote/_run.py +820 -0
  190. union/remote/_secret.py +132 -0
  191. union/remote/_task.py +193 -0
  192. union/report/__init__.py +3 -0
  193. union/report/_report.py +178 -0
  194. union/report/_template.html +124 -0
  195. union/storage/__init__.py +24 -0
  196. union/storage/_remote_fs.py +34 -0
  197. union/storage/_storage.py +247 -0
  198. union/storage/_utils.py +5 -0
  199. union/types/__init__.py +11 -0
  200. union/types/_renderer.py +162 -0
  201. union/types/_string_literals.py +120 -0
  202. union/types/_type_engine.py +2131 -0
  203. union/types/_utils.py +80 -0
  204. flyte/config/__init__.py +0 -168
  205. flyte/config/_config.py +0 -196
  206. flyte/config/_internal.py +0 -64
  207. flyte-0.0.1b1.dist-info/RECORD +0 -204
  208. {flyte-0.0.1b1.dist-info → flyte-0.0.1b2.dist-info}/WHEEL +0 -0
  209. {flyte-0.0.1b1.dist-info → flyte-0.0.1b2.dist-info}/entry_points.txt +0 -0
  210. {flyte-0.0.1b1.dist-info → flyte-0.0.1b2.dist-info}/top_level.txt +0 -0
union/_cli/_get.py ADDED
@@ -0,0 +1,162 @@
1
+ import asyncio
2
+
3
+ import rich_click as click
4
+ from rich.console import Console
5
+
6
+ from . import _common as common
7
+
8
+
9
@click.group(name="get")
def get():
    """
    Get remote resources: projects, runs, tasks, actions, logs or secrets.
    """
    # The docstring doubles as the CLI help text for the group, so it lists the
    # subcommands that are actually registered on it in this module.
14
+
15
+
16
@get.command()
@click.argument("name", type=str, required=False)
@click.pass_obj
def project(cfg: common.CLIConfig, name: str | None = None):
    """
    Get a project by name, or list all projects when no name is given.
    """
    # Deferred import, matching the other commands in this module.
    from union.remote import Project

    # NOTE: a leftover debug `print(cfg)` used to sit here; it wrote the raw CLI
    # configuration to stdout ahead of the actual command output.
    cfg.init()

    console = Console()
    if name:
        # A single project was requested: render it directly.
        console.print(Project.get(name))
    else:
        # No name given: render every project as a table.
        console.print(common.get_table("Projects", Project.listall()))
33
+
34
+
35
@get.command(cls=common.CommandBase)
@click.argument("name", type=str, required=False)
@click.pass_obj
def run(cfg: common.CLIConfig, name: str | None = None, project: str | None = None, domain: str | None = None):
    """
    Get the current run.
    """
    # Deferred import, matching the other commands in this module.
    from union.remote import Run, RunDetails

    cfg.init(project=project, domain=domain)

    out = Console()
    # With a name, show that run's details; otherwise tabulate all runs.
    renderable = RunDetails.get(name=name) if name else common.get_table("Runs", Run.listall())
    out.print(renderable)
52
+
53
+
54
@get.command(cls=common.CommandBase)
@click.argument("name", type=str, required=False)
@click.argument("version", type=str, required=False)
@click.pass_obj
def task(
    cfg: common.CLIConfig,
    name: str | None = None,
    version: str | None = None,
    project: str | None = None,
    domain: str | None = None,
):
    """
    Get the current task.
    """
    # Deferred import, matching the other commands in this module.
    from union.remote import Task

    cfg.init(project=project, domain=domain)

    out = Console()
    if not name:
        # TODO: render common.get_table("Tasks", Task.listall()) once listing is supported.
        raise click.BadParameter("Task listing is not supported yet, please provide a name.")

    handle = Task.get(name=name, version=version)
    if handle is None:
        raise click.BadParameter(f"Task {name} not found.")

    # Task.get returns a handle; fetch() yields the object that gets printed.
    out.print(handle.fetch())
82
+
83
+
84
@get.command(cls=common.CommandBase)
@click.argument("run_name", type=str, required=True)
@click.argument("action_name", type=str, required=False)
@click.pass_obj
def action(
    cfg: common.CLIConfig,
    run_name: str,
    action_name: str | None = None,
    project: str | None = None,
    domain: str | None = None,
):
    """
    Get all actions for a run or details for a specific action.
    """
    # Deferred import, matching the other commands in this module.
    import union.remote as remote

    cfg.init(project=project, domain=domain)

    out = Console()
    if not action_name:
        # No action selected: tabulate every action that belongs to the run.
        listing = remote.Action.listall(for_run_name=run_name)
        out.print(common.get_table(f"Actions for {run_name}", listing))
        return

    # A specific action was requested: print just that one.
    out.print(remote.Action.get(run_name=run_name, name=action_name))
108
+
109
+
110
@get.command(cls=common.CommandBase)
@click.argument("run_name", type=str, required=True)
@click.argument("action_name", type=str, required=False)
@click.pass_obj
def logs(
    cfg: common.CLIConfig,
    run_name: str,
    action_name: str | None = None,
    project: str | None = None,
    domain: str | None = None,
):
    """
    Show logs for a run, or for a single action of that run when ACTION_NAME is given.
    """
    # RUN_NAME is required (as in the `action` command): both lookups below
    # need it, and the parameter is annotated as a plain `str`.
    import union.remote as remote

    cfg.init(project=project, domain=domain)

    async def _run_log_view(_obj):
        # Run show_logs() as its own task so it can be cancelled on Ctrl-C.
        task = asyncio.create_task(_obj.show_logs())
        try:
            await task
        except KeyboardInterrupt:
            task.cancel()

    if action_name:
        obj = remote.Action.get(run_name=run_name, name=action_name)
    else:
        obj = remote.Run.get(run_name)
    asyncio.run(_run_log_view(obj))
140
+
141
+
142
@get.command(cls=common.CommandBase)
@click.argument("name", type=str, required=False)
@click.pass_obj
def secret(
    cfg: common.CLIConfig,
    name: str | None = None,
    project: str | None = None,
    domain: str | None = None,
):
    """
    Get the current secret.
    """
    # Deferred import, matching the other commands in this module.
    import union.remote as remote

    cfg.init(project=project, domain=domain)

    out = Console()
    if not name:
        # No name given: tabulate every secret.
        out.print(common.get_table("Secrets", remote.Secret.listall()))
        return

    # A specific secret was requested: print just that one.
    out.print(remote.Secret.get(name))
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import dataclasses
2
3
  import datetime
3
4
  import enum
@@ -6,7 +7,6 @@ import importlib.util
6
7
  import json
7
8
  import os
8
9
  import pathlib
9
- import re
10
10
  import sys
11
11
  import typing
12
12
  import typing as t
@@ -14,16 +14,24 @@ from typing import get_args
14
14
 
15
15
  import rich_click as click
16
16
  import yaml
17
- from click import Parameter
18
- from flyteidl.core.interface_pb2 import Variable
19
17
  from flyteidl.core.literals_pb2 import Literal
20
18
  from flyteidl.core.types_pb2 import BlobType, LiteralType, SimpleType
21
- from google.protobuf.json_format import MessageToDict
22
- from mashumaro.codecs.json import JSONEncoder
23
19
 
24
- from flyte._logging import logger
25
- from flyte.io import Dir, File
26
- from flyte.io.pickle.transformer import FlytePickleTransformer
20
+ from union._logging import logger
21
+ from union.io import Dir, File
22
+ from union.io.pickle.transformer import FlytePickleTransformer
23
+ from union.storage._remote_fs import RemoteFSPathResolver
24
+ from union.types import TypeEngine
25
+
26
+
27
+ # ---------------------------------------------------
28
+ # TODO replace these
29
class ArtifactQuery:
    """Placeholder type (see the TODO above); the param converters in this module return instances unchanged."""
31
+
32
+
33
def is_remote(v: str) -> bool:
    """
    Return True when ``v`` names a non-local location, i.e. it carries a
    ``<protocol>://`` prefix other than ``file``.

    The previous stub always returned False, so remote URIs were validated as
    local paths by the file/directory param types and rejected.

    :param v: path or URI given on the command line.
    :return: False for plain local paths (including ``file://`` URIs and
        Windows drive paths), True for anything with another protocol prefix.
    """
    protocol, separator, _ = str(v).partition("://")
    if not separator:
        # No "<protocol>://" prefix at all: a plain local path.
        return False
    # A one-character "protocol" is taken to be a Windows drive letter, not a scheme.
    return len(protocol) > 1 and protocol.lower() != "file"
27
35
 
28
36
 
29
37
  class StructuredDataset:
@@ -35,6 +43,26 @@ class StructuredDataset:
35
43
  # ---------------------------------------------------
36
44
 
37
45
 
46
def is_pydantic_basemodel(python_type: typing.Type) -> bool:
    """
    Checks if the python type is a pydantic BaseModel

    Both the current ``pydantic.BaseModel`` and, when the ``pydantic.v1``
    namespace is importable, the v1 ``BaseModel`` are accepted.

    :param python_type: the type to inspect; may also be a non-class object
        such as a ``typing`` generic alias.
    :return: True only if ``python_type`` is a class deriving from a pydantic
        BaseModel; False if pydantic is not installed or the argument is not a class.
    """
    try:
        from pydantic import BaseModel
    except ImportError:
        return False

    candidates: typing.Tuple[type, ...] = (BaseModel,)
    try:
        from pydantic.v1 import BaseModel as BaseModelV1
    except ImportError:
        # No pydantic.v1 namespace available: only the top-level BaseModel applies.
        pass
    else:
        candidates = (BaseModelV1, BaseModel)

    try:
        return issubclass(python_type, candidates)
    except TypeError:
        # issubclass() rejects non-class arguments (e.g. typing.List[int]);
        # such values are by definition not BaseModel subclasses.
        return False
64
+
65
+
38
66
  def key_value_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
39
67
  """
40
68
  Callback for click to parse key-value pairs.
@@ -72,13 +100,17 @@ class DirParamType(click.ParamType):
72
100
  def convert(
73
101
  self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
74
102
  ) -> typing.Any:
75
- from flyte.storage import is_remote
103
+ if isinstance(value, ArtifactQuery):
104
+ return value
76
105
 
106
+ # set remote_directory to false if running pyflyte run locally. This makes sure that the original
107
+ # directory is used and not a random one.
108
+ remote_directory = None if getattr(ctx.obj, "is_remote", False) else False
77
109
  if not is_remote(value):
78
110
  p = pathlib.Path(value)
79
111
  if not p.exists() or not p.is_dir():
80
112
  raise click.BadParameter(f"parameter should be a valid flytedirectory path, {value}")
81
- return Dir(path=value)
113
+ return Dir(path=value, remote_directory=remote_directory)
82
114
 
83
115
 
84
116
  class StructuredDatasetParamType(click.ParamType):
@@ -91,6 +123,8 @@ class StructuredDatasetParamType(click.ParamType):
91
123
  def convert(
92
124
  self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
93
125
  ) -> typing.Any:
126
+ if isinstance(value, ArtifactQuery):
127
+ return value
94
128
  if isinstance(value, str):
95
129
  return StructuredDataset(uri=value)
96
130
  elif isinstance(value, StructuredDataset):
@@ -104,19 +138,22 @@ class FileParamType(click.ParamType):
104
138
  def convert(
105
139
  self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
106
140
  ) -> typing.Any:
107
- from flyte.storage import is_remote
108
-
141
+ if isinstance(value, ArtifactQuery):
142
+ return value
143
+ # set remote_directory to false if running pyflyte run locally. This makes sure that the original
144
+ # file is used and not a random one.
145
+ remote_path = None if getattr(ctx.obj, "is_remote", False) else False
109
146
  if not is_remote(value):
110
147
  p = pathlib.Path(value)
111
148
  if not p.exists() or not p.is_file():
112
149
  raise click.BadParameter(f"parameter should be a valid file path, {value}")
113
- return File(path=value)
150
+ return File(path=value, remote_path=remote_path)
114
151
 
115
152
 
116
153
  class PickleParamType(click.ParamType):
117
154
  name = "pickle"
118
155
 
119
- def get_metavar(self, param: "Parameter", *args) -> t.Optional[str]:
156
+ def get_metavar(self, param: click.Parameter) -> t.Optional[str]:
120
157
  return "Python Object <Module>:<Object>"
121
158
 
122
159
  def convert(
@@ -126,7 +163,7 @@ class PickleParamType(click.ParamType):
126
163
  return value
127
164
  parts = value.split(":")
128
165
  if len(parts) != 2:
129
- if ctx and ctx.obj and ctx.obj.log_level >= 10: # DEBUG level
166
+ if ctx and ctx.obj and ctx.obj.verbose > 0:
130
167
  click.echo(f"Did not receive a string in the expected format <MODULE>:<VAR>, falling back to: {value}")
131
168
  return value
132
169
  try:
@@ -148,6 +185,9 @@ class JSONIteratorParamType(click.ParamType):
148
185
  return value
149
186
 
150
187
 
188
+ import re
189
+
190
+
151
191
  def parse_iso8601_duration(iso_duration: str) -> datetime.timedelta:
152
192
  pattern = re.compile(
153
193
  r"^P" # Starts with 'P'
@@ -171,10 +211,10 @@ def parse_human_durations(text: str) -> list[datetime.timedelta]:
171
211
  durations = []
172
212
 
173
213
  for part in raw_parts:
174
- new_part = part.strip().lower()
214
+ part = part.strip().lower()
175
215
 
176
216
  # Match 1:24 or :45
177
- m_colon = re.match(r"^(?:(\d+):)?(\d+)$", new_part)
217
+ m_colon = re.match(r"^(?:(\d+):)?(\d+)$", part)
178
218
  if m_colon:
179
219
  minutes = int(m_colon.group(1)) if m_colon.group(1) else 0
180
220
  seconds = int(m_colon.group(2))
@@ -182,7 +222,7 @@ def parse_human_durations(text: str) -> list[datetime.timedelta]:
182
222
  continue
183
223
 
184
224
  # Match "10 days", "1 minute", etc.
185
- m_units = re.match(r"^(\d+)\s*(day|hour|minute|second)s?$", new_part)
225
+ m_units = re.match(r"^(\d+)\s*(day|hour|minute|second)s?$", part)
186
226
  if m_units:
187
227
  value = int(m_units.group(1))
188
228
  unit = m_units.group(2)
@@ -230,6 +270,9 @@ class DateTimeType(click.DateTime):
230
270
  def convert(
231
271
  self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
232
272
  ) -> typing.Any:
273
+ if isinstance(value, ArtifactQuery):
274
+ return value
275
+
233
276
  if isinstance(value, str) and " " in value:
234
277
  import re
235
278
 
@@ -260,6 +303,8 @@ class DurationParamType(click.ParamType):
260
303
  def convert(
261
304
  self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
262
305
  ) -> typing.Any:
306
+ if isinstance(value, ArtifactQuery):
307
+ return value
263
308
  if value is None:
264
309
  raise click.BadParameter("None value cannot be converted to a Duration type.")
265
310
  return parse_duration(value)
@@ -273,6 +318,8 @@ class EnumParamType(click.Choice):
273
318
  def convert(
274
319
  self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
275
320
  ) -> enum.Enum:
321
+ if isinstance(value, ArtifactQuery):
322
+ return value
276
323
  if isinstance(value, self._enum_type):
277
324
  return value
278
325
  return self._enum_type(super().convert(value, param, ctx))
@@ -286,7 +333,10 @@ class UnionParamType(click.ParamType):
286
333
  def __init__(self, types: typing.List[click.ParamType]):
287
334
  super().__init__()
288
335
  self._types = self._sort_precedence(types)
289
- self.name = "|".join([t.name for t in self._types])
336
+
337
+ @property
338
+ def name(self) -> str:
339
+ return "|".join([t.name for t in self._types])
290
340
 
291
341
  @staticmethod
292
342
  def _sort_precedence(tp: typing.List[click.ParamType]) -> typing.List[click.ParamType]:
@@ -300,7 +350,7 @@ class UnionParamType(click.ParamType):
300
350
  str_types.append(p)
301
351
  else:
302
352
  others.append(p)
303
- return others + str_types + unprocessed # type: ignore
353
+ return others + str_types + unprocessed
304
354
 
305
355
  def convert(
306
356
  self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
@@ -309,6 +359,8 @@ class UnionParamType(click.ParamType):
309
359
  Important to implement NoneType / Optional.
310
360
  Also could we just determine the click types from the python types
311
361
  """
362
+ if isinstance(value, ArtifactQuery):
363
+ return value
312
364
  for p in self._types:
313
365
  try:
314
366
  return p.convert(value, param, ctx)
@@ -334,7 +386,7 @@ class JsonParamType(click.ParamType):
334
386
  # We failed to load the json, so we'll try to load it as a file
335
387
  if os.path.exists(value):
336
388
  # if the value is a yaml file, we'll try to load it as yaml
337
- if value.endswith((".yaml", "yml")):
389
+ if value.endswith(".yaml") or value.endswith(".yml"):
338
390
  with open(value, "r") as f:
339
391
  return yaml.safe_load(f)
340
392
  with open(value, "r") as f:
@@ -346,6 +398,8 @@ class JsonParamType(click.ParamType):
346
398
  def convert(
347
399
  self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
348
400
  ) -> typing.Any:
401
+ if isinstance(value, ArtifactQuery):
402
+ return value
349
403
  if value is None:
350
404
  raise click.BadParameter("None value cannot be converted to a Json type.")
351
405
 
@@ -353,38 +407,68 @@ class JsonParamType(click.ParamType):
353
407
 
354
408
  # We compare the origin type because the json parsed value for list or dict is always a list or dict without
355
409
  # the covariant type information.
356
- if type(parsed_value) is typing.get_origin(self._python_type) or type(parsed_value) is self._python_type:
410
+ if type(parsed_value) == typing.get_origin(self._python_type) or type(parsed_value) == self._python_type:
357
411
  # Indexing the return value of get_args will raise an error for native dict and list types.
358
412
  # We don't support native list/dict types with nested dataclasses.
359
413
  if get_args(self._python_type) == ():
360
414
  return parsed_value
361
415
  elif isinstance(parsed_value, list) and dataclasses.is_dataclass(get_args(self._python_type)[0]):
362
416
  j = JsonParamType(get_args(self._python_type)[0])
363
- # turn object back into json string
364
- return [j.convert(json.dumps(v), param, ctx) for v in parsed_value]
417
+ return [j.convert(v, param, ctx) for v in parsed_value]
365
418
  elif isinstance(parsed_value, dict) and dataclasses.is_dataclass(get_args(self._python_type)[1]):
366
419
  j = JsonParamType(get_args(self._python_type)[1])
367
- # turn object back into json string
368
- return {k: j.convert(json.dumps(v), param, ctx) for k, v in parsed_value.items()}
420
+ return {k: j.convert(v, param, ctx) for k, v in parsed_value.items()}
369
421
 
370
422
  return parsed_value
371
423
 
372
- from pydantic import BaseModel
373
-
374
- if issubclass(self._python_type, BaseModel):
375
- return typing.cast(BaseModel, self._python_type).model_validate_json(
376
- json.dumps(parsed_value), strict=False, context={"deserialize": True}
377
- )
378
- elif dataclasses.is_dataclass(self._python_type):
379
- from mashumaro.codecs.json import JSONDecoder
424
+ if is_pydantic_basemodel(self._python_type):
425
+ """
426
+ This function supports backward compatibility for the Pydantic v1 plugin.
427
+ If the class is a Pydantic BaseModel, it attempts to parse JSON input using
428
+ the appropriate version of Pydantic (v1 or v2).
429
+ """
430
+ try:
431
+ if importlib.util.find_spec("pydantic.v1") is not None:
432
+ from pydantic import BaseModel as BaseModelV2
433
+
434
+ if issubclass(self._python_type, BaseModelV2):
435
+ return self._python_type.model_validate_json(
436
+ json.dumps(parsed_value), strict=False, context={"deserialize": True}
437
+ )
438
+ except ImportError:
439
+ pass
440
+
441
+ # The behavior of the Pydantic v1 plugin.
442
+ return self._python_type.parse_raw(json.dumps(parsed_value))
443
+ return None
380
444
 
381
- decoder = JSONDecoder(self._python_type)
382
- return decoder.decode(value)
383
445
 
384
- return parsed_value
446
def modify_literal_uris(lit: Literal):
    """
    Modifies the literal object recursively to replace the URIs with the native paths.

    ``Literal`` is the flyteidl protobuf message, so presence of a sub-message is
    tested with ``HasField`` (truthiness of a sub-message does not indicate whether
    it is set) and the URI is written through the real ``uri`` field.

    :param lit: literal to rewrite in place. Collections, maps and union values
        are walked recursively; blob and structured-dataset URIs starting with
        ``RemoteFSPathResolver.protocol`` are replaced by the resolved path.
    """
    if lit.HasField("collection"):
        for item in lit.collection.literals:
            modify_literal_uris(item)
    elif lit.HasField("map"):
        for item in lit.map.literals.values():
            modify_literal_uris(item)
    elif lit.HasField("scalar"):
        scalar = lit.scalar
        if scalar.HasField("blob"):
            blob = scalar.blob
            if blob.uri and blob.uri.startswith(RemoteFSPathResolver.protocol):
                blob.uri = RemoteFSPathResolver.resolve_remote_path(blob.uri)
        elif scalar.HasField("union"):
            # A union wraps another literal; rewrite whatever it carries.
            modify_literal_uris(scalar.union.value)
        elif scalar.HasField("structured_dataset"):
            dataset = scalar.structured_dataset
            if dataset.uri and dataset.uri.startswith(RemoteFSPathResolver.protocol):
                dataset.uri = RemoteFSPathResolver.resolve_remote_path(dataset.uri)
385
469
 
386
470
 
387
- SIMPLE_TYPE_CONVERTER = {
471
+ SIMPLE_TYPE_CONVERTER: typing.Dict[SimpleType, click.ParamType] = {
388
472
  SimpleType.FLOAT: click.FLOAT,
389
473
  SimpleType.INTEGER: click.INT,
390
474
  SimpleType.STRING: click.STRING,
@@ -398,7 +482,7 @@ def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> cli
398
482
  """
399
483
  Converts a Flyte LiteralType given a python_type to a click.ParamType
400
484
  """
401
- if lt.HasField("simple"):
485
+ if lt.simple:
402
486
  if lt.simple == SimpleType.STRUCT:
403
487
  ct = JsonParamType(python_type)
404
488
  ct.name = f"JSON object {python_type.__name__}"
@@ -407,38 +491,38 @@ def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> cli
407
491
  return SIMPLE_TYPE_CONVERTER[lt.simple]
408
492
  raise NotImplementedError(f"Type {lt.simple} is not supported in pyflyte run")
409
493
 
410
- if lt.HasField("structured_dataset_type"):
494
+ if lt.enum_type:
495
+ return EnumParamType(python_type) # type: ignore
496
+
497
+ if lt.structured_dataset_type:
411
498
  return StructuredDatasetParamType()
412
499
 
413
- if lt.HasField("collection_type") or lt.HasField("map_value_type"):
500
+ if lt.collection_type or lt.map_value_type:
414
501
  ct = JsonParamType(python_type)
415
- if lt.HasField("collection_type"):
502
+ if lt.collection_type:
416
503
  ct.name = "json list"
417
504
  else:
418
505
  ct.name = "json dictionary"
419
506
  return ct
420
507
 
421
- if lt.HasField("blob"):
508
+ if lt.blob:
422
509
  if lt.blob.dimensionality == BlobType.BlobDimensionality.SINGLE:
423
510
  if lt.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT:
424
511
  return PickleParamType()
425
- # TODO: Add JSONIteratorTransformer
426
512
  # elif lt.blob.format == JSONIteratorTransformer.JSON_ITERATOR_FORMAT:
427
513
  # return JSONIteratorParamType()
428
514
  return FileParamType()
429
515
  return DirParamType()
430
516
 
431
- if lt.HasField("union_type"):
517
+ if lt.union_type:
432
518
  cts = []
433
519
  for i in range(len(lt.union_type.variants)):
434
520
  variant = lt.union_type.variants[i]
435
521
  variant_python_type = typing.get_args(python_type)[i]
436
- cts.append(literal_type_to_click_type(variant, variant_python_type))
522
+ ct = literal_type_to_click_type(variant, variant_python_type)
523
+ cts.append(ct)
437
524
  return UnionParamType(cts)
438
525
 
439
- if lt.HasField("enum_type"):
440
- return EnumParamType(python_type) # type: ignore
441
-
442
526
  return click.UNPROCESSED
443
527
 
444
528
 
@@ -449,7 +533,9 @@ class FlyteLiteralConverter(object):
449
533
  self,
450
534
  literal_type: LiteralType,
451
535
  python_type: typing.Type,
536
+ is_remote: bool,
452
537
  ):
538
+ self._is_remote = is_remote
453
539
  self._literal_type = literal_type
454
540
  self._python_type = python_type
455
541
  self._click_type = literal_type_to_click_type(literal_type, python_type)
@@ -465,15 +551,25 @@ class FlyteLiteralConverter(object):
465
551
  self, ctx: click.Context, param: typing.Optional[click.Parameter], value: typing.Any
466
552
  ) -> typing.Union[Literal, typing.Any]:
467
553
  """
468
- Convert the value to a python native type. This is used by click to convert the input.
554
+ Convert the value to a Flyte Literal or a python native type. This is used by click to convert the input.
469
555
  """
556
+ if isinstance(value, ArtifactQuery):
557
+ return value
470
558
  try:
471
559
  # If the expected Python type is datetime.date, adjust the value to date
472
560
  if self._python_type is datetime.date:
473
561
  # Click produces datetime, so converting to date to avoid type mismatch error
474
562
  value = value.date()
563
+ # If the input matches the default value in the launch plan, serialization can be skipped.
564
+ if param and value == param.default:
565
+ return None
475
566
 
476
- return value
567
+ # If this is used for remote execution, then we need to convert it back to a python native type
568
+ if not self._is_remote:
569
+ return value
570
+
571
+ lit = asyncio.run(TypeEngine.to_literal(value, self._python_type, self._literal_type))
572
+ return lit
477
573
  except click.BadParameter:
478
574
  raise
479
575
  except Exception as e:
@@ -481,58 +577,3 @@ class FlyteLiteralConverter(object):
481
577
  f"Failed to convert param: {param if param else 'NA'}, value: {value} to type: {self._python_type}."
482
578
  f" Reason {e}"
483
579
  ) from e
484
-
485
-
486
- def to_click_option(
487
- input_name: str,
488
- literal_var: Variable,
489
- python_type: typing.Type,
490
- default_val: typing.Any,
491
- ) -> click.Option:
492
- """
493
- This handles converting workflow input types to supported click parameters with callbacks to initialize
494
- the input values to their expected types.
495
- """
496
- from flyteidl.core.types_pb2 import SimpleType
497
-
498
- if input_name != input_name.lower():
499
- # Click does not support uppercase option names: https://github.com/pallets/click/issues/837
500
- raise ValueError(f"Workflow input name must be lowercase: {input_name!r}")
501
-
502
- literal_converter = FlyteLiteralConverter(
503
- literal_type=literal_var.type,
504
- python_type=python_type,
505
- )
506
-
507
- if literal_converter.is_bool() and not default_val:
508
- default_val = False
509
-
510
- description_extra = ""
511
- if literal_var.type.simple == SimpleType.STRUCT:
512
- if default_val:
513
- # pydantic v2
514
- if hasattr(default_val, "model_dump_json"):
515
- default_val = default_val.model_dump_json()
516
- else:
517
- encoder = JSONEncoder(python_type)
518
- default_val = encoder.encode(default_val)
519
- if literal_var.type.metadata:
520
- description_extra = f": {MessageToDict(literal_var.type.metadata)}"
521
-
522
- # If a query has been specified, the input is never strictly required at this layer
523
- required = False if default_val is not None else True
524
- is_flag: typing.Optional[bool] = None
525
- if literal_converter.is_bool():
526
- required = False
527
- is_flag = True
528
-
529
- return click.Option(
530
- param_decls=[f"--{input_name}"],
531
- type=literal_converter.click_type,
532
- is_flag=is_flag,
533
- default=default_val,
534
- show_default=True,
535
- required=required,
536
- help=literal_var.description + description_extra,
537
- callback=literal_converter.convert,
538
- )