flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211) hide show
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,342 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import inspect
5
+ from typing import Any, Iterable, Optional, Protocol, Union, runtime_checkable
6
+
7
+
8
@runtime_checkable
class HashMethod(Protocol):
    """
    Structural interface for an incremental hash accumulator.

    Implementations receive raw bytes through ``update`` and report the digest through ``result``.
    The writers and readers in this module only ever call ``update`` and ``result``.

    Note: because this protocol is ``runtime_checkable``, ``isinstance(obj, HashMethod)`` checks for
    *all* members listed here, so ``obj`` must also define ``reset`` to pass that check.
    """

    def update(self, data: memoryview, /) -> None:
        """Feed the next chunk of bytes into the hash."""

    def result(self) -> str:
        """Return the digest of everything fed in so far."""

    def reset(self) -> None:
        """Optional convenience: return the accumulator to its initial state."""
15
+
16
+
17
class PrecomputedValue(HashMethod):
    """
    A ``HashMethod`` whose result is fixed up front.

    ``update`` ignores whatever it is given, so ``result`` always returns the value passed at
    construction time.
    """

    def __init__(self, value: str):
        self._value = value

    def update(self, data: memoryview, /) -> None:
        # The digest is already known; incoming bytes cannot change it.
        return None

    def result(self) -> str:
        return self._value
25
+
26
+
27
class HashlibAccumulator(HashMethod):
    """
    Wrap a hashlib-like object to the ``HashMethod`` protocol.

        h = hashlib.new("sha256")
        acc = HashlibAccumulator(h)

    ``reset`` restores the state the wrapped object had when this accumulator was created. That
    requires the object to provide ``copy()`` (hashlib objects do); otherwise ``reset`` raises
    ``NotImplementedError`` instead of silently doing nothing.
    """

    def __init__(self, h):
        self._h = h
        # Snapshot of the starting state, used by ``reset``. Without an explicit ``reset`` the
        # protocol's no-op stub would be inherited and the hash would keep accumulating.
        copy_fn = getattr(h, "copy", None)
        self._initial = copy_fn() if callable(copy_fn) else None

    def update(self, data: memoryview, /) -> None:
        self._h.update(data)

    def result(self) -> str:
        return self._h.hexdigest()

    def reset(self) -> None:
        """
        Restore the wrapped hash to the state it had at construction.

        :raises NotImplementedError: if the wrapped object has no ``copy()`` method.
        """
        if self._initial is None:
            raise NotImplementedError(f"{type(self._h).__name__} has no copy(); cannot reset this accumulator.")
        # Copy again so the snapshot itself is never mutated and reset() can be called repeatedly.
        self._h = self._initial.copy()

    @classmethod
    def from_hash_name(cls, name: str) -> HashlibAccumulator:
        """
        Create an accumulator from a hashlib algorithm name.
        """
        h = hashlib.new(name)
        return cls(h)
