flyte-2.0.0b32-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of flyte has been flagged for review.
Files changed (204)
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/storage/_config.py
@@ -0,0 +1,237 @@
+ from __future__ import annotations
+
+ import datetime
+ import os
+ import typing
+ from dataclasses import dataclass
+ from typing import ClassVar
+
+ from flyte.config import set_if_exists
+
+
+ @dataclass(init=True, repr=True, eq=True, frozen=True)
+ class Storage(object):
+     """
+     Data storage configuration that applies across any provider.
+     """
+
+     retries: int = 3
+     backoff: datetime.timedelta = datetime.timedelta(seconds=5)
+     enable_debug: bool = False
+     attach_execution_metadata: bool = True
+
+     _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
+         "enable_debug": "UNION_STORAGE_DEBUG",
+         "retries": "UNION_STORAGE_RETRIES",
+         "backoff": "UNION_STORAGE_BACKOFF_SECONDS",
+     }
+
+     def get_fsspec_kwargs(self, anonymous: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
+         """
+         Returns the configuration as kwargs for constructing an fsspec filesystem.
+         """
+         return {}
+
+     @classmethod
+     def _auto_as_kwargs(cls) -> typing.Dict[str, typing.Any]:
+         retries = os.getenv(cls._KEY_ENV_VAR_MAPPING["retries"])
+         backoff = os.getenv(cls._KEY_ENV_VAR_MAPPING["backoff"])
+         enable_debug = os.getenv(cls._KEY_ENV_VAR_MAPPING["enable_debug"])
+
+         kwargs: typing.Dict[str, typing.Any] = {}
+         kwargs = set_if_exists(kwargs, "enable_debug", enable_debug)
+         kwargs = set_if_exists(kwargs, "retries", retries)
+         kwargs = set_if_exists(kwargs, "backoff", backoff)
+         return kwargs
+
+     @classmethod
+     def auto(cls) -> Storage:
+         """
+         Construct the config object automatically from environment variables.
+         """
+         return cls(**cls._auto_as_kwargs())
+
+
+ @dataclass(init=True, repr=True, eq=True, frozen=True)
+ class S3(Storage):
+     """
+     S3 specific configuration
+     """
+
+     endpoint: typing.Optional[str] = None
+     access_key_id: typing.Optional[str] = None
+     secret_access_key: typing.Optional[str] = None
+     region: typing.Optional[str] = None
+
+     _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
+         "endpoint": "FLYTE_AWS_ENDPOINT",
+         "access_key_id": "FLYTE_AWS_ACCESS_KEY_ID",
+         "secret_access_key": "FLYTE_AWS_SECRET_ACCESS_KEY",
+     } | Storage._KEY_ENV_VAR_MAPPING
+
+     # Refer to https://github.com/developmentseed/obstore/blob/33654fc37f19a657689eb93327b621e9f9e01494/obstore/python/obstore/store/_aws.pyi#L11
+     # for key and secret
+     _CONFIG_KEY_FSSPEC_S3_KEY_ID: ClassVar = "access_key_id"
+     _CONFIG_KEY_FSSPEC_S3_SECRET: ClassVar = "secret_access_key"
+     _CONFIG_KEY_ENDPOINT: ClassVar = "endpoint_url"
+     _KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"
+
+     @classmethod
+     def auto(cls, region: str | None = None) -> S3:
+         """
+         :return: Config
+         """
+         endpoint = os.getenv(cls._KEY_ENV_VAR_MAPPING["endpoint"], None)
+         access_key_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["access_key_id"], None)
+         secret_access_key = os.getenv(cls._KEY_ENV_VAR_MAPPING["secret_access_key"], None)
+
+         kwargs = super()._auto_as_kwargs()
+         kwargs = set_if_exists(kwargs, "endpoint", endpoint)
+         kwargs = set_if_exists(kwargs, "access_key_id", access_key_id)
+         kwargs = set_if_exists(kwargs, "secret_access_key", secret_access_key)
+         kwargs = set_if_exists(kwargs, "region", region)
+
+         return S3(**kwargs)
+
+     @classmethod
+     def for_sandbox(cls) -> S3:
+         """
+         :return:
+         """
+         kwargs = super()._auto_as_kwargs()
+         final_kwargs = kwargs | {
+             "endpoint": "http://localhost:4566",
+             "access_key_id": "minio",
+             "secret_access_key": "miniostorage",
+         }
+         return S3(**final_kwargs)
+
+     def get_fsspec_kwargs(self, anonymous: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
+         # Construct the config object
+         kwargs.pop("anonymous", None)  # Remove anonymous if it exists, as we handle it separately
+         config: typing.Dict[str, typing.Any] = {}
+         if self._CONFIG_KEY_FSSPEC_S3_KEY_ID in kwargs or self.access_key_id:
+             config[self._CONFIG_KEY_FSSPEC_S3_KEY_ID] = kwargs.pop(
+                 self._CONFIG_KEY_FSSPEC_S3_KEY_ID, self.access_key_id
+             )
+         if self._CONFIG_KEY_FSSPEC_S3_SECRET in kwargs or self.secret_access_key:
+             config[self._CONFIG_KEY_FSSPEC_S3_SECRET] = kwargs.pop(
+                 self._CONFIG_KEY_FSSPEC_S3_SECRET, self.secret_access_key
+             )
+         if self._CONFIG_KEY_ENDPOINT in kwargs or self.endpoint:
+             config["endpoint_url"] = kwargs.pop(self._CONFIG_KEY_ENDPOINT, self.endpoint)
+
+         retries = kwargs.pop("retries", self.retries)
+         backoff = kwargs.pop("backoff", self.backoff)
+
+         if anonymous:
+             config[self._KEY_SKIP_SIGNATURE] = True
+
+         retry_config = {
+             "max_retries": retries,
+             "backoff": {
+                 "base": 2,
+                 "init_backoff": backoff,
+                 "max_backoff": datetime.timedelta(seconds=16),
+             },
+             "retry_timeout": datetime.timedelta(minutes=3),
+         }
+
+         client_options = {"timeout": "99999s", "allow_http": True}
+
+         if config:
+             kwargs["config"] = config
+         kwargs["client_options"] = client_options or None
+         kwargs["retry_config"] = retry_config or None
+         if self.region:
+             kwargs["region"] = self.region
+
+         return kwargs
+
+
+ @dataclass(init=True, repr=True, eq=True, frozen=True)
+ class GCS(Storage):
+     """
+     Any GCS specific configuration.
+     """
+
+     gsutil_parallelism: bool = False
+
+     _KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
+         "gsutil_parallelism": "GCP_GSUTIL_PARALLELISM",
+     }
+
+     @classmethod
+     def auto(cls) -> GCS:
+         gsutil_parallelism = os.getenv(cls._KEY_ENV_VAR_MAPPING["gsutil_parallelism"], None)
+
+         kwargs: typing.Dict[str, typing.Any] = {}
+         kwargs = set_if_exists(kwargs, "gsutil_parallelism", gsutil_parallelism)
+         return GCS(**kwargs)
+
+     def get_fsspec_kwargs(self, anonymous: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
+         kwargs.pop("anonymous", None)
+         return kwargs
+
+
+ @dataclass(init=True, repr=True, eq=True, frozen=True)
+ class ABFS(Storage):
+     """
+     Any Azure Blob Storage specific configuration.
+     """
+
+     account_name: typing.Optional[str] = None
+     account_key: typing.Optional[str] = None
+     tenant_id: typing.Optional[str] = None
+     client_id: typing.Optional[str] = None
+     client_secret: typing.Optional[str] = None
+
+     _KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
+         "account_name": "AZURE_STORAGE_ACCOUNT_NAME",
+         "account_key": "AZURE_STORAGE_ACCOUNT_KEY",
+         "tenant_id": "AZURE_TENANT_ID",
+         "client_id": "AZURE_CLIENT_ID",
+         "client_secret": "AZURE_CLIENT_SECRET",
+     }
+     _KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"
+
+     @classmethod
+     def auto(cls) -> ABFS:
+         account_name = os.getenv(cls._KEY_ENV_VAR_MAPPING["account_name"], None)
+         account_key = os.getenv(cls._KEY_ENV_VAR_MAPPING["account_key"], None)
+         tenant_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["tenant_id"], None)
+         client_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["client_id"], None)
+         client_secret = os.getenv(cls._KEY_ENV_VAR_MAPPING["client_secret"], None)
+
+         kwargs: typing.Dict[str, typing.Any] = {}
+         kwargs = set_if_exists(kwargs, "account_name", account_name)
+         kwargs = set_if_exists(kwargs, "account_key", account_key)
+         kwargs = set_if_exists(kwargs, "tenant_id", tenant_id)
+         kwargs = set_if_exists(kwargs, "client_id", client_id)
+         kwargs = set_if_exists(kwargs, "client_secret", client_secret)
+         return ABFS(**kwargs)
+
+     def get_fsspec_kwargs(self, anonymous: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
+         kwargs.pop("anonymous", None)
+         config: typing.Dict[str, typing.Any] = {}
+         if "account_name" in kwargs or self.account_name:
+             config["account_name"] = kwargs.get("account_name", self.account_name)
+         if "account_key" in kwargs or self.account_key:
+             config["account_key"] = kwargs.get("account_key", self.account_key)
+         if "client_id" in kwargs or self.client_id:
+             config["client_id"] = kwargs.get("client_id", self.client_id)
+         if "client_secret" in kwargs or self.client_secret:
+             config["client_secret"] = kwargs.get("client_secret", self.client_secret)
+         if "tenant_id" in kwargs or self.tenant_id:
+             config["tenant_id"] = kwargs.get("tenant_id", self.tenant_id)
+
+         if anonymous:
+             config[self._KEY_SKIP_SIGNATURE] = True
+
+         client_options = {"timeout": "99999s", "allow_http": "true"}
+
+         if config:
+             kwargs["config"] = config
+         kwargs["client_options"] = client_options
+
+         return kwargs
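
Note: the configuration classes above are driven entirely by environment variables plus the set_if_exists helper, which only records values that are actually present, and every auto()/for_sandbox() call returns a new frozen dataclass rather than mutating shared state. A minimal usage sketch (assuming the module is importable as flyte.storage._config; the credentials and region below are placeholders, not real values):

    import os

    from flyte.storage._config import S3

    # Placeholder credentials for illustration only.
    os.environ["FLYTE_AWS_ACCESS_KEY_ID"] = "example-key-id"
    os.environ["FLYTE_AWS_SECRET_ACCESS_KEY"] = "example-secret"

    cfg = S3.auto(region="us-east-1")   # reads the FLYTE_AWS_* variables set above
    sandbox = S3.for_sandbox()          # localhost:4566 endpoint with minio credentials

    # Turn the config into kwargs for an obstore-backed fsspec filesystem.
    fs_kwargs = cfg.get_fsspec_kwargs()
    print(fs_kwargs["config"]["access_key_id"])      # "example-key-id"
    print(fs_kwargs["retry_config"]["max_retries"])  # 3, the Storage default
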
flyte/storage/_parallel_reader.py
@@ -0,0 +1,274 @@
+ from __future__ import annotations
+
+ import asyncio
+ import dataclasses
+ import io
+ import os
+ import pathlib
+ import sys
+ import tempfile
+ import typing
+ from typing import Any, Hashable, Protocol
+
+ import aiofiles
+ import aiofiles.os
+ import obstore
+
+ if typing.TYPE_CHECKING:
+     from obstore import Bytes, ObjectMeta
+     from obstore.store import ObjectStore
+
+ CHUNK_SIZE = int(os.getenv("FLYTE_IO_CHUNK_SIZE", str(16 * 1024 * 1024)))
+ MAX_CONCURRENCY = int(os.getenv("FLYTE_IO_MAX_CONCURRENCY", str(32)))
+
+
+ class DownloadQueueEmpty(RuntimeError):
+     pass
+
+
+ class BufferProtocol(Protocol):
+     async def write(self, offset, length, value: Bytes) -> None: ...
+
+     async def read(self) -> memoryview: ...
+
+     @property
+     def complete(self) -> bool: ...
+
+
+ @dataclasses.dataclass
+ class _MemoryBuffer:
+     arr: bytearray
+     pending: int
+     _closed: bool = False
+
+     async def write(self, offset: int, length: int, value: Bytes) -> None:
+         self.arr[offset : offset + length] = memoryview(value)
+         self.pending -= length
+
+     async def read(self) -> memoryview:
+         return memoryview(self.arr)
+
+     @property
+     def complete(self) -> bool:
+         return self.pending == 0
+
+     @classmethod
+     def new(cls, size):
+         return cls(arr=bytearray(size), pending=size)
+
+
+ @dataclasses.dataclass
+ class _FileBuffer:
+     path: pathlib.Path
+     pending: int
+     _handle: io.FileIO | None = None
+     _closed: bool = False
+
+     async def write(self, offset: int, length: int, value: Bytes) -> None:
+         async with aiofiles.open(self.path, mode="r+b") as f:
+             await f.seek(offset)
+             await f.write(value)
+         self.pending -= length
+
+     async def read(self) -> memoryview:
+         async with aiofiles.open(self.path, mode="rb") as f:
+             return memoryview(await f.read())
+
+     @property
+     def complete(self) -> bool:
+         return self.pending == 0
+
+     @classmethod
+     def new(cls, path: pathlib.Path, size: int):
+         path.parent.mkdir(parents=True, exist_ok=True)
+         path.touch()
+         return cls(path=path, pending=size)
+
+
+ @dataclasses.dataclass
+ class Chunk:
+     offset: int
+     length: int
+
+
+ @dataclasses.dataclass
+ class Source:
+     id: Hashable
+     path: pathlib.Path  # Should be str, represents the fully qualified prefix of a file (no bucket)
+     length: int
+     metadata: Any | None = None
+
+
+ @dataclasses.dataclass
+ class DownloadTask:
+     source: Source
+     chunk: Chunk
+     target: pathlib.Path | None = None
+
+
+ class ObstoreParallelReader:
+     def __init__(
+         self,
+         store: ObjectStore,
+         *,
+         chunk_size=CHUNK_SIZE,
+         max_concurrency=MAX_CONCURRENCY,
+     ):
+         self._store = store
+         self._chunk_size = chunk_size
+         self._max_concurrency = max_concurrency
+
+     def _chunks(self, size) -> typing.Iterator[tuple[int, int]]:
+         cs = self._chunk_size
+         for offset in range(0, size, cs):
+             length = min(cs, size - offset)
+             yield offset, length
+
+     async def _as_completed(self, gen: typing.AsyncGenerator[DownloadTask, None], transformer=None):
+         inq: asyncio.Queue = asyncio.Queue(self._max_concurrency * 2)
+         outq: asyncio.Queue = asyncio.Queue()
+         sentinel = object()
+         done = asyncio.Event()
+
+         active: dict[Hashable, _FileBuffer | _MemoryBuffer] = {}
+
+         async def _fill():
+             # Helper function to fill the input queue, this is because the generator is async because it does list/head
+             # calls on the object store which are async.
+             try:
+                 counter = 0
+                 async for task in gen:
+                     if task.source.id not in active:
+                         active[task.source.id] = (
+                             _FileBuffer.new(task.target, task.source.length)
+                             if task.target is not None
+                             else _MemoryBuffer.new(task.source.length)
+                         )
+                     await inq.put(task)
+                     counter += 1
+                 await inq.put(sentinel)
+                 if counter == 0:
+                     raise DownloadQueueEmpty
+             except asyncio.CancelledError:
+                 # document why we need to swallow this
+                 pass
+
+         async def _worker():
+             try:
+                 while not done.is_set():
+                     task = await inq.get()
+                     if task is sentinel:
+                         inq.put_nowait(sentinel)
+                         break
+                     chunk_source_offset = task.chunk.offset
+                     buf = active[task.source.id]
+                     data_to_write = await obstore.get_range_async(
+                         self._store,
+                         str(task.source.path),
+                         start=chunk_source_offset,
+                         end=chunk_source_offset + task.chunk.length,
+                     )
+                     await buf.write(
+                         task.chunk.offset,
+                         task.chunk.length,
+                         data_to_write,
+                     )
+                     if not buf.complete:
+                         continue
+                     if transformer is not None:
+                         result = await transformer(buf)
+                     elif task.target is not None:
+                         result = task.target
+                     else:
+                         result = task.source
+                     outq.put_nowait((task.source.id, result))
+                     del active[task.source.id]
+             except asyncio.CancelledError:
+                 pass
+             finally:
+                 done.set()
+
+         # Yield results as they are completed
+         if sys.version_info >= (3, 11):
+             async with asyncio.TaskGroup() as tg:
+                 tg.create_task(_fill())
+                 for _ in range(self._max_concurrency):
+                     tg.create_task(_worker())
+                 while not done.is_set():
+                     yield await outq.get()
+         else:
+             fill_task = asyncio.create_task(_fill())
+             worker_tasks = [asyncio.create_task(_worker()) for _ in range(self._max_concurrency)]
+             try:
+                 while not done.is_set():
+                     yield await outq.get()
+             except Exception as e:
+                 if not fill_task.done():
+                     fill_task.cancel()
+                 for wt in worker_tasks:
+                     if not wt.done():
+                         wt.cancel()
+                 raise e
+             finally:
+                 await asyncio.gather(fill_task, *worker_tasks, return_exceptions=True)
+
+         # Drain the output queue
+         try:
+             while True:
+                 yield outq.get_nowait()
+         except asyncio.QueueEmpty:
+             pass
+
+     async def download_files(
+         self, src_prefix: pathlib.Path, target_prefix: pathlib.Path, *paths, destination_file_name: str | None = None
+     ) -> None:
+         """
+         src_prefix: Prefix you want to download from in the object store, not including the bucket name, nor file name.
+             Should be replaced with string
+         target_prefix: Local directory to download to
+         paths: Specific paths (relative to src_prefix) to download. If empty, download everything
+         """
+
+         async def _list_downloadable() -> typing.AsyncGenerator[ObjectMeta, None]:
+             if paths:
+                 # For specific file paths, use async head
+                 for path_ in paths:
+                     path = src_prefix / path_
+                     x = await obstore.head_async(self._store, str(path))
+                     yield x
+                 return
+
+             # Use obstore.list() for recursive listing (all files in all subdirectories)
+             # obstore.list() returns an async iterator that yields batches (lists) of objects
+             async for batch in obstore.list(self._store, prefix=str(src_prefix)):
+                 for obj in batch:
+                     yield obj
+
+         async def _gen(tmp_dir: str) -> typing.AsyncGenerator[DownloadTask, None]:
+             async for obj in _list_downloadable():
+                 path = pathlib.Path(obj["path"])  # e.g. Path(prefix/file.txt), needs to be changed to str.
+                 size = obj["size"]
+                 source = Source(id=path, path=path, length=size)
+                 # Strip src_prefix from path for destination
+                 rel_path = path.relative_to(src_prefix)  # doesn't work on windows
+                 for offset, length in self._chunks(size):
+                     yield DownloadTask(
+                         source=source,
+                         target=tmp_dir / rel_path,  # doesn't work on windows
+                         chunk=Chunk(offset, length),
+                     )
+
+         def _transform_decorator(tmp_dir: str):
+             async def _transformer(buf: _FileBuffer) -> None:
+                 if len(paths) == 1 and destination_file_name is not None:
+                     target = target_prefix / destination_file_name
+                 else:
+                     target = target_prefix / buf.path.relative_to(tmp_dir)
+                 await aiofiles.os.makedirs(target.parent, exist_ok=True)
+                 return await aiofiles.os.replace(buf.path, target)  # mv buf.path target
+
+             return _transformer
+
+         with tempfile.TemporaryDirectory() as temporary_dir:
+             async for _ in self._as_completed(_gen(temporary_dir), transformer=_transform_decorator(temporary_dir)):
+                 pass
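
Note: ObstoreParallelReader splits each object into fixed-size ranges (CHUNK_SIZE, 16 MiB by default), fans the range reads out across up to max_concurrency worker tasks via obstore.get_range_async, reassembles them into a memory or temporary-file buffer, and then moves each completed file into place. FLYTE_IO_CHUNK_SIZE and FLYTE_IO_MAX_CONCURRENCY override the defaults process-wide. A minimal driving sketch (assuming the module is importable as flyte.storage._parallel_reader and that obstore's LocalStore accepts a directory prefix; all paths are illustrative):

    import asyncio
    import pathlib

    from obstore.store import LocalStore  # any obstore ObjectStore implementation works here

    from flyte.storage._parallel_reader import ObstoreParallelReader


    async def main() -> None:
        # LocalStore keeps the sketch self-contained; in practice the store would be
        # an S3/GCS/Azure store built from the Storage config shown earlier.
        store = LocalStore("/tmp/source-data")
        reader = ObstoreParallelReader(store, chunk_size=8 * 1024 * 1024, max_concurrency=8)

        # Download everything under datasets/run-1 into ./downloads, preserving layout.
        await reader.download_files(
            pathlib.Path("datasets/run-1"),
            pathlib.Path("downloads"),
        )

        # Or fetch a single object under a chosen local name.
        await reader.download_files(
            pathlib.Path("datasets/run-1"),
            pathlib.Path("downloads"),
            "metrics.json",
            destination_file_name="run-1-metrics.json",
        )


    asyncio.run(main())

Choosing a smaller chunk_size increases request count but lowers per-request memory; the default of 16 MiB is a middle ground for large objects.
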
flyte/storage/_remote_fs.py
@@ -0,0 +1,34 @@
+ from __future__ import annotations
+
+ import threading
+ import typing
+
+ # This file system is not really a filesystem, so users aren't really able to specify the remote path,
+ # at least not yet.
+ REMOTE_PLACEHOLDER = "flyte://data"
+
+ HashStructure = typing.Dict[str, typing.Tuple[bytes, int]]
+
+
+ class RemoteFSPathResolver:
+     protocol = "flyte://"
+     _flyte_path_to_remote_map: typing.ClassVar[typing.Dict[str, str]] = {}
+     _lock = threading.Lock()
+
+     @classmethod
+     def resolve_remote_path(cls, flyte_uri: str) -> typing.Optional[str]:
+         """
+         Given a flyte uri, return the remote path if it exists or was created in current session, otherwise return None
+         """
+         with cls._lock:
+             if flyte_uri in cls._flyte_path_to_remote_map:
+                 return cls._flyte_path_to_remote_map[flyte_uri]
+         return None
+
+     @classmethod
+     def add_mapping(cls, flyte_uri: str, remote_path: str):
+         """
+         Thread safe method to add a mapping from a flyte uri to a remote path
+         """
+         with cls._lock:
+             cls._flyte_path_to_remote_map[flyte_uri] = remote_path
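
Note: RemoteFSPathResolver is a process-wide, lock-guarded registry from flyte:// URIs to concrete remote paths; resolve_remote_path returns None for URIs that were never mapped in the current session. A small sketch (assuming the module is importable as flyte.storage._remote_fs; the URIs and bucket are illustrative):

    from flyte.storage._remote_fs import REMOTE_PLACEHOLDER, RemoteFSPathResolver

    # Record where a flyte:// URI actually lives; the class-level dict is guarded by a lock.
    RemoteFSPathResolver.add_mapping(
        f"{REMOTE_PLACEHOLDER}/outputs/model.pt",
        "s3://my-bucket/outputs/model.pt",
    )

    # Lookups return the mapped remote path, or None if the URI was never registered.
    assert RemoteFSPathResolver.resolve_remote_path(
        f"{REMOTE_PLACEHOLDER}/outputs/model.pt"
    ) == "s3://my-bucket/outputs/model.pt"
    assert RemoteFSPathResolver.resolve_remote_path(f"{REMOTE_PLACEHOLDER}/never-mapped") is None
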