flyte 2.0.0b22__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/runtime.py +43 -5
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +216 -0
  5. flyte/_code_bundle/_ignore.py +1 -1
  6. flyte/_code_bundle/_packaging.py +4 -4
  7. flyte/_code_bundle/_utils.py +14 -8
  8. flyte/_code_bundle/bundle.py +13 -5
  9. flyte/_constants.py +1 -0
  10. flyte/_context.py +4 -1
  11. flyte/_custom_context.py +73 -0
  12. flyte/_debug/constants.py +0 -1
  13. flyte/_debug/vscode.py +6 -1
  14. flyte/_deploy.py +223 -59
  15. flyte/_environment.py +5 -0
  16. flyte/_excepthook.py +1 -1
  17. flyte/_image.py +144 -82
  18. flyte/_initialize.py +95 -12
  19. flyte/_interface.py +2 -0
  20. flyte/_internal/controllers/_local_controller.py +65 -24
  21. flyte/_internal/controllers/_trace.py +1 -1
  22. flyte/_internal/controllers/remote/_action.py +13 -11
  23. flyte/_internal/controllers/remote/_client.py +1 -1
  24. flyte/_internal/controllers/remote/_controller.py +9 -4
  25. flyte/_internal/controllers/remote/_core.py +16 -16
  26. flyte/_internal/controllers/remote/_informer.py +4 -4
  27. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  28. flyte/_internal/imagebuild/docker_builder.py +139 -84
  29. flyte/_internal/imagebuild/image_builder.py +7 -13
  30. flyte/_internal/imagebuild/remote_builder.py +65 -13
  31. flyte/_internal/imagebuild/utils.py +51 -3
  32. flyte/_internal/resolvers/_task_module.py +5 -38
  33. flyte/_internal/resolvers/default.py +2 -2
  34. flyte/_internal/runtime/convert.py +42 -20
  35. flyte/_internal/runtime/entrypoints.py +24 -1
  36. flyte/_internal/runtime/io.py +21 -8
  37. flyte/_internal/runtime/resources_serde.py +20 -6
  38. flyte/_internal/runtime/reuse.py +1 -1
  39. flyte/_internal/runtime/rusty.py +20 -5
  40. flyte/_internal/runtime/task_serde.py +33 -27
  41. flyte/_internal/runtime/taskrunner.py +10 -1
  42. flyte/_internal/runtime/trigger_serde.py +160 -0
  43. flyte/_internal/runtime/types_serde.py +1 -1
  44. flyte/_keyring/file.py +39 -9
  45. flyte/_logging.py +79 -12
  46. flyte/_map.py +31 -12
  47. flyte/_module.py +70 -0
  48. flyte/_pod.py +2 -2
  49. flyte/_resources.py +213 -31
  50. flyte/_run.py +107 -41
  51. flyte/_task.py +66 -10
  52. flyte/_task_environment.py +96 -24
  53. flyte/_task_plugins.py +4 -2
  54. flyte/_trigger.py +1000 -0
  55. flyte/_utils/__init__.py +2 -1
  56. flyte/_utils/asyn.py +3 -1
  57. flyte/_utils/docker_credentials.py +173 -0
  58. flyte/_utils/module_loader.py +17 -2
  59. flyte/_version.py +3 -3
  60. flyte/cli/_abort.py +3 -3
  61. flyte/cli/_build.py +1 -3
  62. flyte/cli/_common.py +78 -7
  63. flyte/cli/_create.py +178 -3
  64. flyte/cli/_delete.py +23 -1
  65. flyte/cli/_deploy.py +49 -11
  66. flyte/cli/_get.py +79 -34
  67. flyte/cli/_params.py +8 -6
  68. flyte/cli/_plugins.py +209 -0
  69. flyte/cli/_run.py +127 -11
  70. flyte/cli/_serve.py +64 -0
  71. flyte/cli/_update.py +37 -0
  72. flyte/cli/_user.py +17 -0
  73. flyte/cli/main.py +30 -4
  74. flyte/config/_config.py +2 -0
  75. flyte/config/_internal.py +1 -0
  76. flyte/config/_reader.py +3 -3
  77. flyte/connectors/__init__.py +11 -0
  78. flyte/connectors/_connector.py +270 -0
  79. flyte/connectors/_server.py +197 -0
  80. flyte/connectors/utils.py +135 -0
  81. flyte/errors.py +10 -1
  82. flyte/extend.py +8 -1
  83. flyte/extras/_container.py +6 -1
  84. flyte/git/_config.py +11 -9
  85. flyte/io/__init__.py +2 -0
  86. flyte/io/_dataframe/__init__.py +2 -0
  87. flyte/io/_dataframe/basic_dfs.py +1 -1
  88. flyte/io/_dataframe/dataframe.py +12 -8
  89. flyte/io/_dir.py +551 -120
  90. flyte/io/_file.py +538 -141
  91. flyte/models.py +57 -12
  92. flyte/remote/__init__.py +6 -1
  93. flyte/remote/_action.py +18 -16
  94. flyte/remote/_client/_protocols.py +39 -4
  95. flyte/remote/_client/auth/_channel.py +10 -6
  96. flyte/remote/_client/controlplane.py +17 -5
  97. flyte/remote/_console.py +3 -2
  98. flyte/remote/_data.py +4 -3
  99. flyte/remote/_logs.py +3 -3
  100. flyte/remote/_run.py +47 -7
  101. flyte/remote/_secret.py +26 -17
  102. flyte/remote/_task.py +21 -9
  103. flyte/remote/_trigger.py +306 -0
  104. flyte/remote/_user.py +33 -0
  105. flyte/storage/__init__.py +6 -1
  106. flyte/storage/_parallel_reader.py +274 -0
  107. flyte/storage/_storage.py +185 -103
  108. flyte/types/__init__.py +16 -0
  109. flyte/types/_interface.py +2 -2
  110. flyte/types/_pickle.py +17 -4
  111. flyte/types/_string_literals.py +8 -9
  112. flyte/types/_type_engine.py +26 -19
  113. flyte/types/_utils.py +1 -1
  114. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/runtime.py +43 -5
  115. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/METADATA +8 -1
  116. flyte-2.0.0b30.dist-info/RECORD +192 -0
  117. flyte/_protos/__init__.py +0 -0
  118. flyte/_protos/common/authorization_pb2.py +0 -66
  119. flyte/_protos/common/authorization_pb2.pyi +0 -108
  120. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  121. flyte/_protos/common/identifier_pb2.py +0 -99
  122. flyte/_protos/common/identifier_pb2.pyi +0 -120
  123. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  124. flyte/_protos/common/identity_pb2.py +0 -48
  125. flyte/_protos/common/identity_pb2.pyi +0 -72
  126. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  127. flyte/_protos/common/list_pb2.py +0 -36
  128. flyte/_protos/common/list_pb2.pyi +0 -71
  129. flyte/_protos/common/list_pb2_grpc.py +0 -4
  130. flyte/_protos/common/policy_pb2.py +0 -37
  131. flyte/_protos/common/policy_pb2.pyi +0 -27
  132. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  133. flyte/_protos/common/role_pb2.py +0 -37
  134. flyte/_protos/common/role_pb2.pyi +0 -53
  135. flyte/_protos/common/role_pb2_grpc.py +0 -4
  136. flyte/_protos/common/runtime_version_pb2.py +0 -28
  137. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  138. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  139. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  140. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  141. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  142. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  143. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  144. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  145. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  146. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  147. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  148. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  149. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  150. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  151. flyte/_protos/secret/definition_pb2.py +0 -49
  152. flyte/_protos/secret/definition_pb2.pyi +0 -93
  153. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  154. flyte/_protos/secret/payload_pb2.py +0 -62
  155. flyte/_protos/secret/payload_pb2.pyi +0 -94
  156. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  157. flyte/_protos/secret/secret_pb2.py +0 -38
  158. flyte/_protos/secret/secret_pb2.pyi +0 -6
  159. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  160. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  161. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  162. flyte/_protos/workflow/common_pb2.py +0 -27
  163. flyte/_protos/workflow/common_pb2.pyi +0 -14
  164. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  165. flyte/_protos/workflow/environment_pb2.py +0 -29
  166. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  167. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  168. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  169. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  170. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  171. flyte/_protos/workflow/queue_service_pb2.py +0 -111
  172. flyte/_protos/workflow/queue_service_pb2.pyi +0 -168
  173. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  174. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  175. flyte/_protos/workflow/run_definition_pb2.pyi +0 -352
  176. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  177. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  178. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  179. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  180. flyte/_protos/workflow/run_service_pb2.py +0 -137
  181. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  182. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  183. flyte/_protos/workflow/state_service_pb2.py +0 -67
  184. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  185. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  186. flyte/_protos/workflow/task_definition_pb2.py +0 -82
  187. flyte/_protos/workflow/task_definition_pb2.pyi +0 -88
  188. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  189. flyte/_protos/workflow/task_service_pb2.py +0 -60
  190. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  191. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  192. flyte-2.0.0b22.dist-info/RECORD +0 -250
  193. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/debug.py +0 -0
  194. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  195. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +0 -0
  196. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  197. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/storage/_parallel_reader.py ADDED
@@ -0,0 +1,274 @@
+ from __future__ import annotations
+
+ import asyncio
+ import dataclasses
+ import io
+ import os
+ import pathlib
+ import sys
+ import tempfile
+ import typing
+ from typing import Any, Hashable, Protocol
+
+ import aiofiles
+ import aiofiles.os
+ import obstore
+
+ if typing.TYPE_CHECKING:
+     from obstore import Bytes, ObjectMeta
+     from obstore.store import ObjectStore
+
+ CHUNK_SIZE = int(os.getenv("FLYTE_IO_CHUNK_SIZE", str(16 * 1024 * 1024)))
+ MAX_CONCURRENCY = int(os.getenv("FLYTE_IO_MAX_CONCURRENCY", str(32)))
+
+
+ class DownloadQueueEmpty(RuntimeError):
+     pass
+
+
+ class BufferProtocol(Protocol):
+     async def write(self, offset, length, value: Bytes) -> None: ...
+
+     async def read(self) -> memoryview: ...
+
+     @property
+     def complete(self) -> bool: ...
+
+
+ @dataclasses.dataclass
+ class _MemoryBuffer:
+     arr: bytearray
+     pending: int
+     _closed: bool = False
+
+     async def write(self, offset: int, length: int, value: Bytes) -> None:
+         self.arr[offset : offset + length] = memoryview(value)
+         self.pending -= length
+
+     async def read(self) -> memoryview:
+         return memoryview(self.arr)
+
+     @property
+     def complete(self) -> bool:
+         return self.pending == 0
+
+     @classmethod
+     def new(cls, size):
+         return cls(arr=bytearray(size), pending=size)
+
+
+ @dataclasses.dataclass
+ class _FileBuffer:
+     path: pathlib.Path
+     pending: int
+     _handle: io.FileIO | None = None
+     _closed: bool = False
+
+     async def write(self, offset: int, length: int, value: Bytes) -> None:
+         async with aiofiles.open(self.path, mode="r+b") as f:
+             await f.seek(offset)
+             await f.write(value)
+         self.pending -= length
+
+     async def read(self) -> memoryview:
+         async with aiofiles.open(self.path, mode="rb") as f:
+             return memoryview(await f.read())
+
+     @property
+     def complete(self) -> bool:
+         return self.pending == 0
+
+     @classmethod
+     def new(cls, path: pathlib.Path, size: int):
+         path.parent.mkdir(parents=True, exist_ok=True)
+         path.touch()
+         return cls(path=path, pending=size)
+
+
+ @dataclasses.dataclass
+ class Chunk:
+     offset: int
+     length: int
+
+
+ @dataclasses.dataclass
+ class Source:
+     id: Hashable
+     path: pathlib.Path  # Should be str, represents the fully qualified prefix of a file (no bucket)
+     length: int
+     metadata: Any | None = None
+
+
+ @dataclasses.dataclass
+ class DownloadTask:
+     source: Source
+     chunk: Chunk
+     target: pathlib.Path | None = None
+
+
+ class ObstoreParallelReader:
+     def __init__(
+         self,
+         store: ObjectStore,
+         *,
+         chunk_size=CHUNK_SIZE,
+         max_concurrency=MAX_CONCURRENCY,
+     ):
+         self._store = store
+         self._chunk_size = chunk_size
+         self._max_concurrency = max_concurrency
+
+     def _chunks(self, size) -> typing.Iterator[tuple[int, int]]:
+         cs = self._chunk_size
+         for offset in range(0, size, cs):
+             length = min(cs, size - offset)
+             yield offset, length
+
+     async def _as_completed(self, gen: typing.AsyncGenerator[DownloadTask, None], transformer=None):
+         inq: asyncio.Queue = asyncio.Queue(self._max_concurrency * 2)
+         outq: asyncio.Queue = asyncio.Queue()
+         sentinel = object()
+         done = asyncio.Event()
+
+         active: dict[Hashable, _FileBuffer | _MemoryBuffer] = {}
+
+         async def _fill():
+             # Helper function to fill the input queue, this is because the generator is async because it does list/head
+             # calls on the object store which are async.
+             try:
+                 counter = 0
+                 async for task in gen:
+                     if task.source.id not in active:
+                         active[task.source.id] = (
+                             _FileBuffer.new(task.target, task.source.length)
+                             if task.target is not None
+                             else _MemoryBuffer.new(task.source.length)
+                         )
+                     await inq.put(task)
+                     counter += 1
+                 await inq.put(sentinel)
+                 if counter == 0:
+                     raise DownloadQueueEmpty
+             except asyncio.CancelledError:
+                 # document why we need to swallow this
+                 pass
+
+         async def _worker():
+             try:
+                 while not done.is_set():
+                     task = await inq.get()
+                     if task is sentinel:
+                         inq.put_nowait(sentinel)
+                         break
+                     chunk_source_offset = task.chunk.offset
+                     buf = active[task.source.id]
+                     data_to_write = await obstore.get_range_async(
+                         self._store,
+                         str(task.source.path),
+                         start=chunk_source_offset,
+                         end=chunk_source_offset + task.chunk.length,
+                     )
+                     await buf.write(
+                         task.chunk.offset,
+                         task.chunk.length,
+                         data_to_write,
+                     )
+                     if not buf.complete:
+                         continue
+                     if transformer is not None:
+                         result = await transformer(buf)
+                     elif task.target is not None:
+                         result = task.target
+                     else:
+                         result = task.source
+                     outq.put_nowait((task.source.id, result))
+                     del active[task.source.id]
+             except asyncio.CancelledError:
+                 pass
+             finally:
+                 done.set()
+
+         # Yield results as they are completed
+         if sys.version_info >= (3, 11):
+             async with asyncio.TaskGroup() as tg:
+                 tg.create_task(_fill())
+                 for _ in range(self._max_concurrency):
+                     tg.create_task(_worker())
+                 while not done.is_set():
+                     yield await outq.get()
+         else:
+             fill_task = asyncio.create_task(_fill())
+             worker_tasks = [asyncio.create_task(_worker()) for _ in range(self._max_concurrency)]
+             try:
+                 while not done.is_set():
+                     yield await outq.get()
+             except Exception as e:
+                 if not fill_task.done():
+                     fill_task.cancel()
+                 for wt in worker_tasks:
+                     if not wt.done():
+                         wt.cancel()
+                 raise e
+             finally:
+                 await asyncio.gather(fill_task, *worker_tasks, return_exceptions=True)
+
+         # Drain the output queue
+         try:
+             while True:
+                 yield outq.get_nowait()
+         except asyncio.QueueEmpty:
+             pass
+
+     async def download_files(
+         self, src_prefix: pathlib.Path, target_prefix: pathlib.Path, *paths, destination_file_name: str | None = None
+     ) -> None:
+         """
+         src_prefix: Prefix you want to download from in the object store, not including the bucket name, nor file name.
+             Should be replaced with string
+         target_prefix: Local directory to download to
+         paths: Specific paths (relative to src_prefix) to download. If empty, download everything
+         """
+
+         async def _list_downloadable() -> typing.AsyncGenerator[ObjectMeta, None]:
+             if paths:
+                 # For specific file paths, use async head
+                 for path_ in paths:
+                     path = src_prefix / path_
+                     x = await obstore.head_async(self._store, str(path))
+                     yield x
+                 return
+
+             # Use obstore.list() for recursive listing (all files in all subdirectories)
+             # obstore.list() returns an async iterator that yields batches (lists) of objects
+             async for batch in obstore.list(self._store, prefix=str(src_prefix)):
+                 for obj in batch:
+                     yield obj
+
+         async def _gen(tmp_dir: str) -> typing.AsyncGenerator[DownloadTask, None]:
+             async for obj in _list_downloadable():
+                 path = pathlib.Path(obj["path"])  # e.g. Path(prefix/file.txt), needs to be changed to str.
+                 size = obj["size"]
+                 source = Source(id=path, path=path, length=size)
+                 # Strip src_prefix from path for destination
+                 rel_path = path.relative_to(src_prefix)  # doesn't work on windows
+                 for offset, length in self._chunks(size):
+                     yield DownloadTask(
+                         source=source,
+                         target=tmp_dir / rel_path,  # doesn't work on windows
+                         chunk=Chunk(offset, length),
+                     )
+
+         def _transform_decorator(tmp_dir: str):
+             async def _transformer(buf: _FileBuffer) -> None:
+                 if len(paths) == 1 and destination_file_name is not None:
+                     target = target_prefix / destination_file_name
+                 else:
+                     target = target_prefix / buf.path.relative_to(tmp_dir)
+                 await aiofiles.os.makedirs(target.parent, exist_ok=True)
+                 return await aiofiles.os.replace(buf.path, target)  # mv buf.path target
+
+             return _transformer
+
+         with tempfile.TemporaryDirectory() as temporary_dir:
+             async for _ in self._as_completed(_gen(temporary_dir), transformer=_transform_decorator(temporary_dir)):
+                 pass
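
The new `ObstoreParallelReader` splits every object into fixed-size byte ranges, fans the range reads out over a bounded worker pool, and reassembles each file (in memory or on disk) before moving it into place. A minimal usage sketch, assuming an already-constructed obstore `ObjectStore`; the bucket, prefix, and file names below are hypothetical:

```python
import asyncio
import pathlib

from obstore.store import S3Store  # any obstore ObjectStore works here

from flyte.storage._parallel_reader import ObstoreParallelReader


async def main() -> None:
    # Hypothetical bucket; credentials are picked up from the environment.
    store = S3Store("my-bucket")
    reader = ObstoreParallelReader(store, chunk_size=8 * 1024 * 1024, max_concurrency=16)

    # Download everything under datasets/v1/ into ./local_data. Each object is
    # fetched as parallel range requests and renamed into place once complete.
    await reader.download_files(pathlib.Path("datasets/v1"), pathlib.Path("local_data"))

    # Download a single object (hypothetical name) under a different local name.
    await reader.download_files(
        pathlib.Path("datasets/v1"),
        pathlib.Path("local_data"),
        "part-0000.parquet",
        destination_file_name="first_part.parquet",
    )


asyncio.run(main())
```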
flyte/storage/_storage.py CHANGED
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  import os
  import pathlib
  import random
@@ -7,6 +9,7 @@ from typing import AsyncGenerator, Optional
  from uuid import UUID
 
  import fsspec
+ import obstore
  from fsspec.asyn import AsyncFileSystem
  from fsspec.utils import get_protocol
  from obstore.exceptions import GenericError
@@ -14,7 +17,10 @@ from obstore.fsspec import register
 
  from flyte._initialize import get_storage
  from flyte._logging import logger
- from flyte.errors import InitializationError
+ from flyte.errors import InitializationError, OnlyAsyncIOSupportedError
+
+ if typing.TYPE_CHECKING:
+     from obstore import AsyncReadableFile, AsyncWritableFile
 
  _OBSTORE_SUPPORTED_PROTOCOLS = ["s3", "gs", "abfs", "abfss"]
 
@@ -141,12 +147,76 @@ def _get_anonymous_filesystem(from_path):
      return get_underlying_filesystem(get_protocol(from_path), anonymous=True, asynchronous=True)
 
 
+ async def _get_obstore_bypass(from_path: str, to_path: str | pathlib.Path, recursive: bool = False, **kwargs) -> str:
+     from obstore.store import ObjectStore
+
+     from flyte.storage._parallel_reader import ObstoreParallelReader
+
+     fs = get_underlying_filesystem(path=from_path)
+     bucket, prefix = fs._split_path(from_path)  # pylint: disable=W0212
+     store: ObjectStore = fs._construct_store(bucket)
+
+     download_kwargs = {}
+     if "chunk_size" in kwargs:
+         download_kwargs["chunk_size"] = kwargs["chunk_size"]
+     if "max_concurrency" in kwargs:
+         download_kwargs["max_concurrency"] = kwargs["max_concurrency"]
+
+     reader = ObstoreParallelReader(store, **download_kwargs)
+     target_path = pathlib.Path(to_path) if isinstance(to_path, str) else to_path
+
+     # if recursive, just download the prefix to the target path
+     if recursive:
+         logger.debug(f"Downloading recursively {prefix=} to {target_path=}")
+         await reader.download_files(
+             prefix,
+             target_path,
+         )
+         return str(to_path)
+
+     # if not recursive, we need to split out the file name from the prefix
+     else:
+         path_for_reader = pathlib.Path(prefix).name
+         final_prefix = pathlib.Path(prefix).parent
+         logger.debug(f"Downloading single file {final_prefix=}, {path_for_reader=} to {target_path=}")
+         await reader.download_files(
+             final_prefix,
+             target_path.parent,
+             path_for_reader,
+             destination_file_name=target_path.name,
+         )
+         return str(target_path)
+
+
  async def get(from_path: str, to_path: Optional[str | pathlib.Path] = None, recursive: bool = False, **kwargs) -> str:
      if not to_path:
-         name = pathlib.Path(from_path).name
+         name = pathlib.Path(from_path).name  # may need to be adjusted for windows
          to_path = get_random_local_path(file_path_or_file_name=name)
          logger.debug(f"Storing file from {from_path} to {to_path}")
+     else:
+         # Only apply directory logic for single files (not recursive)
+         if not recursive:
+             to_path_str = str(to_path)
+             # Check for trailing separator BEFORE converting to Path (which normalizes and removes it)
+             ends_with_sep = to_path_str.endswith(os.sep)
+             to_path_obj = pathlib.Path(to_path)
+
+             # If path ends with os.sep or is an existing directory, append source filename
+             if ends_with_sep or (to_path_obj.exists() and to_path_obj.is_dir()):
+                 source_filename = pathlib.Path(from_path).name  # may need to be adjusted for windows
+                 to_path = to_path_obj / source_filename
+         # For recursive=True, keep to_path as-is (it's the destination directory for contents)
+
      file_system = get_underlying_filesystem(path=from_path)
+
+     # Check if we should use obstore bypass
+     if (
+         _is_obstore_supported_protocol(file_system.protocol)
+         and hasattr(file_system, "_split_path")
+         and hasattr(file_system, "_construct_store")
+     ):
+         return await _get_obstore_bypass(from_path, to_path, recursive, **kwargs)
+
      try:
          return await _get_from_filesystem(file_system, from_path, to_path, recursive=recursive, **kwargs)
      except (OSError, GenericError) as oe:
@@ -189,7 +259,7 @@ async def put(from_path: str, to_path: Optional[str] = None, recursive: bool = F
      from flyte._context import internal_ctx
 
      ctx = internal_ctx()
-     name = pathlib.Path(from_path).name if not recursive else None  # don't pass a name for folders
+     name = pathlib.Path(from_path).name
      to_path = ctx.raw_data.get_random_remote_path(file_name=name)
 
      file_system = get_underlying_filesystem(path=to_path)
@@ -204,34 +274,48 @@ async def put(from_path: str, to_path: Optional[str] = None, recursive: bool = F
      return to_path
 
 
- async def _put_stream_obstore_bypass(data_iterable: typing.AsyncIterable[bytes] | bytes, to_path: str, **kwargs) -> str:
+ async def _open_obstore_bypass(path: str, mode: str = "rb", **kwargs) -> AsyncReadableFile | AsyncWritableFile:
      """
-     NOTE: This can break if obstore changes its API.
-
-     This function is a workaround for obstore's fsspec implementation which does not support async file operations.
-     It uses the synchronous methods directly to put a stream of data.
+     Simple obstore bypass for opening files. No fallbacks, obstore only.
      """
-     import obstore
      from obstore.store import ObjectStore
 
-     fs = get_underlying_filesystem(path=to_path)
-     if not hasattr(fs, "_split_path") or not hasattr(fs, "_construct_store"):
-         raise NotImplementedError(f"Obstore bypass not supported for {fs.protocol} protocol, methods missing.")
-     bucket, path = fs._split_path(to_path)  # pylint: disable=W0212
+     fs = get_underlying_filesystem(path=path)
+     bucket, file_path = fs._split_path(path)  # pylint: disable=W0212
      store: ObjectStore = fs._construct_store(bucket)
-     if "attributes" in kwargs:
-         attributes = kwargs.pop("attributes")
-     else:
-         attributes = {}
-     buf_file = obstore.open_writer_async(store, path, attributes=attributes)
-     if isinstance(data_iterable, bytes):
-         await buf_file.write(data_iterable)
-     else:
-         async for data in data_iterable:
-             await buf_file.write(data)
-     # await buf_file.flush()
-     await buf_file.close()
-     return to_path
+
+     file_handle: AsyncReadableFile | AsyncWritableFile
+
+     if "w" in mode:
+         attributes = kwargs.pop("attributes", {})
+         file_handle = obstore.open_writer_async(store, file_path, attributes=attributes)
+     else:  # read mode
+         buffer_size = kwargs.pop("buffer_size", 10 * 2**20)
+         file_handle = await obstore.open_reader_async(store, file_path, buffer_size=buffer_size)
+
+     return file_handle
+
+
+ async def open(path: str, mode: str = "rb", **kwargs) -> AsyncReadableFile | AsyncWritableFile:
+     """
+     Asynchronously open a file and return an async context manager.
+     This function checks if the underlying filesystem supports obstore bypass.
+     If it does, it uses obstore to open the file. Otherwise, it falls back to
+     the standard _open function which uses AsyncFileSystem.
+
+     It will raise NotImplementedError if neither obstore nor AsyncFileSystem is supported.
+     """
+     fs = get_underlying_filesystem(path=path)
+
+     # Check if we should use obstore bypass
+     if _is_obstore_supported_protocol(fs.protocol) and hasattr(fs, "_split_path") and hasattr(fs, "_construct_store"):
+         return await _open_obstore_bypass(path, mode, **kwargs)
+
+     # Fallback to normal open
+     if isinstance(fs, AsyncFileSystem):
+         return await fs.open_async(path, mode, **kwargs)
+
+     raise OnlyAsyncIOSupportedError(f"Filesystem {fs} does not support async operations")
 
 
  async def put_stream(
@@ -259,60 +343,31 @@ async def put_stream(
 
      ctx = internal_ctx()
      to_path = ctx.raw_data.get_random_remote_path(file_name=name)
-     fs = get_underlying_filesystem(path=to_path)
 
-     file_handle = None
-     if isinstance(fs, AsyncFileSystem):
-         try:
-             if _is_obstore_supported_protocol(fs.protocol):
-                 # If the protocol is supported by obstore, use the obstore bypass method
-                 return await _put_stream_obstore_bypass(data_iterable, to_path=to_path, **kwargs)
-             file_handle = await fs.open_async(to_path, "wb", **kwargs)
-             if isinstance(data_iterable, bytes):
-                 await file_handle.write(data_iterable)
-             else:
-                 async for data in data_iterable:
-                     await file_handle.write(data)
-             return str(to_path)
-         except NotImplementedError as e:
-             logger.debug(f"{fs} doesn't implement 'open_async', falling back to sync, {e}")
-         finally:
-             if file_handle is not None:
-                 await file_handle.close()
-
-     with fs.open(to_path, "wb", **kwargs) as f:
+     # Check if we should use obstore bypass
+     fs = get_underlying_filesystem(path=to_path)
+     try:
+         file_handle = typing.cast("AsyncWritableFile", await open(to_path, "wb", **kwargs))
          if isinstance(data_iterable, bytes):
-             f.write(data_iterable)
+             await file_handle.write(data_iterable)
          else:
-             # If data_iterable is async iterable, iterate over it and write each chunk to the file
              async for data in data_iterable:
-                 f.write(data)
-     return str(to_path)
-
-
- async def _get_stream_obstore_bypass(path: str, chunk_size, **kwargs) -> AsyncGenerator[bytes, None]:
-     """
-     NOTE: This can break if obstore changes its API.
-     This function is a workaround for obstore's fsspec implementation which does not support async file operations.
-     It uses the synchronous methods directly to get a stream of data.
-     """
-     import obstore
-     from obstore.store import ObjectStore
+                 await file_handle.write(data)
+         await file_handle.close()
+         return str(to_path)
+     except OnlyAsyncIOSupportedError:
+         pass
+
+     # Fallback to normal open
+     file_handle_io: typing.IO = fs.open(to_path, mode="wb", **kwargs)
+     if isinstance(data_iterable, bytes):
+         file_handle_io.write(data_iterable)
+     else:
+         async for data in data_iterable:
+             file_handle_io.write(data)
+     file_handle_io.close()
 
-     fs = get_underlying_filesystem(path=path)
-     if not hasattr(fs, "_split_path") or not hasattr(fs, "_construct_store"):
-         raise NotImplementedError(f"Obstore bypass not supported for {fs.protocol} protocol, methods missing.")
-     bucket, rem_path = fs._split_path(path)  # pylint: disable=W0212
-     store: ObjectStore = fs._construct_store(bucket)
-     buf_file = await obstore.open_reader_async(store, rem_path, buffer_size=chunk_size)
-     try:
-         while True:
-             chunk = await buf_file.read()
-             if not chunk:
-                 break
-             yield bytes(chunk)
-     finally:
-         buf_file.close()
+     return str(to_path)
 
 
  async def get_stream(path: str, chunk_size=10 * 2**20, **kwargs) -> AsyncGenerator[bytes, None]:
@@ -322,42 +377,41 @@ async def get_stream(path: str, chunk_size=10 * 2**20, **kwargs) -> AsyncGenerat
      Example usage:
      ```python
      import flyte.storage as storage
-     obj = storage.get_stream(path="s3://my_bucket/my_file.txt")
+     async for chunk in storage.get_stream(path="s3://my_bucket/my_file.txt"):
+         process(chunk)
      ```
 
      :param path: Path to the remote location where the data will be downloaded.
      :param kwargs: Additional arguments to be passed to the underlying filesystem.
      :param chunk_size: Size of each chunk to be read from the file.
-     :return: An async iterator that yields chunks of data.
+     :return: An async iterator that yields chunks of bytes.
      """
-     fs = get_underlying_filesystem(path=path, **kwargs)
+     # Check if we should use obstore bypass
+     fs = get_underlying_filesystem(path=path)
+     if _is_obstore_supported_protocol(fs.protocol) and hasattr(fs, "_split_path") and hasattr(fs, "_construct_store"):
+         # Set buffer_size for obstore if chunk_size is provided
+         if "buffer_size" not in kwargs:
+             kwargs["buffer_size"] = chunk_size
+         file_handle = typing.cast("AsyncReadableFile", await _open_obstore_bypass(path, "rb", **kwargs))
+         while chunk := await file_handle.read():
+             yield bytes(chunk)
+         return
 
-     file_size = fs.info(path)["size"]
-     total_read = 0
-     file_handle = None
-     try:
-         if _is_obstore_supported_protocol(fs.protocol):
-             # If the protocol is supported by obstore, use the obstore bypass method
-             async for x in _get_stream_obstore_bypass(path, chunk_size=chunk_size, **kwargs):
-                 yield x
-             return
-         if isinstance(fs, AsyncFileSystem):
-             file_handle = await fs.open_async(path, "rb")
-             while chunk := await file_handle.read(min(chunk_size, file_size - total_read)):
-                 total_read += len(chunk)
-                 yield chunk
-             return
-     except NotImplementedError as e:
-         logger.debug(f"{fs} doesn't implement 'open_async', falling back to sync, error: {e}")
-     finally:
-         if file_handle is not None:
-             file_handle.close()
-
-     # Sync fallback
-     with fs.open(path, "rb") as file_handle:
-         while chunk := file_handle.read(min(chunk_size, file_size - total_read)):
-             total_read += len(chunk)
+     # Fallback to normal open
+     if "block_size" not in kwargs:
+         kwargs["block_size"] = chunk_size
+
+     if isinstance(fs, AsyncFileSystem):
+         file_handle = await fs.open_async(path, "rb", **kwargs)
+         while chunk := await file_handle.read():
              yield chunk
+         await file_handle.close()
+         return
+
+     file_handle = fs.open(path, "rb", **kwargs)
+     while chunk := file_handle.read():
+         yield chunk
+     file_handle.close()
 
 
  def join(*paths: str) -> str:
@@ -370,4 +424,32 @@
      return str(os.path.join(*paths))
 
 
+ async def exists(path: str, **kwargs) -> bool:
+     """
+     Check if a path exists.
+
+     :param path: Path to be checked.
+     :param kwargs: Additional arguments to be passed to the underlying filesystem.
+     :return: True if the path exists, False otherwise.
+     """
+     try:
+         fs = get_underlying_filesystem(path=path, **kwargs)
+         if isinstance(fs, AsyncFileSystem):
+             _ = await fs._info(path)
+             return True
+         _ = fs.info(path)
+         return True
+     except FileNotFoundError:
+         return False
+
+
+ def exists_sync(path: str, **kwargs) -> bool:
+     try:
+         fs = get_underlying_filesystem(path=path, **kwargs)
+         _ = fs.info(path)
+         return True
+     except FileNotFoundError:
+         return False
+
+
  register(_OBSTORE_SUPPORTED_PROTOCOLS, asynchronous=True)
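
Taken together, the `_storage.py` changes route `get`, `open`, `put_stream`, and `get_stream` through the obstore bypass whenever the protocol supports it, and add `exists`/`exists_sync` helpers. A short sketch of the resulting calls, assuming these helpers are re-exported from `flyte.storage` (the `flyte/storage/__init__.py` change suggests this); the S3 paths are hypothetical:

```python
import asyncio

import flyte.storage as storage


async def main() -> None:
    # On an obstore-backed protocol (s3/gs/abfs/abfss) these calls use the
    # parallel reader shown above; otherwise they fall back to fsspec.
    if await storage.exists("s3://my-bucket/data/input.csv"):
        # A trailing separator means "download into this directory", so the
        # source filename is appended automatically (POSIX os.sep here).
        local = await storage.get("s3://my-bucket/data/input.csv", "/tmp/inputs/")
        print(f"downloaded to {local}")

    # Stream a large object chunk by chunk instead of materializing it.
    async for chunk in storage.get_stream("s3://my-bucket/data/big.bin", chunk_size=4 * 2**20):
        print(f"read {len(chunk)} bytes")


asyncio.run(main())
```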
flyte/types/__init__.py CHANGED
@@ -19,6 +19,10 @@ It is always possible to bypass the type system and use the `FlytePickle` type t
  written in python. The Pickled objects cannot be represented in the UI, and may be in-efficient for large datasets.
  """
 
+ from importlib.metadata import entry_points
+
+ from flyte._logging import logger
+
  from ._interface import guess_interface
  from ._pickle import FlytePickle
  from ._renderer import Renderable
@@ -34,3 +38,15 @@ __all__ = [
      "guess_interface",
      "literal_string_repr",
  ]
+
+
+ def _load_custom_type_transformers():
+     plugins = entry_points(group="flyte.plugins.types")
+     for ep in plugins:
+         try:
+             logger.info(f"Loading type transformer: {ep.name}")
+             loaded = ep.load()
+             if callable(loaded):
+                 loaded()
+         except Exception as e:
+             logger.warning(f"Failed to load type transformer {ep.name} with error: {e}")
flyte/types/_interface.py CHANGED
@@ -1,9 +1,9 @@
  import inspect
  from typing import Any, Dict, Iterable, Tuple, Type, cast
 
- from flyteidl.core import interface_pb2, literals_pb2
+ from flyteidl2.core import interface_pb2, literals_pb2
+ from flyteidl2.task import common_pb2
 
- from flyte._protos.workflow import common_pb2
  from flyte.models import NativeInterface
 