50
+
51
+
52
class HashingWriter:
    """
    Sync writer that updates a user-supplied accumulator on every write.

    Hashing covers the exact data you pass in. If you write str, it is encoded for hashing only,
    using the explicit ``encoding`` argument if given, else the underlying handle's ``.encoding``
    if available, else UTF-8.

    The accumulator is updated only after the underlying ``write`` returns, so a write that raises
    leaves the hash untouched. The return value of ``write`` is not inspected: if the handle
    reports a short write, the hash still covers all of ``data``.
    """

    def __init__(
        self,
        fh,
        accumulator: HashMethod,
        *,
        encoding: Optional[str] = None,
        errors: str = "strict",
    ):
        self._fh = fh
        self._acc: HashMethod = accumulator
        self._encoding = encoding or getattr(fh, "encoding", None)
        self._errors = errors

    def result(self) -> Any:
        return self._acc.result()

    def _to_bytes_mv(self, data) -> memoryview:
        if isinstance(data, str):
            return memoryview(data.encode(self._encoding or "utf-8", self._errors))
        # bytes, bytearray, memoryview, or any other buffer-protocol object (e.g., numpy arrays)
        return memoryview(data)

    def write(self, data):
        # Convert first so unencodable / non-buffer data fails before anything is written.
        mv = self._to_bytes_mv(data)
        written = self._fh.write(data)
        # Hash only once the write succeeded, keeping the digest in sync with what was written.
        self._acc.update(mv)
        return written

    def writelines(self, lines: Iterable[Union[str, bytes, bytearray, memoryview]]):
        for line in lines:
            self.write(line)

    def flush(self):
        return self._fh.flush()

    def close(self):
        return self._fh.close()

    def __getattr__(self, name):
        # Only reached for attributes missing on the wrapper itself. Refuse private/dunder names
        # (as HashingReader does) so a partially-initialised instance, e.g. one being rebuilt by
        # ``copy.copy`` before ``_fh`` is restored, raises AttributeError instead of recursing
        # forever through ``self._fh``.
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._fh, name)
102
+
103
+
104
class AsyncHashingWriter:
    """
    Async version of HashingWriter with the same behavior.

    ``write`` awaits the underlying handle's ``write``; ``flush``/``close`` may be awaitable or sync
    (or missing, in which case they return None). The accumulator is updated only after the
    underlying write completes, so a write that raises leaves the hash untouched. Text (str) data is
    encoded for hashing only, using the explicit ``encoding`` argument if given, else the handle's
    ``.encoding`` if available, else UTF-8.
    """

    def __init__(
        self,
        fh,
        accumulator: HashMethod,
        *,
        encoding: Optional[str] = None,
        errors: str = "strict",
    ):
        self._fh = fh
        self._acc = accumulator
        self._encoding = encoding or getattr(fh, "encoding", None)
        self._errors = errors

    def result(self) -> Any:
        return self._acc.result()

    def _to_bytes_mv(self, data) -> memoryview:
        if isinstance(data, str):
            return memoryview(data.encode(self._encoding or "utf-8", self._errors))
        # bytes, bytearray, memoryview, or any other buffer-protocol object
        return memoryview(data)

    async def write(self, data):
        # Convert first so unencodable / non-buffer data fails before anything is written.
        mv = self._to_bytes_mv(data)
        written = await self._fh.write(data)
        # Hash only once the write succeeded, keeping the digest in sync with what was written.
        self._acc.update(mv)
        return written

    async def writelines(self, lines: Iterable[Union[str, bytes, bytearray, memoryview]]):
        for line in lines:
            await self.write(line)

    async def flush(self):
        fn = getattr(self._fh, "flush", None)
        if fn is None:
            return None
        res = fn()
        if inspect.isawaitable(res):
            return await res
        return res

    async def close(self):
        fn = getattr(self._fh, "close", None)
        if fn is None:
            return None
        res = fn()
        if inspect.isawaitable(res):
            return await res
        return res

    def __getattr__(self, name):
        # Only reached for attributes missing on the wrapper itself. Refuse private/dunder names
        # (as the readers do) so a partially-initialised instance, e.g. one being rebuilt by
        # ``copy.copy`` before ``_fh`` is restored, raises AttributeError instead of recursing
        # forever through ``self._fh``.
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._fh, name)
162
+
163
+
164
class HashingReader:
    """
    Sync reader that updates a user-supplied accumulator on every read operation.

    Data returned by ``read``, ``readline``, ``readlines`` and iteration is hashed; ``None`` and
    empty chunks are skipped. Any other method reached through attribute passthrough (for example
    ``readinto``, if the handle has one) is forwarded to the handle and does NOT update the hash.

    If the underlying handle returns str (text mode), it is encoded for hashing only, using the
    explicit ``encoding`` argument if given, else the handle's ``.encoding`` if present, else UTF-8.
    """

    def __init__(
        self,
        fh,
        accumulator: HashMethod,
        *,
        encoding: Optional[str] = None,
        errors: str = "strict",
    ):
        self._fh = fh
        self._acc = accumulator
        self._encoding = encoding or getattr(fh, "encoding", None)
        self._errors = errors

    def result(self) -> str:
        return self._acc.result()

    def _to_bytes_mv(self, data) -> Optional[memoryview]:
        if data is None:
            return None
        if isinstance(data, str):
            return memoryview(data.encode(self._encoding or "utf-8", self._errors))
        # bytes, bytearray, memoryview, or any other buffer-protocol object (rare for read paths, but safe)
        return memoryview(data)

    def _hash(self, data) -> None:
        """Feed ``data`` to the accumulator, skipping ``None`` and empty chunks."""
        mv = self._to_bytes_mv(data)
        if mv is not None and len(mv) > 0:
            self._acc.update(mv)

    def read(self, size: int = -1):
        data = self._fh.read(size)
        self._hash(data)
        return data

    def readline(self, size: int = -1):
        line = self._fh.readline(size)
        self._hash(line)
        return line

    def readlines(self, hint: int = -1):
        lines = self._fh.readlines(hint)
        # Update in order so the digest reflects the exact concatenation of the lines.
        for line in lines:
            self._hash(line)
        return lines

    def __iter__(self):
        return self

    def __next__(self):
        # Delegate to the underlying iterator to preserve its semantics (including buffering),
        # but intercept the produced line to hash it.
        line = next(self._fh)
        self._hash(line)
        return line

    # ---- passthrough ----
    def __getattr__(self, name):
        # Avoid leaking private lookups to the underlying object if we're missing something internal
        # (this also prevents infinite recursion when ``_fh`` is not set yet).
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._fh, name)
239
+
240
+
241
class AsyncHashingReader:
    """
    Async reader that updates a user-supplied accumulator on every read operation.

    Intended for aiofiles/fsspec-style async handles. The handle's ``readlines``, ``flush`` and
    ``close`` may each be awaitable or sync. Text (str) data is encoded for hashing only, using the
    explicit ``encoding`` argument if given, else the handle's ``.encoding`` if present, else UTF-8.
    """

    def __init__(
        self,
        fh,
        accumulator: HashMethod,
        *,
        encoding: Optional[str] = None,
        errors: str = "strict",
    ):
        self._fh = fh
        self._acc = accumulator
        self._encoding = encoding or getattr(fh, "encoding", None)
        self._errors = errors

    def result(self) -> str:
        return self._acc.result()

    def _to_bytes_mv(self, data) -> Optional[memoryview]:
        if data is None:
            return None
        if isinstance(data, str):
            return memoryview(data.encode(self._encoding or "utf-8", self._errors))
        # bytes, bytearray, memoryview, or any other buffer-protocol object
        return memoryview(data)

    def _hash(self, data) -> None:
        """Feed ``data`` to the accumulator, skipping ``None`` and empty chunks."""
        mv = self._to_bytes_mv(data)
        if mv is not None and len(mv) > 0:
            self._acc.update(mv)

    async def read(self, size: int = -1):
        data = await self._fh.read(size)
        self._hash(data)
        return data

    async def readline(self, size: int = -1):
        line = await self._fh.readline(size)
        self._hash(line)
        return line

    async def readlines(self, hint: int = -1):
        """
        Read and hash lines until EOF, or until their total size reaches a positive ``hint``.
        """
        readlines_fn = getattr(self._fh, "readlines", None)
        if readlines_fn is not None:
            lines = readlines_fn(hint)
            # Some handles implement readlines() synchronously; only await when needed.
            if inspect.isawaitable(lines):
                lines = await lines
            for line in lines:
                self._hash(line)
            return lines
        # Fallback: iterate (which hashes each line). Like io.IOBase.readlines, stop once the
        # total size of the lines read so far reaches a positive hint.
        lines = []
        total = 0
        async for line in self:
            lines.append(line)
            total += len(line)
            if hint is not None and 0 < hint <= total:
                break
        return lines

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Prefer the underlying async iterator if present; its StopAsyncIteration propagates as-is.
        anext_fn = getattr(self._fh, "__anext__", None)
        if anext_fn is not None:
            line = await anext_fn()
            self._hash(line)
            return line

        # Fallback to readline-based iteration (readline already hashes the line).
        line = await self.readline()
        # Empty str/bytes (or None) signals end of stream.
        if not line:
            raise StopAsyncIteration
        return line

    async def flush(self):
        fn = getattr(self._fh, "flush", None)
        if fn is None:
            return None
        res = fn()
        return await res if inspect.isawaitable(res) else res

    async def close(self):
        fn = getattr(self._fh, "close", None)
        if fn is None:
            return None
        res = fn()
        return await res if inspect.isawaitable(res) else res

    # ---- passthrough ----
    def __getattr__(self, name):
        # Refuse private/dunder names so missing internals never recurse through ``self._fh``.
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._fh, name)
flyte/models.py CHANGED
@@ -3,17 +3,18 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import os
5
5
  import pathlib
