flyteplugins-huggingface 2.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/huggingface/__init__.py +0 -0
- flyteplugins/huggingface/datasets/__init__.py +33 -0
- flyteplugins/huggingface/datasets/_io.py +499 -0
- flyteplugins/huggingface/datasets/_source.py +158 -0
- flyteplugins/huggingface/datasets/_transformers.py +352 -0
- flyteplugins_huggingface-2.2.1.dist-info/METADATA +345 -0
- flyteplugins_huggingface-2.2.1.dist-info/RECORD +10 -0
- flyteplugins_huggingface-2.2.1.dist-info/WHEEL +5 -0
- flyteplugins_huggingface-2.2.1.dist-info/entry_points.txt +2 -0
- flyteplugins_huggingface-2.2.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import pathlib
|
|
4
|
+
import typing
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import datasets
|
|
8
|
+
import flyte.storage as storage
|
|
9
|
+
import pyarrow as pa
|
|
10
|
+
import pyarrow.parquet as pq
|
|
11
|
+
from flyte._logging import logger
|
|
12
|
+
from flyte.io import PARQUET, DataFrame
|
|
13
|
+
from flyte.io.extend import DataFrameDecoder, DataFrameEncoder
|
|
14
|
+
from flyteidl2.core import literals_pb2, types_pb2
|
|
15
|
+
from fsspec.core import strip_protocol
|
|
16
|
+
|
|
17
|
+
from ._io import ensure_hf_cached, join_uri_path, list_parquet_files, run_sync_io
|
|
18
|
+
from ._source import HFSource
|
|
19
|
+
|
|
20
|
+
_ROWS_PER_SHARD = 100_000
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _write_dataset(df: datasets.Dataset, path: str, filesystem) -> None:
    """Write a Hugging Face ``Dataset`` to a single Parquet file.

    Args:
        df: Dataset to serialize.
        path: Destination file path or URI. The protocol prefix is stripped
            before the path is handed to PyArrow.
        filesystem: Filesystem object forwarded to ``pq.ParquetWriter``.
    """
    # ``Dataset.data`` exposes the backing Arrow table and does NOT apply the
    # indices mapping that ``select``/``filter``/``shuffle`` leave on the
    # dataset. Flatten first so the rows written are the rows the user sees.
    if getattr(df, "_indices", None) is not None:
        df = df.flatten_indices()

    table = df.data.table

    # The context manager closes the writer even if a batch write fails.
    with pq.ParquetWriter(strip_protocol(path), table.schema, filesystem=filesystem) as writer:
        for batch in table.to_batches(max_chunksize=10_000):
            writer.write_batch(batch)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _read_parquet_files(
    parquet_files: list[str],
    columns: list[str] | None,
    filesystem=None,
) -> pa.Table:
    """Read one or more Parquet files into a single Arrow table.

    Args:
        parquet_files: Paths or URIs of the Parquet files to read.
        columns: Column names to load, or ``None`` to load every column.
        filesystem: Filesystem used for remote paths. Local paths are always
            read with PyArrow's default filesystem handling.

    Returns:
        The table of the only file, or the concatenation of all tables when
        more than one file is given.

    Raises:
        ValueError: If ``parquet_files`` is empty.
    """
    if not parquet_files:
        # Fail with a descriptive message instead of a bare IndexError below.
        raise ValueError("No parquet files were provided to read.")

    tables = [
        pq.read_table(
            strip_protocol(file_path),
            columns=columns,
            filesystem=filesystem if storage.is_remote(file_path) else None,
        )
        for file_path in parquet_files
    ]
    return pa.concat_tables(tables) if len(tables) > 1 else tables[0]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _batch_to_rows(batch: pa.RecordBatch) -> list[dict]:
    """Convert an Arrow record batch into a list of per-row dictionaries."""
    names = batch.schema.names

    # Materialize each column once as a Python list, keyed by column name.
    values_by_name = {}
    for name, column in zip(names, batch.columns):
        values_by_name[name] = column.to_pylist()

    rows = []
    for row_idx in range(batch.num_rows):
        rows.append({name: values_by_name[name][row_idx] for name in names})
    return rows
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _iter_parquet_rows(parquet_files: list[str], columns: list[str] | None) -> typing.Iterator[dict]:
    """Yield rows as dictionaries from a sequence of Parquet files.

    Args:
        parquet_files: Paths or URIs of the Parquet files, read in order.
        columns: Column names to load, or ``None`` to load every column.

    Yields:
        One dictionary per row, mapping column name to the Python value.
    """
    for file_path in parquet_files:
        # Only remote paths need an explicit filesystem object.
        filesystem = storage.get_underlying_filesystem(path=file_path) if storage.is_remote(file_path) else None

        # The context manager releases the file handle when the shard is
        # exhausted or when the generator is closed before it finishes.
        with pq.ParquetFile(strip_protocol(file_path), filesystem=filesystem) as pf:
            for batch in pf.iter_batches(batch_size=10_000, columns=columns):
                yield from _batch_to_rows(batch)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _write_iterable_dataset(ds: datasets.IterableDataset, uri: str, filesystem) -> None:
    """Stream an ``IterableDataset`` into numbered Parquet shards under ``uri``.

    Shards are named ``00000.parquet``, ``00001.parquet``, ... A shard is
    closed once it holds at least ``_ROWS_PER_SHARD`` rows; the next batch
    then opens a new shard using that batch's schema.

    Args:
        ds: Dataset to consume in batches of 10,000 rows.
        uri: Destination directory path or URI.
        filesystem: Filesystem object forwarded to ``pq.ParquetWriter``.
    """
    shard_index = 0
    shard_rows = 0
    shard_writer = None

    try:
        for py_batch in ds.iter(batch_size=10_000):
            batch_table = pa.table(py_batch)

            # Open a shard lazily so nothing is created for an empty dataset.
            if shard_writer is None:
                target = join_uri_path(uri, f"{shard_index:05}.parquet")
                shard_writer = pq.ParquetWriter(
                    strip_protocol(target),
                    batch_table.schema,
                    filesystem=filesystem,
                )

            for record_batch in batch_table.to_batches():
                shard_writer.write_batch(record_batch)

            shard_rows += len(batch_table)
            if shard_rows < _ROWS_PER_SHARD:
                continue

            # Shard is full: close it and roll over to the next file.
            shard_writer.close()
            shard_writer = None
            shard_index += 1
            shard_rows = 0
    finally:
        # Close a partially filled shard, including when an error occurred.
        if shard_writer is not None:
            shard_writer.close()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _requested_columns(
    current_task_metadata: literals_pb2.StructuredDatasetMetadata,
) -> list[str] | None:
    """Return the column names declared on the task metadata, if any.

    Returns:
        The declared column names in order, or ``None`` when the metadata has
        no structured dataset type or that type declares no columns.
    """
    sd_type = current_task_metadata.structured_dataset_type
    if not sd_type:
        return None

    declared = sd_type.columns
    if not declared:
        return None

    return [column.name for column in declared]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
