flyteplugins-huggingface 2.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/huggingface/__init__.py +0 -0
- flyteplugins/huggingface/datasets/__init__.py +33 -0
- flyteplugins/huggingface/datasets/_io.py +499 -0
- flyteplugins/huggingface/datasets/_source.py +158 -0
- flyteplugins/huggingface/datasets/_transformers.py +352 -0
- flyteplugins_huggingface-2.2.1.dist-info/METADATA +345 -0
- flyteplugins_huggingface-2.2.1.dist-info/RECORD +10 -0
- flyteplugins_huggingface-2.2.1.dist-info/WHEEL +5 -0
- flyteplugins_huggingface-2.2.1.dist-info/entry_points.txt +2 -0
- flyteplugins_huggingface-2.2.1.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
|
|
3
|
+
from ._source import HFSource, from_hf
|
|
4
|
+
|
|
5
|
+
__all__ = ["HFSource", "from_hf"]
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@functools.lru_cache(maxsize=None)
def register_huggingface_dataset_transformers():
    """Register the plugin's Hugging Face encoders/decoders with Flyte.

    The function takes no arguments and is wrapped in ``lru_cache``, so after
    one successful call the registrations are not repeated.
    """
    # Both imports stay local to the function, as in the original layout.
    from flyte.io.extend import DataFrameTransformerEngine

    from ._transformers import (
        HFToHuggingFaceDatasetDecodingHandler,
        HFToHuggingFaceIterableDatasetDecodingHandler,
        HuggingFaceDatasetToParquetEncodingHandler,
        HuggingFaceIterableDatasetToParquetEncodingHandler,
        ParquetToHuggingFaceDatasetDecodingHandler,
        ParquetToHuggingFaceIterableDatasetDecodingHandler,
    )

    as_default = {"default_format_for_type": True}
    no_options: dict = {}

    # (handler class, extra keyword arguments) in the exact registration order.
    registration_plan = (
        (HuggingFaceDatasetToParquetEncodingHandler, as_default),
        (ParquetToHuggingFaceDatasetDecodingHandler, as_default),
        (HFToHuggingFaceDatasetDecodingHandler, no_options),
        (HuggingFaceIterableDatasetToParquetEncodingHandler, as_default),
        (ParquetToHuggingFaceIterableDatasetDecodingHandler, as_default),
        (HFToHuggingFaceIterableDatasetDecodingHandler, no_options),
    )
    for handler_cls, options in registration_plan:
        DataFrameTransformerEngine.register(handler_cls(), **options)
