flyte 0.1.0__py3-none-any.whl → 0.2.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (219) hide show
  1. flyte/__init__.py +78 -2
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/runtime.py +152 -0
  4. flyte/_build.py +26 -0
  5. flyte/_cache/__init__.py +12 -0
  6. flyte/_cache/cache.py +145 -0
  7. flyte/_cache/defaults.py +9 -0
  8. flyte/_cache/policy_function_body.py +42 -0
  9. flyte/_code_bundle/__init__.py +8 -0
  10. flyte/_code_bundle/_ignore.py +113 -0
  11. flyte/_code_bundle/_packaging.py +187 -0
  12. flyte/_code_bundle/_utils.py +323 -0
  13. flyte/_code_bundle/bundle.py +209 -0
  14. flyte/_context.py +152 -0
  15. flyte/_deploy.py +243 -0
  16. flyte/_doc.py +29 -0
  17. flyte/_docstring.py +32 -0
  18. flyte/_environment.py +84 -0
  19. flyte/_excepthook.py +37 -0
  20. flyte/_group.py +32 -0
  21. flyte/_hash.py +23 -0
  22. flyte/_image.py +762 -0
  23. flyte/_initialize.py +492 -0
  24. flyte/_interface.py +84 -0
  25. flyte/_internal/__init__.py +3 -0
  26. flyte/_internal/controllers/__init__.py +128 -0
  27. flyte/_internal/controllers/_local_controller.py +193 -0
  28. flyte/_internal/controllers/_trace.py +41 -0
  29. flyte/_internal/controllers/remote/__init__.py +60 -0
  30. flyte/_internal/controllers/remote/_action.py +146 -0
  31. flyte/_internal/controllers/remote/_client.py +47 -0
  32. flyte/_internal/controllers/remote/_controller.py +494 -0
  33. flyte/_internal/controllers/remote/_core.py +410 -0
  34. flyte/_internal/controllers/remote/_informer.py +361 -0
  35. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  36. flyte/_internal/imagebuild/__init__.py +11 -0
  37. flyte/_internal/imagebuild/docker_builder.py +427 -0
  38. flyte/_internal/imagebuild/image_builder.py +246 -0
  39. flyte/_internal/imagebuild/remote_builder.py +0 -0
  40. flyte/_internal/resolvers/__init__.py +0 -0
  41. flyte/_internal/resolvers/_task_module.py +54 -0
  42. flyte/_internal/resolvers/common.py +31 -0
  43. flyte/_internal/resolvers/default.py +28 -0
  44. flyte/_internal/runtime/__init__.py +0 -0
  45. flyte/_internal/runtime/convert.py +342 -0
  46. flyte/_internal/runtime/entrypoints.py +135 -0
  47. flyte/_internal/runtime/io.py +136 -0
  48. flyte/_internal/runtime/resources_serde.py +138 -0
  49. flyte/_internal/runtime/task_serde.py +330 -0
  50. flyte/_internal/runtime/taskrunner.py +191 -0
  51. flyte/_internal/runtime/types_serde.py +54 -0
  52. flyte/_logging.py +135 -0
  53. flyte/_map.py +215 -0
  54. flyte/_pod.py +19 -0
  55. flyte/_protos/__init__.py +0 -0
  56. flyte/_protos/common/authorization_pb2.py +66 -0
  57. flyte/_protos/common/authorization_pb2.pyi +108 -0
  58. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  59. flyte/_protos/common/identifier_pb2.py +71 -0
  60. flyte/_protos/common/identifier_pb2.pyi +82 -0
  61. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  62. flyte/_protos/common/identity_pb2.py +48 -0
  63. flyte/_protos/common/identity_pb2.pyi +72 -0
  64. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  65. flyte/_protos/common/list_pb2.py +36 -0
  66. flyte/_protos/common/list_pb2.pyi +71 -0
  67. flyte/_protos/common/list_pb2_grpc.py +4 -0
  68. flyte/_protos/common/policy_pb2.py +37 -0
  69. flyte/_protos/common/policy_pb2.pyi +27 -0
  70. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  71. flyte/_protos/common/role_pb2.py +37 -0
  72. flyte/_protos/common/role_pb2.pyi +53 -0
  73. flyte/_protos/common/role_pb2_grpc.py +4 -0
  74. flyte/_protos/common/runtime_version_pb2.py +28 -0
  75. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  76. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  77. flyte/_protos/logs/dataplane/payload_pb2.py +100 -0
  78. flyte/_protos/logs/dataplane/payload_pb2.pyi +177 -0
  79. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  80. flyte/_protos/secret/definition_pb2.py +49 -0
  81. flyte/_protos/secret/definition_pb2.pyi +93 -0
  82. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  83. flyte/_protos/secret/payload_pb2.py +62 -0
  84. flyte/_protos/secret/payload_pb2.pyi +94 -0
  85. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  86. flyte/_protos/secret/secret_pb2.py +38 -0
  87. flyte/_protos/secret/secret_pb2.pyi +6 -0
  88. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  89. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  90. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  91. flyte/_protos/workflow/common_pb2.py +27 -0
  92. flyte/_protos/workflow/common_pb2.pyi +14 -0
  93. flyte/_protos/workflow/common_pb2_grpc.py +4 -0
  94. flyte/_protos/workflow/environment_pb2.py +29 -0
  95. flyte/_protos/workflow/environment_pb2.pyi +12 -0
  96. flyte/_protos/workflow/environment_pb2_grpc.py +4 -0
  97. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  98. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  99. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  100. flyte/_protos/workflow/queue_service_pb2.py +105 -0
  101. flyte/_protos/workflow/queue_service_pb2.pyi +146 -0
  102. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  103. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  104. flyte/_protos/workflow/run_definition_pb2.pyi +314 -0
  105. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  106. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  107. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  108. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  109. flyte/_protos/workflow/run_service_pb2.py +129 -0
  110. flyte/_protos/workflow/run_service_pb2.pyi +171 -0
  111. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  112. flyte/_protos/workflow/state_service_pb2.py +66 -0
  113. flyte/_protos/workflow/state_service_pb2.pyi +75 -0
  114. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  115. flyte/_protos/workflow/task_definition_pb2.py +79 -0
  116. flyte/_protos/workflow/task_definition_pb2.pyi +81 -0
  117. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  118. flyte/_protos/workflow/task_service_pb2.py +60 -0
  119. flyte/_protos/workflow/task_service_pb2.pyi +59 -0
  120. flyte/_protos/workflow/task_service_pb2_grpc.py +138 -0
  121. flyte/_resources.py +226 -0
  122. flyte/_retry.py +32 -0
  123. flyte/_reusable_environment.py +25 -0
  124. flyte/_run.py +482 -0
  125. flyte/_secret.py +61 -0
  126. flyte/_task.py +449 -0
  127. flyte/_task_environment.py +183 -0
  128. flyte/_timeout.py +47 -0
  129. flyte/_tools.py +27 -0
  130. flyte/_trace.py +120 -0
  131. flyte/_utils/__init__.py +26 -0
  132. flyte/_utils/asyn.py +119 -0
  133. flyte/_utils/async_cache.py +139 -0
  134. flyte/_utils/coro_management.py +23 -0
  135. flyte/_utils/file_handling.py +72 -0
  136. flyte/_utils/helpers.py +134 -0
  137. flyte/_utils/lazy_module.py +54 -0
  138. flyte/_utils/org_discovery.py +57 -0
  139. flyte/_utils/uv_script_parser.py +49 -0
  140. flyte/_version.py +21 -0
  141. flyte/cli/__init__.py +3 -0
  142. flyte/cli/_abort.py +28 -0
  143. flyte/cli/_common.py +337 -0
  144. flyte/cli/_create.py +145 -0
  145. flyte/cli/_delete.py +23 -0
  146. flyte/cli/_deploy.py +152 -0
  147. flyte/cli/_gen.py +163 -0
  148. flyte/cli/_get.py +310 -0
  149. flyte/cli/_params.py +538 -0
  150. flyte/cli/_run.py +231 -0
  151. flyte/cli/main.py +166 -0
  152. flyte/config/__init__.py +3 -0
  153. flyte/config/_config.py +216 -0
  154. flyte/config/_internal.py +64 -0
  155. flyte/config/_reader.py +207 -0
  156. flyte/connectors/__init__.py +0 -0
  157. flyte/errors.py +172 -0
  158. flyte/extras/__init__.py +5 -0
  159. flyte/extras/_container.py +263 -0
  160. flyte/io/__init__.py +27 -0
  161. flyte/io/_dir.py +448 -0
  162. flyte/io/_file.py +467 -0
  163. flyte/io/_structured_dataset/__init__.py +129 -0
  164. flyte/io/_structured_dataset/basic_dfs.py +219 -0
  165. flyte/io/_structured_dataset/structured_dataset.py +1061 -0
  166. flyte/models.py +391 -0
  167. flyte/remote/__init__.py +26 -0
  168. flyte/remote/_client/__init__.py +0 -0
  169. flyte/remote/_client/_protocols.py +133 -0
  170. flyte/remote/_client/auth/__init__.py +12 -0
  171. flyte/remote/_client/auth/_auth_utils.py +14 -0
  172. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  173. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  174. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  175. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  176. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  177. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  178. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  179. flyte/remote/_client/auth/_channel.py +215 -0
  180. flyte/remote/_client/auth/_client_config.py +83 -0
  181. flyte/remote/_client/auth/_default_html.py +32 -0
  182. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  183. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  184. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  185. flyte/remote/_client/auth/_keyring.py +143 -0
  186. flyte/remote/_client/auth/_token_client.py +260 -0
  187. flyte/remote/_client/auth/errors.py +16 -0
  188. flyte/remote/_client/controlplane.py +95 -0
  189. flyte/remote/_console.py +18 -0
  190. flyte/remote/_data.py +159 -0
  191. flyte/remote/_logs.py +176 -0
  192. flyte/remote/_project.py +85 -0
  193. flyte/remote/_run.py +970 -0
  194. flyte/remote/_secret.py +132 -0
  195. flyte/remote/_task.py +391 -0
  196. flyte/report/__init__.py +3 -0
  197. flyte/report/_report.py +178 -0
  198. flyte/report/_template.html +124 -0
  199. flyte/storage/__init__.py +29 -0
  200. flyte/storage/_config.py +233 -0
  201. flyte/storage/_remote_fs.py +34 -0
  202. flyte/storage/_storage.py +271 -0
  203. flyte/storage/_utils.py +5 -0
  204. flyte/syncify/__init__.py +56 -0
  205. flyte/syncify/_api.py +371 -0
  206. flyte/types/__init__.py +36 -0
  207. flyte/types/_interface.py +40 -0
  208. flyte/types/_pickle.py +118 -0
  209. flyte/types/_renderer.py +162 -0
  210. flyte/types/_string_literals.py +120 -0
  211. flyte/types/_type_engine.py +2287 -0
  212. flyte/types/_utils.py +80 -0
  213. flyte-0.2.0a0.dist-info/METADATA +249 -0
  214. flyte-0.2.0a0.dist-info/RECORD +218 -0
  215. {flyte-0.1.0.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +2 -1
  216. flyte-0.2.0a0.dist-info/entry_points.txt +3 -0
  217. flyte-0.2.0a0.dist-info/top_level.txt +1 -0
  218. flyte-0.1.0.dist-info/METADATA +0 -6
  219. flyte-0.1.0.dist-info/RECORD +0 -5
flyte/models.py ADDED
@@ -0,0 +1,391 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import os
5
+ import pathlib
6
+ import tempfile
7
+ from dataclasses import dataclass, field, replace
8
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Literal, Optional, Tuple, Type
9
+
10
+ import rich.repr
11
+
12
+ from flyte._docstring import Docstring
13
+ from flyte._interface import extract_return_annotation
14
+ from flyte._logging import logger
15
+ from flyte._utils.helpers import base36_encode
16
+
17
+ if TYPE_CHECKING:
18
+ from flyteidl.core import literals_pb2
19
+
20
+ from flyte._internal.imagebuild.image_builder import ImageCache
21
+ from flyte.report import Report
22
+
23
+
24
def generate_random_name() -> str:
    """
    Produce a unique name by stringifying a freshly generated version-4 UUID.

    TODO: switch to friendlier names (e.g. unique-namer) in the future; for now it is just a GUID.
    """
    import uuid

    # Placeholder for actual random name generation logic
    return str(uuid.uuid4())
32
+
33
+
34
@rich.repr.auto
@dataclass(frozen=True, kw_only=True)
class ActionID:
    """
    A class representing the ID of an Action, nested within a Run. This is used to identify a specific action on a task.

    ``run_name`` defaults to ``name`` when it is not supplied (see ``__post_init__``).
    """

    name: str
    run_name: str | None = None
    project: str | None = None
    domain: str | None = None
    org: str | None = None

    def __post_init__(self):
        # The dataclass is frozen, so the fallback value has to be written via object.__setattr__.
        if self.run_name is None:
            object.__setattr__(self, "run_name", self.name)

    @classmethod
    def create_random(cls):
        """
        Create an ActionID whose ``name`` and ``run_name`` are the same freshly generated random name.
        """
        name = generate_random_name()
        return cls(name=name, run_name=name)

    def new_sub_action(self, name: str | None = None) -> ActionID:
        """
        Create a copy of this ID with a different action name; every other field is kept.

        :param name: Name of the sub-action. If None, a random name will be generated.
        :return: A new ActionID.
        """
        if name is None:
            name = generate_random_name()
        return replace(self, name=name)

    def new_sub_action_from(self, task_call_seq: int, task_hash: str, input_hash: str, group: str | None) -> ActionID:
        """
        Make a deterministic sub-action name from this action's name and the given components.

        :param task_call_seq: Sequence number of the task call.
        :param task_hash: Hash identifying the task.
        :param input_hash: Hash identifying the inputs.
        :param group: Optional group name; appended to the hashed components only when truthy.
        :return: A new ActionID whose name is the base36 encoding of the MD5 digest of the components.
        """
        import hashlib

        components = f"{self.name}-{input_hash}-{task_hash}-{task_call_seq}" + (f"-{group}" if group else "")
        logger.debug(f"----- Generating sub-action ID from components: {components}")
        # Hash the components into something deterministic. MD5 is used purely for naming, not for
        # security; flagging that keeps the call usable on FIPS-restricted builds and does not change
        # the resulting digest.
        bytes_digest = hashlib.md5(components.encode(), usedforsecurity=False).digest()
        new_name = base36_encode(bytes_digest)
        return self.new_sub_action(new_name)
74
+
75
+
76
+ @rich.repr.auto
77
+ @dataclass(frozen=True, kw_only=True)
78
+ class RawDataPath:
79
+ """
80
+ A class representing the raw data path for a task. This is used to store the raw data for the task execution and
81
+ also get mutations on the path.
82
+ """
83
+
84
+ path: str
85
+
86
+ @classmethod
87
+ def from_local_folder(cls, local_folder: str | pathlib.Path | None = None) -> RawDataPath:
88
+ """
89
+ Create a new context attribute object, with local path given. Will be created if it doesn't exist.
90
+ :return: Path to the temporary directory
91
+ """
92
+ match local_folder:
93
+ case pathlib.Path():
94
+ local_folder.mkdir(parents=True, exist_ok=True)
95
+ return RawDataPath(path=str(local_folder))
96
+ case None:
97
+ # Create a temporary directory for data storage
98
+ p = tempfile.mkdtemp()
99
+ logger.debug(f"Creating temporary directory for data storage: {p}")
100
+ pathlib.Path(p).mkdir(parents=True, exist_ok=True)
101
+ return RawDataPath(path=p)
102
+ case str():
103
+ return RawDataPath(path=local_folder)
104
+ case _:
105
+ raise ValueError(f"Invalid local path {local_folder}")
106
+
107
+ def get_random_remote_path(self, file_name: Optional[str] = None) -> str:
108
+ """
109
+ Returns a random path for uploading a file/directory to.
110
+
111
+ :param file_name: If given, will be joined after a randomly generated portion.
112
+ :return:
113
+ """
114
+ import random
115
+ from uuid import UUID
116
+
117
+ import fsspec
118
+ from fsspec.utils import get_protocol
119
+
120
+ random_string = UUID(int=random.getrandbits(128)).hex
121
+ file_prefix = self.path
122
+
123
+ protocol = get_protocol(file_prefix)
124
+ if "file" in protocol:
125
+ local_path = pathlib.Path(file_prefix) / random_string
126
+ if file_name:
127
+ # Only if file name is given do we create the parent, because it may be needed as a folder otherwise
128
+ local_path = local_path / file_name
129
+ if not local_path.exists():
130
+ local_path.parent.mkdir(exist_ok=True, parents=True)
131
+ local_path.touch()
132
+ return str(local_path.absolute())
133
+
134
+ fs = fsspec.filesystem(protocol)
135
+ if file_prefix.endswith(fs.sep):
136
+ file_prefix = file_prefix[:-1]
137
+ remote_path = fs.sep.join([file_prefix, random_string])
138
+ if file_name:
139
+ remote_path = fs.sep.join([remote_path, file_name])
140
+ return remote_path
141
+
142
+
143
@rich.repr.auto
@dataclass(frozen=True)
class GroupData:
    """Immutable record that carries the name of a group."""

    # Name of the group.
    name: str
147
+
148
+
149
@rich.repr.auto
@dataclass(frozen=True, kw_only=True)
class TaskContext:
    """
    Immutable holder for the context of the task execution currently in progress.
    Users can read various contextual parameters of the execution from it.

    :param action: The action ID of the current execution. This is always set, within a run.
    :param version: The version of the executed task. This is set when the task is executed by an action and will be
      set on all sub-actions.
    """

    action: ActionID
    version: str
    raw_data_path: RawDataPath
    output_path: str
    run_base_dir: str
    report: Report
    group_data: GroupData | None = None
    checkpoints: Checkpoints | None = None
    code_bundle: CodeBundle | None = None
    compiled_image_cache: ImageCache | None = None
    data: Dict[str, Any] = field(default_factory=dict)
    mode: Literal["local", "remote", "hybrid"] = "remote"

    def replace(self, **kwargs) -> TaskContext:
        """
        Return a copy of this context with the given fields replaced.

        ``data`` is merged rather than replaced: the supplied mapping is layered on top of a copy of the
        current ``data``. Passing ``data=None`` leaves the current ``data`` untouched.
        """
        if "data" not in kwargs:
            return replace(self, **kwargs)

        incoming = kwargs.pop("data")
        if incoming is not None:
            merged = {} if self.data is None else self.data.copy()
            merged.update(incoming)
            kwargs["data"] = merged
        return replace(self, **kwargs)

    def __getitem__(self, key: str) -> Optional[Any]:
        """Look up ``key`` in ``data``; a missing key yields None instead of raising."""
        return self.data.get(key)
188
+
189
+
190
@rich.repr.auto
@dataclass(frozen=True, kw_only=True)
class CodeBundle:
    """
    Describes a packaged code bundle for a task together with the path it is inflated to.
    The version of the bundle is computed from the hash of the code.

    :param computed_version: The version of the code bundle. This is the hash of the code.
    :param destination: The destination path for the code bundle to be inflated to.
    :param tgz: Optional path to the tgz file.
    :param pkl: Optional path to the pkl file.
    :param downloaded_path: The path to the downloaded code bundle. This is only available during runtime, when
        the code bundle has been downloaded and inflated.
    """

    computed_version: str
    destination: str = "."
    tgz: str | None = None
    pkl: str | None = None
    downloaded_path: pathlib.Path | None = None

    # runtime_dependencies: Tuple[str, ...] = field(default_factory=tuple) In the future if we want we could add this
    # but this messes up actors, spark etc

    def __post_init__(self):
        # At least one of the two bundle sources has to be present.
        sources = (self.tgz, self.pkl)
        if all(source is None for source in sources):
            raise ValueError("Either tgz or pkl must be provided")

    def with_downloaded_path(self, path: pathlib.Path) -> CodeBundle:
        """
        Return a copy of this bundle whose ``downloaded_path`` is set to ``path``.
        """
        updated = replace(self, downloaded_path=path)
        return updated
223
+
224
+
225
@rich.repr.auto
@dataclass(frozen=True)
class Checkpoints:
    """
    Immutable pair of checkpoint locations for a task execution.
    """

    # Path of the previous checkpoint, or None.
    prev_checkpoint_path: str | None
    # Path of the current checkpoint, or None.
    checkpoint_path: str | None
234
+
235
+
236
class _has_default:
    """
    A marker class to indicate that a specific input has a default value or not.
    This is used to determine if the input is required or not.
    """


@dataclass(frozen=True)
class NativeInterface:
    """
    A class representing the native interface for a task. This is used to interact with the task and its execution
    context.

    ``inputs`` maps each input name to a ``(type, default)`` pair, where the default is
    ``inspect.Parameter.empty`` for required inputs. ``outputs`` maps each output name to its type.
    """

    inputs: Dict[str, Tuple[Type, Any]]
    outputs: Dict[str, Type]
    docstring: Optional[Docstring] = None

    # This field is used to indicate that the task has a default value for the input, but already in the
    # remote form.
    _remote_defaults: Optional[Dict[str, literals_pb2.Literal]] = field(default=None, repr=False)

    has_default: ClassVar[Type[_has_default]] = _has_default  # This can be used to indicate if a specific input
    # has a default value or not, in the case when the default value is not known. An example would be remote tasks.

    def has_outputs(self) -> bool:
        """
        Check if the task has outputs, i.e. ``outputs`` is neither None nor empty.
        """
        return bool(self.outputs)

    def num_required_inputs(self) -> int:
        """
        Get the number of required inputs for the task, i.e. inputs whose default is ``inspect.Parameter.empty``.
        """
        return sum(1 for t in self.inputs.values() if t[1] is inspect.Parameter.empty)

    @classmethod
    def from_types(
        cls,
        inputs: Dict[str, Tuple[Type, Type[_has_default] | Type[inspect._empty]]],
        outputs: Dict[str, Type],
        default_inputs: Optional[Dict[str, literals_pb2.Literal]] = None,
    ) -> NativeInterface:
        """
        Create a new NativeInterface from the given types. This is used to create a native interface for the task.

        :param inputs: A dictionary of input names and their types and a value indicating if they have a default value.
        :param outputs: A dictionary of output names and their types.
        :param default_inputs: Optional dictionary of default inputs for remote tasks.
        :return: A NativeInterface object with the given inputs and outputs.
        :raises ValueError: If an input is marked with ``has_default`` but is missing from ``default_inputs``.
        """
        for k, v in inputs.items():
            if v[1] is cls.has_default and (default_inputs is None or k not in default_inputs):
                raise ValueError(f"Input {k} has a default value but no default input provided for remote task.")
        return cls(inputs=inputs, outputs=outputs, _remote_defaults=default_inputs)

    @classmethod
    def from_callable(cls, func: Callable) -> NativeInterface:
        """
        Extract the native interface from the signature of the given function.
        """
        sig = inspect.signature(func)

        # Extract parameter details (name, type, default value)
        param_info = {name: (param.annotation, param.default) for name, param in sig.parameters.items()}

        # Get return type
        outputs = extract_return_annotation(sig.return_annotation)
        return cls(inputs=param_info, outputs=outputs)

    def convert_to_kwargs(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Convert positional arguments to keyword arguments by pairing them with the input names in declaration order.

        :return: The keyword arguments, extended with one entry per positional argument.
        :raises ValueError: If more positional arguments than inputs are given, or if an input receives both a
            positional and a keyword value.
        """
        if len(args) > len(self.inputs):
            raise ValueError(f"Too many positional arguments provided, inputs {self.inputs.keys()}, args {len(args)}")
        for arg, input_name in zip(args, self.inputs.keys()):
            # A plain Python call would reject this too; silently preferring one value hides caller mistakes.
            if input_name in kwargs:
                raise ValueError(f"Multiple values provided for input '{input_name}' (positional and keyword)")
            kwargs[input_name] = arg
        return kwargs

    def get_input_types(self) -> Dict[str, Type]:
        """
        Get the input types for the task, keyed by input name.
        """
        return {k: v[0] for k, v in self.inputs.items()}

    def __repr__(self):
        """
        Returns a string representation of the task interface, e.g. ``(a: int, b: str = x) -> o0: int:``.
        """
        rendered_inputs = []
        for key, tpe in (self.inputs or {}).items():
            declared = tpe[0]
            if isinstance(declared, str):
                tp = declared
            else:
                # Not every annotation object carries __name__; fall back to the object itself
                # (formatted via str) instead of raising from __repr__.
                tp = getattr(declared, "__name__", declared)
            text = f"{key}: {tp}"
            if tpe[1] is not inspect.Parameter.empty:
                text += " = ..." if tpe[1] is self.has_default else f" = {tpe[1]}"
            rendered_inputs.append(text)
        i = "(" + ", ".join(rendered_inputs) + ")"

        if self.outputs:
            rendered_outputs = ", ".join(
                f"{key}: {tpe.__name__ if isinstance(tpe, type) else tpe}" for key, tpe in self.outputs.items()
            )
            # Multiple outputs are wrapped in parentheses.
            if len(self.outputs) > 1:
                rendered_outputs = f"({rendered_outputs})"
            i += " -> " + rendered_outputs
        return i + ":"
359
+
360
+
361
@dataclass
class SerializationContext:
    """
    Contextual information available at serialization time, usable while serializing a task and the various
    parameters of a tasktemplate. It only exists while a task is being serialized, which can happen during a
    deployment or at runtime.

    :param version: The version of the task
    :param code_bundle: The code bundle for the task. This is used to package the code and the inflation path.
    :param input_path: The path to the inputs for the task. This is used to determine where the inputs will be located
    :param output_path: The path to the outputs for the task. This is used to determine where the outputs will be
        located
    """

    version: str
    project: str | None = None
    domain: str | None = None
    org: str | None = None
    code_bundle: Optional[CodeBundle] = None
    input_path: str = "{{.input}}"
    output_path: str = "{{.outputPrefix}}"
    _entrypoint_path: str = field(default="_bin/runtime.py", init=False)
    image_cache: ImageCache | None = None
    root_dir: Optional[pathlib.Path] = None

    def get_entrypoint_path(self, interpreter_path: str) -> str:
        """
        Build the entrypoint path for the task execution.

        The entrypoint is resolved relative to the directory two levels above the interpreter.

        :param interpreter_path: The path to the interpreter (python)
        :return: That directory joined with the relative entrypoint path.
        """
        interpreter_dir = os.path.dirname(interpreter_path)
        install_root = os.path.dirname(interpreter_dir)
        return os.path.join(install_root, self._entrypoint_path)
@@ -0,0 +1,26 @@
1
+ """
2
+ Remote Entities that are accessible from the Union Server once deployed or created.
3
+ """
4
+
5
+ __all__ = [
6
+ "Action",
7
+ "ActionDetails",
8
+ "ActionInputs",
9
+ "ActionOutputs",
10
+ "Project",
11
+ "Run",
12
+ "RunDetails",
13
+ "Secret",
14
+ "SecretTypes",
15
+ "Task",
16
+ "create_channel",
17
+ "upload_dir",
18
+ "upload_file",
19
+ ]
20
+
21
+ from ._client.auth import create_channel
22
+ from ._data import upload_dir, upload_file
23
+ from ._project import Project
24
+ from ._run import Action, ActionDetails, ActionInputs, ActionOutputs, Run, RunDetails
25
+ from ._secret import Secret, SecretTypes
26
+ from ._task import Task
File without changes
@@ -0,0 +1,133 @@
1
+ from typing import AsyncIterator, Protocol
2
+
3
+ from flyteidl.admin import project_attributes_pb2, project_pb2, version_pb2
4
+ from flyteidl.service import dataproxy_pb2
5
+ from grpc.aio import UnaryStreamCall
6
+ from grpc.aio._typing import RequestType
7
+
8
+ from flyte._protos.secret import payload_pb2
9
+ from flyte._protos.workflow import run_logs_service_pb2, run_service_pb2, task_service_pb2
10
+
11
+
12
class MetadataServiceProtocol(Protocol):
    """Structural type for an object exposing the asynchronous ``GetVersion`` call."""

    async def GetVersion(
        self,
        request: version_pb2.GetVersionRequest,
    ) -> version_pb2.GetVersionResponse:
        ...
14
+
15
+
16
class ProjectDomainService(Protocol):
    """Structural type for an object exposing the asynchronous project, domain and attribute calls below."""

    # --- Projects and domains ---

    async def RegisterProject(
        self,
        request: project_pb2.ProjectRegisterRequest,
    ) -> project_pb2.ProjectRegisterResponse:
        ...

    async def UpdateProject(
        self,
        request: project_pb2.Project,
    ) -> project_pb2.ProjectUpdateResponse:
        ...

    async def GetProject(
        self,
        request: project_pb2.ProjectGetRequest,
    ) -> project_pb2.Project:
        ...

    async def ListProjects(
        self,
        request: project_pb2.ProjectListRequest,
    ) -> project_pb2.Projects:
        ...

    async def GetDomains(
        self,
        request: project_pb2.GetDomainRequest,
    ) -> project_pb2.GetDomainsResponse:
        ...

    # --- Project-domain attributes ---

    # NOTE(review): unlike UpdateProjectAttributes below, this returns project_pb2.ProjectUpdateResponse;
    # confirm against the service definition that this is intended.
    async def UpdateProjectDomainAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesUpdateRequest,
    ) -> project_pb2.ProjectUpdateResponse:
        ...

    async def GetProjectDomainAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesGetRequest,
    ) -> project_attributes_pb2.ProjectAttributes:
        ...

    async def DeleteProjectDomainAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesDeleteRequest,
    ) -> project_attributes_pb2.ProjectAttributesDeleteResponse:
        ...

    # --- Project attributes ---

    async def UpdateProjectAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesUpdateRequest,
    ) -> project_attributes_pb2.ProjectAttributesUpdateResponse:
        ...

    async def GetProjectAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesGetRequest,
    ) -> project_attributes_pb2.ProjectAttributes:
        ...

    async def DeleteProjectAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesDeleteRequest,
    ) -> project_attributes_pb2.ProjectAttributesDeleteResponse:
        ...
52
+
53
+
54
class TaskService(Protocol):
    """Structural type for an object exposing the asynchronous task calls below."""

    async def DeployTask(
        self,
        request: task_service_pb2.DeployTaskRequest,
    ) -> task_service_pb2.DeployTaskResponse:
        ...

    async def GetTaskDetails(
        self,
        request: task_service_pb2.GetTaskDetailsRequest,
    ) -> task_service_pb2.GetTaskDetailsResponse:
        ...

    async def ListTasks(
        self,
        request: task_service_pb2.ListTasksRequest,
    ) -> task_service_pb2.ListTasksResponse:
        ...
62
+
63
+
64
class RunService(Protocol):
    """
    Structural type for an object exposing the run and action calls below.

    Request/response calls are coroutines. The ``Watch*`` calls are declared as plain ``def`` returning an
    ``AsyncIterator``: a method declared ``async def ... -> AsyncIterator[X]`` with an empty body is typed as a
    coroutine that must be awaited to obtain the iterator, which does not match consuming the call directly
    with ``async for`` (the same shape ``RunLogsService.TailLogs`` uses for its streaming call).
    """

    async def CreateRun(self, request: run_service_pb2.CreateRunRequest) -> run_service_pb2.CreateRunResponse: ...

    async def AbortRun(self, request: run_service_pb2.AbortRunRequest) -> run_service_pb2.AbortRunResponse: ...

    async def GetRunDetails(
        self, request: run_service_pb2.GetRunDetailsRequest
    ) -> run_service_pb2.GetRunDetailsResponse: ...

    # Streaming call: returns the async iterator directly (no await).
    def WatchRunDetails(
        self, request: run_service_pb2.WatchRunDetailsRequest
    ) -> AsyncIterator[run_service_pb2.WatchRunDetailsResponse]: ...

    async def GetActionDetails(
        self, request: run_service_pb2.GetActionDetailsRequest
    ) -> run_service_pb2.GetActionDetailsResponse: ...

    # Streaming call: returns the async iterator directly (no await).
    def WatchActionDetails(
        self, request: run_service_pb2.WatchActionDetailsRequest
    ) -> AsyncIterator[run_service_pb2.WatchActionDetailsResponse]: ...

    async def GetActionData(
        self, request: run_service_pb2.GetActionDataRequest
    ) -> run_service_pb2.GetActionDataResponse: ...

    async def ListRuns(self, request: run_service_pb2.ListRunsRequest) -> run_service_pb2.ListRunsResponse: ...

    # Streaming call: returns the async iterator directly (no await).
    def WatchRuns(
        self, request: run_service_pb2.WatchRunsRequest
    ) -> AsyncIterator[run_service_pb2.WatchRunsResponse]: ...

    async def ListActions(self, request: run_service_pb2.ListActionsRequest) -> run_service_pb2.ListActionsResponse: ...

    # Streaming call: returns the async iterator directly (no await).
    def WatchActions(
        self, request: run_service_pb2.WatchActionsRequest
    ) -> AsyncIterator[run_service_pb2.WatchActionsResponse]: ...
100
+
101
+
102
class DataProxyService(Protocol):
    """Structural type for an object exposing the asynchronous data-proxy calls below."""

    async def CreateUploadLocation(
        self,
        request: dataproxy_pb2.CreateUploadLocationRequest,
    ) -> dataproxy_pb2.CreateUploadLocationResponse:
        ...

    async def CreateDownloadLocation(
        self,
        request: dataproxy_pb2.CreateDownloadLocationRequest,
    ) -> dataproxy_pb2.CreateDownloadLocationResponse:
        ...

    async def CreateDownloadLink(
        self,
        request: dataproxy_pb2.CreateDownloadLinkRequest,
    ) -> dataproxy_pb2.CreateDownloadLinkResponse:
        ...

    async def GetData(
        self,
        request: dataproxy_pb2.GetDataRequest,
    ) -> dataproxy_pb2.GetDataResponse:
        ...
116
+
117
+
118
class RunLogsService(Protocol):
    """Structural type for an object exposing the ``TailLogs`` call."""

    # Declared as a plain (non-async) method whose return value is the stream call object.
    def TailLogs(
        self,
        request: run_logs_service_pb2.TailLogsRequest,
    ) -> UnaryStreamCall[RequestType, run_logs_service_pb2.TailLogsResponse]:
        ...
122
+
123
+
124
class SecretService(Protocol):
    """Structural type for an object exposing the asynchronous secret calls below."""

    async def CreateSecret(
        self,
        request: payload_pb2.CreateSecretRequest,
    ) -> payload_pb2.CreateSecretResponse:
        ...

    async def UpdateSecret(
        self,
        request: payload_pb2.UpdateSecretRequest,
    ) -> payload_pb2.UpdateSecretResponse:
        ...

    async def GetSecret(
        self,
        request: payload_pb2.GetSecretRequest,
    ) -> payload_pb2.GetSecretResponse:
        ...

    async def ListSecrets(
        self,
        request: payload_pb2.ListSecretsRequest,
    ) -> payload_pb2.ListSecretsResponse:
        ...

    async def DeleteSecret(
        self,
        request: payload_pb2.DeleteSecretRequest,
    ) -> payload_pb2.DeleteSecretResponse:
        ...
@@ -0,0 +1,12 @@
1
+ from flyte.remote._client.auth._channel import create_channel
2
+ from flyte.remote._client.auth._client_config import AuthType, ClientConfig
3
+ from flyte.remote._client.auth.errors import AccessTokenNotFoundError, AuthenticationError, AuthenticationPending
4
+
5
+ __all__ = [
6
+ "AccessTokenNotFoundError",
7
+ "AuthType",
8
+ "AuthenticationError",
9
+ "AuthenticationPending",
10
+ "ClientConfig",
11
+ "create_channel",
12
+ ]
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ from typing import Literal
5
+
6
+
7
def decode_api_key(encoded_str: str) -> tuple[str, str, str, str | Literal["None"]]:
    """
    Decode a base64-encoded API key into app credentials.

    The decoded text is expected to look like ``endpoint:client_id:client_secret:org``.

    :param encoded_str: The base64-encoded key.
    :return: ``(endpoint, client_id, client_secret, org)``; an empty org is returned as the string ``"None"``.
    :raises ValueError: If the decoded text does not contain the four colon-separated fields.
    """
    decoded = base64.b64decode(encoded_str.encode("utf-8")).decode("utf-8")
    # Split from the right with a bounded count: the endpoint may itself contain colons
    # (a scheme or a port), which an unbounded split would break into extra fields.
    parts = decoded.rsplit(":", 3)
    if len(parts) != 4:
        # Deliberately does not echo the decoded value, since it carries the client secret.
        raise ValueError("Invalid API key: expected 'endpoint:client_id:client_secret:org' after decoding")
    endpoint, client_id, client_secret, org = parts
    # For consistency, let's make sure org is always a non-empty string
    if not org:
        org = "None"

    return endpoint, client_id, client_secret, org
File without changes