6
+ import typing
6
7
  from dataclasses import dataclass, field, replace
7
8
  from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Literal, Optional, Tuple, Type
8
9
 
9
10
  import rich.repr
10
11
 
11
12
  from flyte._docstring import Docstring
12
- from flyte._interface import extract_return_annotation
13
+ from flyte._interface import extract_return_annotation, literal_to_enum
13
14
  from flyte._logging import logger
14
15
 
15
16
  if TYPE_CHECKING:
16
- from flyteidl.core import literals_pb2
17
+ from flyteidl2.core import literals_pb2
17
18
 
18
19
  from flyte._internal.imagebuild.image_builder import ImageCache
19
20
  from flyte.report import Report
@@ -76,6 +77,37 @@ class ActionID:
76
77
  return self.new_sub_action(new_name)
77
78
 
78
79
 
80
@dataclass
class PathRewrite:
    """
    Configuration for rewriting paths during input loading.

    A path starting with ``old_prefix`` is meant to be rewritten to start with ``new_prefix``
    instead. Both prefixes are required, must be non-empty, and must differ.

    The textual form is ``old_prefix->new_prefix``: ``repr()``/``str()`` produce it and ``from_str``
    parses it, so the two round-trip for prefixes without surrounding whitespace or an embedded
    ``->``.
    """

    # Prefix matched at the start of a path.
    old_prefix: str
    # Prefix that replaces ``old_prefix``.
    new_prefix: str

    def __post_init__(self):
        if not self.old_prefix or not self.new_prefix:
            raise ValueError("Both old_prefix and new_prefix must be non-empty strings.")
        if self.old_prefix == self.new_prefix:
            raise ValueError("old_prefix and new_prefix must be different.")

    @classmethod
    def from_str(cls, pattern: str) -> PathRewrite:
        """
        Create a PathRewrite from a string pattern of the form `old_prefix->new_prefix`.

        Whitespace around each prefix is stripped, so ``"a -> b"`` parses like ``"a->b"``.

        :raises ValueError: if the pattern does not contain exactly one ``->`` separator, or the
            resulting prefixes are empty or identical.
        """
        parts = pattern.split("->")
        if len(parts) != 2:
            raise ValueError(f"Invalid path rewrite pattern: {pattern}. Expected format 'old_prefix->new_prefix'.")
        old_prefix, new_prefix = (part.strip() for part in parts)
        return cls(old_prefix=old_prefix, new_prefix=new_prefix)

    def __repr__(self) -> str:
        # Deliberately not decorated with ``rich.repr.auto``: that decorator installs its own
        # ``__repr__`` on the class and would silently replace this one.
        return f"{self.old_prefix}->{self.new_prefix}"
