flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,274 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import dataclasses
5
+ import io
6
+ import os
7
+ import pathlib
8
+ import sys
9
+ import tempfile
10
+ import typing
11
+ from typing import Any, Hashable, Protocol
12
+
13
+ import aiofiles
14
+ import aiofiles.os
15
+ import obstore
16
+
17
+ if typing.TYPE_CHECKING:
18
+ from obstore import Bytes, ObjectMeta
19
+ from obstore.store import ObjectStore
20
+
21
+ CHUNK_SIZE = int(os.getenv("FLYTE_IO_CHUNK_SIZE", str(16 * 1024 * 1024)))
22
+ MAX_CONCURRENCY = int(os.getenv("FLYTE_IO_MAX_CONCURRENCY", str(32)))
23
+
24
+
25
class DownloadQueueEmpty(RuntimeError):
    """Raised when the async task generator yielded no download tasks at all."""
27
+
28
+
29
class BufferProtocol(Protocol):
    """Structural interface shared by the in-memory and on-disk chunk buffers."""

    async def write(self, offset, length, value: Bytes) -> None:
        """Store ``value`` at byte ``offset``, marking ``length`` bytes as received."""
        ...

    async def read(self) -> memoryview:
        """Return a view over the fully assembled payload."""
        ...

    @property
    def complete(self) -> bool:
        """True once every expected byte has been written."""
        ...
36
+
37
+
38
+ @dataclasses.dataclass
39
+ class _MemoryBuffer:
40
+ arr: bytearray
41
+ pending: int
42
+ _closed: bool = False
43
+
44
+ async def write(self, offset: int, length: int, value: Bytes) -> None:
45
+ self.arr[offset : offset + length] = memoryview(value)
46
+ self.pending -= length
47
+
48
+ async def read(self) -> memoryview:
49
+ return memoryview(self.arr)
50
+
51
+ @property
52
+ def complete(self) -> bool:
53
+ return self.pending == 0
54
+
55
+ @classmethod
56
+ def new(cls, size):
57
+ return cls(arr=bytearray(size), pending=size)
58
+
59
+
60
@dataclasses.dataclass
class _FileBuffer:
    """Chunk buffer that assembles a download directly into a local file."""

    path: pathlib.Path   # destination file, created up-front by new()
    pending: int         # bytes still expected; 0 means fully assembled
    _handle: io.FileIO | None = None
    _closed: bool = False

    async def write(self, offset: int, length: int, value: Bytes) -> None:
        """Write ``value`` at byte ``offset`` of the target file."""
        # Re-open per write: chunks may land out of order, so seek each time.
        async with aiofiles.open(self.path, mode="r+b") as handle:
            await handle.seek(offset)
            await handle.write(value)
        self.pending = self.pending - length

    async def read(self) -> memoryview:
        """Read the whole file back into memory and return a view over it."""
        async with aiofiles.open(self.path, mode="rb") as handle:
            contents = await handle.read()
        return memoryview(contents)

    @property
    def complete(self) -> bool:
        """Whether every expected byte has been flushed to disk."""
        return self.pending == 0

    @classmethod
    def new(cls, path: pathlib.Path, size: int):
        """Create the destination file (and parent dirs) expecting ``size`` bytes."""
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
        return cls(path=path, pending=size)
86
+
87
+
88
@dataclasses.dataclass
class Chunk:
    """A contiguous byte range within a single source object."""

    offset: int
    length: int
92
+
93
+
94
@dataclasses.dataclass
class Source:
    """One remote object to download; ``id`` keys chunk reassembly."""

    id: Hashable
    # NOTE: should be str — the fully qualified prefix/key of the object (no bucket).
    path: pathlib.Path
    length: int
    metadata: Any | None = None
100
+
101
+
102
@dataclasses.dataclass
class DownloadTask:
    """One chunk download: which source, which byte range, and where it lands."""

    source: Source
    chunk: Chunk
    # None means "assemble in memory" rather than into a local file.
    target: pathlib.Path | None = None