|
|
@@ -0,0 +1,499 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import typing
|
|
7
|
+
|
|
8
|
+
import flyte.storage as storage
|
|
9
|
+
from flyte._logging import logger
|
|
10
|
+
from fsspec.asyn import AsyncFileSystem
|
|
11
|
+
|
|
12
|
+
from ._source import (
|
|
13
|
+
HFShard,
|
|
14
|
+
HFSource,
|
|
15
|
+
hf_revision,
|
|
16
|
+
hf_source_cache_key,
|
|
17
|
+
hf_source_payload,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
_HF_CACHE_MANIFEST = "_flyte_hf_manifest.json"
|
|
21
|
+
_HF_DATASET_REGISTRY = "huggingface/datasets"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class HFParquetError(ValueError):
    """Signal that a Hugging Face parquet conversion could not be resolved."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_hf_cache_path(source: HFSource, cache_key: str | None = None) -> str:
    """Build the ``blobs/<cache_key>`` location for ``source`` under its cache root.

    Args:
        source: Dataset source; its ``cache_root`` must be set.
        cache_key: Key to use for the final path segment. When omitted (or
            empty) the key is derived with ``hf_source_cache_key(source)``.

    Returns:
        ``<cache_root>/huggingface/datasets/blobs/<cache_key>``.

    Raises:
        ValueError: If ``source.cache_root`` is ``None``.
    """
    root = source.cache_root
    if root is None:
        raise ValueError("cache_root is required for deterministic HF cache paths")

    key = cache_key if cache_key else hf_source_cache_key(source)
    return join_uri_path(root, _HF_DATASET_REGISTRY, "blobs", key)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_hf_registry_record_path(source: HFSource, cache_key: str) -> str:
    """Build the ``by-key/<cache_key>.json`` registry record location.

    Args:
        source: Dataset source; its ``cache_root`` must be set.
        cache_key: Key used as the record's file stem.

    Returns:
        ``<cache_root>/huggingface/datasets/by-key/<cache_key>.json``.

    Raises:
        ValueError: If ``source.cache_root`` is ``None``.
    """
    root = source.cache_root
    if root is None:
        raise ValueError("cache_root is required for HF registry records")

    record_name = f"{cache_key}.json"
    return join_uri_path(root, _HF_DATASET_REGISTRY, "by-key", record_name)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_random_hf_path() -> str:
    """Return an ``hf-dataset`` path inside ``storage.get_random_local_directory()``."""
    scratch_dir = str(storage.get_random_local_directory())
    return os.path.join(scratch_dir, "hf-dataset")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def join_uri_path(base: str, *parts: str) -> str:
    """Concatenate URI/object-store path pieces using ``/`` as the separator.

    Trailing slashes are removed from ``base``; leading and trailing slashes
    are removed from every part. Pieces that end up empty are dropped, so no
    doubled or dangling separators are produced.

    Args:
        base: Leading component (for example ``s3://bucket/prefix``).
        *parts: Components to append, in order.

    Returns:
        The joined path.
    """
    pieces = [base.rstrip("/")]
    pieces.extend(part.strip("/") for part in parts)
    # Skipping empty pieces also covers an empty base: the first non-empty
    # part then starts the result without a leading separator.
    return "/".join(piece for piece in pieces if piece)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _storage_kind(path: str) -> str:
    """Label ``path`` as ``"remote"`` or ``"local"`` according to ``storage.is_remote``."""
    if storage.is_remote(path):
        return "remote"
    return "local"
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _source_log_description(source: HFSource) -> str:
|
|
68
|
+
config = source.name if source.name is not None else "auto"
|
|
69
|
+
split = source.split if source.split is not None else "all"
|
|
70
|
+
return f"config={config}, split={split}"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
async def run_sync_io(
    label: str,
    func: typing.Callable[..., typing.Any],
    *args: typing.Any,
    **kwargs: typing.Any,
) -> typing.Any:
    """Await a blocking callable by running it through ``asyncio.to_thread``.

    Args:
        label: Description of the operation; only used in the cancellation
            warning.
        func: Synchronous callable to execute.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        asyncio.CancelledError: Re-raised after a warning is logged when the
            awaiting task is cancelled.
    """
    # to_thread() fits here because the wrapped calls are mostly IO-bound.
    pending = asyncio.to_thread(func, *args, **kwargs)
    try:
        outcome = await pending
    except asyncio.CancelledError:
        logger.warning(
            f"Cancellation requested while running {label}. The active sync IO "
            "call may finish in its worker thread before stopping."
        )
        raise
    return outcome
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
async def storage_path_exists(path: str) -> bool:
    """Report whether ``path`` exists, for both local and remote locations.

    Local paths are checked with ``os.path.exists`` in a worker thread.
    Remote paths are checked with ``storage.exists``; any exception from that
    call is logged at debug level and reported as "does not exist".
    """
    if not storage.is_remote(path):
        found = await run_sync_io("local exists", os.path.exists, path)
        return typing.cast(bool, found)

    try:
        return await storage.exists(path)
    except Exception as e:
        # Best effort: a failed remote check is treated as a cache miss.
        logger.debug(f"Unable to check whether {path} exists: {e}")
        return False
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
async def storage_read_bytes(path: str) -> bytes:
    """Return the full contents of ``path`` as bytes.

    A remote path is first downloaded with ``storage.get`` to a location
    obtained from ``storage.get_random_local_path``; the local copy (or the
    original local path) is then read in a worker thread.
    """
    readable = path
    if storage.is_remote(path):
        staged = storage.get_random_local_path(file_path_or_file_name=os.path.basename(path))
        await storage.get(path, str(staged))
        readable = str(staged)

    def _slurp() -> bytes:
        with open(readable, "rb") as handle:
            return handle.read()

    return typing.cast(bytes, await run_sync_io("read local file", _slurp))
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
async def storage_write_bytes(path: str, data: bytes) -> None:
    """Write ``data`` to ``path``, which may be local or remote.

    Remote paths are written with ``storage.put_stream``. Local paths are
    written in a worker thread; the parent directory is created there as
    well, so no blocking filesystem call runs on the event loop.

    Args:
        path: Destination path or URI.
        data: Bytes to store.
    """
    if storage.is_remote(path):
        await storage.put_stream(data, to_path=path)
        return

    def _write() -> None:
        # Create the parent directory inside the worker too: makedirs is
        # blocking filesystem IO just like the write itself.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, "wb") as fh:
            fh.write(data)

    await run_sync_io("write local file", _write)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
async def list_parquet_files(uri: str, filesystem) -> list[str]:
    """List the ``.parquet`` paths found under ``uri``, sorted.

    Async fsspec filesystems are queried through ``_find``; other filesystems
    have their blocking ``find`` run in a worker thread. When nothing is found,
    or when listing raises, a single ``00000.parquet`` path under ``uri`` is
    returned instead.

    Args:
        uri: Directory path or URI to search.
        filesystem: fsspec-style filesystem object used for the listing.

    Returns:
        Sorted parquet paths, or the single fallback path.
    """

    def _fallback() -> list[str]:
        return [join_uri_path(uri, f"{0:05}.parquet")]

    try:
        if isinstance(filesystem, AsyncFileSystem):
            listing = await filesystem._find(uri)
        else:
            listing = await run_sync_io("filesystem find parquet", filesystem.find, uri)
        parquet_paths = sorted(p for p in listing if p.endswith(".parquet"))

        if not parquet_paths:
            return _fallback()

        # The listing may come back without a scheme; restore the one from uri.
        if "://" in uri and "://" not in parquet_paths[0]:
            scheme = uri.split("://", maxsplit=1)[0] + "://"
            parquet_paths = [f"{scheme}{p}" for p in parquet_paths]
        return parquet_paths
    except Exception as e:
        logger.warning(f"Unable to list parquet files under {uri}: {e}")
        return _fallback()
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _format_names(names: list[str]) -> str:
|
|
154
|
+
return ", ".join(names) if names else "none"
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _list_hf_dir(hfs, path: str, revision: str) -> list[dict[str, typing.Any]]:
|
|
158
|
+
return typing.cast(
|
|
159
|
+
list[dict[str, typing.Any]],
|
|
160
|
+
hfs.ls(path, revision=revision, detail=True),
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _try_list_hf_dir(
    hfs,
    path: str,
    revision: str,
) -> list[dict[str, typing.Any]] | None:
    """List ``path`` like ``_list_hf_dir``, returning ``None`` if it is missing.

    Only ``FileNotFoundError`` is translated to ``None``; any other exception
    propagates to the caller.
    """
    try:
        listing = _list_hf_dir(hfs, path, revision)
    except FileNotFoundError:
        return None
    return listing
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _entry_name(entry: dict[str, typing.Any]) -> str:
|
|
176
|
+
return typing.cast(str, entry["name"]).rstrip("/").split("/")[-1]
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _is_directory(entry: dict[str, typing.Any]) -> bool:
|
|
180
|
+
return entry.get("type") == "directory"
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _is_parquet_file(entry: dict[str, typing.Any]) -> bool:
|
|
184
|
+
return entry.get("type") == "file" and typing.cast(str, entry.get("name", "")).endswith(".parquet")
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _config_dirs(entries: list[dict[str, typing.Any]]) -> list[dict[str, typing.Any]]:
    """Return the directory entries from ``entries``, ordered by their ``name``."""
    directories = [candidate for candidate in entries if _is_directory(candidate)]
    directories.sort(key=lambda candidate: typing.cast(str, candidate["name"]))
    return directories
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _resolve_hf_config(
    hfs,
    source: HFSource,
    revision: str,
) -> tuple[str, str, list[dict[str, typing.Any]]]:
    """Choose the parquet config directory to read for ``source``.

    Resolution order:
      1. If ``source.name`` is set, that config directory must be listable;
         otherwise an error naming the available configs is raised.
      2. Otherwise the ``default`` config directory is used when listable.
      3. Otherwise the repository must contain exactly one directory, which
         becomes the config.

    Args:
        hfs: Hugging Face filesystem object used for listings.
        source: Dataset source being resolved.
        revision: Revision passed to every listing call.

    Returns:
        ``(config name, config path, entries listed at that path)``.

    Raises:
        HFParquetError: If no config can be chosen unambiguously.
    """
    repo_root = join_uri_path("datasets", source.repo)

    # Explicit config requested: it either exists or we fail with a hint.
    if source.name is not None:
        named_path = join_uri_path(repo_root, source.name)
        named_entries = _try_list_hf_dir(hfs, named_path, revision)
        if named_entries is not None:
            return source.name, named_path, named_entries

        siblings = _try_list_hf_dir(hfs, repo_root, revision) or []
        available = [_entry_name(directory) for directory in _config_dirs(siblings)]
        raise HFParquetError(
            f"No Hugging Face parquet config named {source.name!r} found for "
            f"{source.repo} at revision {revision!r}. Available configs: "
            f"{_format_names(available)}."
        )

    # No config requested: prefer the conventional "default" directory.
    default_path = join_uri_path(repo_root, "default")
    default_entries = _try_list_hf_dir(hfs, default_path, revision)
    if default_entries is not None:
        return "default", default_path, default_entries

    top_level = _try_list_hf_dir(hfs, repo_root, revision)
    if top_level is None:
        raise HFParquetError(f"No Hugging Face parquet conversion found for {source.repo} at revision {revision!r}.")

    candidates = _config_dirs(top_level)
    candidate_names = [_entry_name(directory) for directory in candidates]

    if not candidates:
        raise HFParquetError(f"No Hugging Face parquet configs found for {source.repo} at revision {revision!r}.")

    if len(candidates) > 1:
        raise HFParquetError(
            f"Hugging Face dataset {source.repo} has multiple parquet configs: "
            f"{_format_names(candidate_names)}. Pass name=... to from_hf()."
        )

    # Exactly one config directory: use it.
    only_path = typing.cast(str, candidates[0]["name"])
    return candidate_names[0], only_path, _list_hf_dir(hfs, only_path, revision)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def collect_hf_shards(source: HFSource) -> list[HFShard]:
    """Enumerate the parquet shards that make up ``source`` on the Hub.

    The ``HF_TOKEN`` environment variable, when present, is passed to
    ``HfFileSystem``; a warning is logged when it is absent. With an explicit
    ``source.split`` only that split directory is searched and shard
    ``rel_path`` values are bare file names. Without one, every directory of
    the resolved config is treated as a split and ``rel_path`` becomes
    ``<split>/<file>``. If the config holds parquet files directly, they are
    collected under the split name ``data``.

    Returns:
        Shards sorted by ``rel_path``.

    Raises:
        HFParquetError: If the config or a split cannot be listed, or no
            parquet files are found.
    """
    import huggingface_hub

    token = os.environ.get("HF_TOKEN")
    if token is None:
        logger.warning("HF_TOKEN not set, using anonymous access. Private datasets will fail.")

    hfs = huggingface_hub.HfFileSystem(token=token)
    revision = hf_revision(source)
    config, base_path, base_entries = _resolve_hf_config(hfs, source, revision)

    # Work out which (split name, directory) pairs to scan.
    if source.split:
        targets = [(source.split, join_uri_path(base_path, source.split))]
    else:
        targets = [
            (_entry_name(item), typing.cast(str, item["name"])) for item in base_entries if _is_directory(item)
        ]
        if not targets and any(_is_parquet_file(item) for item in base_entries):
            targets = [("data", base_path)]

    collected: list[HFShard] = []
    for split_name, search_path in targets:
        listing = _try_list_hf_dir(hfs, search_path, revision)
        if listing is None:
            raise HFParquetError(
                f"No Hugging Face parquet split named {split_name!r} found for "
                f"{source.repo} (config={config}) at revision {revision!r}."
            )

        for info in listing:
            if not _is_parquet_file(info):
                continue
            file_name = info["name"].split("/")[-1]
            if source.split:
                relative = file_name
            else:
                relative = join_uri_path(split_name, file_name)
            collected.append(
                HFShard(
                    rel_path=relative,
                    hf_name=info["name"],
                    size=info.get("size"),
                    etag=info.get("etag") or info.get("ETag"),
                    last_modified=info.get("last_modified"),
                )
            )

    if not collected:
        raise HFParquetError(
            f"No parquet files found for {source.repo} "
            f"(config={config}, split={source.split}). "
            "The dataset may not have been auto-converted to parquet yet."
        )

    collected.sort(key=lambda shard: shard.rel_path)
    return collected
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def manifest_path(remote_path: str) -> str:
    """Return the location of the cache manifest file inside ``remote_path``."""
    location = join_uri_path(remote_path, _HF_CACHE_MANIFEST)
    return location
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def hf_cache_manifest(
    source: HFSource,
    shards: list[HFShard],
    cache_key: str,
) -> dict[str, typing.Any]:
    """Assemble the manifest dictionary describing a cached dataset.

    Returns:
        A dict with ``version`` (always ``1``), ``cache_key``, the ``source``
        payload and the ``shards`` payload, both produced by
        ``hf_source_payload``.
    """
    source_section = hf_source_payload(source)
    shard_section = hf_source_payload(source, shards)["shards"]

    manifest: dict[str, typing.Any] = {
        "version": 1,
        "cache_key": cache_key,
        "source": source_section,
        "shards": shard_section,
    }
    return manifest
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
async def read_cache_manifest(remote_path: str) -> dict[str, typing.Any] | None:
    """Load the cache manifest stored under ``remote_path``, if any.

    Returns:
        The decoded manifest, or ``None`` when the manifest file does not
        exist or cannot be read/parsed (the failure is logged at debug level).
    """
    path = manifest_path(remote_path)
    try:
        if await storage_path_exists(path):
            raw = await storage_read_bytes(path)
            return typing.cast(dict[str, typing.Any], json.loads(raw.decode("utf-8")))
        return None
    except Exception as e:
        # An unreadable manifest is treated the same as a missing one.
        logger.debug(f"Unable to read HF cache manifest {path}: {e}")
        return None
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
async def write_cache_manifest(remote_path: str, manifest: dict[str, typing.Any]) -> None:
    """Serialize ``manifest`` as key-sorted, indented JSON under ``remote_path``."""
    encoded = json.dumps(manifest, sort_keys=True, indent=2).encode("utf-8")
    target = manifest_path(remote_path)
    await storage_write_bytes(target, encoded)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
async def read_registry_record(
    source: HFSource,
    cache_key: str,
) -> dict[str, typing.Any] | None:
    """Load the registry record stored for ``cache_key``, if any.

    Returns:
        The decoded record, or ``None`` when the record file does not exist
        or cannot be read/parsed (the failure is logged at debug level).

    Raises:
        ValueError: From ``get_hf_registry_record_path`` when
            ``source.cache_root`` is ``None``; the path is computed outside
            the ``try`` block.
    """
    path = get_hf_registry_record_path(source, cache_key)
    try:
        if await storage_path_exists(path):
            raw = await storage_read_bytes(path)
            return typing.cast(dict[str, typing.Any], json.loads(raw.decode("utf-8")))
        return None
    except Exception as e:
        # An unreadable record is treated the same as a missing one.
        logger.debug(f"Unable to read HF registry record {path}: {e}")
        return None
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
async def write_registry_record(
    source: HFSource,
    cache_key: str,
    manifest: dict[str, typing.Any],
) -> None:
    """Store a copy of ``manifest`` as the registry record for ``cache_key``.

    The record is written as key-sorted, indented UTF-8 JSON at the path
    given by ``get_hf_registry_record_path``.
    """
    record = {**manifest}  # shallow copy, so the caller's dict is left alone
    encoded = json.dumps(record, sort_keys=True, indent=2).encode("utf-8")
    target = get_hf_registry_record_path(source, cache_key)
    await storage_write_bytes(target, encoded)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
async def open_sync_hf_reader(hfs, hf_name: str, revision: str):
    """Open ``hf_name`` for binary reading at ``revision`` in a worker thread.

    Returns:
        The file object produced by ``hfs.open``.
    """
    reader = await run_sync_io("open HF shard", hfs.open, hf_name, "rb", revision=revision)
    return reader
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
async def close_sync_file(label: str, file_obj) -> None:
    """Call ``file_obj.close`` in a worker thread, if the object has one.

    Args:
        label: Operation description passed through to ``run_sync_io``.
        file_obj: Object to close; objects without ``close`` are ignored.
    """
    closer = getattr(file_obj, "close", None)
    if closer is None:
        return
    await run_sync_io(label, closer)
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
async def iter_hf_shard_chunks(
    hfs,
    shard: HFShard,
    *,
    revision: str,
    chunk_size: int,
) -> typing.AsyncIterator[bytes]:
    """Yield the contents of ``shard`` in pieces of at most ``chunk_size`` bytes.

    The shard is opened through ``open_sync_hf_reader`` and every blocking
    ``read`` runs in a worker thread. Iteration stops at the first empty
    read; the reader is closed when the generator finishes or is closed.
    """
    reader = await open_sync_hf_reader(hfs, shard.hf_name, revision)
    try:
        while True:
            piece = await run_sync_io("read HF shard chunk", reader.read, chunk_size)
            if not piece:
                break
            yield piece
    finally:
        await close_sync_file("close HF shard", reader)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
async def stream_hf_shard(
    hfs,
    shard: HFShard,
    dest: str,
    *,
    revision: str,
    chunk_size: int,
) -> None:
    """Copy one shard from the Hub to ``dest`` via ``storage.put_stream``.

    Args:
        hfs: Hugging Face filesystem object used to read the shard.
        shard: Shard to copy.
        dest: Destination path or URI.
        revision: Revision the shard is read at.
        chunk_size: Number of bytes requested per read.
    """
    is_local_dest = not storage.is_remote(dest)
    if is_local_dest:
        # A local destination needs its parent directory to exist first.
        parent_dir = os.path.dirname(dest)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    chunk_stream = iter_hf_shard_chunks(
        hfs,
        shard,
        revision=revision,
        chunk_size=chunk_size,
    )
    await storage.put_stream(chunk_stream, to_path=dest)
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
async def stream_hf_to_remote(
    source: HFSource,
    remote_path: str,
    shards: list[HFShard] | None = None,
    manifest: dict[str, typing.Any] | None = None,
) -> None:
    """Copy every parquet shard of ``source`` from the Hub into ``remote_path``.

    Args:
        source: Dataset source to stream.
        remote_path: Destination directory (local path or remote URI).
        shards: Shards to copy; discovered with ``collect_hf_shards`` when
            omitted.
        manifest: If given, written to ``remote_path`` after all shards have
            been copied.
    """
    import huggingface_hub

    hfs = huggingface_hub.HfFileSystem(token=os.environ.get("HF_TOKEN"))
    chunk_size = 64 * 1024 * 1024  # 64 MiB per read
    revision = hf_revision(source)

    if shards is None:
        discovered = await run_sync_io("collect HF parquet shards", collect_hf_shards, source)
        shards = typing.cast(list[HFShard], discovered)

    logger.info(
        f"Streaming {len(shards)} Hugging Face parquet shard(s) for {source.repo} "
        f"to {_storage_kind(remote_path)} path {remote_path}"
    )

    # Shards are copied one after another, each to <remote_path>/<rel_path>.
    for shard in shards:
        dest = join_uri_path(remote_path, shard.rel_path)
        logger.info(f"Streaming {shard.rel_path} to {dest}")
        await stream_hf_shard(hfs, shard, dest, revision=revision, chunk_size=chunk_size)

    # The manifest goes last, after every shard has been written.
    if manifest is not None:
        await write_cache_manifest(remote_path, manifest)

    logger.info(f"Streamed {len(shards)} parquet file(s) to {remote_path}")
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
async def ensure_hf_cached(source: HFSource) -> str:
    """Make the parquet files of ``source`` available and return their location.

    The shard listing is always fetched from the Hub first and folded into
    the cache key. Without a ``cache_root`` the dataset is streamed to a
    random local path. With one, the manifest stored at the deterministic
    cache path is compared with the expected manifest: a match is a cache hit
    (and a missing registry record is written), anything else triggers a
    fresh stream followed by a registry record write.

    Returns:
        The directory (local path or remote URI) holding the parquet files.
    """
    listed = await run_sync_io("collect HF parquet shards", collect_hf_shards, source)
    shards = typing.cast(list[HFShard], listed)
    cache_key = hf_source_cache_key(source, shards)
    expected_manifest = hf_cache_manifest(source, shards, cache_key)

    # No cache root: materialize into a throwaway local directory.
    if source.cache_root is None:
        scratch_path = get_random_hf_path()
        logger.info(
            f"Materializing Hugging Face dataset {source.repo} "
            f"({_source_log_description(source)}) "
            f"to local path {scratch_path}"
        )
        await stream_hf_to_remote(source, scratch_path, shards, expected_manifest)
        return scratch_path

    default_remote_path = get_hf_cache_path(source, cache_key)
    logger.info(
        f"Checking Hugging Face dataset cache for {source.repo} "
        f"({_source_log_description(source)}) "
        f"under {source.cache_root}"
    )
    registry_record = await read_registry_record(source, cache_key)
    stored_manifest = await read_cache_manifest(default_remote_path)
    cache_hit = stored_manifest == expected_manifest

    if cache_hit:
        if registry_record is None:
            await write_registry_record(source, cache_key, expected_manifest)
        logger.info(f"Using cached Hugging Face dataset at {default_remote_path}")
        return default_remote_path

    logger.info(
        f"Materializing Hugging Face dataset {source.repo} "
        f"({_source_log_description(source)}) "
        f"to remote cache artifact {default_remote_path}"
    )
    await stream_hf_to_remote(source, default_remote_path, shards, expected_manifest)
    await write_registry_record(source, cache_key, expected_manifest)
    return default_remote_path
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import typing
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from urllib.parse import parse_qs, urlencode
|
|
8
|
+
|
|
9
|
+
from flyte.io import PARQUET, DataFrame
|
|
10
|
+
|
|
11
|
+
_HF_PARQUET_REVISION = "refs/convert/parquet"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class HFSource:
    """HuggingFace dataset source for task parameter defaults.

    All fields are whitespace-stripped in ``__post_init__``. ``repo`` must be
    a non-empty string; the optional fields must be ``None`` or a non-blank
    string.

    Raises:
        TypeError: If a field that is set is not a string.
        ValueError: If ``repo`` is empty, or an optional field is blank.
    """

    # Dataset repository id on the Hub (required).
    repo: str
    # Config name; None lets the config be resolved later.
    name: str | None = None
    # Split name; None means no specific split was requested.
    split: str | None = None
    # Revision; None means the default parquet-conversion ref is used.
    revision: str | None = None
    # Directory under which materialized datasets are cached; None disables it.
    cache_root: str | None = None

    def __post_init__(self) -> None:
        self.repo = self._normalize_required_field("repo", self.repo)
        self.name = self._normalize_optional_field("name", self.name)
        self.split = self._normalize_optional_field("split", self.split)
        self.revision = self._normalize_optional_field("revision", self.revision)
        self.cache_root = self._normalize_optional_field("cache_root", self.cache_root)

    @staticmethod
    def _normalize_required_field(field_name: str, value: str) -> str:
        """Strip ``value`` and reject non-strings and empty results."""
        # Fail with a message naming the field instead of an AttributeError
        # from ``.strip()`` when None or another non-string is passed.
        if not isinstance(value, str):
            raise TypeError(f"HFSource {field_name} must be a string, got {type(value).__name__}")
        normalized = value.strip()
        if not normalized:
            raise ValueError(f"HFSource {field_name} must not be empty")
        return normalized

    @staticmethod
    def _normalize_optional_field(field_name: str, value: str | None) -> str | None:
        """Pass ``None`` through; otherwise strip and reject blanks/non-strings."""
        if value is None:
            return None
        if not isinstance(value, str):
            raise TypeError(f"HFSource {field_name} must be a string or None, got {type(value).__name__}")
        normalized = value.strip()
        if not normalized:
            raise ValueError(f"HFSource {field_name} must not be blank")
        return normalized

    def to_hf_uri(self) -> str:
        """Encode this source as ``hf://<repo>[?name=..&split=..&cache_root=..&revision=..]``."""
        uri = f"hf://{self.repo}"
        params = {}
        if self.name:
            params["name"] = self.name
        if self.split:
            params["split"] = self.split
        if self.cache_root:
            params["cache_root"] = self.cache_root
        if self.revision:
            params["revision"] = self.revision
        if params:
            # ':' and '/' stay unescaped so URIs and refs remain readable.
            uri = f"{uri}?{urlencode(params, safe=':/')}"
        return uri

    @classmethod
    def from_hf_uri(cls, uri: str) -> "HFSource":
        """Parse a URI produced by ``to_hf_uri`` back into an ``HFSource``.

        Raises:
            ValueError: If ``uri`` does not start with ``hf://`` or has no repo.
        """
        if not uri.startswith("hf://"):
            raise ValueError(f"Invalid HF URI: {uri}")

        path, _, query = uri[5:].partition("?")
        repo = path.strip("/")
        if not repo:
            raise ValueError(f"Invalid HF URI: {uri}")

        # Only the first value of a repeated query parameter is used.
        query_params = parse_qs(query)
        name = query_params.get("name", [None])[0]
        split = query_params.get("split", [None])[0]
        revision = query_params.get("revision", [None])[0]
        cache_root = query_params.get("cache_root", [None])[0]

        return cls(
            repo=repo,
            name=name,
            split=split,
            revision=revision,
            cache_root=cache_root,
        )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def from_hf(
    repo: str,
    *,
    name: str | None = None,
    split: str | None = None,
    revision: str | None = None,
    cache_root: str | None = None,
) -> DataFrame:
    """Create a DataFrame reference suitable as a task parameter default.

    The arguments are validated by building an ``HFSource``; the returned
    ``DataFrame`` carries that source's ``hf://`` URI, the parquet format and
    the source cache key as its hash.

    Args:
        repo: Dataset repository id on the Hugging Face Hub.
        name: Config name. If omitted, the plugin resolves the dataset's
            default converted-parquet config, or the only available config
            when there is exactly one.
        split: Split to read; all splits when omitted.
        revision: Revision to read from.
        cache_root: Optional stable remote directory that can be reused
            across runs. Without it, the dataset is materialized to a
            generated path for this run.

    Returns:
        A ``DataFrame`` pointing at the Hugging Face source.
    """
    spec = HFSource(repo=repo, name=name, split=split, revision=revision, cache_root=cache_root)
    return DataFrame(uri=spec.to_hf_uri(), format=PARQUET, hash=hf_source_cache_key(spec))
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@dataclass(frozen=True)
class HFShard:
    """Immutable description of one parquet file belonging to a dataset."""

    # Path of the file relative to the materialized dataset directory.
    rel_path: str
    # Full name of the file as reported by the Hugging Face filesystem listing.
    hf_name: str
    # Listing metadata; each value is None when the listing did not provide it.
    size: int | None = None
    etag: str | None = None
    last_modified: str | None = None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def hf_revision(source: HFSource) -> str:
    """Return the revision to read: ``source.revision`` if set, else the parquet-conversion ref."""
    if source.revision:
        return source.revision
    return _HF_PARQUET_REVISION
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def hf_source_payload(source: HFSource, shards: list[HFShard] | None = None) -> dict[str, typing.Any]:
    """Describe ``source`` (and optionally its shards) as a plain dictionary.

    The payload always holds ``repo``, ``name``, ``split`` and the effective
    ``revision``; ``cache_root`` is not part of it. When ``shards`` is given,
    a ``shards`` list is added, ordered by ``rel_path``, with
    ``last_modified`` converted to ``str`` when present.
    """
    payload: dict[str, typing.Any] = {
        "repo": source.repo,
        "name": source.name,
        "split": source.split,
        "revision": hf_revision(source),
    }
    if shards is None:
        return payload

    shard_records = []
    for shard in sorted(shards, key=lambda item: item.rel_path):
        modified = shard.last_modified
        shard_records.append(
            {
                "rel_path": shard.rel_path,
                "hf_name": shard.hf_name,
                "size": shard.size,
                "etag": shard.etag,
                "last_modified": str(modified) if modified is not None else None,
            }
        )
    payload["shards"] = shard_records
    return payload
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def hf_source_cache_key(source: HFSource, shards: list[HFShard] | None = None) -> str:
    """Return the SHA-256 hex digest of the canonical JSON payload for ``source``.

    The payload from ``hf_source_payload`` is serialized with sorted keys and
    compact separators before hashing, so equal payloads always produce the
    same key.
    """
    canonical = json.dumps(
        hf_source_payload(source, shards),
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = hashlib.sha256(canonical.encode("utf-8"))
    return digest.hexdigest()
|