109
+
110
+
79
111
  @rich.repr.auto
80
112
  @dataclass(frozen=True, kw_only=True)
81
113
  class RawDataPath:
@@ -85,6 +117,7 @@ class RawDataPath:
85
117
  """
86
118
 
87
119
  path: str
120
+ path_rewrite: Optional[PathRewrite] = None
88
121
 
89
122
  @classmethod
90
123
  def from_local_folder(cls, local_folder: str | pathlib.Path | None = None) -> RawDataPath:
@@ -111,7 +144,7 @@ class RawDataPath:
111
144
 
112
145
  def get_random_remote_path(self, file_name: Optional[str] = None) -> str:
113
146
  """
114
- Returns a random path for uploading a file/directory to.
147
+ Returns a random path for uploading a file/directory to. This file/folder will not be created, it's just a path.
115
148
 
116
149
  :param file_name: If given, will be joined after a randomly generated portion.
117
150
  :return:
@@ -127,13 +160,14 @@ class RawDataPath:
127
160
 
128
161
  protocol = get_protocol(file_prefix)
129
162
  if "file" in protocol:
130
- local_path = pathlib.Path(file_prefix) / random_string
163
+ parent_folder = pathlib.Path(file_prefix)
164
+ parent_folder.mkdir(exist_ok=True, parents=True)
131
165
  if file_name:
132
- # Only if file name is given do we create the parent, because it may be needed as a folder otherwise
133
- local_path = local_path / file_name
134
- if not local_path.exists():
135
- local_path.parent.mkdir(exist_ok=True, parents=True)
136
- local_path.touch()
166
+ random_folder = parent_folder / random_string
167
+ random_folder.mkdir()
168
+ local_path = random_folder / file_name
169
+ else:
170
+ local_path = parent_folder / random_string
137
171
  return str(local_path.absolute())
138
172
 
139
173
  fs = fsspec.filesystem(protocol)
@@ -161,11 +195,14 @@ class TaskContext:
161
195
  :param action: The action ID of the current execution. This is always set, within a run.
162
196
  :param version: The version of the executed task. This is set when the task is executed by an action and will be
163
197
  set on all sub-actions.
198
+ :param custom_context: Context metadata for the action. If an action receives context, it'll automatically pass it
199
+ to any actions it spawns. Context will not be used for cache key computation.
164
200
  """
165
201
 
166
202
  action: ActionID
167
203
  version: str
168
204
  raw_data_path: RawDataPath
205
+ input_path: str | None = None
169
206
  output_path: str
170
207
  run_base_dir: str
171
208
  report: Report
@@ -175,6 +212,8 @@ class TaskContext:
175
212
  compiled_image_cache: ImageCache | None = None
176
213
  data: Dict[str, Any] = field(default_factory=dict)
177
214
  mode: Literal["local", "remote", "hybrid"] = "remote"
215
+ interactive_mode: bool = False
216
+ custom_context: Dict[str, str] = field(default_factory=dict)
178
217
 
179
218
  def replace(self, **kwargs) -> TaskContext:
180
219
  if "data" in kwargs:
@@ -316,10 +355,18 @@ class NativeInterface:
316
355
  """
317
356
  Extract the native interface from the given function. This is used to create a native interface for the task.
