atdata 0.2.3b1__py3-none-any.whl → 0.3.1b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +39 -0
- atdata/_cid.py +0 -21
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +41 -15
- atdata/_hf_api.py +95 -11
- atdata/_logging.py +70 -0
- atdata/_protocols.py +77 -238
- atdata/_schema_codec.py +7 -6
- atdata/_stub_manager.py +5 -25
- atdata/_type_utils.py +28 -2
- atdata/atmosphere/__init__.py +31 -20
- atdata/atmosphere/_types.py +4 -4
- atdata/atmosphere/client.py +64 -12
- atdata/atmosphere/lens.py +11 -12
- atdata/atmosphere/records.py +12 -12
- atdata/atmosphere/schema.py +16 -18
- atdata/atmosphere/store.py +6 -7
- atdata/cli/__init__.py +161 -175
- atdata/cli/diagnose.py +2 -2
- atdata/cli/{local.py → infra.py} +11 -11
- atdata/cli/inspect.py +69 -0
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +583 -328
- atdata/index/__init__.py +54 -0
- atdata/index/_entry.py +157 -0
- atdata/index/_index.py +1198 -0
- atdata/index/_schema.py +380 -0
- atdata/lens.py +9 -2
- atdata/lexicons/__init__.py +121 -0
- atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
- atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
- atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
- atdata/lexicons/ac.foundation.dataset.record.json +96 -0
- atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
- atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
- atdata/lexicons/ndarray_shim.json +16 -0
- atdata/local/__init__.py +70 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +18 -14
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/stores/__init__.py +23 -0
- atdata/stores/_disk.py +123 -0
- atdata/stores/_s3.py +349 -0
- atdata/testing.py +341 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
- atdata-0.3.1b1.dist-info/RECORD +67 -0
- atdata/local.py +0 -1720
- atdata-0.2.3b1.dist-info/RECORD +0 -28
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py
CHANGED
@@ -31,6 +31,7 @@ Examples:
 import webdataset as wds

 from pathlib import Path
+import itertools
 import uuid

 import dataclasses
@@ -42,15 +43,16 @@ from dataclasses import (
 from abc import ABC

 from ._sources import URLSource
-from ._protocols import DataSource
+from ._protocols import DataSource, Packable
+from ._exceptions import SampleKeyError, PartialFailureError

-from tqdm import tqdm
 import numpy as np
 import pandas as pd
 import requests

 import typing
 from typing import (
+    TYPE_CHECKING,
     Any,
     Optional,
     Dict,
@@ -66,6 +68,9 @@ from typing import (
     dataclass_transform,
     overload,
 )
+
+if TYPE_CHECKING:
+    from .manifest._query import SampleLocation
 from numpy.typing import NDArray

 import msgpack
@@ -94,9 +99,11 @@ DT = TypeVar("DT")


 def _make_packable(x):
-    """Convert numpy arrays to bytes;
+    """Convert numpy arrays to bytes; coerce numpy scalars to Python natives."""
     if isinstance(x, np.ndarray):
         return eh.array_to_bytes(x)
+    if isinstance(x, np.generic):
+        return x.item()
     return x


@@ -157,37 +164,17 @@ class DictSample:

     @classmethod
     def from_data(cls, data: dict[str, Any]) -> "DictSample":
-        """Create a DictSample from unpacked msgpack data.
-
-        Args:
-            data: Dictionary with field names as keys.
-
-        Returns:
-            New DictSample instance wrapping the data.
-        """
+        """Create a DictSample from unpacked msgpack data."""
         return cls(_data=data)

     @classmethod
     def from_bytes(cls, bs: bytes) -> "DictSample":
-        """Create a DictSample from raw msgpack bytes.
-
-        Args:
-            bs: Raw bytes from a msgpack-serialized sample.
-
-        Returns:
-            New DictSample instance with the unpacked data.
-        """
+        """Create a DictSample from raw msgpack bytes."""
         return cls.from_data(ormsgpack.unpackb(bs))

     def __getattr__(self, name: str) -> Any:
         """Access a field by attribute name.

-        Args:
-            name: Field name to access.
-
-        Returns:
-            The field value.
-
         Raises:
             AttributeError: If the field doesn't exist.
         """
@@ -203,21 +190,9 @@ class DictSample:
         ) from None

     def __getitem__(self, key: str) -> Any:
-        """Access a field by dict key.
-
-        Args:
-            key: Field name to access.
-
-        Returns:
-            The field value.
-
-        Raises:
-            KeyError: If the field doesn't exist.
-        """
         return self._data[key]

     def __contains__(self, key: str) -> bool:
-        """Check if a field exists."""
         return key in self._data

     def keys(self) -> list[str]:
@@ -225,23 +200,13 @@ class DictSample:
         return list(self._data.keys())

     def values(self) -> list[Any]:
-        """Return list of field values."""
         return list(self._data.values())

     def items(self) -> list[tuple[str, Any]]:
-        """Return list of (field_name, value) tuples."""
         return list(self._data.items())

     def get(self, key: str, default: Any = None) -> Any:
-        """Get a field value
-
-        Args:
-            key: Field name to access.
-            default: Value to return if field doesn't exist.
-
-        Returns:
-            The field value or default.
-        """
+        """Get a field value, returning *default* if missing."""
         return self._data.get(key, default)

     def to_dict(self) -> dict[str, Any]:
@@ -250,20 +215,12 @@ class DictSample:

     @property
     def packed(self) -> bytes:
-        """
-
-        Returns:
-            Raw msgpack bytes representing this sample's data.
-        """
+        """Serialize to msgpack bytes."""
         return msgpack.packb(self._data)

     @property
     def as_wds(self) -> "WDSRawSample":
-        """
-
-        Returns:
-            A dictionary with ``__key__`` and ``msgpack`` fields.
-        """
+        """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
         return {
             "__key__": str(uuid.uuid1(0, 0)),
             "msgpack": self.packed,
@@ -300,31 +257,13 @@ class PackableSample(ABC):

     def _ensure_good(self):
         """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
-
-        # Auto-convert known types when annotated
-        # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
         for field in dataclasses.fields(self):
-
-
-
-            # Annotation for this variable is to be an NDArray
-            if _is_possibly_ndarray_type(var_type):
-                # ... so, we'll always auto-convert to numpy
-
-                var_cur_value = getattr(self, var_name)
-
-                # Execute the appropriate conversion for intermediate data
-                # based on what is provided
-
-                if isinstance(var_cur_value, np.ndarray):
-                    # Already the correct type, no conversion needed
+            if _is_possibly_ndarray_type(field.type):
+                value = getattr(self, field.name)
+                if isinstance(value, np.ndarray):
                     continue
-
-
-                # Design note: bytes in NDArray-typed fields are always interpreted
-                # as serialized arrays. This means raw bytes fields must not be
-                # annotated as NDArray.
-                setattr(self, var_name, eh.bytes_to_array(var_cur_value))
+                elif isinstance(value, bytes):
+                    setattr(self, field.name, eh.bytes_to_array(value))

     def __post_init__(self):
         self._ensure_good()
@@ -333,67 +272,31 @@ class PackableSample(ABC):

     @classmethod
     def from_data(cls, data: WDSRawSample) -> Self:
-        """Create
-
-        Args:
-            data: Dictionary with keys matching the sample's field names.
-
-        Returns:
-            New instance with NDArray fields auto-converted from bytes.
-        """
+        """Create an instance from unpacked msgpack data."""
         return cls(**data)

     @classmethod
     def from_bytes(cls, bs: bytes) -> Self:
-        """Create
-
-        Args:
-            bs: Raw bytes from a msgpack-serialized sample.
-
-        Returns:
-            A new instance of this sample class deserialized from the bytes.
-        """
+        """Create an instance from raw msgpack bytes."""
         return cls.from_data(ormsgpack.unpackb(bs))

     @property
     def packed(self) -> bytes:
-        """
-
-        NDArray fields are automatically converted to bytes before packing.
-        All other fields are packed as-is if they're msgpack-compatible.
-
-        Returns:
-            Raw msgpack bytes representing this sample's data.
+        """Serialize to msgpack bytes. NDArray fields are auto-converted.

         Raises:
             RuntimeError: If msgpack serialization fails.
         """
-
-        # Make sure that all of our (possibly unpackable) data is in a packable
-        # format
         o = {k: _make_packable(v) for k, v in vars(self).items()}
-
         ret = msgpack.packb(o)
-
         if ret is None:
             raise RuntimeError(f"Failed to pack sample to bytes: {o}")
-
         return ret

     @property
     def as_wds(self) -> WDSRawSample:
-        """
-
-        Returns:
-            A dictionary with ``__key__`` (UUID v1 for sortable keys) and
-            ``msgpack`` (packed sample data) fields suitable for WebDataset.
-
-        Note:
-            Keys are auto-generated as UUID v1 for time-sortable ordering.
-            Custom key specification is not currently supported.
-        """
+        """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
         return {
-            # Generates a UUID that is timelike-sortable
             "__key__": str(uuid.uuid1(0, 0)),
             "msgpack": self.packed,
         }
@@ -404,82 +307,45 @@ def _batch_aggregate(xs: Sequence):
     if not xs:
         return []
     if isinstance(xs[0], np.ndarray):
-        return np.
+        return np.stack(xs)
     return list(xs)


 class SampleBatch(Generic[DT]):
     """A batch of samples with automatic attribute aggregation.

-
-
-
-    samples in the batch.
-
-    NDArray fields are stacked into a numpy array with a batch dimension.
-    Other fields are aggregated into a list.
+    Accessing an attribute aggregates that field across all samples:
+    NDArray fields are stacked into a numpy array with a batch dimension;
+    other fields are collected into a list. Results are cached.

     Parameters:
         DT: The sample type, must derive from ``PackableSample``.

-    Attributes:
-        samples: The list of sample instances in this batch.
-
     Examples:
         >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
-        >>> batch.embeddings #
-        >>> batch.names #
-
-    Note:
-        This class uses Python's ``__orig_class__`` mechanism to extract the
-        type parameter at runtime. Instances must be created using the
-        subscripted syntax ``SampleBatch[MyType](samples)`` rather than
-        calling the constructor directly with an unsubscripted class.
+        >>> batch.embeddings # Stacked numpy array of shape (3, ...)
+        >>> batch.names # List of names
     """

-    # Design note: The docstring uses "Parameters:" for type parameters because
-    # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
-
     def __init__(self, samples: Sequence[DT]):
-        """Create a batch from a sequence of samples.
-
-        Args:
-            samples: A sequence of sample instances to aggregate into a batch.
-                Each sample must be an instance of a type derived from
-                ``PackableSample``.
-        """
+        """Create a batch from a sequence of samples."""
         self.samples = list(samples)
         self._aggregate_cache = dict()
         self._sample_type_cache: Type | None = None

     @property
     def sample_type(self) -> Type:
-        """The type
-
-        Returns:
-            The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
-        """
+        """The type parameter ``DT`` used when creating this batch."""
         if self._sample_type_cache is None:
             self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
-
+        if self._sample_type_cache is None:
+            raise TypeError(
+                "SampleBatch requires a type parameter, e.g. SampleBatch[MySample]"
+            )
         return self._sample_type_cache

     def __getattr__(self, name):
-        """Aggregate
-
-        This magic method enables attribute-style access to aggregated sample
-        fields. Results are cached for efficiency.
-
-        Args:
-            name: The attribute name to aggregate across samples.
-
-        Returns:
-            For NDArray fields: a stacked numpy array with batch dimension.
-            For other fields: a list of values from each sample.
-
-        Raises:
-            AttributeError: If the attribute doesn't exist on the sample type.
-        """
+        """Aggregate a field across all samples (cached)."""
         # Aggregate named params of sample type
         if name in vars(self.sample_type)["__annotations__"]:
             if name not in self._aggregate_cache:
@@ -492,8 +358,8 @@ class SampleBatch(Generic[DT]):
         raise AttributeError(f"No sample attribute named {name}")


-ST = TypeVar("ST", bound=
-RT = TypeVar("RT", bound=
+ST = TypeVar("ST", bound=Packable)
+RT = TypeVar("RT", bound=Packable)


 class _ShardListStage(wds.utils.PipelineStage):
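
The two hunks above change how generics are resolved: the ``ST``/``RT`` type variables are now bound to the ``Packable`` protocol, and ``sample_type`` raises a ``TypeError`` when the type parameter cannot be recovered from ``__orig_class__``. A minimal sketch of the intended subscripted usage, assuming ``packable`` and ``SampleBatch`` are importable from ``atdata.dataset`` as defined in this file (the ``MySample`` class and its fields are illustrative, not part of the package):

import numpy as np
from numpy.typing import NDArray

from atdata.dataset import SampleBatch, packable


@packable
class MySample:
    name: str
    embedding: NDArray


batch = SampleBatch[MySample](
    [
        MySample(name="a", embedding=np.zeros(4)),
        MySample(name="b", embedding=np.ones(4)),
    ]
)
print(batch.embedding.shape)  # NDArray fields stack into a (2, 4) array
print(batch.name)             # other fields aggregate into a list

The new guard makes the subscripted form an explicit requirement rather than a documented convention.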
@@ -571,23 +437,18 @@ class Dataset(Generic[ST]):

     @property
     def sample_type(self) -> Type:
-        """The type
-
-        Returns:
-            The type parameter ``ST`` used when creating this ``Dataset[ST]``.
-        """
+        """The type parameter ``ST`` used when creating this dataset."""
         if self._sample_type_cache is None:
             self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
-
+        if self._sample_type_cache is None:
+            raise TypeError(
+                "Dataset requires a type parameter, e.g. Dataset[MySample]"
+            )
         return self._sample_type_cache

     @property
     def batch_type(self) -> Type:
-        """
-
-        Returns:
-            ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
-        """
+        """``SampleBatch[ST]`` where ``ST`` is this dataset's sample type."""
         return SampleBatch[self.sample_type]

     def __init__(
@@ -614,28 +475,21 @@ class Dataset(Generic[ST]):
         """
         super().__init__()

-        # Handle backward compatibility: url= keyword argument
         if source is None and url is not None:
             source = url
         elif source is None:
             raise TypeError("Dataset() missing required argument: 'source' or 'url'")

-        # Normalize source: strings become URLSource for backward compatibility
         if isinstance(source, str):
             self._source: DataSource = URLSource(source)
             self.url = source
         else:
             self._source = source
-            # For compatibility, expose URL if source has list_shards
             shards = source.list_shards()
-            # Design note: Using first shard as url for legacy compatibility.
-            # Full shard list is available via list_shards() method.
             self.url = shards[0] if shards else ""

         self._metadata: dict[str, Any] | None = None
         self.metadata_url: str | None = metadata_url
-        """Optional URL to msgpack-encoded metadata for this dataset."""
-
         self._output_lens: Lens | None = None
         self._sample_type_cache: Type | None = None

@@ -645,47 +499,23 @@ class Dataset(Generic[ST]):
         return self._source

     def as_type(self, other: Type[RT]) -> "Dataset[RT]":
-        """View this dataset through a different sample type
-
-        Args:
-            other: The target sample type to transform into. Must be a type
-                derived from ``PackableSample``.
-
-        Returns:
-            A new ``Dataset`` instance that yields samples of type ``other``
-            by applying the appropriate lens transformation from the global
-            ``LensNetwork`` registry.
+        """View this dataset through a different sample type via a registered lens.

         Raises:
-            ValueError: If no
-                sample type and the target type.
+            ValueError: If no lens exists between the current and target types.
         """
         ret = Dataset[other](self._source)
-        # Get the singleton lens registry
         lenses = LensNetwork()
         ret._output_lens = lenses.transform(self.sample_type, ret.sample_type)
         return ret

     @property
     def shards(self) -> Iterator[str]:
-        """Lazily iterate over shard identifiers.
-
-        Yields:
-            Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
-
-        Examples:
-            >>> for shard in ds.shards:
-            ...     print(f"Processing {shard}")
-        """
+        """Lazily iterate over shard identifiers."""
         return iter(self._source.list_shards())

     def list_shards(self) -> list[str]:
-        """
-
-        Returns:
-            A full (non-lazy) list of the individual ``tar`` files within the
-            source WebDataset.
-        """
+        """Return all shard paths/URLs as a list."""
         return self._source.list_shards()

     # Legacy alias for backwards compatibility
@@ -707,14 +537,7 @@ class Dataset(Generic[ST]):

     @property
     def metadata(self) -> dict[str, Any] | None:
-        """Fetch and cache metadata from metadata_url.
-
-        Returns:
-            Deserialized metadata dictionary, or None if no metadata_url is set.
-
-        Raises:
-            requests.HTTPError: If metadata fetch fails.
-        """
+        """Fetch and cache metadata from metadata_url, or ``None`` if unset."""
         if self.metadata_url is None:
             return None

@@ -726,6 +549,367 @@ class Dataset(Generic[ST]):
         # Use our cached values
         return self._metadata

+    ##
+    # Convenience methods (GH#38 developer experience)
+
+    @property
+    def schema(self) -> dict[str, type]:
+        """Field names and types for this dataset's sample type.
+
+        Examples:
+            >>> ds = Dataset[MyData]("data.tar")
+            >>> ds.schema
+            {'name': <class 'str'>, 'embedding': numpy.ndarray}
+        """
+        st = self.sample_type
+        if st is DictSample:
+            return {"_data": dict}
+        if dataclasses.is_dataclass(st):
+            return {f.name: f.type for f in dataclasses.fields(st)}
+        return {}
+
+    @property
+    def column_names(self) -> list[str]:
+        """List of field names for this dataset's sample type."""
+        st = self.sample_type
+        if dataclasses.is_dataclass(st):
+            return [f.name for f in dataclasses.fields(st)]
+        return []
+
+    def __iter__(self) -> Iterator[ST]:
+        """Shorthand for ``ds.ordered()``."""
+        return iter(self.ordered())
+
+    def __len__(self) -> int:
+        """Total sample count (iterates all shards on first call, then cached)."""
+        if not hasattr(self, "_len_cache"):
+            self._len_cache: int = sum(1 for _ in self.ordered())
+        return self._len_cache
+
+    def head(self, n: int = 5) -> list[ST]:
+        """Return the first *n* samples from the dataset.
+
+        Args:
+            n: Number of samples to return. Default: 5.
+
+        Returns:
+            List of up to *n* samples in shard order.
+
+        Examples:
+            >>> samples = ds.head(3)
+            >>> len(samples)
+            3
+        """
+        return list(itertools.islice(self.ordered(), n))
+
+    def get(self, key: str) -> ST:
+        """Retrieve a single sample by its ``__key__``.
+
+        Scans shards sequentially until a sample with a matching key is found.
+        This is O(n) for streaming datasets.
+
+        Args:
+            key: The WebDataset ``__key__`` string to search for.
+
+        Returns:
+            The matching sample.
+
+        Raises:
+            SampleKeyError: If no sample with the given key exists.
+
+        Examples:
+            >>> sample = ds.get("00000001-0001-1000-8000-010000000000")
+        """
+        pipeline = wds.pipeline.DataPipeline(
+            _ShardListStage(self._source),
+            wds.shardlists.split_by_worker,
+            _StreamOpenerStage(self._source),
+            wds.tariterators.tar_file_expander,
+            wds.tariterators.group_by_keys,
+        )
+        for raw_sample in pipeline:
+            if raw_sample.get("__key__") == key:
+                return self.wrap(raw_sample)
+        raise SampleKeyError(key)
+
+    def describe(self) -> dict[str, Any]:
+        """Summary statistics: sample_type, fields, num_shards, shards, url, metadata."""
+        shards = self.list_shards()
+        return {
+            "sample_type": self.sample_type.__name__,
+            "fields": self.schema,
+            "num_shards": len(shards),
+            "shards": shards,
+            "url": self.url,
+            "metadata": self.metadata,
+        }
+
+    def filter(self, predicate: Callable[[ST], bool]) -> "Dataset[ST]":
+        """Return a new dataset that yields only samples matching *predicate*.
+
+        The filter is applied lazily during iteration — no data is copied.
+
+        Args:
+            predicate: A function that takes a sample and returns ``True``
+                to keep it or ``False`` to discard it.
+
+        Returns:
+            A new ``Dataset`` whose iterators apply the filter.
+
+        Examples:
+            >>> long_names = ds.filter(lambda s: len(s.name) > 10)
+            >>> for sample in long_names:
+            ...     assert len(sample.name) > 10
+        """
+        filtered = Dataset[self.sample_type](self._source, self.metadata_url)
+        filtered._sample_type_cache = self._sample_type_cache
+        filtered._output_lens = self._output_lens
+        filtered._filter_fn = predicate
+        # Preserve any existing filters
+        parent_filters = getattr(self, "_filter_fn", None)
+        if parent_filters is not None:
+            outer = parent_filters
+            filtered._filter_fn = lambda s: outer(s) and predicate(s)
+        # Preserve any existing map
+        if hasattr(self, "_map_fn"):
+            filtered._map_fn = self._map_fn
+        return filtered
+
+    def map(self, fn: Callable[[ST], Any]) -> "Dataset":
+        """Return a new dataset that applies *fn* to each sample during iteration.
+
+        The mapping is applied lazily during iteration — no data is copied.
+
+        Args:
+            fn: A function that takes a sample of type ``ST`` and returns
+                a transformed value.
+
+        Returns:
+            A new ``Dataset`` whose iterators apply the mapping.
+
+        Examples:
+            >>> names = ds.map(lambda s: s.name)
+            >>> for name in names:
+            ...     print(name)
+        """
+        mapped = Dataset[self.sample_type](self._source, self.metadata_url)
+        mapped._sample_type_cache = self._sample_type_cache
+        mapped._output_lens = self._output_lens
+        mapped._map_fn = fn
+        # Preserve any existing map
+        if hasattr(self, "_map_fn"):
+            outer = self._map_fn
+            mapped._map_fn = lambda s: fn(outer(s))
+        # Preserve any existing filter
+        if hasattr(self, "_filter_fn"):
+            mapped._filter_fn = self._filter_fn
+        return mapped
+
+    def process_shards(
+        self,
+        fn: Callable[[list[ST]], Any],
+        *,
+        shards: list[str] | None = None,
+    ) -> dict[str, Any]:
+        """Process each shard independently, collecting per-shard results.
+
+        Unlike :meth:`map` (which is lazy and per-sample), this method eagerly
+        processes each shard in turn, calling *fn* with the full list of samples
+        from that shard. If some shards fail, raises
+        :class:`~atdata._exceptions.PartialFailureError` containing both the
+        successful results and the per-shard errors.
+
+        Args:
+            fn: Function receiving a list of samples from one shard and
+                returning an arbitrary result.
+            shards: Optional list of shard identifiers to process. If ``None``,
+                processes all shards in the dataset. Useful for retrying only
+                the failed shards from a previous ``PartialFailureError``.
+
+        Returns:
+            Dict mapping shard identifier to *fn*'s return value for each shard.
+
+        Raises:
+            PartialFailureError: If at least one shard fails. The exception
+                carries ``.succeeded_shards``, ``.failed_shards``, ``.errors``,
+                and ``.results`` for inspection and retry.
+
+        Examples:
+            >>> results = ds.process_shards(lambda samples: len(samples))
+            >>> # On partial failure, retry just the failed shards:
+            >>> try:
+            ...     results = ds.process_shards(expensive_fn)
+            ... except PartialFailureError as e:
+            ...     retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
+        """
+        from ._logging import get_logger
+
+        log = get_logger()
+        shard_ids = shards or self.list_shards()
+        log.info("process_shards: starting %d shards", len(shard_ids))
+
+        succeeded: list[str] = []
+        failed: list[str] = []
+        errors: dict[str, Exception] = {}
+        results: dict[str, Any] = {}
+
+        for shard_id in shard_ids:
+            try:
+                shard_ds = Dataset[self.sample_type](shard_id)
+                shard_ds._sample_type_cache = self._sample_type_cache
+                samples = list(shard_ds.ordered())
+                results[shard_id] = fn(samples)
+                succeeded.append(shard_id)
+                log.debug("process_shards: shard ok %s", shard_id)
+            except Exception as exc:
+                failed.append(shard_id)
+                errors[shard_id] = exc
+                log.warning("process_shards: shard failed %s: %s", shard_id, exc)
+
+        if failed:
+            log.error(
+                "process_shards: %d/%d shards failed",
+                len(failed),
+                len(shard_ids),
+            )
+            raise PartialFailureError(
+                succeeded_shards=succeeded,
+                failed_shards=failed,
+                errors=errors,
+                results=results,
+            )
+
+        log.info("process_shards: all %d shards succeeded", len(shard_ids))
+        return results
+
+    def select(self, indices: Sequence[int]) -> list[ST]:
+        """Return samples at the given integer indices.
+
+        Iterates through the dataset in order and collects samples whose
+        positional index matches. This is O(n) for streaming datasets.
+
+        Args:
+            indices: Sequence of zero-based indices to select.
+
+        Returns:
+            List of samples at the requested positions, in index order.
+
+        Examples:
+            >>> samples = ds.select([0, 5, 10])
+            >>> len(samples)
+            3
+        """
+        if not indices:
+            return []
+        target = set(indices)
+        max_idx = max(indices)
+        result: dict[int, ST] = {}
+        for i, sample in enumerate(self.ordered()):
+            if i in target:
+                result[i] = sample
+            if i >= max_idx:
+                break
+        return [result[i] for i in indices if i in result]
+
+    def query(
+        self,
+        where: "Callable[[pd.DataFrame], pd.Series]",
+    ) -> "list[SampleLocation]":
+        """Query this dataset using per-shard manifest metadata.
+
+        Requires manifests to have been generated during shard writing.
+        Discovers manifest files alongside the tar shards, loads them,
+        and executes a two-phase query (shard-level aggregate pruning,
+        then sample-level parquet filtering).
+
+        Args:
+            where: Predicate function that receives a pandas DataFrame
+                of manifest fields and returns a boolean Series selecting
+                matching rows.
+
+        Returns:
+            List of ``SampleLocation`` for matching samples.
+
+        Raises:
+            FileNotFoundError: If no manifest files are found alongside shards.
+
+        Examples:
+            >>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
+            >>> len(locs)
+            42
+        """
+        from .manifest import QueryExecutor
+
+        shard_urls = self.list_shards()
+        executor = QueryExecutor.from_shard_urls(shard_urls)
+        return executor.query(where=where)
+
+    def to_pandas(self, limit: int | None = None) -> "pd.DataFrame":
+        """Materialize the dataset (or first *limit* samples) as a DataFrame.
+
+        Args:
+            limit: Maximum number of samples to include. ``None`` means all
+                samples (may use significant memory for large datasets).
+
+        Returns:
+            A pandas DataFrame with one row per sample and columns matching
+            the sample fields.
+
+        Warning:
+            With ``limit=None`` this loads the entire dataset into memory.
+
+        Examples:
+            >>> df = ds.to_pandas(limit=100)
+            >>> df.columns.tolist()
+            ['name', 'embedding']
+        """
+        samples = self.head(limit) if limit is not None else list(self.ordered())
+        rows = [
+            asdict(s) if dataclasses.is_dataclass(s) else s.to_dict() for s in samples
+        ]
+        return pd.DataFrame(rows)
+
+    def to_dict(self, limit: int | None = None) -> dict[str, list[Any]]:
+        """Materialize the dataset as a column-oriented dictionary.
+
+        Args:
+            limit: Maximum number of samples to include. ``None`` means all.
+
+        Returns:
+            Dictionary mapping field names to lists of values (one entry
+            per sample).
+
+        Warning:
+            With ``limit=None`` this loads the entire dataset into memory.
+
+        Examples:
+            >>> d = ds.to_dict(limit=10)
+            >>> d.keys()
+            dict_keys(['name', 'embedding'])
+            >>> len(d['name'])
+            10
+        """
+        samples = self.head(limit) if limit is not None else list(self.ordered())
+        if not samples:
+            return {}
+        if dataclasses.is_dataclass(samples[0]):
+            fields = [f.name for f in dataclasses.fields(samples[0])]
+            return {f: [getattr(s, f) for s in samples] for f in fields}
+        # DictSample path
+        keys = samples[0].keys()
+        return {k: [s[k] for s in samples] for k in keys}
+
+    def _post_wrap_stages(self) -> list:
+        """Build extra pipeline stages for filter/map set via .filter()/.map()."""
+        stages: list = []
+        filter_fn = getattr(self, "_filter_fn", None)
+        if filter_fn is not None:
+            stages.append(wds.filters.select(filter_fn))
+        map_fn = getattr(self, "_map_fn", None)
+        if map_fn is not None:
+            stages.append(wds.filters.map(map_fn))
+        return stages
+
     @overload
     def ordered(
         self,
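
The convenience block above reads most naturally as a usage sequence. A minimal sketch, assuming a ``@packable`` sample type ``MySample`` with ``name`` and ``embedding`` fields and an existing ``data.tar`` previously written by this library (class, field, and path names here are illustrative):

from atdata.dataset import Dataset

ds = Dataset[MySample]("data.tar")

print(ds.column_names)             # e.g. ['name', 'embedding']
print(ds.describe()["num_shards"])
first_three = ds.head(3)           # first three samples, in shard order

# filter() and map() are lazy: they stash _filter_fn/_map_fn on the new Dataset,
# and _post_wrap_stages() appends them after wds.filters.map(self.wrap).
names = ds.filter(lambda s: len(s.name) <= 8).map(lambda s: s.name)
for name in names:
    print(name)

df = ds.to_pandas(limit=100)       # bounded materialization; limit=None loads everything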
@@ -769,6 +953,7 @@ class Dataset(Generic[ST]):
             wds.tariterators.tar_file_expander,
             wds.tariterators.group_by_keys,
             wds.filters.map(self.wrap),
+            *self._post_wrap_stages(),
         )

         return wds.pipeline.DataPipeline(
@@ -839,6 +1024,7 @@ class Dataset(Generic[ST]):
             wds.tariterators.group_by_keys,
             wds.filters.shuffle(buffer_samples),
             wds.filters.map(self.wrap),
+            *self._post_wrap_stages(),
         )

         return wds.pipeline.DataPipeline(
@@ -862,100 +1048,47 @@ class Dataset(Generic[ST]):
         maxcount: Optional[int] = None,
         **kwargs,
     ):
-        """Export dataset
-
-        Converts all samples to a pandas DataFrame and saves to parquet file(s).
-        Useful for interoperability with data analysis tools.
+        """Export dataset to parquet file(s).

         Args:
-            path: Output path
-
-            sample_map:
-
-
-
-            **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
-                Common options include ``compression``, ``index``, ``engine``.
-
-        Warning:
-            **Memory Usage**: When ``maxcount=None`` (default), this method loads
-            the **entire dataset into memory** as a pandas DataFrame before writing.
-            For large datasets, this can cause memory exhaustion.
-
-            For datasets larger than available RAM, always specify ``maxcount``::
-
-                # Safe for large datasets - processes in chunks
-                ds.to_parquet("output.parquet", maxcount=10000)
-
-            This creates multiple parquet files: ``output-000000.parquet``,
-            ``output-000001.parquet``, etc.
+            path: Output path. With *maxcount*, files are named
+                ``{stem}-{segment:06d}.parquet``.
+            sample_map: Convert sample to dict. Defaults to ``dataclasses.asdict``.
+            maxcount: Split into files of at most this many samples.
+                Without it, the entire dataset is loaded into memory.
+            **kwargs: Passed to ``pandas.DataFrame.to_parquet()``.

         Examples:
-            >>> ds = Dataset[MySample]("data.tar")
-            >>> # Small dataset - load all at once
-            >>> ds.to_parquet("output.parquet")
-            >>>
-            >>> # Large dataset - process in chunks
            >>> ds.to_parquet("output.parquet", maxcount=50000)
         """
-        ##
-
-        # Normalize args
         path = Path(path)
         if sample_map is None:
             sample_map = asdict

-        verbose = kwargs.get("verbose", False)
-
-        it = self.ordered(batch_size=None)
-        if verbose:
-            it = tqdm(it)
-
-        #
-
         if maxcount is None:
-            # Load and save full dataset
             df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)])
             df.to_parquet(path, **kwargs)
-
         else:
-            # Load and save dataset in segments of size `maxcount`
-
             cur_segment = 0
-            cur_buffer = []
+            cur_buffer: list = []
             path_template = (
                 path.parent / f"{path.stem}-{{:06d}}{path.suffix}"
             ).as_posix()

             for x in self.ordered(batch_size=None):
                 cur_buffer.append(sample_map(x))
-
                 if len(cur_buffer) >= maxcount:
-                    # Write current segment
                     cur_path = path_template.format(cur_segment)
-
-                    df.to_parquet(cur_path, **kwargs)
-
+                    pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)
                     cur_segment += 1
                     cur_buffer = []

-            if
-                # Write one last segment with remainder
+            if cur_buffer:
                 cur_path = path_template.format(cur_segment)
-
-                df.to_parquet(cur_path, **kwargs)
+                pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)

     def wrap(self, sample: WDSRawSample) -> ST:
-        """
-
-        Args:
-            sample: A dictionary containing at minimum a ``'msgpack'`` key with
-                serialized sample bytes.
-
-        Returns:
-            A deserialized sample of type ``ST``, optionally transformed through
-            a lens if ``as_type()`` was called.
-        """
+        """Deserialize a raw WDS sample dict into type ``ST``."""
         if "msgpack" not in sample:
             raise ValueError(
                 f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}"
@@ -972,20 +1105,7 @@ class Dataset(Generic[ST]):
         return self._output_lens(source_sample)

     def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]:
-        """
-
-        Args:
-            batch: A dictionary containing a ``'msgpack'`` key with a list of
-                serialized sample bytes.
-
-        Returns:
-            A ``SampleBatch[ST]`` containing deserialized samples, optionally
-            transformed through a lens if ``as_type()`` was called.
-
-        Note:
-            This implementation deserializes samples one at a time, then
-            aggregates them into a batch.
-        """
+        """Deserialize a raw WDS batch dict into ``SampleBatch[ST]``."""

         if "msgpack" not in batch:
             raise ValueError(
@@ -1009,24 +1129,12 @@ _T = TypeVar("_T")


 @dataclass_transform()
-def packable(cls: type[_T]) -> type[
-    """
-
-    This decorator transforms a class into a dataclass that inherits from
-    ``PackableSample``, enabling automatic msgpack serialization/deserialization
-    with special handling for NDArray fields.
+def packable(cls: type[_T]) -> type[Packable]:
+    """Convert a class into a ``PackableSample`` dataclass with msgpack serialization.

-    The resulting class
-
-
-
-    Args:
-        cls: The class to convert. Should have type annotations for its fields.
-
-    Returns:
-        A new dataclass that inherits from ``PackableSample`` with the same
-        name and annotations as the original class. The class satisfies the
-        ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
+    The resulting class gains ``packed``, ``as_wds``, ``from_bytes``, and
+    ``from_data`` methods, and satisfies the ``Packable`` protocol.
+    NDArray fields are automatically handled during serialization.

     Examples:
         >>> @packable
@@ -1035,11 +1143,7 @@ def packable(cls: type[_T]) -> type[_T]:
        ...     values: NDArray
        ...
        >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
-       >>>
-       >>> restored = MyData.from_bytes(bytes_data)
-       >>>
-       >>> # Works with Packable-typed APIs
-       >>> index.publish_schema(MyData, version="1.0.0")  # Type-safe
+       >>> restored = MyData.from_bytes(sample.packed)
     """

     ##
@@ -1086,3 +1190,154 @@ def packable(cls: type[_T]) -> type[_T]:
     ##

     return as_packable
+
+
+# ---------------------------------------------------------------------------
+# write_samples — convenience function for writing samples to tar files
+# ---------------------------------------------------------------------------
+
+
+def write_samples(
+    samples: Iterable[ST],
+    path: str | Path,
+    *,
+    maxcount: int | None = None,
+    maxsize: int | None = None,
+    manifest: bool = False,
+) -> "Dataset[ST]":
+    """Write an iterable of samples to WebDataset tar file(s).
+
+    Args:
+        samples: Iterable of ``PackableSample`` instances. Must be non-empty.
+        path: Output path for the tar file. For sharded output (when
+            *maxcount* or *maxsize* is set), a ``%06d`` pattern is
+            auto-appended if the path does not already contain ``%``.
+        maxcount: Maximum samples per shard. Triggers multi-shard output.
+        maxsize: Maximum bytes per shard. Triggers multi-shard output.
+        manifest: If True, write per-shard manifest sidecar files
+            (``.manifest.json`` + ``.manifest.parquet``) alongside each
+            tar file. Manifests enable metadata queries via
+            ``QueryExecutor`` without opening the tars.
+
+    Returns:
+        A ``Dataset`` wrapping the written file(s), typed to the sample
+        type of the input samples.
+
+    Raises:
+        ValueError: If *samples* is empty.
+
+    Examples:
+        >>> samples = [MySample(key="0", text="hello")]
+        >>> ds = write_samples(samples, "out.tar")
+        >>> list(ds.ordered())
+        [MySample(key='0', text='hello')]
+    """
+    from ._hf_api import _shards_to_wds_url
+
+    if manifest:
+        from .manifest._builder import ManifestBuilder
+        from .manifest._writer import ManifestWriter
+
+    path = Path(path)
+    path.parent.mkdir(parents=True, exist_ok=True)
+
+    use_shard_writer = maxcount is not None or maxsize is not None
+    sample_type: type | None = None
+    written_paths: list[str] = []
+
+    # Manifest tracking state
+    _current_builder: list = []  # single-element list for nonlocal mutation
+    _builders: list[tuple[str, "ManifestBuilder"]] = []
+    _running_offset: list[int] = [0]
+
+    def _finalize_builder() -> None:
+        """Finalize the current manifest builder and stash it."""
+        if _current_builder:
+            shard_path = written_paths[-1] if written_paths else ""
+            _builders.append((shard_path, _current_builder[0]))
+            _current_builder.clear()
+
+    def _start_builder(shard_path: str) -> None:
+        """Start a new manifest builder for a shard."""
+        _finalize_builder()
+        shard_id = Path(shard_path).stem
+        _current_builder.append(
+            ManifestBuilder(sample_type=sample_type, shard_id=shard_id)
+        )
+        _running_offset[0] = 0
+
+    def _record_sample(sample: "PackableSample", wds_dict: dict) -> None:
+        """Record a sample in the active manifest builder."""
+        if not _current_builder:
+            return
+        packed_bytes = wds_dict["msgpack"]
+        size = len(packed_bytes)
+        _current_builder[0].add_sample(
+            key=wds_dict["__key__"],
+            offset=_running_offset[0],
+            size=size,
+            sample=sample,
+        )
+        _running_offset[0] += size
+
+    if use_shard_writer:
+        # Build shard pattern from path
+        if "%" not in str(path):
+            pattern = str(path.parent / f"{path.stem}-%06d{path.suffix}")
+        else:
+            pattern = str(path)
+
+        writer_kwargs: dict[str, Any] = {}
+        if maxcount is not None:
+            writer_kwargs["maxcount"] = maxcount
+        if maxsize is not None:
+            writer_kwargs["maxsize"] = maxsize
+
+        def _track(p: str) -> None:
+            written_paths.append(str(Path(p).resolve()))
+            if manifest and sample_type is not None:
+                _start_builder(p)
+
+        with wds.writer.ShardWriter(pattern, post=_track, **writer_kwargs) as sink:
+            for sample in samples:
+                if sample_type is None:
+                    sample_type = type(sample)
+                wds_dict = sample.as_wds
+                sink.write(wds_dict)
+                if manifest:
+                    # The first sample triggers _track before we get here when
+                    # ShardWriter opens the first shard, but just in case:
+                    if not _current_builder and sample_type is not None:
+                        _start_builder(str(path))
+                    _record_sample(sample, wds_dict)
+    else:
+        with wds.writer.TarWriter(str(path)) as sink:
+            for sample in samples:
+                if sample_type is None:
+                    sample_type = type(sample)
+                wds_dict = sample.as_wds
+                sink.write(wds_dict)
+                if manifest:
+                    if not _current_builder and sample_type is not None:
+                        _current_builder.append(
+                            ManifestBuilder(sample_type=sample_type, shard_id=path.stem)
+                        )
+                    _record_sample(sample, wds_dict)
+        written_paths.append(str(path.resolve()))
+
+    if sample_type is None:
+        raise ValueError("samples must be non-empty")
+
+    # Finalize and write manifests
+    if manifest:
+        _finalize_builder()
+        for shard_path, builder in _builders:
+            m = builder.build()
+            base = str(Path(shard_path).with_suffix(""))
+            writer = ManifestWriter(base)
+            writer.write(m)
+
+    url = _shards_to_wds_url(written_paths)
+    ds: Dataset = Dataset(url)
+    ds._sample_type_cache = sample_type
+    return ds
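
The new ``write_samples`` entry point pairs naturally with the manifest-based ``query`` path added earlier in this file. A minimal round-trip sketch, assuming a ``@packable`` sample type with a scalar ``confidence`` field, and assuming the manifest sidecars capture that field for filtering as the ``query`` docstring's example suggests (the class, field, and file names are illustrative):

import numpy as np
from numpy.typing import NDArray

from atdata.dataset import packable, write_samples


@packable
class Detection:
    label: str
    confidence: float
    embedding: NDArray


samples = [
    Detection(label=f"obj{i}", confidence=i / 100, embedding=np.random.rand(8))
    for i in range(100)
]

# maxcount switches to ShardWriter output; manifest=True also writes
# .manifest.json / .manifest.parquet sidecars next to each shard.
ds = write_samples(samples, "detections.tar", maxcount=25, manifest=True)

# Two-phase metadata query over the sidecars, without opening the tars.
locations = ds.query(where=lambda df: df["confidence"] > 0.9)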