async def _localize_parquet_files(parquet_files: list[str]) -> list[str]:
    """Return local paths for the given shards, downloading remote ones first.

    Remote shards are fetched one at a time, each into its own random local
    directory under the shard's original file name. Local entries are kept,
    with any protocol prefix stripped.
    """
    localized: list[str] = []

    for source_path in parquet_files:
        if not storage.is_remote(source_path):
            localized.append(strip_protocol(source_path))
            continue

        target_dir = storage.get_random_local_directory()
        shard_name = pathlib.PurePosixPath(source_path.rstrip("/")).name
        target_path = str(target_dir / shard_name)
        logger.info(f"Downloading remote parquet shard {source_path} to {target_path}")

        # ``storage.get`` returns the path recorded for the downloaded shard.
        downloaded = await storage.get(source_path, target_path)
        localized.append(downloaded)

    return localized
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _probe_remote_parquet_files(parquet_files: list[str], filesystem) -> None:
    """Open every remote shard and read its metadata to check it is readable.

    Local entries are skipped. Any error raised while opening a shard or
    reading its metadata propagates to the caller.

    Args:
        parquet_files: Paths or URIs of the Parquet files to probe.
        filesystem: Filesystem object used to open the remote shards.
    """
    for file_path in parquet_files:
        if not storage.is_remote(file_path):
            continue

        # Close the probe handle right away; only the side effect of reading
        # the metadata matters here.
        with pq.ParquetFile(strip_protocol(file_path), filesystem=filesystem) as parquet_file:
            _ = parquet_file.metadata
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class HuggingFaceDatasetToParquetEncodingHandler(DataFrameEncoder):
    """Encoder that serializes ``datasets.Dataset`` values to Parquet."""

    def __init__(self):
        super().__init__(datasets.Dataset, None, PARQUET)

    async def encode(
        self,
        dataframe: DataFrame,
        structured_dataset_type: types_pb2.StructuredDatasetType,
    ) -> literals_pb2.StructuredDataset:
        """Encode ``dataframe`` as a Parquet-backed structured dataset literal.

        Args:
            dataframe: Wrapper holding the dataset value and/or a target URI.
            structured_dataset_type: Type descriptor; its ``format`` is set to
                ``PARQUET`` in place.

        Returns:
            A literal pointing at the directory URI that holds the data.
        """
        val = dataframe.val

        # A URI without an in-memory value is a pure reference: pass it on
        # without writing anything.
        if val is None and dataframe.uri:
            structured_dataset_type.format = PARQUET
            return literals_pb2.StructuredDataset(
                uri=typing.cast(str, dataframe.uri),
                metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type),
            )

        if not dataframe.uri:
            from flyte._context import internal_ctx

            uri = str(internal_ctx().raw_data.get_random_remote_path())
        else:
            uri = typing.cast(str, dataframe.uri)

        if not storage.is_remote(uri):
            # Strip any scheme (e.g. ``file://``) before touching the local
            # filesystem; ``Path`` would otherwise treat it as a directory name.
            Path(strip_protocol(uri)).mkdir(parents=True, exist_ok=True)

        # A Dataset is written as a single shard named 00000.parquet.
        path = join_uri_path(uri, f"{0:05}.parquet")
        df = typing.cast(datasets.Dataset, val)

        filesystem = storage.get_underlying_filesystem(path=path)
        logger.info(
            f"Writing Hugging Face Dataset output to "
            f"{'remote' if storage.is_remote(uri) else 'local'} parquet directory {uri}"
        )
        await run_sync_io("write HuggingFace dataset", _write_dataset, df, path, filesystem)

        structured_dataset_type.format = PARQUET
        return literals_pb2.StructuredDataset(
            uri=uri,
            metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type),
        )
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class ParquetToHuggingFaceDatasetDecodingHandler(DataFrameDecoder):
    """Decoder that loads Parquet-backed structured datasets as ``datasets.Dataset``."""

    def __init__(self, protocol: str | None = None):
        super().__init__(datasets.Dataset, protocol, PARQUET)

    async def decode(
        self,
        flyte_value: literals_pb2.StructuredDataset,
        current_task_metadata: literals_pb2.StructuredDatasetMetadata,
    ) -> datasets.Dataset:
        """Materialize the literal's Parquet shards into a ``datasets.Dataset``.

        ``hf://`` URIs are first resolved to a Parquet directory. Shards are
        read directly; if that fails and any shard is remote, the shards are
        downloaded and read locally instead. Failures with only local shards
        are re-raised.

        Raises:
            RuntimeError: If an ``hf://`` source cannot be materialized.
        """
        uri = flyte_value.uri

        if uri.startswith("hf://"):
            source_uri = uri
            try:
                hf_source = HFSource.from_hf_uri(source_uri)
                uri = await ensure_hf_cached(hf_source)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to materialize Hugging Face dataset from {source_uri}: {type(e).__name__}: {e!r}"
                ) from e
            resolved_kind = "remote" if storage.is_remote(uri) else "local"
            logger.info(f"Resolved Hugging Face source {source_uri} to {resolved_kind} parquet directory {uri}")

        filesystem = storage.get_underlying_filesystem(path=uri)
        directory_kind = "remote" if storage.is_remote(uri) else "local"
        logger.info(f"Reading Hugging Face Dataset parquet from {directory_kind} directory {uri}")

        parquet_files = await list_parquet_files(uri, filesystem)
        columns = _requested_columns(current_task_metadata)

        has_remote_shards = any(storage.is_remote(shard) for shard in parquet_files)
        if has_remote_shards:
            logger.info(f"Using direct remote parquet reads for {uri} via Flyte storage filesystem")

        try:
            table = await run_sync_io(
                "read parquet files",
                _read_parquet_files,
                parquet_files,
                columns,
                filesystem,
            )
        except Exception as exc:
            # Only remote reads have a fallback; local failures are final.
            if not has_remote_shards:
                raise
            logger.warning(
                f"Direct parquet read failed for {uri}: {type(exc).__name__}: {exc}. "
                "Falling back to localizing remote parquet shards."
            )
            parquet_files = await _localize_parquet_files(parquet_files)
            logger.info(f"Using localized parquet shard reads for {uri}")
            table = await run_sync_io(
                "read localized parquet files",
                _read_parquet_files,
                parquet_files,
                columns,
            )

        return datasets.Dataset(table)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class HFToHuggingFaceDatasetDecodingHandler(ParquetToHuggingFaceDatasetDecodingHandler):
    """``datasets.Dataset`` decoder registered for the ``hf`` protocol."""

    def __init__(self):
        super().__init__(protocol="hf")
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class HuggingFaceIterableDatasetToParquetEncodingHandler(DataFrameEncoder):
    """Encoder that serializes ``datasets.IterableDataset`` values to sharded Parquet."""

    def __init__(self):
        super().__init__(datasets.IterableDataset, None, PARQUET)

    async def encode(
        self,
        dataframe: DataFrame,
        structured_dataset_type: types_pb2.StructuredDatasetType,
    ) -> literals_pb2.StructuredDataset:
        """Encode ``dataframe`` as a Parquet-backed structured dataset literal.

        Args:
            dataframe: Wrapper holding the dataset value and/or a target URI.
            structured_dataset_type: Type descriptor; its ``format`` is set to
                ``PARQUET`` in place.

        Returns:
            A literal pointing at the directory URI that holds the shards.
        """
        val = dataframe.val

        # A URI without an in-memory value is a pure reference: pass it on
        # without writing anything.
        if val is None and dataframe.uri:
            structured_dataset_type.format = PARQUET
            return literals_pb2.StructuredDataset(
                uri=typing.cast(str, dataframe.uri),
                metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type),
            )

        if not dataframe.uri:
            from flyte._context import internal_ctx

            uri = str(internal_ctx().raw_data.get_random_remote_path())
        else:
            uri = typing.cast(str, dataframe.uri)

        if not storage.is_remote(uri):
            # Strip any scheme (e.g. ``file://``) before touching the local
            # filesystem; ``Path`` would otherwise treat it as a directory name.
            Path(strip_protocol(uri)).mkdir(parents=True, exist_ok=True)

        ds = typing.cast(datasets.IterableDataset, val)
        filesystem = storage.get_underlying_filesystem(path=uri)
        logger.info(
            f"Writing Hugging Face IterableDataset output to "
            f"{'remote' if storage.is_remote(uri) else 'local'} parquet directory {uri}"
        )
        await run_sync_io(
            "write HuggingFace iterable dataset",
            _write_iterable_dataset,
            ds,
            uri,
            filesystem,
        )

        structured_dataset_type.format = PARQUET
        return literals_pb2.StructuredDataset(
            uri=uri,
            metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type),
        )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class ParquetToHuggingFaceIterableDatasetDecodingHandler(DataFrameDecoder):
    """Decoder that exposes Parquet-backed structured datasets as ``datasets.IterableDataset``."""

    def __init__(self, protocol: str | None = None):
        super().__init__(datasets.IterableDataset, protocol, PARQUET)

    async def decode(
        self,
        flyte_value: literals_pb2.StructuredDataset,
        current_task_metadata: literals_pb2.StructuredDatasetMetadata,
    ) -> datasets.IterableDataset:
        """Build an ``IterableDataset`` that streams rows from the literal's shards.

        ``hf://`` URIs are first resolved to a Parquet directory. Remote
        shards are probed up front; if probing fails they are downloaded and
        the dataset streams from the local copies instead. Failures with only
        local shards are re-raised.

        Raises:
            RuntimeError: If an ``hf://`` source cannot be materialized.
        """
        uri = flyte_value.uri

        if uri.startswith("hf://"):
            source_uri = uri
            try:
                hf_source = HFSource.from_hf_uri(source_uri)
                uri = await ensure_hf_cached(hf_source)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to materialize Hugging Face dataset from {source_uri}: {type(e).__name__}: {e!r}"
                ) from e
            resolved_kind = "remote" if storage.is_remote(uri) else "local"
            logger.info(f"Resolved Hugging Face source {source_uri} to {resolved_kind} parquet directory {uri}")

        filesystem = storage.get_underlying_filesystem(path=uri)
        directory_kind = "remote" if storage.is_remote(uri) else "local"
        logger.info(f"Reading Hugging Face IterableDataset parquet from {directory_kind} directory {uri}")
        parquet_files = await list_parquet_files(uri, filesystem)
        columns = _requested_columns(current_task_metadata)

        has_remote_shards = any(storage.is_remote(shard) for shard in parquet_files)
        if has_remote_shards:
            logger.info(f"Using direct remote iterable parquet reads for {uri} via Flyte storage filesystem")

        try:
            await run_sync_io(
                "probe parquet files",
                _probe_remote_parquet_files,
                parquet_files,
                filesystem,
            )
        except Exception as exc:
            # Only remote access has a fallback; local failures are final.
            if not has_remote_shards:
                raise
            logger.warning(
                f"Direct iterable parquet access failed for {uri}: {type(exc).__name__}: {exc}. "
                "Falling back to localizing remote parquet shards."
            )
            parquet_files = await _localize_parquet_files(parquet_files)
            logger.info(f"Using localized iterable parquet shard reads for {uri}")

        generator_kwargs = {"parquet_files": parquet_files, "columns": columns}
        return datasets.IterableDataset.from_generator(
            _iter_parquet_rows,
            gen_kwargs=generator_kwargs,
        )
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class HFToHuggingFaceIterableDatasetDecodingHandler(ParquetToHuggingFaceIterableDatasetDecodingHandler):
    """``datasets.IterableDataset`` decoder registered for the ``hf`` protocol."""

    def __init__(self):
        super().__init__(protocol="hf")