107
+
108
+
109
class ObstoreParallelReader:
    """Downloads objects from an obstore ``ObjectStore`` in parallel.

    Each object is split into fixed-size byte ranges that are fetched
    concurrently by a pool of workers and reassembled into either a
    ``_MemoryBuffer`` (no target path) or a ``_FileBuffer`` (target path set).
    """

    def __init__(
        self,
        store: ObjectStore,
        *,
        chunk_size=CHUNK_SIZE,
        max_concurrency=MAX_CONCURRENCY,
    ):
        self._store = store
        self._chunk_size = chunk_size
        self._max_concurrency = max_concurrency

    def _chunks(self, size) -> typing.Iterator[tuple[int, int]]:
        # Yield (offset, length) pairs covering [0, size); the last chunk may be short.
        cs = self._chunk_size
        for offset in range(0, size, cs):
            length = min(cs, size - offset)
            yield offset, length

    async def _as_completed(self, gen: typing.AsyncGenerator[DownloadTask, None], transformer=None):
        """Run download tasks from ``gen`` and yield ``(source_id, result)`` as each
        source finishes.

        ``result`` is the transformer's return value when one is given, else the
        task's target path, else the ``Source`` itself (in-memory case).
        """
        inq: asyncio.Queue = asyncio.Queue(self._max_concurrency * 2)
        outq: asyncio.Queue = asyncio.Queue()
        sentinel = object()
        done = asyncio.Event()

        # Per-source buffers keyed by Source.id; removed once a source completes.
        active: dict[Hashable, _FileBuffer | _MemoryBuffer] = {}

        async def _fill():
            # Helper function to fill the input queue, this is because the generator is async because it does list/head
            # calls on the object store which are async.
            try:
                counter = 0
                async for task in gen:
                    if task.source.id not in active:
                        # First chunk for this source: allocate its buffer.
                        active[task.source.id] = (
                            _FileBuffer.new(task.target, task.source.length)
                            if task.target is not None
                            else _MemoryBuffer.new(task.source.length)
                        )
                    await inq.put(task)
                    counter += 1
                # Sentinel tells the first worker that input is exhausted;
                # workers re-queue it so every worker eventually sees it.
                await inq.put(sentinel)
                if counter == 0:
                    raise DownloadQueueEmpty
            except asyncio.CancelledError:
                # NOTE(review): cancellation is swallowed so teardown (worker
                # failure / consumer abandoning the generator) does not surface a
                # secondary error from the filler — presumably nothing needs
                # cleanup here. TODO: confirm and document the rationale.
                pass

        async def _worker():
            try:
                while not done.is_set():
                    task = await inq.get()
                    if task is sentinel:
                        # Propagate the sentinel to the next worker, then exit.
                        inq.put_nowait(sentinel)
                        break
                    chunk_source_offset = task.chunk.offset
                    buf = active[task.source.id]
                    # Ranged GET for exactly this chunk's bytes.
                    data_to_write = await obstore.get_range_async(
                        self._store,
                        str(task.source.path),
                        start=chunk_source_offset,
                        end=chunk_source_offset + task.chunk.length,
                    )
                    await buf.write(
                        task.chunk.offset,
                        task.chunk.length,
                        data_to_write,
                    )
                    if not buf.complete:
                        continue
                    # Last chunk of this source just landed: emit its result.
                    if transformer is not None:
                        result = await transformer(buf)
                    elif task.target is not None:
                        result = task.target
                    else:
                        result = task.source
                    outq.put_nowait((task.source.id, result))
                    del active[task.source.id]
            except asyncio.CancelledError:
                pass
            finally:
                # First worker to finish flips the event; the yield loops below
                # stop polling and fall through to the drain phase.
                done.set()

        # Yield results as they are completed
        if sys.version_info >= (3, 11):
            async with asyncio.TaskGroup() as tg:
                tg.create_task(_fill())
                for _ in range(self._max_concurrency):
                    tg.create_task(_worker())
                while not done.is_set():
                    yield await outq.get()
        else:
            # Pre-3.11 fallback: manual task management instead of TaskGroup.
            fill_task = asyncio.create_task(_fill())
            worker_tasks = [asyncio.create_task(_worker()) for _ in range(self._max_concurrency)]
            try:
                while not done.is_set():
                    yield await outq.get()
            except Exception as e:
                if not fill_task.done():
                    fill_task.cancel()
                for wt in worker_tasks:
                    if not wt.done():
                        wt.cancel()
                raise e
            finally:
                await asyncio.gather(fill_task, *worker_tasks, return_exceptions=True)

        # Drain the output queue
        try:
            while True:
                yield outq.get_nowait()
        except asyncio.QueueEmpty:
            pass

    async def download_files(
        self, src_prefix: pathlib.Path, target_prefix: pathlib.Path, *paths, destination_file_name: str | None = None
    ) -> None:
        """
        Download objects under ``src_prefix`` into ``target_prefix``.

        src_prefix: Prefix you want to download from in the object store, not including the bucket name, nor file name.
            Should be replaced with string
        target_prefix: Local directory to download to
        paths: Specific paths (relative to src_prefix) to download. If empty, download everything
        destination_file_name: When exactly one path is requested, rename the
            downloaded file to this name inside ``target_prefix``.
        """

        async def _list_downloadable() -> typing.AsyncGenerator[ObjectMeta, None]:
            if paths:
                # For specific file paths, use async head
                for path_ in paths:
                    path = src_prefix / path_
                    x = await obstore.head_async(self._store, str(path))
                    yield x
                return

            # Use obstore.list() for recursive listing (all files in all subdirectories)
            # obstore.list() returns an async iterator that yields batches (lists) of objects
            async for batch in obstore.list(self._store, prefix=str(src_prefix)):
                for obj in batch:
                    yield obj

        async def _gen(tmp_dir: str) -> typing.AsyncGenerator[DownloadTask, None]:
            # Turn each listed object into one DownloadTask per chunk, targeting
            # a staging location under tmp_dir.
            async for obj in _list_downloadable():
                path = pathlib.Path(obj["path"])  # e.g. Path(prefix/file.txt), needs to be changed to str.
                size = obj["size"]
                source = Source(id=path, path=path, length=size)
                # Strip src_prefix from path for destination
                rel_path = path.relative_to(src_prefix)  # doesn't work on windows
                for offset, length in self._chunks(size):
                    yield DownloadTask(
                        source=source,
                        target=tmp_dir / rel_path,  # doesn't work on windows
                        chunk=Chunk(offset, length),
                    )

        def _transform_decorator(tmp_dir: str):
            # Build the per-source finalizer: move the staged file from tmp_dir
            # into its final location under target_prefix.
            async def _transformer(buf: _FileBuffer) -> None:
                if len(paths) == 1 and destination_file_name is not None:
                    target = target_prefix / destination_file_name
                else:
                    target = target_prefix / buf.path.relative_to(tmp_dir)
                await aiofiles.os.makedirs(target.parent, exist_ok=True)
                return await aiofiles.os.replace(buf.path, target)  # mv buf.path target

            return _transformer

        # Stage downloads in a temp dir, then atomically move each completed
        # file into place via the transformer.
        with tempfile.TemporaryDirectory() as temporary_dir:
            async for _ in self._as_completed(_gen(temporary_dir), transformer=_transform_decorator(temporary_dir)):
                pass
