atdata 0.2.3b1__py3-none-any.whl → 0.3.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +30 -0
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +29 -15
- atdata/_hf_api.py +63 -11
- atdata/_logging.py +70 -0
- atdata/_protocols.py +19 -62
- atdata/_schema_codec.py +5 -4
- atdata/_type_utils.py +28 -2
- atdata/atmosphere/__init__.py +19 -9
- atdata/atmosphere/records.py +3 -2
- atdata/atmosphere/schema.py +2 -2
- atdata/cli/__init__.py +157 -171
- atdata/cli/inspect.py +69 -0
- atdata/cli/local.py +1 -1
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +428 -326
- atdata/lens.py +9 -2
- atdata/local/__init__.py +71 -0
- atdata/local/_entry.py +157 -0
- atdata/local/_index.py +940 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/local/_s3.py +349 -0
- atdata/local/_schema.py +380 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +4 -4
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/testing.py +337 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +4 -1
- atdata-0.3.0b1.dist-info/RECORD +54 -0
- atdata/local.py +0 -1720
- atdata-0.2.3b1.dist-info/RECORD +0 -28
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py
CHANGED
@@ -31,6 +31,7 @@ Examples:
 import webdataset as wds

 from pathlib import Path
+import itertools
 import uuid

 import dataclasses
@@ -42,15 +43,16 @@ from dataclasses import (
 from abc import ABC

 from ._sources import URLSource
-from ._protocols import DataSource
+from ._protocols import DataSource, Packable
+from ._exceptions import SampleKeyError, PartialFailureError

-from tqdm import tqdm
 import numpy as np
 import pandas as pd
 import requests

 import typing
 from typing import (
+    TYPE_CHECKING,
     Any,
     Optional,
     Dict,
@@ -66,6 +68,9 @@ from typing import (
     dataclass_transform,
     overload,
 )
+
+if TYPE_CHECKING:
+    from .manifest._query import SampleLocation
 from numpy.typing import NDArray

 import msgpack
@@ -157,37 +162,17 @@ class DictSample:

     @classmethod
     def from_data(cls, data: dict[str, Any]) -> "DictSample":
-        """Create a DictSample from unpacked msgpack data.
-
-        Args:
-            data: Dictionary with field names as keys.
-
-        Returns:
-            New DictSample instance wrapping the data.
-        """
+        """Create a DictSample from unpacked msgpack data."""
         return cls(_data=data)

     @classmethod
     def from_bytes(cls, bs: bytes) -> "DictSample":
-        """Create a DictSample from raw msgpack bytes.
-
-        Args:
-            bs: Raw bytes from a msgpack-serialized sample.
-
-        Returns:
-            New DictSample instance with the unpacked data.
-        """
+        """Create a DictSample from raw msgpack bytes."""
         return cls.from_data(ormsgpack.unpackb(bs))

     def __getattr__(self, name: str) -> Any:
         """Access a field by attribute name.

-        Args:
-            name: Field name to access.
-
-        Returns:
-            The field value.
-
         Raises:
             AttributeError: If the field doesn't exist.
         """
@@ -203,21 +188,9 @@ class DictSample:
         ) from None

     def __getitem__(self, key: str) -> Any:
-        """Access a field by dict key.
-
-        Args:
-            key: Field name to access.
-
-        Returns:
-            The field value.
-
-        Raises:
-            KeyError: If the field doesn't exist.
-        """
         return self._data[key]

     def __contains__(self, key: str) -> bool:
-        """Check if a field exists."""
         return key in self._data

     def keys(self) -> list[str]:
@@ -225,23 +198,13 @@ class DictSample:
         return list(self._data.keys())

     def values(self) -> list[Any]:
-        """Return list of field values."""
         return list(self._data.values())

     def items(self) -> list[tuple[str, Any]]:
-        """Return list of (field_name, value) tuples."""
         return list(self._data.items())

     def get(self, key: str, default: Any = None) -> Any:
-        """Get a field value
-
-        Args:
-            key: Field name to access.
-            default: Value to return if field doesn't exist.
-
-        Returns:
-            The field value or default.
-        """
+        """Get a field value, returning *default* if missing."""
         return self._data.get(key, default)

     def to_dict(self) -> dict[str, Any]:
@@ -250,20 +213,12 @@ class DictSample:

     @property
     def packed(self) -> bytes:
-        """
-
-        Returns:
-            Raw msgpack bytes representing this sample's data.
-        """
+        """Serialize to msgpack bytes."""
         return msgpack.packb(self._data)

     @property
     def as_wds(self) -> "WDSRawSample":
-        """
-
-        Returns:
-            A dictionary with ``__key__`` and ``msgpack`` fields.
-        """
+        """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
         return {
             "__key__": str(uuid.uuid1(0, 0)),
             "msgpack": self.packed,
@@ -300,31 +255,13 @@ class PackableSample(ABC):

     def _ensure_good(self):
         """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
-
-        # Auto-convert known types when annotated
-        # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
         for field in dataclasses.fields(self):
-
-
-
-            # Annotation for this variable is to be an NDArray
-            if _is_possibly_ndarray_type(var_type):
-                # ... so, we'll always auto-convert to numpy
-
-                var_cur_value = getattr(self, var_name)
-
-                # Execute the appropriate conversion for intermediate data
-                # based on what is provided
-
-                if isinstance(var_cur_value, np.ndarray):
-                    # Already the correct type, no conversion needed
+            if _is_possibly_ndarray_type(field.type):
+                value = getattr(self, field.name)
+                if isinstance(value, np.ndarray):
                     continue
-
-
-                # Design note: bytes in NDArray-typed fields are always interpreted
-                # as serialized arrays. This means raw bytes fields must not be
-                # annotated as NDArray.
-                setattr(self, var_name, eh.bytes_to_array(var_cur_value))
+                elif isinstance(value, bytes):
+                    setattr(self, field.name, eh.bytes_to_array(value))

     def __post_init__(self):
         self._ensure_good()
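Note: the tightened `_ensure_good` loop above is what lets a `@packable` class accept either a numpy array or serialized bytes in an NDArray-annotated field. A minimal round-trip sketch, assuming `packable` is importable from the package root (the `MyData` class and its fields are illustrative, not from the package):

    import numpy as np
    from numpy.typing import NDArray
    from atdata import packable  # import path assumed

    @packable
    class MyData:  # hypothetical sample type
        name: str
        values: NDArray  # bytes assigned here are decoded back to an ndarray

    sample = MyData(name="a", values=np.arange(3))
    restored = MyData.from_bytes(sample.packed)  # msgpack round-trip
    assert isinstance(restored.values, np.ndarray)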
@@ -333,67 +270,31 @@ class PackableSample(ABC):

     @classmethod
     def from_data(cls, data: WDSRawSample) -> Self:
-        """Create
-
-        Args:
-            data: Dictionary with keys matching the sample's field names.
-
-        Returns:
-            New instance with NDArray fields auto-converted from bytes.
-        """
+        """Create an instance from unpacked msgpack data."""
         return cls(**data)

     @classmethod
     def from_bytes(cls, bs: bytes) -> Self:
-        """Create
-
-        Args:
-            bs: Raw bytes from a msgpack-serialized sample.
-
-        Returns:
-            A new instance of this sample class deserialized from the bytes.
-        """
+        """Create an instance from raw msgpack bytes."""
         return cls.from_data(ormsgpack.unpackb(bs))

     @property
     def packed(self) -> bytes:
-        """
-
-        NDArray fields are automatically converted to bytes before packing.
-        All other fields are packed as-is if they're msgpack-compatible.
-
-        Returns:
-            Raw msgpack bytes representing this sample's data.
+        """Serialize to msgpack bytes. NDArray fields are auto-converted.

         Raises:
             RuntimeError: If msgpack serialization fails.
         """
-
-        # Make sure that all of our (possibly unpackable) data is in a packable
-        # format
         o = {k: _make_packable(v) for k, v in vars(self).items()}
-
         ret = msgpack.packb(o)
-
         if ret is None:
             raise RuntimeError(f"Failed to pack sample to bytes: {o}")
-
         return ret

     @property
     def as_wds(self) -> WDSRawSample:
-        """
-
-        Returns:
-            A dictionary with ``__key__`` (UUID v1 for sortable keys) and
-            ``msgpack`` (packed sample data) fields suitable for WebDataset.
-
-        Note:
-            Keys are auto-generated as UUID v1 for time-sortable ordering.
-            Custom key specification is not currently supported.
-        """
+        """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
         return {
-            # Generates a UUID that is timelike-sortable
             "__key__": str(uuid.uuid1(0, 0)),
             "msgpack": self.packed,
         }
@@ -411,75 +312,38 @@ def _batch_aggregate(xs: Sequence):
 class SampleBatch(Generic[DT]):
     """A batch of samples with automatic attribute aggregation.

-
-
-
-    samples in the batch.
-
-    NDArray fields are stacked into a numpy array with a batch dimension.
-    Other fields are aggregated into a list.
+    Accessing an attribute aggregates that field across all samples:
+    NDArray fields are stacked into a numpy array with a batch dimension;
+    other fields are collected into a list. Results are cached.

     Parameters:
         DT: The sample type, must derive from ``PackableSample``.

-    Attributes:
-        samples: The list of sample instances in this batch.
-
     Examples:
         >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
-        >>> batch.embeddings  #
-        >>> batch.names  #
-
-    Note:
-        This class uses Python's ``__orig_class__`` mechanism to extract the
-        type parameter at runtime. Instances must be created using the
-        subscripted syntax ``SampleBatch[MyType](samples)`` rather than
-        calling the constructor directly with an unsubscripted class.
+        >>> batch.embeddings  # Stacked numpy array of shape (3, ...)
+        >>> batch.names  # List of names
     """

-    # Design note: The docstring uses "Parameters:" for type parameters because
-    # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
-
     def __init__(self, samples: Sequence[DT]):
-        """Create a batch from a sequence of samples.
-
-        Args:
-            samples: A sequence of sample instances to aggregate into a batch.
-                Each sample must be an instance of a type derived from
-                ``PackableSample``.
-        """
+        """Create a batch from a sequence of samples."""
         self.samples = list(samples)
         self._aggregate_cache = dict()
         self._sample_type_cache: Type | None = None

     @property
     def sample_type(self) -> Type:
-        """The type
-
-        Returns:
-            The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
-        """
+        """The type parameter ``DT`` used when creating this batch."""
         if self._sample_type_cache is None:
             self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
-
+            if self._sample_type_cache is None:
+                raise TypeError(
+                    "SampleBatch requires a type parameter, e.g. SampleBatch[MySample]"
+                )
         return self._sample_type_cache

     def __getattr__(self, name):
-        """Aggregate
-
-        This magic method enables attribute-style access to aggregated sample
-        fields. Results are cached for efficiency.
-
-        Args:
-            name: The attribute name to aggregate across samples.
-
-        Returns:
-            For NDArray fields: a stacked numpy array with batch dimension.
-            For other fields: a list of values from each sample.
-
-        Raises:
-            AttributeError: If the attribute doesn't exist on the sample type.
-        """
+        """Aggregate a field across all samples (cached)."""
         # Aggregate named params of sample type
         if name in vars(self.sample_type)["__annotations__"]:
             if name not in self._aggregate_cache:
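Note: as the condensed `SampleBatch` docstring describes, attribute access aggregates one field across the whole batch and caches the result. A sketch reusing the hypothetical `MyData` type from above (import path assumed):

    from atdata.dataset import SampleBatch

    batch = SampleBatch[MyData]([sample_a, sample_b, sample_c])
    batch.values  # NDArray field: stacked into one array with a batch dimension
    batch.name    # non-array field: collected into a plain list of 3 strings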
@@ -492,8 +356,8 @@ class SampleBatch(Generic[DT]):
         raise AttributeError(f"No sample attribute named {name}")


-ST = TypeVar("ST", bound=
-RT = TypeVar("RT", bound=
+ST = TypeVar("ST", bound=Packable)
+RT = TypeVar("RT", bound=Packable)


 class _ShardListStage(wds.utils.PipelineStage):
@@ -571,23 +435,18 @@ class Dataset(Generic[ST]):

     @property
     def sample_type(self) -> Type:
-        """The type
-
-        Returns:
-            The type parameter ``ST`` used when creating this ``Dataset[ST]``.
-        """
+        """The type parameter ``ST`` used when creating this dataset."""
         if self._sample_type_cache is None:
             self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
-
+            if self._sample_type_cache is None:
+                raise TypeError(
+                    "Dataset requires a type parameter, e.g. Dataset[MySample]"
+                )
         return self._sample_type_cache

     @property
     def batch_type(self) -> Type:
-        """
-
-        Returns:
-            ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
-        """
+        """``SampleBatch[ST]`` where ``ST`` is this dataset's sample type."""
         return SampleBatch[self.sample_type]

     def __init__(
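Note: with this hunk, `Dataset.sample_type` (like `SampleBatch.sample_type` above) fails fast with a descriptive `TypeError` when no type parameter can be recovered, rather than surfacing a confusing downstream error. The contract, per the new message:

    ds = Dataset[MySample]("data.tar")  # subscripted: sample_type is MySample
    # Constructing without a subscript leaves no type parameter to recover;
    # sample_type would then raise
    # TypeError("Dataset requires a type parameter, e.g. Dataset[MySample]")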
@@ -614,28 +473,21 @@ class Dataset(Generic[ST]):
         """
         super().__init__()

-        # Handle backward compatibility: url= keyword argument
         if source is None and url is not None:
             source = url
         elif source is None:
             raise TypeError("Dataset() missing required argument: 'source' or 'url'")

-        # Normalize source: strings become URLSource for backward compatibility
         if isinstance(source, str):
             self._source: DataSource = URLSource(source)
             self.url = source
         else:
             self._source = source
-            # For compatibility, expose URL if source has list_shards
             shards = source.list_shards()
-            # Design note: Using first shard as url for legacy compatibility.
-            # Full shard list is available via list_shards() method.
             self.url = shards[0] if shards else ""

         self._metadata: dict[str, Any] | None = None
         self.metadata_url: str | None = metadata_url
-        """Optional URL to msgpack-encoded metadata for this dataset."""
-
         self._output_lens: Lens | None = None
         self._sample_type_cache: Type | None = None
@@ -645,47 +497,23 @@ class Dataset(Generic[ST]):
         return self._source

     def as_type(self, other: Type[RT]) -> "Dataset[RT]":
-        """View this dataset through a different sample type
-
-        Args:
-            other: The target sample type to transform into. Must be a type
-                derived from ``PackableSample``.
-
-        Returns:
-            A new ``Dataset`` instance that yields samples of type ``other``
-            by applying the appropriate lens transformation from the global
-            ``LensNetwork`` registry.
+        """View this dataset through a different sample type via a registered lens.

         Raises:
-            ValueError: If no
-                sample type and the target type.
+            ValueError: If no lens exists between the current and target types.
         """
         ret = Dataset[other](self._source)
-        # Get the singleton lens registry
         lenses = LensNetwork()
         ret._output_lens = lenses.transform(self.sample_type, ret.sample_type)
         return ret

     @property
     def shards(self) -> Iterator[str]:
-        """Lazily iterate over shard identifiers.
-
-        Yields:
-            Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
-
-        Examples:
-            >>> for shard in ds.shards:
-            ...     print(f"Processing {shard}")
-        """
+        """Lazily iterate over shard identifiers."""
         return iter(self._source.list_shards())

     def list_shards(self) -> list[str]:
-        """
-
-        Returns:
-            A full (non-lazy) list of the individual ``tar`` files within the
-            source WebDataset.
-        """
+        """Return all shard paths/URLs as a list."""
         return self._source.list_shards()

     # Legacy alias for backwards compatibility
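Note: `as_type` behavior is unchanged by this hunk; only its docstring was condensed. A usage sketch of the lens-based view it returns (`OtherSample` is hypothetical, and a lens between the two types must already be registered with `LensNetwork`):

    other = ds.as_type(OtherSample)  # raises ValueError if no lens exists
    for s in other:                  # each sample is converted through the lens
        ...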
@@ -707,14 +535,7 @@ class Dataset(Generic[ST]):

     @property
     def metadata(self) -> dict[str, Any] | None:
-        """Fetch and cache metadata from metadata_url.
-
-        Returns:
-            Deserialized metadata dictionary, or None if no metadata_url is set.
-
-        Raises:
-            requests.HTTPError: If metadata fetch fails.
-        """
+        """Fetch and cache metadata from metadata_url, or ``None`` if unset."""
         if self.metadata_url is None:
             return None

@@ -726,6 +547,367 @@ class Dataset(Generic[ST]):
         # Use our cached values
         return self._metadata

+    ##
+    # Convenience methods (GH#38 developer experience)
+
+    @property
+    def schema(self) -> dict[str, type]:
+        """Field names and types for this dataset's sample type.
+
+        Examples:
+            >>> ds = Dataset[MyData]("data.tar")
+            >>> ds.schema
+            {'name': <class 'str'>, 'embedding': numpy.ndarray}
+        """
+        st = self.sample_type
+        if st is DictSample:
+            return {"_data": dict}
+        if dataclasses.is_dataclass(st):
+            return {f.name: f.type for f in dataclasses.fields(st)}
+        return {}
+
+    @property
+    def column_names(self) -> list[str]:
+        """List of field names for this dataset's sample type."""
+        st = self.sample_type
+        if dataclasses.is_dataclass(st):
+            return [f.name for f in dataclasses.fields(st)]
+        return []
+
+    def __iter__(self) -> Iterator[ST]:
+        """Shorthand for ``ds.ordered()``."""
+        return iter(self.ordered())
+
+    def __len__(self) -> int:
+        """Total sample count (iterates all shards on first call, then cached)."""
+        if not hasattr(self, "_len_cache"):
+            self._len_cache: int = sum(1 for _ in self.ordered())
+        return self._len_cache
+
+    def head(self, n: int = 5) -> list[ST]:
+        """Return the first *n* samples from the dataset.
+
+        Args:
+            n: Number of samples to return. Default: 5.
+
+        Returns:
+            List of up to *n* samples in shard order.
+
+        Examples:
+            >>> samples = ds.head(3)
+            >>> len(samples)
+            3
+        """
+        return list(itertools.islice(self.ordered(), n))
+
+    def get(self, key: str) -> ST:
+        """Retrieve a single sample by its ``__key__``.
+
+        Scans shards sequentially until a sample with a matching key is found.
+        This is O(n) for streaming datasets.
+
+        Args:
+            key: The WebDataset ``__key__`` string to search for.
+
+        Returns:
+            The matching sample.
+
+        Raises:
+            SampleKeyError: If no sample with the given key exists.
+
+        Examples:
+            >>> sample = ds.get("00000001-0001-1000-8000-010000000000")
+        """
+        pipeline = wds.pipeline.DataPipeline(
+            _ShardListStage(self._source),
+            wds.shardlists.split_by_worker,
+            _StreamOpenerStage(self._source),
+            wds.tariterators.tar_file_expander,
+            wds.tariterators.group_by_keys,
+        )
+        for raw_sample in pipeline:
+            if raw_sample.get("__key__") == key:
+                return self.wrap(raw_sample)
+        raise SampleKeyError(key)
+
+    def describe(self) -> dict[str, Any]:
+        """Summary statistics: sample_type, fields, num_shards, shards, url, metadata."""
+        shards = self.list_shards()
+        return {
+            "sample_type": self.sample_type.__name__,
+            "fields": self.schema,
+            "num_shards": len(shards),
+            "shards": shards,
+            "url": self.url,
+            "metadata": self.metadata,
+        }
+
+    def filter(self, predicate: Callable[[ST], bool]) -> "Dataset[ST]":
+        """Return a new dataset that yields only samples matching *predicate*.
+
+        The filter is applied lazily during iteration — no data is copied.
+
+        Args:
+            predicate: A function that takes a sample and returns ``True``
+                to keep it or ``False`` to discard it.
+
+        Returns:
+            A new ``Dataset`` whose iterators apply the filter.
+
+        Examples:
+            >>> long_names = ds.filter(lambda s: len(s.name) > 10)
+            >>> for sample in long_names:
+            ...     assert len(sample.name) > 10
+        """
+        filtered = Dataset[self.sample_type](self._source, self.metadata_url)
+        filtered._sample_type_cache = self._sample_type_cache
+        filtered._output_lens = self._output_lens
+        filtered._filter_fn = predicate
+        # Preserve any existing filters
+        parent_filters = getattr(self, "_filter_fn", None)
+        if parent_filters is not None:
+            outer = parent_filters
+            filtered._filter_fn = lambda s: outer(s) and predicate(s)
+        # Preserve any existing map
+        if hasattr(self, "_map_fn"):
+            filtered._map_fn = self._map_fn
+        return filtered
+
+    def map(self, fn: Callable[[ST], Any]) -> "Dataset":
+        """Return a new dataset that applies *fn* to each sample during iteration.
+
+        The mapping is applied lazily during iteration — no data is copied.
+
+        Args:
+            fn: A function that takes a sample of type ``ST`` and returns
+                a transformed value.
+
+        Returns:
+            A new ``Dataset`` whose iterators apply the mapping.
+
+        Examples:
+            >>> names = ds.map(lambda s: s.name)
+            >>> for name in names:
+            ...     print(name)
+        """
+        mapped = Dataset[self.sample_type](self._source, self.metadata_url)
+        mapped._sample_type_cache = self._sample_type_cache
+        mapped._output_lens = self._output_lens
+        mapped._map_fn = fn
+        # Preserve any existing map
+        if hasattr(self, "_map_fn"):
+            outer = self._map_fn
+            mapped._map_fn = lambda s: fn(outer(s))
+        # Preserve any existing filter
+        if hasattr(self, "_filter_fn"):
+            mapped._filter_fn = self._filter_fn
+        return mapped
+
+    def process_shards(
+        self,
+        fn: Callable[[list[ST]], Any],
+        *,
+        shards: list[str] | None = None,
+    ) -> dict[str, Any]:
+        """Process each shard independently, collecting per-shard results.
+
+        Unlike :meth:`map` (which is lazy and per-sample), this method eagerly
+        processes each shard in turn, calling *fn* with the full list of samples
+        from that shard. If some shards fail, raises
+        :class:`~atdata._exceptions.PartialFailureError` containing both the
+        successful results and the per-shard errors.
+
+        Args:
+            fn: Function receiving a list of samples from one shard and
+                returning an arbitrary result.
+            shards: Optional list of shard identifiers to process. If ``None``,
+                processes all shards in the dataset. Useful for retrying only
+                the failed shards from a previous ``PartialFailureError``.
+
+        Returns:
+            Dict mapping shard identifier to *fn*'s return value for each shard.
+
+        Raises:
+            PartialFailureError: If at least one shard fails. The exception
+                carries ``.succeeded_shards``, ``.failed_shards``, ``.errors``,
+                and ``.results`` for inspection and retry.
+
+        Examples:
+            >>> results = ds.process_shards(lambda samples: len(samples))
+            >>> # On partial failure, retry just the failed shards:
+            >>> try:
+            ...     results = ds.process_shards(expensive_fn)
+            ... except PartialFailureError as e:
+            ...     retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
+        """
+        from ._logging import get_logger
+
+        log = get_logger()
+        shard_ids = shards or self.list_shards()
+        log.info("process_shards: starting %d shards", len(shard_ids))
+
+        succeeded: list[str] = []
+        failed: list[str] = []
+        errors: dict[str, Exception] = {}
+        results: dict[str, Any] = {}
+
+        for shard_id in shard_ids:
+            try:
+                shard_ds = Dataset[self.sample_type](shard_id)
+                shard_ds._sample_type_cache = self._sample_type_cache
+                samples = list(shard_ds.ordered())
+                results[shard_id] = fn(samples)
+                succeeded.append(shard_id)
+                log.debug("process_shards: shard ok %s", shard_id)
+            except Exception as exc:
+                failed.append(shard_id)
+                errors[shard_id] = exc
+                log.warning("process_shards: shard failed %s: %s", shard_id, exc)
+
+        if failed:
+            log.error(
+                "process_shards: %d/%d shards failed",
+                len(failed),
+                len(shard_ids),
+            )
+            raise PartialFailureError(
+                succeeded_shards=succeeded,
+                failed_shards=failed,
+                errors=errors,
+                results=results,
+            )
+
+        log.info("process_shards: all %d shards succeeded", len(shard_ids))
+        return results
+
+    def select(self, indices: Sequence[int]) -> list[ST]:
+        """Return samples at the given integer indices.
+
+        Iterates through the dataset in order and collects samples whose
+        positional index matches. This is O(n) for streaming datasets.
+
+        Args:
+            indices: Sequence of zero-based indices to select.
+
+        Returns:
+            List of samples at the requested positions, in index order.
+
+        Examples:
+            >>> samples = ds.select([0, 5, 10])
+            >>> len(samples)
+            3
+        """
+        if not indices:
+            return []
+        target = set(indices)
+        max_idx = max(indices)
+        result: dict[int, ST] = {}
+        for i, sample in enumerate(self.ordered()):
+            if i in target:
+                result[i] = sample
+            if i >= max_idx:
+                break
+        return [result[i] for i in indices if i in result]
+
+    def query(
+        self,
+        where: "Callable[[pd.DataFrame], pd.Series]",
+    ) -> "list[SampleLocation]":
+        """Query this dataset using per-shard manifest metadata.
+
+        Requires manifests to have been generated during shard writing.
+        Discovers manifest files alongside the tar shards, loads them,
+        and executes a two-phase query (shard-level aggregate pruning,
+        then sample-level parquet filtering).
+
+        Args:
+            where: Predicate function that receives a pandas DataFrame
+                of manifest fields and returns a boolean Series selecting
+                matching rows.
+
+        Returns:
+            List of ``SampleLocation`` for matching samples.
+
+        Raises:
+            FileNotFoundError: If no manifest files are found alongside shards.
+
+        Examples:
+            >>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
+            >>> len(locs)
+            42
+        """
+        from .manifest import QueryExecutor
+
+        shard_urls = self.list_shards()
+        executor = QueryExecutor.from_shard_urls(shard_urls)
+        return executor.query(where=where)
+
+    def to_pandas(self, limit: int | None = None) -> "pd.DataFrame":
+        """Materialize the dataset (or first *limit* samples) as a DataFrame.
+
+        Args:
+            limit: Maximum number of samples to include. ``None`` means all
+                samples (may use significant memory for large datasets).
+
+        Returns:
+            A pandas DataFrame with one row per sample and columns matching
+            the sample fields.
+
+        Warning:
+            With ``limit=None`` this loads the entire dataset into memory.
+
+        Examples:
+            >>> df = ds.to_pandas(limit=100)
+            >>> df.columns.tolist()
+            ['name', 'embedding']
+        """
+        samples = self.head(limit) if limit is not None else list(self.ordered())
+        rows = [
+            asdict(s) if dataclasses.is_dataclass(s) else s.to_dict() for s in samples
+        ]
+        return pd.DataFrame(rows)
+
+    def to_dict(self, limit: int | None = None) -> dict[str, list[Any]]:
+        """Materialize the dataset as a column-oriented dictionary.
+
+        Args:
+            limit: Maximum number of samples to include. ``None`` means all.
+
+        Returns:
+            Dictionary mapping field names to lists of values (one entry
+            per sample).
+
+        Warning:
+            With ``limit=None`` this loads the entire dataset into memory.
+
+        Examples:
+            >>> d = ds.to_dict(limit=10)
+            >>> d.keys()
+            dict_keys(['name', 'embedding'])
+            >>> len(d['name'])
+            10
+        """
+        samples = self.head(limit) if limit is not None else list(self.ordered())
+        if not samples:
+            return {}
+        if dataclasses.is_dataclass(samples[0]):
+            fields = [f.name for f in dataclasses.fields(samples[0])]
+            return {f: [getattr(s, f) for s in samples] for f in fields}
+        # DictSample path
+        keys = samples[0].keys()
+        return {k: [s[k] for s in samples] for k in keys}
+
+    def _post_wrap_stages(self) -> list:
+        """Build extra pipeline stages for filter/map set via .filter()/.map()."""
+        stages: list = []
+        filter_fn = getattr(self, "_filter_fn", None)
+        if filter_fn is not None:
+            stages.append(wds.filters.select(filter_fn))
+        map_fn = getattr(self, "_map_fn", None)
+        if map_fn is not None:
+            stages.append(wds.filters.map(map_fn))
+        return stages
+
     @overload
     def ordered(
         self,
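Note: the convenience layer added above composes lazily. `filter` and `map` only stash callables; `_post_wrap_stages` turns them into extra WebDataset pipeline stages, which the two hunks below splice into the ordered and shuffled pipelines. `process_shards` is the eager counterpart with partial-failure recovery. A sketch based on the docstrings above (shard URL and field names are illustrative):

    ds = Dataset[MyData]("shard-{000000..000009}.tar")

    # Lazy: both stages run inside the iteration pipeline, nothing is copied
    short_names = ds.filter(lambda s: len(s.name) < 8).map(lambda s: s.name)
    for name in short_names:
        print(name)

    # Eager per-shard processing; on partial failure, retry only what failed
    from atdata._exceptions import PartialFailureError  # import path assumed
    try:
        counts = ds.process_shards(lambda samples: len(samples))
    except PartialFailureError as e:
        remaining = ds.process_shards(lambda samples: len(samples),
                                      shards=e.failed_shards)
        counts = {**e.results, **remaining}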
@@ -769,6 +951,7 @@ class Dataset(Generic[ST]):
             wds.tariterators.tar_file_expander,
             wds.tariterators.group_by_keys,
             wds.filters.map(self.wrap),
+            *self._post_wrap_stages(),
         )

         return wds.pipeline.DataPipeline(
@@ -839,6 +1022,7 @@ class Dataset(Generic[ST]):
             wds.tariterators.group_by_keys,
             wds.filters.shuffle(buffer_samples),
             wds.filters.map(self.wrap),
+            *self._post_wrap_stages(),
         )

         return wds.pipeline.DataPipeline(
@@ -862,100 +1046,47 @@ class Dataset(Generic[ST]):
         maxcount: Optional[int] = None,
         **kwargs,
     ):
-        """Export dataset
-
-        Converts all samples to a pandas DataFrame and saves to parquet file(s).
-        Useful for interoperability with data analysis tools.
+        """Export dataset to parquet file(s).

         Args:
-            path: Output path
-
-            sample_map:
-
-
-
-            **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
-                Common options include ``compression``, ``index``, ``engine``.
-
-        Warning:
-            **Memory Usage**: When ``maxcount=None`` (default), this method loads
-            the **entire dataset into memory** as a pandas DataFrame before writing.
-            For large datasets, this can cause memory exhaustion.
-
-            For datasets larger than available RAM, always specify ``maxcount``::
-
-                # Safe for large datasets - processes in chunks
-                ds.to_parquet("output.parquet", maxcount=10000)
-
-            This creates multiple parquet files: ``output-000000.parquet``,
-            ``output-000001.parquet``, etc.
+            path: Output path. With *maxcount*, files are named
+                ``{stem}-{segment:06d}.parquet``.
+            sample_map: Convert sample to dict. Defaults to ``dataclasses.asdict``.
+            maxcount: Split into files of at most this many samples.
+                Without it, the entire dataset is loaded into memory.
+            **kwargs: Passed to ``pandas.DataFrame.to_parquet()``.

         Examples:
-            >>> ds = Dataset[MySample]("data.tar")
-            >>> # Small dataset - load all at once
-            >>> ds.to_parquet("output.parquet")
-            >>>
-            >>> # Large dataset - process in chunks
             >>> ds.to_parquet("output.parquet", maxcount=50000)
         """
-        ##
-
-        # Normalize args
         path = Path(path)
         if sample_map is None:
             sample_map = asdict

-        verbose = kwargs.get("verbose", False)
-
-        it = self.ordered(batch_size=None)
-        if verbose:
-            it = tqdm(it)
-
-        #
-
         if maxcount is None:
-            # Load and save full dataset
             df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)])
             df.to_parquet(path, **kwargs)
-
         else:
-            # Load and save dataset in segments of size `maxcount`
-
             cur_segment = 0
-            cur_buffer = []
+            cur_buffer: list = []
             path_template = (
                 path.parent / f"{path.stem}-{{:06d}}{path.suffix}"
             ).as_posix()

             for x in self.ordered(batch_size=None):
                 cur_buffer.append(sample_map(x))
-
                 if len(cur_buffer) >= maxcount:
-                    # Write current segment
                     cur_path = path_template.format(cur_segment)
-
-                    df.to_parquet(cur_path, **kwargs)
-
+                    pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)
                     cur_segment += 1
                     cur_buffer = []

-            if
-                # Write one last segment with remainder
+            if cur_buffer:
                 cur_path = path_template.format(cur_segment)
-
-                df.to_parquet(cur_path, **kwargs)
+                pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)

     def wrap(self, sample: WDSRawSample) -> ST:
-        """
-
-        Args:
-            sample: A dictionary containing at minimum a ``'msgpack'`` key with
-                serialized sample bytes.
-
-        Returns:
-            A deserialized sample of type ``ST``, optionally transformed through
-            a lens if ``as_type()`` was called.
-        """
+        """Deserialize a raw WDS sample dict into type ``ST``."""
         if "msgpack" not in sample:
             raise ValueError(
                 f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}"
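Note: besides trimming the docstring, this hunk fixes a latent bug in the segmented branch: the old code called `df.to_parquet(...)` on a `df` that was never assigned in that branch, while the new code builds a DataFrame from each buffer. Usage stays as documented:

    ds.to_parquet("out.parquet")                  # one file, whole dataset in memory
    ds.to_parquet("out.parquet", maxcount=10000)  # out-000000.parquet, out-000001.parquet, ...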
@@ -972,20 +1103,7 @@ class Dataset(Generic[ST]):
         return self._output_lens(source_sample)

     def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]:
-        """
-
-        Args:
-            batch: A dictionary containing a ``'msgpack'`` key with a list of
-                serialized sample bytes.
-
-        Returns:
-            A ``SampleBatch[ST]`` containing deserialized samples, optionally
-            transformed through a lens if ``as_type()`` was called.
-
-        Note:
-            This implementation deserializes samples one at a time, then
-            aggregates them into a batch.
-        """
+        """Deserialize a raw WDS batch dict into ``SampleBatch[ST]``."""

         if "msgpack" not in batch:
             raise ValueError(
@@ -1009,24 +1127,12 @@ _T = TypeVar("_T")


 @dataclass_transform()
-def packable(cls: type[_T]) -> type[
-    """
-
-    This decorator transforms a class into a dataclass that inherits from
-    ``PackableSample``, enabling automatic msgpack serialization/deserialization
-    with special handling for NDArray fields.
-
-    The resulting class satisfies the ``Packable`` protocol, making it compatible
-    with all atdata APIs that accept packable types (e.g., ``publish_schema``,
-    lens transformations, etc.).
-
-    Args:
-        cls: The class to convert. Should have type annotations for its fields.
+def packable(cls: type[_T]) -> type[Packable]:
+    """Convert a class into a ``PackableSample`` dataclass with msgpack serialization.

-
-
-
-    ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
+    The resulting class gains ``packed``, ``as_wds``, ``from_bytes``, and
+    ``from_data`` methods, and satisfies the ``Packable`` protocol.
+    NDArray fields are automatically handled during serialization.

     Examples:
         >>> @packable
@@ -1035,11 +1141,7 @@ def packable(cls: type[_T]) -> type[_T]:
         ...     values: NDArray
         ...
         >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
-        >>>
-        >>> restored = MyData.from_bytes(bytes_data)
-        >>>
-        >>> # Works with Packable-typed APIs
-        >>> index.publish_schema(MyData, version="1.0.0")  # Type-safe
+        >>> restored = MyData.from_bytes(sample.packed)
     """

     ##