|
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: flyteplugins-huggingface
|
|
3
|
+
Version: 2.2.1
|
|
4
|
+
Summary: Hugging Face Plugin for Flyte
|
|
5
|
+
Author-email: André Ahlert <andre@aex.partners>, Samhita Alla <samhita@union.ai>
|
|
6
|
+
Requires-Python: >=3.10
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: datasets
|
|
9
|
+
Requires-Dist: huggingface-hub
|
|
10
|
+
Requires-Dist: pyarrow
|
|
11
|
+
Requires-Dist: flyte
|
|
12
|
+
|
|
13
|
+
# Hugging Face Plugin
|
|
14
|
+
|
|
15
|
+
Native support for Hugging Face integrations in Flyte.
|
|
16
|
+
|
|
17
|
+
This plugin provides dataset support for Hugging Face `datasets.Dataset`
|
|
18
|
+
and `datasets.IterableDataset` objects. It gives you two related capabilities:
|
|
19
|
+
|
|
20
|
+
1. Use `from_hf(...)` to reference a dataset on the Hugging Face Hub as a task
|
|
21
|
+
input default.
|
|
22
|
+
2. Pass Hugging Face dataset objects between Flyte tasks with automatic Parquet
|
|
23
|
+
serialization.
|
|
24
|
+
|
|
25
|
+
The plugin works by treating Hub datasets as Parquet-backed structured data. For
|
|
26
|
+
Hub sources, it first resolves the dataset's converted Parquet shards, then
|
|
27
|
+
materializes them either into a generated path for the current run or into a
|
|
28
|
+
shared artifact registry rooted at `cache_root`.
|
|
29
|
+
|
|
30
|
+
## Installation
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
pip install flyteplugins-huggingface
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## Quick start
|
|
37
|
+
|
|
38
|
+
```python
|
|
39
|
+
import datasets
|
|
40
|
+
import flyte
|
|
41
|
+
from flyteplugins.huggingface.datasets import from_hf
|
|
42
|
+
|
|
43
|
+
env = flyte.TaskEnvironment(name="hf-example")
|
|
44
|
+
|
|
45
|
+
@env.task
|
|
46
|
+
async def count_reviews(
|
|
47
|
+
ds: datasets.Dataset = from_hf(
|
|
48
|
+
"stanfordnlp/imdb",
|
|
49
|
+
name="plain_text",
|
|
50
|
+
split="train",
|
|
51
|
+
),
|
|
52
|
+
) -> int:
|
|
53
|
+
return len(ds)
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
At the Flyte literal level this source is represented as an `hf://` URI, for
|
|
57
|
+
example:
|
|
58
|
+
|
|
59
|
+
```text
|
|
60
|
+
hf://stanfordnlp/imdb?name=plain_text&split=train
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
The task receives a hydrated `datasets.Dataset`. The `hf://` URI is only the
|
|
64
|
+
reference used between Flyte and the plugin.
|
|
65
|
+
|
|
66
|
+
## `from_hf(...)`
|
|
67
|
+
|
|
68
|
+
`from_hf(...)` is the entry point for Hub-backed task defaults:
|
|
69
|
+
|
|
70
|
+
```python
|
|
71
|
+
from flyteplugins.huggingface.datasets import from_hf
|
|
72
|
+
|
|
73
|
+
from_hf(
|
|
74
|
+
repo: str,
|
|
75
|
+
*,
|
|
76
|
+
name: str | None = None,
|
|
77
|
+
split: str | None = None,
|
|
78
|
+
revision: str | None = None,
|
|
79
|
+
cache_root: str | None = None,
|
|
80
|
+
)
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
Arguments:
|
|
84
|
+
|
|
85
|
+
- `repo`: Hugging Face dataset repo, such as `"stanfordnlp/imdb"` or `"glue"`.
|
|
86
|
+
- `name`: Optional dataset config/subset.
|
|
87
|
+
- `split`: Optional split such as `"train"` or `"validation"`.
|
|
88
|
+
- `revision`: Optional Hub revision. Defaults to `refs/convert/parquet`.
|
|
89
|
+
- `cache_root`: Optional shared remote cache root for cross-run reuse.
|
|
90
|
+
|
|
91
|
+
`from_hf(...)` returns a Flyte `DataFrame` reference, not an eagerly loaded
|
|
92
|
+
dataset object. When the task input is typed as `datasets.Dataset` or
|
|
93
|
+
`datasets.IterableDataset`, the plugin decoder materializes that reference into
|
|
94
|
+
the requested Hugging Face type.
|
|
95
|
+
|
|
96
|
+
## Config resolution
|
|
97
|
+
|
|
98
|
+
If you specify `name`, the plugin uses that config directly.
|
|
99
|
+
|
|
100
|
+
If you omit `name`, the plugin resolves the config as follows:
|
|
101
|
+
|
|
102
|
+
1. Try the converted-Parquet config named `default`.
|
|
103
|
+
2. If `default` does not exist and there is exactly one config, use that one.
|
|
104
|
+
3. If there are multiple configs, raise an error and ask for `name=...`.
|
|
105
|
+
|
|
106
|
+
Examples:
|
|
107
|
+
|
|
108
|
+
```python
|
|
109
|
+
# Works: imdb has a single converted-parquet config, plain_text.
|
|
110
|
+
from_hf("stanfordnlp/imdb", split="train")
|
|
111
|
+
|
|
112
|
+
# Required: glue has multiple configs such as mrpc, sst2, qnli, ...
|
|
113
|
+
from_hf("glue", name="mrpc", split="train")
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
Using `name=` explicitly is recommended in examples and production code because
|
|
117
|
+
it makes the UI literal and task signature more obvious.
|
|
118
|
+
|
|
119
|
+
## Split behavior
|
|
120
|
+
|
|
121
|
+
If you specify `split`, only that split is materialized:
|
|
122
|
+
|
|
123
|
+
```python
|
|
124
|
+
@env.task
|
|
125
|
+
async def train_split(
|
|
126
|
+
ds: datasets.Dataset = from_hf(
|
|
127
|
+
"stanfordnlp/imdb",
|
|
128
|
+
name="plain_text",
|
|
129
|
+
split="train",
|
|
130
|
+
),
|
|
131
|
+
) -> int:
|
|
132
|
+
return len(ds)
|
|
133
|
+
```
|
|
134
|
+
|
|
135
|
+
If you omit `split`, the plugin reads every converted Parquet split under the
|
|
136
|
+
resolved config and presents them as one dataset stream/table:
|
|
137
|
+
|
|
138
|
+
```python
|
|
139
|
+
@env.task
|
|
140
|
+
async def all_splits(
|
|
141
|
+
ds: datasets.Dataset = from_hf(
|
|
142
|
+
"stanfordnlp/imdb",
|
|
143
|
+
name="plain_text",
|
|
144
|
+
),
|
|
145
|
+
) -> list[str]:
|
|
146
|
+
return ds.column_names
|
|
147
|
+
```
|
|
148
|
+
|
|
149
|
+
That means the result is a combined dataset, not a mapping of split name to
|
|
150
|
+
dataset.
|
|
151
|
+
|
|
152
|
+
## Cross-run reuse with `cache_root`
|
|
153
|
+
|
|
154
|
+
Without `cache_root`, a Hub source is materialized into a generated path for the
|
|
155
|
+
current execution only.
|
|
156
|
+
|
|
157
|
+
With `cache_root`, the plugin uses a shared cache registry so later runs can
|
|
158
|
+
skip the Hub download entirely:
|
|
159
|
+
|
|
160
|
+
```python
|
|
161
|
+
@env.task
|
|
162
|
+
async def train_cached(
|
|
163
|
+
ds: datasets.Dataset = from_hf(
|
|
164
|
+
"stanfordnlp/imdb",
|
|
165
|
+
name="plain_text",
|
|
166
|
+
split="train",
|
|
167
|
+
cache_root="s3://my-bucket/flyte-hf-cache",
|
|
168
|
+
),
|
|
169
|
+
) -> int:
|
|
170
|
+
return len(ds)
|
|
171
|
+
```
|
|
172
|
+
|
|
173
|
+
The shared cache layout is:
|
|
174
|
+
|
|
175
|
+
```text
|
|
176
|
+
{cache_root}/huggingface/datasets/
|
|
177
|
+
by-key/{source-cache-key}.json
|
|
178
|
+
blobs/{source-cache-key}/...
|
|
179
|
+
```
|
|
180
|
+
|
|
181
|
+
The cache key is derived from:
|
|
182
|
+
|
|
183
|
+
- repo
|
|
184
|
+
- config name
|
|
185
|
+
- split
|
|
186
|
+
- revision
|
|
187
|
+
- resolved Parquet shard metadata
|
|
188
|
+
|
|
189
|
+
This means the cache is stable across runs as long as the underlying converted
|
|
190
|
+
Parquet source does not change.
|
|
191
|
+
|
|
192
|
+
The canonical artifact location is always
|
|
193
|
+
`{cache_root}/huggingface/datasets/blobs/{source-cache-key}/...`. The registry
|
|
194
|
+
record under `by-key/` is metadata for that cache key.
|
|
195
|
+
|
|
196
|
+
## What the plugin logs
|
|
197
|
+
|
|
198
|
+
When `LOG_LEVEL` is `INFO` or lower, the plugin logs whether it is:
|
|
199
|
+
|
|
200
|
+
- checking the shared dataset cache
|
|
201
|
+
- materializing from the Hugging Face Hub
|
|
202
|
+
- using a cached artifact
|
|
203
|
+
- reading Parquet from a local or remote directory
|
|
204
|
+
|
|
205
|
+
This is the easiest way to confirm whether a run is reading from the Hub or
|
|
206
|
+
from your shared cache artifact.
|
|
207
|
+
|
|
208
|
+
## `datasets.Dataset` between tasks
|
|
209
|
+
|
|
210
|
+
You can return and pass real `datasets.Dataset` objects between tasks:
|
|
211
|
+
|
|
212
|
+
```python
|
|
213
|
+
import datasets
|
|
214
|
+
import flyte
|
|
215
|
+
|
|
216
|
+
env = flyte.TaskEnvironment(name="hf-transform")
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@env.task
|
|
220
|
+
async def create_dataset() -> datasets.Dataset:
|
|
221
|
+
return datasets.Dataset.from_dict(
|
|
222
|
+
{
|
|
223
|
+
"text": ["hello", "world", "flyte"],
|
|
224
|
+
"label": [0, 1, 0],
|
|
225
|
+
}
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
@env.task
|
|
230
|
+
async def filter_positive(ds: datasets.Dataset) -> datasets.Dataset:
|
|
231
|
+
return ds.filter(lambda row: row["label"] == 1)
|
|
232
|
+
```
|
|
233
|
+
|
|
234
|
+
Task-produced in-memory datasets are serialized to Parquet automatically. This
|
|
235
|
+
is separate from `from_hf(...)`, which is a source reference rather than a
|
|
236
|
+
materialized dataset object.
|
|
237
|
+
|
|
238
|
+
## `datasets.IterableDataset`
|
|
239
|
+
|
|
240
|
+
Use `datasets.IterableDataset` when you want row streaming behavior instead of a
|
|
241
|
+
fully materialized table:
|
|
242
|
+
|
|
243
|
+
```python
|
|
244
|
+
@env.task
|
|
245
|
+
async def stream_reviews(
|
|
246
|
+
ds: datasets.IterableDataset = from_hf(
|
|
247
|
+
"stanfordnlp/imdb",
|
|
248
|
+
name="plain_text",
|
|
249
|
+
split="train",
|
|
250
|
+
cache_root="s3://my-bucket/flyte-hf-cache",
|
|
251
|
+
),
|
|
252
|
+
) -> datasets.IterableDataset:
|
|
253
|
+
def add_length(batch):
|
|
254
|
+
batch["length"] = [len(text) for text in batch["text"]]
|
|
255
|
+
return batch
|
|
256
|
+
|
|
257
|
+
return ds.map(add_length, batched=True)
|
|
258
|
+
```
|
|
259
|
+
|
|
260
|
+
Notes:
|
|
261
|
+
|
|
262
|
+
- The returned Hugging Face `IterableDataset` is consumed with normal synchronous
|
|
263
|
+
iteration.
|
|
264
|
+
- Internally the plugin streams row batches from Parquet files.
|
|
265
|
+
- Iterable outputs are written back as sharded Parquet directories.
|
|
266
|
+
|
|
267
|
+
## Column projection
|
|
268
|
+
|
|
269
|
+
Use a Flyte structured-dataset column annotation when you only want selected
|
|
270
|
+
columns:
|
|
271
|
+
|
|
272
|
+
```python
|
|
273
|
+
from collections import OrderedDict
|
|
274
|
+
from typing import Annotated
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
@env.task
|
|
278
|
+
async def load_text_only(
|
|
279
|
+
ds: Annotated[datasets.Dataset, OrderedDict(text=str)] = from_hf(
|
|
280
|
+
"stanfordnlp/imdb",
|
|
281
|
+
name="plain_text",
|
|
282
|
+
split="train",
|
|
283
|
+
),
|
|
284
|
+
) -> list[str]:
|
|
285
|
+
return ds["text"][:10]
|
|
286
|
+
```
|
|
287
|
+
|
|
288
|
+
The plugin uses the annotation to request only those columns when reading
|
|
289
|
+
Parquet.
|
|
290
|
+
|
|
291
|
+
## Revision selection
|
|
292
|
+
|
|
293
|
+
Use `revision=` if you want to pin a specific converted-Parquet revision:
|
|
294
|
+
|
|
295
|
+
```python
|
|
296
|
+
@env.task
|
|
297
|
+
async def pinned_revision(
|
|
298
|
+
ds: datasets.Dataset = from_hf(
|
|
299
|
+
"stanfordnlp/imdb",
|
|
300
|
+
name="plain_text",
|
|
301
|
+
split="train",
|
|
302
|
+
revision="refs/convert/parquet",
|
|
303
|
+
cache_root="s3://my-bucket/flyte-hf-cache",
|
|
304
|
+
),
|
|
305
|
+
) -> int:
|
|
306
|
+
return len(ds)
|
|
307
|
+
```
|
|
308
|
+
|
|
309
|
+
If you do not specify a revision, the plugin uses `refs/convert/parquet`.
|
|
310
|
+
|
|
311
|
+
## Local vs remote behavior
|
|
312
|
+
|
|
313
|
+
There are two distinct layers to keep in mind:
|
|
314
|
+
|
|
315
|
+
1. Task inputs and outputs inside Flyte tasks.
|
|
316
|
+
2. What your launcher process sees when a run completes.
|
|
317
|
+
|
|
318
|
+
Inside a task, a parameter typed as `datasets.Dataset` or
|
|
319
|
+
`datasets.IterableDataset` is hydrated by the plugin into a Hugging Face object.
|
|
320
|
+
|
|
321
|
+
Outside the task, especially for remote runs, outputs are often represented to
|
|
322
|
+
the launcher as Flyte `DataFrame` references rather than already-opened Hugging
|
|
323
|
+
Face dataset objects. That is expected: the structured dataset literal remains
|
|
324
|
+
the transport format.
|
|
325
|
+
|
|
326
|
+
## Private datasets
|
|
327
|
+
|
|
328
|
+
Set `HF_TOKEN` in the task environment to access private Hugging Face datasets.
|
|
329
|
+
Without it, the plugin uses anonymous Hub access.
|
|
330
|
+
|
|
331
|
+
## Failure modes
|
|
332
|
+
|
|
333
|
+
Common issues:
|
|
334
|
+
|
|
335
|
+
- Missing `name` for a dataset with multiple configs:
|
|
336
|
+
the plugin raises and asks for `name=...`.
|
|
337
|
+
- No converted Parquet shards available:
|
|
338
|
+
the dataset may not have an auto-converted Parquet representation yet.
|
|
339
|
+
- Remote cache path credentials:
|
|
340
|
+
your Flyte runtime must be able to read and write the chosen `cache_root`.
|
|
341
|
+
|
|
342
|
+
## Example
|
|
343
|
+
|
|
344
|
+
See the example workflow in `plugins/huggingface/examples/hf_dataset_workflow.py`
|
|
345
|
+
for end-to-end local and remote scenarios.
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
flyteplugins/huggingface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
flyteplugins/huggingface/datasets/__init__.py,sha256=8wuGxh0ayadFTxcr58VsrOTYrSz5FbZrpIUtjZ_6iqA,1363
|
|
3
|
+
flyteplugins/huggingface/datasets/_io.py,sha256=lZYX6Uc_6ocrlwmQ_GXbf7VUzk6hWZK_wPo5EGDWe5Q,16186
|
|
4
|
+
flyteplugins/huggingface/datasets/_source.py,sha256=YfeCPyN_arxfJ1y8eGm4Jwwu_FmcHsSo_BO4Zh7PuTc,4825
|
|
5
|
+
flyteplugins/huggingface/datasets/_transformers.py,sha256=XjX5zx__PLfZgn3fObbhlWuXxDGWABZAm9ZsqFtLDUk,13008
|
|
6
|
+
flyteplugins_huggingface-2.2.1.dist-info/METADATA,sha256=up4CpJpRR3P8qvzjVS6RlSuVKLKudJtzeItflgENBjQ,9230
|
|
7
|
+
flyteplugins_huggingface-2.2.1.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
8
|
+
flyteplugins_huggingface-2.2.1.dist-info/entry_points.txt,sha256=ixxMpwmr3wlKwLK5GNlxnWeS4i4Pub3O7aNsEQFkjE4,112
|
|
9
|
+
flyteplugins_huggingface-2.2.1.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
|
|
10
|
+
flyteplugins_huggingface-2.2.1.dist-info/RECORD,,
|