flyte/storage/_storage.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import pathlib
3
5
  import random
@@ -7,6 +9,7 @@ from typing import AsyncGenerator, Optional
7
9
  from uuid import UUID
8
10
 
9
11
  import fsspec
12
+ import obstore
10
13
  from fsspec.asyn import AsyncFileSystem
11
14
  from fsspec.utils import get_protocol
12
15
  from obstore.exceptions import GenericError
@@ -14,6 +17,10 @@ from obstore.fsspec import register
14
17
 
15
18
  from flyte._initialize import get_storage
16
19
  from flyte._logging import logger
20
+ from flyte.errors import InitializationError, OnlyAsyncIOSupportedError
21
+
22
+ if typing.TYPE_CHECKING:
23
+ from obstore import AsyncReadableFile, AsyncWritableFile
17
24
 
18
25
  _OBSTORE_SUPPORTED_PROTOCOLS = ["s3", "gs", "abfs", "abfss"]
19
26
 
@@ -77,21 +84,36 @@ def get_configured_fsspec_kwargs(
77
84
  protocol: typing.Optional[str] = None, anonymous: bool = False
78
85
  ) -> typing.Dict[str, typing.Any]:
79
86
  if protocol:
87
+ # Try to get storage config safely - may not be initialized for local operations
88
+ try:
89
+ storage_config = get_storage()
90
+ except InitializationError:
91
+ storage_config = None
92
+
80
93
  match protocol:
81
94
  case "s3":
82
95
  # If the protocol is s3, we can use the s3 filesystem
83
96
  from flyte.storage import S3
84
97
 
98
+ if storage_config and isinstance(storage_config, S3):
99
+ return storage_config.get_fsspec_kwargs(anonymous=anonymous)
100
+
85
101
  return S3.auto().get_fsspec_kwargs(anonymous=anonymous)
86
102
  case "gs":
87
103
  # If the protocol is gs, we can use the gs filesystem
88
104
  from flyte.storage import GCS
89
105
 
106
+ if storage_config and isinstance(storage_config, GCS):
107
+ return storage_config.get_fsspec_kwargs(anonymous=anonymous)
108
+
90
109
  return GCS.auto().get_fsspec_kwargs(anonymous=anonymous)
91
110
  case "abfs" | "abfss":
92
111
  # If the protocol is abfs or abfss, we can use the abfs filesystem
93
112
  from flyte.storage import ABFS
94
113
 
114
+ if storage_config and isinstance(storage_config, ABFS):
115
+ return storage_config.get_fsspec_kwargs(anonymous=anonymous)
116
+
95
117
  return ABFS.auto().get_fsspec_kwargs(anonymous=anonymous)
96
118
  case _:
97
119
  return {}
@@ -125,12 +147,76 @@ def _get_anonymous_filesystem(from_path):
125
147
  return get_underlying_filesystem(get_protocol(from_path), anonymous=True, asynchronous=True)
126
148
 
127
149
 
150
+ async def _get_obstore_bypass(from_path: str, to_path: str | pathlib.Path, recursive: bool = False, **kwargs) -> str:
151
+ from obstore.store import ObjectStore
152
+
153
+ from flyte.storage._parallel_reader import ObstoreParallelReader
154
+
155
+ fs = get_underlying_filesystem(path=from_path)
156
+ bucket, prefix = fs._split_path(from_path) # pylint: disable=W0212
157
+ store: ObjectStore = fs._construct_store(bucket)
158
+
159
+ download_kwargs = {}
160
+ if "chunk_size" in kwargs:
161
+ download_kwargs["chunk_size"] = kwargs["chunk_size"]
162
+ if "max_concurrency" in kwargs:
163
+ download_kwargs["max_concurrency"] = kwargs["max_concurrency"]
164
+
165
+ reader = ObstoreParallelReader(store, **download_kwargs)
166
+ target_path = pathlib.Path(to_path) if isinstance(to_path, str) else to_path
167
+
168
+ # if recursive, just download the prefix to the target path
169
+ if recursive:
170
+ logger.debug(f"Downloading recursively {prefix=} to {target_path=}")
171
+ await reader.download_files(
172
+ prefix,
173
+ target_path,
174
+ )
175
+ return str(to_path)
176
+
177
+ # if not recursive, we need to split out the file name from the prefix
178
+ else:
179
+ path_for_reader = pathlib.Path(prefix).name
180
+ final_prefix = pathlib.Path(prefix).parent
181
+ logger.debug(f"Downloading single file {final_prefix=}, {path_for_reader=} to {target_path=}")
182
+ await reader.download_files(
183
+ final_prefix,
184
+ target_path.parent,
185
+ path_for_reader,
186
+ destination_file_name=target_path.name,
187
+ )
188
+ return str(target_path)
189
+
190
+
128
191
  async def get(from_path: str, to_path: Optional[str | pathlib.Path] = None, recursive: bool = False, **kwargs) -> str:
129
192
  if not to_path:
130
- name = pathlib.Path(from_path).name
193
+ name = pathlib.Path(from_path).name # may need to be adjusted for windows
131
194
  to_path = get_random_local_path(file_path_or_file_name=name)
132
195
  logger.debug(f"Storing file from {from_path} to {to_path}")