318
357
  """
358
+ # Get function parameters, defaults, varargs info (POSITIONAL_ONLY, VAR_POSITIONAL, KEYWORD_ONLY, etc.).
319
359
  sig = inspect.signature(func)
320
360
 
321
361
  # Extract parameter details (name, type, default value)
322
362
  param_info = {}
363
+ try:
364
+ # Get fully evaluated, real Python types for type checking.
365
+ hints = typing.get_type_hints(func, include_extras=True)
366
+ except Exception as e:
367
+ logger.warning(f"Could not get type hints for function {func.__name__}: {e}")
368
+ raise
369
+
323
370
  for name, param in sig.parameters.items():
324
371
  if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
325
372
  raise ValueError(f"Function {func.__name__} cannot have variable positional or keyword arguments.")
@@ -327,10 +374,14 @@ class NativeInterface:
327
374
  logger.warning(
328
375
  f"Function {func.__name__} has parameter {name} without type annotation. Data will be pickled."
329
376
  )
330
- param_info[name] = (param.annotation, param.default)
377
+ arg_type = hints.get(name, param.annotation)
378
+ if typing.get_origin(arg_type) is Literal:
379
+ param_info[name] = (literal_to_enum(arg_type), param.default)
380
+ else:
381
+ param_info[name] = (arg_type, param.default)
331
382
 
332
383
  # Get return type
333
- outputs = extract_return_annotation(sig.return_annotation)
384
+ outputs = extract_return_annotation(hints.get("return", sig.return_annotation))
334
385
  return cls(inputs=param_info, outputs=outputs)
335
386
 
336
387
  def convert_to_kwargs(self, *args, **kwargs) -> Dict[str, Any]:
@@ -340,9 +391,15 @@ class NativeInterface:
340
391
  """
341
392
  # Convert positional arguments to keyword arguments
342
393
  if len(args) > len(self.inputs):
343
- raise ValueError(f"Too many positional arguments provided, inputs {self.inputs.keys()}, args {len(args)}")
394
+ raise ValueError(
395
+ f"Too many positional arguments provided, expected inputs {self.inputs.keys()}, args {len(args)}"
396
+ )
344
397
  for arg, input_name in zip(args, self.inputs.keys()):
345
398
  kwargs[input_name] = arg
399
+ if len(kwargs) > len(self.inputs):
400
+ raise ValueError(
401
+ f"Too many keyword arguments provided, expected inputs {self.inputs.keys()}, args {kwargs.keys()}"
402
+ )
346
403
  return kwargs
347
404
 
348
405
  def get_input_types(self) -> Dict[str, Type]:
@@ -408,13 +465,15 @@ class SerializationContext:
408
465
  code_bundle: Optional[CodeBundle] = None
409
466
  input_path: str = "{{.input}}"
410
467
  output_path: str = "{{.outputPrefix}}"
411
- _entrypoint_path: str = field(default="_bin/runtime.py", init=False)
468
+ interpreter_path: str = "/opt/venv/bin/python"
412
469
  image_cache: ImageCache | None = None
413
470
  root_dir: Optional[pathlib.Path] = None
414
471
 
415
- def get_entrypoint_path(self, interpreter_path: str) -> str:
472
+ def get_entrypoint_path(self, interpreter_path: Optional[str] = None) -> str:
416
473
  """
417
474
  Get the entrypoint path for the task. This is used to determine the entrypoint for the task execution.
418
475
  :param interpreter_path: The path to the interpreter (python)
419
476
  """
420
- return os.path.join(os.path.dirname(os.path.dirname(interpreter_path)), self._entrypoint_path)
477
+ if interpreter_path is None:
478
+ interpreter_path = self.interpreter_path
479
+ return os.path.join(os.path.dirname(interpreter_path), "runtime.py")
flyte/remote/__init__.py CHANGED
@@ -7,12 +7,15 @@ __all__ = [
7
7
  "ActionDetails",
8
8
  "ActionInputs",
9
9
  "ActionOutputs",
10
+ "Phase",
10
11
  "Project",
11
12
  "Run",
12
13
  "RunDetails",
13
14
  "Secret",
14
15
  "SecretTypes",
15
16
  "Task",
17
+ "Trigger",
18
+ "User",
16
19
  "create_channel",
17
20
  "upload_dir",
18
21
  "upload_file",
@@ -22,6 +25,8 @@ from ._action import Action, ActionDetails, ActionInputs, ActionOutputs
22
25
  from ._client.auth import create_channel
23
26
  from ._data import upload_dir, upload_file
24
27
  from ._project import Project
25
- from ._run import Run, RunDetails
28
+ from ._run import Phase, Run, RunDetails
26
29
  from ._secret import Secret, SecretTypes
27
30
  from ._task import Task
31
+ from ._trigger import Trigger
32
+ from ._user import User