196
+ else:
197
+ # Only apply directory logic for single files (not recursive)
198
+ if not recursive:
199
+ to_path_str = str(to_path)
200
+ # Check for trailing separator BEFORE converting to Path (which normalizes and removes it)
201
+ ends_with_sep = to_path_str.endswith(os.sep)
202
+ to_path_obj = pathlib.Path(to_path)
203
+
204
+ # If path ends with os.sep or is an existing directory, append source filename
205
+ if ends_with_sep or (to_path_obj.exists() and to_path_obj.is_dir()):
206
+ source_filename = pathlib.Path(from_path).name # may need to be adjusted for windows
207
+ to_path = to_path_obj / source_filename
208
+ # For recursive=True, keep to_path as-is (it's the destination directory for contents)
209
+
133
210
  file_system = get_underlying_filesystem(path=from_path)
211
+
212
+ # Check if we should use obstore bypass
213
+ if (
214
+ _is_obstore_supported_protocol(file_system.protocol)
215
+ and hasattr(file_system, "_split_path")
216
+ and hasattr(file_system, "_construct_store")
217
+ ):
218
+ return await _get_obstore_bypass(from_path, to_path, recursive, **kwargs)
219
+
134
220
  try:
135
221
  return await _get_from_filesystem(file_system, from_path, to_path, recursive=recursive, **kwargs)
136
222
  except (OSError, GenericError) as oe:
@@ -145,7 +231,6 @@ async def get(from_path: str, to_path: Optional[str | pathlib.Path] = None, recu
145
231
  else:
146
232
  exists = file_system.exists(from_path)
147
233
  if not exists:
148
- # TODO: update exception to be more specific
149
234
  raise AssertionError(f"Unable to load data from {from_path}")
150
235
  file_system = _get_anonymous_filesystem(from_path)
151
236
  logger.debug(f"Attempting anonymous get with {file_system}")
@@ -174,7 +259,7 @@ async def put(from_path: str, to_path: Optional[str] = None, recursive: bool = F
174
259
  from flyte._context import internal_ctx
175
260
 
176
261
  ctx = internal_ctx()
177
- name = pathlib.Path(from_path).name if not recursive else None # don't pass a name for folders
262
+ name = pathlib.Path(from_path).name
178
263
  to_path = ctx.raw_data.get_random_remote_path(file_name=name)
179
264
 
180
265
  file_system = get_underlying_filesystem(path=to_path)
@@ -189,34 +274,48 @@ async def put(from_path: str, to_path: Optional[str] = None, recursive: bool = F
189
274
  return to_path
190
275
 
191
276
 
192
async def _open_obstore_bypass(path: str, mode: str = "rb", **kwargs) -> AsyncReadableFile | AsyncWritableFile:
    """
    Simple obstore bypass for opening files. No fallbacks, obstore only.

    :param path: Remote URI to open.
    :param mode: Any mode containing ``"w"`` opens a writer; everything else a reader.
    :param kwargs: ``attributes`` (write) or ``buffer_size`` (read) are consumed;
        other keys are ignored.
    """
    from obstore.store import ObjectStore

    fs = get_underlying_filesystem(path=path)
    bucket, file_path = fs._split_path(path)  # pylint: disable=W0212
    store: ObjectStore = fs._construct_store(bucket)

    file_handle: AsyncReadableFile | AsyncWritableFile

    if "w" in mode:
        attributes = kwargs.pop("attributes", {})
        # NOTE(review): open_writer_async is deliberately not awaited while
        # open_reader_async is — presumably the writer is returned synchronously
        # by obstore. Confirm against the obstore API if this ever breaks.
        file_handle = obstore.open_writer_async(store, file_path, attributes=attributes)
    else:  # read mode
        buffer_size = kwargs.pop("buffer_size", 10 * 2**20)
        file_handle = await obstore.open_reader_async(store, file_path, buffer_size=buffer_size)

    return file_handle
+ return file_handle
297
+
298
+
299
async def open(path: str, mode: str = "rb", **kwargs) -> AsyncReadableFile | AsyncWritableFile:
    """
    Asynchronously open a file and return an async context manager.

    Prefers the obstore bypass when the underlying filesystem exposes the
    required private helpers; otherwise falls back to
    ``AsyncFileSystem.open_async``. Raises ``OnlyAsyncIOSupportedError`` when
    neither path is available.
    """
    fs = get_underlying_filesystem(path=path)

    # Obstore bypass is usable only when the protocol is supported AND the
    # fsspec wrapper exposes the private split/construct helpers we rely on.
    bypass_available = (
        _is_obstore_supported_protocol(fs.protocol)
        and hasattr(fs, "_split_path")
        and hasattr(fs, "_construct_store")
    )
    if bypass_available:
        return await _open_obstore_bypass(path, mode, **kwargs)

    # Fall back to the standard fsspec async path.
    if isinstance(fs, AsyncFileSystem):
        return await fs.open_async(path, mode, **kwargs)

    raise OnlyAsyncIOSupportedError(f"Filesystem {fs} does not support async operations")
220
319
 
221
320
 
222
321
  async def put_stream(
@@ -244,60 +343,31 @@ async def put_stream(
244
343
 
245
344
  ctx = internal_ctx()
246
345
  to_path = ctx.raw_data.get_random_remote_path(file_name=name)
247
- fs = get_underlying_filesystem(path=to_path)
248
346
 
249
- file_handle = None
250
- if isinstance(fs, AsyncFileSystem):
251
- try:
252
- if _is_obstore_supported_protocol(fs.protocol):
253
- # If the protocol is supported by obstore, use the obstore bypass method
254
- return await _put_stream_obstore_bypass(data_iterable, to_path=to_path, **kwargs)
255
- file_handle = await fs.open_async(to_path, "wb", **kwargs)
256
- if isinstance(data_iterable, bytes):
257
- await file_handle.write(data_iterable)
258
- else:
259
- async for data in data_iterable:
260
- await file_handle.write(data)
261
- return str(to_path)
262
- except NotImplementedError as e:
263
- logger.debug(f"{fs} doesn't implement 'open_async', falling back to sync, {e}")
264
- finally:
265
- if file_handle is not None:
266
- await file_handle.close()
267
-
268
- with fs.open(to_path, "wb", **kwargs) as f:
347
+ # Check if we should use obstore bypass
348
+ fs = get_underlying_filesystem(path=to_path)
349
+ try:
350
+ file_handle = typing.cast("AsyncWritableFile", await open(to_path, "wb", **kwargs))
269
351
  if isinstance(data_iterable, bytes):
270
- f.write(data_iterable)
352
+ await file_handle.write(data_iterable)
271
353
  else:
272
- # If data_iterable is async iterable, iterate over it and write each chunk to the file
273
354
  async for data in data_iterable:
274
- f.write(data)
275
- return str(to_path)
276
-
277
-
278
- async def _get_stream_obstore_bypass(path: str, chunk_size, **kwargs) -> AsyncGenerator[bytes, None]:
279
- """
280
- NOTE: This can break if obstore changes its API.
281
- This function is a workaround for obstore's fsspec implementation which does not support async file operations.
282
- It uses the synchronous methods directly to get a stream of data.
283
- """
284
- import obstore
285
- from obstore.store import ObjectStore
355
+ await file_handle.write(data)
356
+ await file_handle.close()
357
+ return str(to_path)
358
+ except OnlyAsyncIOSupportedError:
359
+ pass
360
+
361
+ # Fallback to normal open
362
+ file_handle_io: typing.IO = fs.open(to_path, mode="wb", **kwargs)
363
+ if isinstance(data_iterable, bytes):
364
+ file_handle_io.write(data_iterable)
365
+ else:
366
+ async for data in data_iterable:
367
+ file_handle_io.write(data)
368
+ file_handle_io.close()
286
369
 
287
- fs = get_underlying_filesystem(path=path)
288
- if not hasattr(fs, "_split_path") or not hasattr(fs, "_construct_store"):
289
- raise NotImplementedError(f"Obstore bypass not supported for {fs.protocol} protocol, methods missing.")
290
- bucket, rem_path = fs._split_path(path) # pylint: disable=W0212
291
- store: ObjectStore = fs._construct_store(bucket)
292
- buf_file = await obstore.open_reader_async(store, rem_path, buffer_size=chunk_size)
293
- try:
294
- while True:
295
- chunk = await buf_file.read()
296
- if not chunk:
297
- break
298
- yield bytes(chunk)
299
- finally:
300
- buf_file.close()
370
+ return str(to_path)
301
371
 
302
372
 
303
373
async def get_stream(path: str, chunk_size=10 * 2**20, **kwargs) -> AsyncGenerator[bytes, None]:
    """
    Get a stream of data from a remote location, yielding it chunk by chunk.

    Example usage:
    ```python
    import flyte.storage as storage
    async for chunk in storage.get_stream(path="s3://my_bucket/my_file.txt"):
        process(chunk)
    ```

    :param path: Path to the remote location where the data will be downloaded.
    :param kwargs: Additional arguments to be passed to the underlying filesystem.
    :param chunk_size: Size of each chunk to be read from the file.
    :return: An async iterator that yields chunks of bytes.
    """
    import inspect

    # Check if we should use the obstore bypass (requires the fsspec wrapper to
    # expose the private hooks obstore needs).
    fs = get_underlying_filesystem(path=path)
    if _is_obstore_supported_protocol(fs.protocol) and hasattr(fs, "_split_path") and hasattr(fs, "_construct_store"):
        # obstore names its read-buffer size `buffer_size`; map chunk_size onto it.
        if "buffer_size" not in kwargs:
            kwargs["buffer_size"] = chunk_size
        file_handle = typing.cast("AsyncReadableFile", await _open_obstore_bypass(path, "rb", **kwargs))
        try:
            while chunk := await file_handle.read():
                yield bytes(chunk)
        finally:
            # Release the reader even if the consumer stops iterating early.
            # NOTE(review): obstore handles have exposed both sync and async
            # close() in different places — handle either to be safe.
            maybe_coro = file_handle.close()
            if inspect.isawaitable(maybe_coro):
                await maybe_coro
        return

    # Fallback to a normal fsspec open; honor chunk_size via fsspec's `block_size`.
    if "block_size" not in kwargs:
        kwargs["block_size"] = chunk_size

    if isinstance(fs, AsyncFileSystem):
        async_handle = await fs.open_async(path, "rb", **kwargs)
        try:
            while chunk := await async_handle.read():
                yield chunk
        finally:
            # Close even when the generator is abandoned before exhaustion.
            await async_handle.close()
        return

    # Last resort: synchronous open (this blocks the event loop while reading).
    with fs.open(path, "rb", **kwargs) as sync_handle:
        while chunk := sync_handle.read():
            yield chunk
346
415
 
347
416
 
348
417
  def join(*paths: str) -> str:
@@ -355,4 +424,32 @@ def join(*paths: str) -> str:
355
424
  return str(os.path.join(*paths))
356
425
 
357
426
 
427
async def exists(path: str, **kwargs) -> bool:
    """
    Check if a path exists.

    :param path: Path to be checked.
    :param kwargs: Additional arguments to be passed to the underlying filesystem.
    :return: True if the path exists, False otherwise.
    """
    try:
        fs = get_underlying_filesystem(path=path, **kwargs)
        # Async filesystems expose the coroutine variant `_info`; use it so we
        # do not block the event loop. A missing path raises FileNotFoundError.
        if isinstance(fs, AsyncFileSystem):
            await fs._info(path)
        else:
            fs.info(path)
    except FileNotFoundError:
        return False
    return True
444
+
445
+
446
def exists_sync(path: str, **kwargs) -> bool:
    """
    Synchronously check if a path exists.

    Blocking counterpart of `exists` for callers outside an event loop.

    :param path: Path to be checked.
    :param kwargs: Additional arguments to be passed to the underlying filesystem.
    :return: True if the path exists, False otherwise.
    """
    try:
        fs = get_underlying_filesystem(path=path, **kwargs)
        # `info` raises FileNotFoundError when the path is absent.
        fs.info(path)
        return True
    except FileNotFoundError:
        return False
453
+
454
+
358
455
# Register the obstore-backed filesystem implementation with fsspec for every
# supported protocol, in asynchronous mode.
register(_OBSTORE_SUPPORTED_PROTOCOLS, asynchronous=True)