atdata-0.2.3b1-py3-none-any.whl → atdata-0.3.0b1-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (48)
  1. atdata/.gitignore +1 -0
  2. atdata/__init__.py +30 -0
  3. atdata/_exceptions.py +168 -0
  4. atdata/_helpers.py +29 -15
  5. atdata/_hf_api.py +63 -11
  6. atdata/_logging.py +70 -0
  7. atdata/_protocols.py +19 -62
  8. atdata/_schema_codec.py +5 -4
  9. atdata/_type_utils.py +28 -2
  10. atdata/atmosphere/__init__.py +19 -9
  11. atdata/atmosphere/records.py +3 -2
  12. atdata/atmosphere/schema.py +2 -2
  13. atdata/cli/__init__.py +157 -171
  14. atdata/cli/inspect.py +69 -0
  15. atdata/cli/local.py +1 -1
  16. atdata/cli/preview.py +63 -0
  17. atdata/cli/schema.py +109 -0
  18. atdata/dataset.py +428 -326
  19. atdata/lens.py +9 -2
  20. atdata/local/__init__.py +71 -0
  21. atdata/local/_entry.py +157 -0
  22. atdata/local/_index.py +940 -0
  23. atdata/local/_repo_legacy.py +218 -0
  24. atdata/local/_s3.py +349 -0
  25. atdata/local/_schema.py +380 -0
  26. atdata/manifest/__init__.py +28 -0
  27. atdata/manifest/_aggregates.py +156 -0
  28. atdata/manifest/_builder.py +163 -0
  29. atdata/manifest/_fields.py +154 -0
  30. atdata/manifest/_manifest.py +146 -0
  31. atdata/manifest/_query.py +150 -0
  32. atdata/manifest/_writer.py +74 -0
  33. atdata/promote.py +4 -4
  34. atdata/providers/__init__.py +25 -0
  35. atdata/providers/_base.py +140 -0
  36. atdata/providers/_factory.py +69 -0
  37. atdata/providers/_postgres.py +214 -0
  38. atdata/providers/_redis.py +171 -0
  39. atdata/providers/_sqlite.py +191 -0
  40. atdata/repository.py +323 -0
  41. atdata/testing.py +337 -0
  42. {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +4 -1
  43. atdata-0.3.0b1.dist-info/RECORD +54 -0
  44. atdata/local.py +0 -1720
  45. atdata-0.2.3b1.dist-info/RECORD +0 -28
  46. {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
  47. {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
  48. {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py CHANGED
@@ -31,6 +31,7 @@ Examples:
31
31
  import webdataset as wds
32
32
 
33
33
  from pathlib import Path
34
+ import itertools
34
35
  import uuid
35
36
 
36
37
  import dataclasses
@@ -42,15 +43,16 @@ from dataclasses import (
42
43
  from abc import ABC
43
44
 
44
45
  from ._sources import URLSource
45
- from ._protocols import DataSource
46
+ from ._protocols import DataSource, Packable
47
+ from ._exceptions import SampleKeyError, PartialFailureError
46
48
 
47
- from tqdm import tqdm
48
49
  import numpy as np
49
50
  import pandas as pd
50
51
  import requests
51
52
 
52
53
  import typing
53
54
  from typing import (
55
+ TYPE_CHECKING,
54
56
  Any,
55
57
  Optional,
56
58
  Dict,
@@ -66,6 +68,9 @@ from typing import (
66
68
  dataclass_transform,
67
69
  overload,
68
70
  )
71
+
72
+ if TYPE_CHECKING:
73
+ from .manifest._query import SampleLocation
69
74
  from numpy.typing import NDArray
70
75
 
71
76
  import msgpack
@@ -157,37 +162,17 @@ class DictSample:
157
162
 
158
163
  @classmethod
159
164
  def from_data(cls, data: dict[str, Any]) -> "DictSample":
160
- """Create a DictSample from unpacked msgpack data.
161
-
162
- Args:
163
- data: Dictionary with field names as keys.
164
-
165
- Returns:
166
- New DictSample instance wrapping the data.
167
- """
165
+ """Create a DictSample from unpacked msgpack data."""
168
166
  return cls(_data=data)
169
167
 
170
168
  @classmethod
171
169
  def from_bytes(cls, bs: bytes) -> "DictSample":
172
- """Create a DictSample from raw msgpack bytes.
173
-
174
- Args:
175
- bs: Raw bytes from a msgpack-serialized sample.
176
-
177
- Returns:
178
- New DictSample instance with the unpacked data.
179
- """
170
+ """Create a DictSample from raw msgpack bytes."""
180
171
  return cls.from_data(ormsgpack.unpackb(bs))
181
172
 
182
173
  def __getattr__(self, name: str) -> Any:
183
174
  """Access a field by attribute name.
184
175
 
185
- Args:
186
- name: Field name to access.
187
-
188
- Returns:
189
- The field value.
190
-
191
176
  Raises:
192
177
  AttributeError: If the field doesn't exist.
193
178
  """
@@ -203,21 +188,9 @@ class DictSample:
203
188
  ) from None
204
189
 
205
190
  def __getitem__(self, key: str) -> Any:
206
- """Access a field by dict key.
207
-
208
- Args:
209
- key: Field name to access.
210
-
211
- Returns:
212
- The field value.
213
-
214
- Raises:
215
- KeyError: If the field doesn't exist.
216
- """
217
191
  return self._data[key]
218
192
 
219
193
  def __contains__(self, key: str) -> bool:
220
- """Check if a field exists."""
221
194
  return key in self._data
222
195
 
223
196
  def keys(self) -> list[str]:
@@ -225,23 +198,13 @@ class DictSample:
225
198
  return list(self._data.keys())
226
199
 
227
200
  def values(self) -> list[Any]:
228
- """Return list of field values."""
229
201
  return list(self._data.values())
230
202
 
231
203
  def items(self) -> list[tuple[str, Any]]:
232
- """Return list of (field_name, value) tuples."""
233
204
  return list(self._data.items())
234
205
 
235
206
  def get(self, key: str, default: Any = None) -> Any:
236
- """Get a field value with optional default.
237
-
238
- Args:
239
- key: Field name to access.
240
- default: Value to return if field doesn't exist.
241
-
242
- Returns:
243
- The field value or default.
244
- """
207
+ """Get a field value, returning *default* if missing."""
245
208
  return self._data.get(key, default)
246
209
 
247
210
  def to_dict(self) -> dict[str, Any]:
@@ -250,20 +213,12 @@ class DictSample:
250
213
 
251
214
  @property
252
215
  def packed(self) -> bytes:
253
- """Pack this sample's data into msgpack bytes.
254
-
255
- Returns:
256
- Raw msgpack bytes representing this sample's data.
257
- """
216
+ """Serialize to msgpack bytes."""
258
217
  return msgpack.packb(self._data)
259
218
 
260
219
  @property
261
220
  def as_wds(self) -> "WDSRawSample":
262
- """Pack this sample's data for writing to WebDataset.
263
-
264
- Returns:
265
- A dictionary with ``__key__`` and ``msgpack`` fields.
266
- """
221
+ """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
267
222
  return {
268
223
  "__key__": str(uuid.uuid1(0, 0)),
269
224
  "msgpack": self.packed,
@@ -300,31 +255,13 @@ class PackableSample(ABC):
300
255
 
301
256
  def _ensure_good(self):
302
257
  """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
303
-
304
- # Auto-convert known types when annotated
305
- # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
306
258
  for field in dataclasses.fields(self):
307
- var_name = field.name
308
- var_type = field.type
309
-
310
- # Annotation for this variable is to be an NDArray
311
- if _is_possibly_ndarray_type(var_type):
312
- # ... so, we'll always auto-convert to numpy
313
-
314
- var_cur_value = getattr(self, var_name)
315
-
316
- # Execute the appropriate conversion for intermediate data
317
- # based on what is provided
318
-
319
- if isinstance(var_cur_value, np.ndarray):
320
- # Already the correct type, no conversion needed
259
+ if _is_possibly_ndarray_type(field.type):
260
+ value = getattr(self, field.name)
261
+ if isinstance(value, np.ndarray):
321
262
  continue
322
-
323
- elif isinstance(var_cur_value, bytes):
324
- # Design note: bytes in NDArray-typed fields are always interpreted
325
- # as serialized arrays. This means raw bytes fields must not be
326
- # annotated as NDArray.
327
- setattr(self, var_name, eh.bytes_to_array(var_cur_value))
263
+ elif isinstance(value, bytes):
264
+ setattr(self, field.name, eh.bytes_to_array(value))
328
265
 
329
266
  def __post_init__(self):
330
267
  self._ensure_good()
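The refactored `_ensure_good` keeps the existing contract: any dataclass field annotated as NDArray (or NDArray | None) that arrives holding bytes is decoded back into a numpy array on construction. A minimal round-trip sketch, reusing the illustrative MyData class from the `packable` docstring later in this diff (import paths assumed):

    import numpy as np

    sample = MyData(name="test", values=np.array([1, 2, 3]))
    raw = sample.packed                 # NDArray field is serialized to bytes inside the msgpack payload
    restored = MyData.from_bytes(raw)   # __post_init__ -> _ensure_good decodes the bytes back to an ndarray
    assert isinstance(restored.values, np.ndarray)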
@@ -333,67 +270,31 @@ class PackableSample(ABC):
333
270
 
334
271
  @classmethod
335
272
  def from_data(cls, data: WDSRawSample) -> Self:
336
- """Create a sample instance from unpacked msgpack data.
337
-
338
- Args:
339
- data: Dictionary with keys matching the sample's field names.
340
-
341
- Returns:
342
- New instance with NDArray fields auto-converted from bytes.
343
- """
273
+ """Create an instance from unpacked msgpack data."""
344
274
  return cls(**data)
345
275
 
346
276
  @classmethod
347
277
  def from_bytes(cls, bs: bytes) -> Self:
348
- """Create a sample instance from raw msgpack bytes.
349
-
350
- Args:
351
- bs: Raw bytes from a msgpack-serialized sample.
352
-
353
- Returns:
354
- A new instance of this sample class deserialized from the bytes.
355
- """
278
+ """Create an instance from raw msgpack bytes."""
356
279
  return cls.from_data(ormsgpack.unpackb(bs))
357
280
 
358
281
  @property
359
282
  def packed(self) -> bytes:
360
- """Pack this sample's data into msgpack bytes.
361
-
362
- NDArray fields are automatically converted to bytes before packing.
363
- All other fields are packed as-is if they're msgpack-compatible.
364
-
365
- Returns:
366
- Raw msgpack bytes representing this sample's data.
283
+ """Serialize to msgpack bytes. NDArray fields are auto-converted.
367
284
 
368
285
  Raises:
369
286
  RuntimeError: If msgpack serialization fails.
370
287
  """
371
-
372
- # Make sure that all of our (possibly unpackable) data is in a packable
373
- # format
374
288
  o = {k: _make_packable(v) for k, v in vars(self).items()}
375
-
376
289
  ret = msgpack.packb(o)
377
-
378
290
  if ret is None:
379
291
  raise RuntimeError(f"Failed to pack sample to bytes: {o}")
380
-
381
292
  return ret
382
293
 
383
294
  @property
384
295
  def as_wds(self) -> WDSRawSample:
385
- """Pack this sample's data for writing to WebDataset.
386
-
387
- Returns:
388
- A dictionary with ``__key__`` (UUID v1 for sortable keys) and
389
- ``msgpack`` (packed sample data) fields suitable for WebDataset.
390
-
391
- Note:
392
- Keys are auto-generated as UUID v1 for time-sortable ordering.
393
- Custom key specification is not currently supported.
394
- """
296
+ """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
395
297
  return {
396
- # Generates a UUID that is timelike-sortable
397
298
  "__key__": str(uuid.uuid1(0, 0)),
398
299
  "msgpack": self.packed,
399
300
  }
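`as_wds` is what a shard writer consumes: it wraps the packed bytes in the `__key__`/`msgpack` dict layout that WebDataset tar writers expect. A rough sketch of producing a shard (webdataset's `TarWriter` usage is an assumption here, not part of this diff; `samples` is an illustrative list of PackableSample instances):

    import webdataset as wds

    with wds.TarWriter("train-000000.tar") as sink:
        for s in samples:
            sink.write(s.as_wds)   # writes one <key>.msgpack entry per sample into the tar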
@@ -411,75 +312,38 @@ def _batch_aggregate(xs: Sequence):
411
312
  class SampleBatch(Generic[DT]):
412
313
  """A batch of samples with automatic attribute aggregation.
413
314
 
414
- This class wraps a sequence of samples and provides magic ``__getattr__``
415
- access to aggregate sample attributes. When you access an attribute that
416
- exists on the sample type, it automatically aggregates values across all
417
- samples in the batch.
418
-
419
- NDArray fields are stacked into a numpy array with a batch dimension.
420
- Other fields are aggregated into a list.
315
+ Accessing an attribute aggregates that field across all samples:
316
+ NDArray fields are stacked into a numpy array with a batch dimension;
317
+ other fields are collected into a list. Results are cached.
421
318
 
422
319
  Parameters:
423
320
  DT: The sample type, must derive from ``PackableSample``.
424
321
 
425
- Attributes:
426
- samples: The list of sample instances in this batch.
427
-
428
322
  Examples:
429
323
  >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
430
- >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
431
- >>> batch.names # Returns list of names
432
-
433
- Note:
434
- This class uses Python's ``__orig_class__`` mechanism to extract the
435
- type parameter at runtime. Instances must be created using the
436
- subscripted syntax ``SampleBatch[MyType](samples)`` rather than
437
- calling the constructor directly with an unsubscripted class.
324
+ >>> batch.embeddings # Stacked numpy array of shape (3, ...)
325
+ >>> batch.names # List of names
438
326
  """
439
327
 
440
- # Design note: The docstring uses "Parameters:" for type parameters because
441
- # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
442
-
443
328
  def __init__(self, samples: Sequence[DT]):
444
- """Create a batch from a sequence of samples.
445
-
446
- Args:
447
- samples: A sequence of sample instances to aggregate into a batch.
448
- Each sample must be an instance of a type derived from
449
- ``PackableSample``.
450
- """
329
+ """Create a batch from a sequence of samples."""
451
330
  self.samples = list(samples)
452
331
  self._aggregate_cache = dict()
453
332
  self._sample_type_cache: Type | None = None
454
333
 
455
334
  @property
456
335
  def sample_type(self) -> Type:
457
- """The type of each sample in this batch.
458
-
459
- Returns:
460
- The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
461
- """
336
+ """The type parameter ``DT`` used when creating this batch."""
462
337
  if self._sample_type_cache is None:
463
338
  self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
464
- assert self._sample_type_cache is not None
339
+ if self._sample_type_cache is None:
340
+ raise TypeError(
341
+ "SampleBatch requires a type parameter, e.g. SampleBatch[MySample]"
342
+ )
465
343
  return self._sample_type_cache
466
344
 
467
345
  def __getattr__(self, name):
468
- """Aggregate an attribute across all samples in the batch.
469
-
470
- This magic method enables attribute-style access to aggregated sample
471
- fields. Results are cached for efficiency.
472
-
473
- Args:
474
- name: The attribute name to aggregate across samples.
475
-
476
- Returns:
477
- For NDArray fields: a stacked numpy array with batch dimension.
478
- For other fields: a list of values from each sample.
479
-
480
- Raises:
481
- AttributeError: If the attribute doesn't exist on the sample type.
482
- """
346
+ """Aggregate a field across all samples (cached)."""
483
347
  # Aggregate named params of sample type
484
348
  if name in vars(self.sample_type)["__annotations__"]:
485
349
  if name not in self._aggregate_cache:
@@ -492,8 +356,8 @@ class SampleBatch(Generic[DT]):
492
356
  raise AttributeError(f"No sample attribute named {name}")
493
357
 
494
358
 
495
- ST = TypeVar("ST", bound=PackableSample)
496
- RT = TypeVar("RT", bound=PackableSample)
359
+ ST = TypeVar("ST", bound=Packable)
360
+ RT = TypeVar("RT", bound=Packable)
497
361
 
498
362
 
499
363
  class _ShardListStage(wds.utils.PipelineStage):
@@ -571,23 +435,18 @@ class Dataset(Generic[ST]):
571
435
 
572
436
  @property
573
437
  def sample_type(self) -> Type:
574
- """The type of each returned sample from this dataset's iterator.
575
-
576
- Returns:
577
- The type parameter ``ST`` used when creating this ``Dataset[ST]``.
578
- """
438
+ """The type parameter ``ST`` used when creating this dataset."""
579
439
  if self._sample_type_cache is None:
580
440
  self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
581
- assert self._sample_type_cache is not None
441
+ if self._sample_type_cache is None:
442
+ raise TypeError(
443
+ "Dataset requires a type parameter, e.g. Dataset[MySample]"
444
+ )
582
445
  return self._sample_type_cache
583
446
 
584
447
  @property
585
448
  def batch_type(self) -> Type:
586
- """The type of batches produced by this dataset.
587
-
588
- Returns:
589
- ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
590
- """
449
+ """``SampleBatch[ST]`` where ``ST`` is this dataset's sample type."""
591
450
  return SampleBatch[self.sample_type]
592
451
 
593
452
  def __init__(
@@ -614,28 +473,21 @@ class Dataset(Generic[ST]):
614
473
  """
615
474
  super().__init__()
616
475
 
617
- # Handle backward compatibility: url= keyword argument
618
476
  if source is None and url is not None:
619
477
  source = url
620
478
  elif source is None:
621
479
  raise TypeError("Dataset() missing required argument: 'source' or 'url'")
622
480
 
623
- # Normalize source: strings become URLSource for backward compatibility
624
481
  if isinstance(source, str):
625
482
  self._source: DataSource = URLSource(source)
626
483
  self.url = source
627
484
  else:
628
485
  self._source = source
629
- # For compatibility, expose URL if source has list_shards
630
486
  shards = source.list_shards()
631
- # Design note: Using first shard as url for legacy compatibility.
632
- # Full shard list is available via list_shards() method.
633
487
  self.url = shards[0] if shards else ""
634
488
 
635
489
  self._metadata: dict[str, Any] | None = None
636
490
  self.metadata_url: str | None = metadata_url
637
- """Optional URL to msgpack-encoded metadata for this dataset."""
638
-
639
491
  self._output_lens: Lens | None = None
640
492
  self._sample_type_cache: Type | None = None
641
493
 
@@ -645,47 +497,23 @@ class Dataset(Generic[ST]):
645
497
  return self._source
646
498
 
647
499
  def as_type(self, other: Type[RT]) -> "Dataset[RT]":
648
- """View this dataset through a different sample type using a registered lens.
649
-
650
- Args:
651
- other: The target sample type to transform into. Must be a type
652
- derived from ``PackableSample``.
653
-
654
- Returns:
655
- A new ``Dataset`` instance that yields samples of type ``other``
656
- by applying the appropriate lens transformation from the global
657
- ``LensNetwork`` registry.
500
+ """View this dataset through a different sample type via a registered lens.
658
501
 
659
502
  Raises:
660
- ValueError: If no registered lens exists between the current
661
- sample type and the target type.
503
+ ValueError: If no lens exists between the current and target types.
662
504
  """
663
505
  ret = Dataset[other](self._source)
664
- # Get the singleton lens registry
665
506
  lenses = LensNetwork()
666
507
  ret._output_lens = lenses.transform(self.sample_type, ret.sample_type)
667
508
  return ret
668
509
 
669
510
  @property
670
511
  def shards(self) -> Iterator[str]:
671
- """Lazily iterate over shard identifiers.
672
-
673
- Yields:
674
- Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
675
-
676
- Examples:
677
- >>> for shard in ds.shards:
678
- ... print(f"Processing {shard}")
679
- """
512
+ """Lazily iterate over shard identifiers."""
680
513
  return iter(self._source.list_shards())
681
514
 
682
515
  def list_shards(self) -> list[str]:
683
- """Get list of individual dataset shards.
684
-
685
- Returns:
686
- A full (non-lazy) list of the individual ``tar`` files within the
687
- source WebDataset.
688
- """
516
+ """Return all shard paths/URLs as a list."""
689
517
  return self._source.list_shards()
690
518
 
691
519
  # Legacy alias for backwards compatibility
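`as_type` builds a sibling dataset that re-wraps every sample through whatever lens the global `LensNetwork` has registered between the two sample types. A hedged sketch (MyData and OtherSample are illustrative, and it assumes a MyData → OtherSample lens has already been registered via the machinery in atdata/lens.py):

    ds = Dataset[MyData]("data.tar")
    view = ds.as_type(OtherSample)    # raises ValueError if no MyData -> OtherSample lens is registered
    for other in view.ordered():
        ...                           # samples arrive already converted to OtherSample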
@@ -707,14 +535,7 @@ class Dataset(Generic[ST]):
707
535
 
708
536
  @property
709
537
  def metadata(self) -> dict[str, Any] | None:
710
- """Fetch and cache metadata from metadata_url.
711
-
712
- Returns:
713
- Deserialized metadata dictionary, or None if no metadata_url is set.
714
-
715
- Raises:
716
- requests.HTTPError: If metadata fetch fails.
717
- """
538
+ """Fetch and cache metadata from metadata_url, or ``None`` if unset."""
718
539
  if self.metadata_url is None:
719
540
  return None
720
541
 
@@ -726,6 +547,367 @@ class Dataset(Generic[ST]):
726
547
  # Use our cached values
727
548
  return self._metadata
728
549
 
550
+ ##
551
+ # Convenience methods (GH#38 developer experience)
552
+
553
+ @property
554
+ def schema(self) -> dict[str, type]:
555
+ """Field names and types for this dataset's sample type.
556
+
557
+ Examples:
558
+ >>> ds = Dataset[MyData]("data.tar")
559
+ >>> ds.schema
560
+ {'name': <class 'str'>, 'embedding': numpy.ndarray}
561
+ """
562
+ st = self.sample_type
563
+ if st is DictSample:
564
+ return {"_data": dict}
565
+ if dataclasses.is_dataclass(st):
566
+ return {f.name: f.type for f in dataclasses.fields(st)}
567
+ return {}
568
+
569
+ @property
570
+ def column_names(self) -> list[str]:
571
+ """List of field names for this dataset's sample type."""
572
+ st = self.sample_type
573
+ if dataclasses.is_dataclass(st):
574
+ return [f.name for f in dataclasses.fields(st)]
575
+ return []
576
+
577
+ def __iter__(self) -> Iterator[ST]:
578
+ """Shorthand for ``ds.ordered()``."""
579
+ return iter(self.ordered())
580
+
581
+ def __len__(self) -> int:
582
+ """Total sample count (iterates all shards on first call, then cached)."""
583
+ if not hasattr(self, "_len_cache"):
584
+ self._len_cache: int = sum(1 for _ in self.ordered())
585
+ return self._len_cache
586
+
587
+ def head(self, n: int = 5) -> list[ST]:
588
+ """Return the first *n* samples from the dataset.
589
+
590
+ Args:
591
+ n: Number of samples to return. Default: 5.
592
+
593
+ Returns:
594
+ List of up to *n* samples in shard order.
595
+
596
+ Examples:
597
+ >>> samples = ds.head(3)
598
+ >>> len(samples)
599
+ 3
600
+ """
601
+ return list(itertools.islice(self.ordered(), n))
602
+
603
+ def get(self, key: str) -> ST:
604
+ """Retrieve a single sample by its ``__key__``.
605
+
606
+ Scans shards sequentially until a sample with a matching key is found.
607
+ This is O(n) for streaming datasets.
608
+
609
+ Args:
610
+ key: The WebDataset ``__key__`` string to search for.
611
+
612
+ Returns:
613
+ The matching sample.
614
+
615
+ Raises:
616
+ SampleKeyError: If no sample with the given key exists.
617
+
618
+ Examples:
619
+ >>> sample = ds.get("00000001-0001-1000-8000-010000000000")
620
+ """
621
+ pipeline = wds.pipeline.DataPipeline(
622
+ _ShardListStage(self._source),
623
+ wds.shardlists.split_by_worker,
624
+ _StreamOpenerStage(self._source),
625
+ wds.tariterators.tar_file_expander,
626
+ wds.tariterators.group_by_keys,
627
+ )
628
+ for raw_sample in pipeline:
629
+ if raw_sample.get("__key__") == key:
630
+ return self.wrap(raw_sample)
631
+ raise SampleKeyError(key)
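# Illustrative sketch (ds and the key are as in the docstring example above):
# SampleKeyError comes from atdata._exceptions, imported at the top of this
# module, so a missing key can be handled explicitly.
try:
    sample = ds.get("00000001-0001-1000-8000-010000000000")
except SampleKeyError:
    sample = None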
632
+
633
+ def describe(self) -> dict[str, Any]:
634
+ """Summary statistics: sample_type, fields, num_shards, shards, url, metadata."""
635
+ shards = self.list_shards()
636
+ return {
637
+ "sample_type": self.sample_type.__name__,
638
+ "fields": self.schema,
639
+ "num_shards": len(shards),
640
+ "shards": shards,
641
+ "url": self.url,
642
+ "metadata": self.metadata,
643
+ }
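# Illustrative usage sketch (ds as in the docstring examples above); the keys
# mirror the dict literal above and the printed values are purely illustrative.
info = ds.describe()
print(info["sample_type"], info["num_shards"])   # e.g. MyData 4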
644
+
645
+ def filter(self, predicate: Callable[[ST], bool]) -> "Dataset[ST]":
646
+ """Return a new dataset that yields only samples matching *predicate*.
647
+
648
+ The filter is applied lazily during iteration — no data is copied.
649
+
650
+ Args:
651
+ predicate: A function that takes a sample and returns ``True``
652
+ to keep it or ``False`` to discard it.
653
+
654
+ Returns:
655
+ A new ``Dataset`` whose iterators apply the filter.
656
+
657
+ Examples:
658
+ >>> long_names = ds.filter(lambda s: len(s.name) > 10)
659
+ >>> for sample in long_names:
660
+ ... assert len(sample.name) > 10
661
+ """
662
+ filtered = Dataset[self.sample_type](self._source, self.metadata_url)
663
+ filtered._sample_type_cache = self._sample_type_cache
664
+ filtered._output_lens = self._output_lens
665
+ filtered._filter_fn = predicate
666
+ # Preserve any existing filters
667
+ parent_filters = getattr(self, "_filter_fn", None)
668
+ if parent_filters is not None:
669
+ outer = parent_filters
670
+ filtered._filter_fn = lambda s: outer(s) and predicate(s)
671
+ # Preserve any existing map
672
+ if hasattr(self, "_map_fn"):
673
+ filtered._map_fn = self._map_fn
674
+ return filtered
675
+
676
+ def map(self, fn: Callable[[ST], Any]) -> "Dataset":
677
+ """Return a new dataset that applies *fn* to each sample during iteration.
678
+
679
+ The mapping is applied lazily during iteration — no data is copied.
680
+
681
+ Args:
682
+ fn: A function that takes a sample of type ``ST`` and returns
683
+ a transformed value.
684
+
685
+ Returns:
686
+ A new ``Dataset`` whose iterators apply the mapping.
687
+
688
+ Examples:
689
+ >>> names = ds.map(lambda s: s.name)
690
+ >>> for name in names:
691
+ ... print(name)
692
+ """
693
+ mapped = Dataset[self.sample_type](self._source, self.metadata_url)
694
+ mapped._sample_type_cache = self._sample_type_cache
695
+ mapped._output_lens = self._output_lens
696
+ mapped._map_fn = fn
697
+ # Preserve any existing map
698
+ if hasattr(self, "_map_fn"):
699
+ outer = self._map_fn
700
+ mapped._map_fn = lambda s: fn(outer(s))
701
+ # Preserve any existing filter
702
+ if hasattr(self, "_filter_fn"):
703
+ mapped._filter_fn = self._filter_fn
704
+ return mapped
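# Illustrative sketch (ds and the name field are as in the docstring examples):
# filter() and map() compose, and both run lazily as extra pipeline stages (see
# _post_wrap_stages below). The filter stage is applied before the map stage,
# so the predicate always sees the original, un-mapped samples.
long_upper = ds.filter(lambda s: len(s.name) > 10).map(lambda s: s.name.upper())
for name in long_upper:
    print(name)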
705
+
706
+ def process_shards(
707
+ self,
708
+ fn: Callable[[list[ST]], Any],
709
+ *,
710
+ shards: list[str] | None = None,
711
+ ) -> dict[str, Any]:
712
+ """Process each shard independently, collecting per-shard results.
713
+
714
+ Unlike :meth:`map` (which is lazy and per-sample), this method eagerly
715
+ processes each shard in turn, calling *fn* with the full list of samples
716
+ from that shard. If some shards fail, raises
717
+ :class:`~atdata._exceptions.PartialFailureError` containing both the
718
+ successful results and the per-shard errors.
719
+
720
+ Args:
721
+ fn: Function receiving a list of samples from one shard and
722
+ returning an arbitrary result.
723
+ shards: Optional list of shard identifiers to process. If ``None``,
724
+ processes all shards in the dataset. Useful for retrying only
725
+ the failed shards from a previous ``PartialFailureError``.
726
+
727
+ Returns:
728
+ Dict mapping shard identifier to *fn*'s return value for each shard.
729
+
730
+ Raises:
731
+ PartialFailureError: If at least one shard fails. The exception
732
+ carries ``.succeeded_shards``, ``.failed_shards``, ``.errors``,
733
+ and ``.results`` for inspection and retry.
734
+
735
+ Examples:
736
+ >>> results = ds.process_shards(lambda samples: len(samples))
737
+ >>> # On partial failure, retry just the failed shards:
738
+ >>> try:
739
+ ... results = ds.process_shards(expensive_fn)
740
+ ... except PartialFailureError as e:
741
+ ... retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
742
+ """
743
+ from ._logging import get_logger
744
+
745
+ log = get_logger()
746
+ shard_ids = shards or self.list_shards()
747
+ log.info("process_shards: starting %d shards", len(shard_ids))
748
+
749
+ succeeded: list[str] = []
750
+ failed: list[str] = []
751
+ errors: dict[str, Exception] = {}
752
+ results: dict[str, Any] = {}
753
+
754
+ for shard_id in shard_ids:
755
+ try:
756
+ shard_ds = Dataset[self.sample_type](shard_id)
757
+ shard_ds._sample_type_cache = self._sample_type_cache
758
+ samples = list(shard_ds.ordered())
759
+ results[shard_id] = fn(samples)
760
+ succeeded.append(shard_id)
761
+ log.debug("process_shards: shard ok %s", shard_id)
762
+ except Exception as exc:
763
+ failed.append(shard_id)
764
+ errors[shard_id] = exc
765
+ log.warning("process_shards: shard failed %s: %s", shard_id, exc)
766
+
767
+ if failed:
768
+ log.error(
769
+ "process_shards: %d/%d shards failed",
770
+ len(failed),
771
+ len(shard_ids),
772
+ )
773
+ raise PartialFailureError(
774
+ succeeded_shards=succeeded,
775
+ failed_shards=failed,
776
+ errors=errors,
777
+ results=results,
778
+ )
779
+
780
+ log.info("process_shards: all %d shards succeeded", len(shard_ids))
781
+ return results
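# Illustrative sketch (ds and expensive_fn as in the docstring example above):
# because the exception carries .results as well as .failed_shards, a retry of
# only the failed shards can be merged back into the partial results.
try:
    results = ds.process_shards(expensive_fn)
except PartialFailureError as e:
    retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
    results = {**e.results, **retry}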
782
+
783
+ def select(self, indices: Sequence[int]) -> list[ST]:
784
+ """Return samples at the given integer indices.
785
+
786
+ Iterates through the dataset in order and collects samples whose
787
+ positional index matches. This is O(n) for streaming datasets.
788
+
789
+ Args:
790
+ indices: Sequence of zero-based indices to select.
791
+
792
+ Returns:
793
+ List of samples at the requested positions, in index order.
794
+
795
+ Examples:
796
+ >>> samples = ds.select([0, 5, 10])
797
+ >>> len(samples)
798
+ 3
799
+ """
800
+ if not indices:
801
+ return []
802
+ target = set(indices)
803
+ max_idx = max(indices)
804
+ result: dict[int, ST] = {}
805
+ for i, sample in enumerate(self.ordered()):
806
+ if i in target:
807
+ result[i] = sample
808
+ if i >= max_idx:
809
+ break
810
+ return [result[i] for i in indices if i in result]
811
+
812
+ def query(
813
+ self,
814
+ where: "Callable[[pd.DataFrame], pd.Series]",
815
+ ) -> "list[SampleLocation]":
816
+ """Query this dataset using per-shard manifest metadata.
817
+
818
+ Requires manifests to have been generated during shard writing.
819
+ Discovers manifest files alongside the tar shards, loads them,
820
+ and executes a two-phase query (shard-level aggregate pruning,
821
+ then sample-level parquet filtering).
822
+
823
+ Args:
824
+ where: Predicate function that receives a pandas DataFrame
825
+ of manifest fields and returns a boolean Series selecting
826
+ matching rows.
827
+
828
+ Returns:
829
+ List of ``SampleLocation`` for matching samples.
830
+
831
+ Raises:
832
+ FileNotFoundError: If no manifest files are found alongside shards.
833
+
834
+ Examples:
835
+ >>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
836
+ >>> len(locs)
837
+ 42
838
+ """
839
+ from .manifest import QueryExecutor
840
+
841
+ shard_urls = self.list_shards()
842
+ executor = QueryExecutor.from_shard_urls(shard_urls)
843
+ return executor.query(where=where)
844
+
845
+ def to_pandas(self, limit: int | None = None) -> "pd.DataFrame":
846
+ """Materialize the dataset (or first *limit* samples) as a DataFrame.
847
+
848
+ Args:
849
+ limit: Maximum number of samples to include. ``None`` means all
850
+ samples (may use significant memory for large datasets).
851
+
852
+ Returns:
853
+ A pandas DataFrame with one row per sample and columns matching
854
+ the sample fields.
855
+
856
+ Warning:
857
+ With ``limit=None`` this loads the entire dataset into memory.
858
+
859
+ Examples:
860
+ >>> df = ds.to_pandas(limit=100)
861
+ >>> df.columns.tolist()
862
+ ['name', 'embedding']
863
+ """
864
+ samples = self.head(limit) if limit is not None else list(self.ordered())
865
+ rows = [
866
+ asdict(s) if dataclasses.is_dataclass(s) else s.to_dict() for s in samples
867
+ ]
868
+ return pd.DataFrame(rows)
869
+
870
+ def to_dict(self, limit: int | None = None) -> dict[str, list[Any]]:
871
+ """Materialize the dataset as a column-oriented dictionary.
872
+
873
+ Args:
874
+ limit: Maximum number of samples to include. ``None`` means all.
875
+
876
+ Returns:
877
+ Dictionary mapping field names to lists of values (one entry
878
+ per sample).
879
+
880
+ Warning:
881
+ With ``limit=None`` this loads the entire dataset into memory.
882
+
883
+ Examples:
884
+ >>> d = ds.to_dict(limit=10)
885
+ >>> d.keys()
886
+ dict_keys(['name', 'embedding'])
887
+ >>> len(d['name'])
888
+ 10
889
+ """
890
+ samples = self.head(limit) if limit is not None else list(self.ordered())
891
+ if not samples:
892
+ return {}
893
+ if dataclasses.is_dataclass(samples[0]):
894
+ fields = [f.name for f in dataclasses.fields(samples[0])]
895
+ return {f: [getattr(s, f) for s in samples] for f in fields}
896
+ # DictSample path
897
+ keys = samples[0].keys()
898
+ return {k: [s[k] for s in samples] for k in keys}
899
+
900
+ def _post_wrap_stages(self) -> list:
901
+ """Build extra pipeline stages for filter/map set via .filter()/.map()."""
902
+ stages: list = []
903
+ filter_fn = getattr(self, "_filter_fn", None)
904
+ if filter_fn is not None:
905
+ stages.append(wds.filters.select(filter_fn))
906
+ map_fn = getattr(self, "_map_fn", None)
907
+ if map_fn is not None:
908
+ stages.append(wds.filters.map(map_fn))
909
+ return stages
910
+
729
911
  @overload
730
912
  def ordered(
731
913
  self,
@@ -769,6 +951,7 @@ class Dataset(Generic[ST]):
769
951
  wds.tariterators.tar_file_expander,
770
952
  wds.tariterators.group_by_keys,
771
953
  wds.filters.map(self.wrap),
954
+ *self._post_wrap_stages(),
772
955
  )
773
956
 
774
957
  return wds.pipeline.DataPipeline(
@@ -839,6 +1022,7 @@ class Dataset(Generic[ST]):
839
1022
  wds.tariterators.group_by_keys,
840
1023
  wds.filters.shuffle(buffer_samples),
841
1024
  wds.filters.map(self.wrap),
1025
+ *self._post_wrap_stages(),
842
1026
  )
843
1027
 
844
1028
  return wds.pipeline.DataPipeline(
@@ -862,100 +1046,47 @@ class Dataset(Generic[ST]):
862
1046
  maxcount: Optional[int] = None,
863
1047
  **kwargs,
864
1048
  ):
865
- """Export dataset contents to parquet format.
866
-
867
- Converts all samples to a pandas DataFrame and saves to parquet file(s).
868
- Useful for interoperability with data analysis tools.
1049
+ """Export dataset to parquet file(s).
869
1050
 
870
1051
  Args:
871
- path: Output path for the parquet file. If ``maxcount`` is specified,
872
- files are named ``{stem}-{segment:06d}.parquet``.
873
- sample_map: Optional function to convert samples to dictionaries.
874
- Defaults to ``dataclasses.asdict``.
875
- maxcount: If specified, split output into multiple files with at most
876
- this many samples each. Recommended for large datasets.
877
- **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
878
- Common options include ``compression``, ``index``, ``engine``.
879
-
880
- Warning:
881
- **Memory Usage**: When ``maxcount=None`` (default), this method loads
882
- the **entire dataset into memory** as a pandas DataFrame before writing.
883
- For large datasets, this can cause memory exhaustion.
884
-
885
- For datasets larger than available RAM, always specify ``maxcount``::
886
-
887
- # Safe for large datasets - processes in chunks
888
- ds.to_parquet("output.parquet", maxcount=10000)
889
-
890
- This creates multiple parquet files: ``output-000000.parquet``,
891
- ``output-000001.parquet``, etc.
1052
+ path: Output path. With *maxcount*, files are named
1053
+ ``{stem}-{segment:06d}.parquet``.
1054
+ sample_map: Convert sample to dict. Defaults to ``dataclasses.asdict``.
1055
+ maxcount: Split into files of at most this many samples.
1056
+ Without it, the entire dataset is loaded into memory.
1057
+ **kwargs: Passed to ``pandas.DataFrame.to_parquet()``.
892
1058
 
893
1059
  Examples:
894
- >>> ds = Dataset[MySample]("data.tar")
895
- >>> # Small dataset - load all at once
896
- >>> ds.to_parquet("output.parquet")
897
- >>>
898
- >>> # Large dataset - process in chunks
899
1060
  >>> ds.to_parquet("output.parquet", maxcount=50000)
900
1061
  """
901
- ##
902
-
903
- # Normalize args
904
1062
  path = Path(path)
905
1063
  if sample_map is None:
906
1064
  sample_map = asdict
907
1065
 
908
- verbose = kwargs.get("verbose", False)
909
-
910
- it = self.ordered(batch_size=None)
911
- if verbose:
912
- it = tqdm(it)
913
-
914
- #
915
-
916
1066
  if maxcount is None:
917
- # Load and save full dataset
918
1067
  df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)])
919
1068
  df.to_parquet(path, **kwargs)
920
-
921
1069
  else:
922
- # Load and save dataset in segments of size `maxcount`
923
-
924
1070
  cur_segment = 0
925
- cur_buffer = []
1071
+ cur_buffer: list = []
926
1072
  path_template = (
927
1073
  path.parent / f"{path.stem}-{{:06d}}{path.suffix}"
928
1074
  ).as_posix()
929
1075
 
930
1076
  for x in self.ordered(batch_size=None):
931
1077
  cur_buffer.append(sample_map(x))
932
-
933
1078
  if len(cur_buffer) >= maxcount:
934
- # Write current segment
935
1079
  cur_path = path_template.format(cur_segment)
936
- df = pd.DataFrame(cur_buffer)
937
- df.to_parquet(cur_path, **kwargs)
938
-
1080
+ pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)
939
1081
  cur_segment += 1
940
1082
  cur_buffer = []
941
1083
 
942
- if len(cur_buffer) > 0:
943
- # Write one last segment with remainder
1084
+ if cur_buffer:
944
1085
  cur_path = path_template.format(cur_segment)
945
- df = pd.DataFrame(cur_buffer)
946
- df.to_parquet(cur_path, **kwargs)
1086
+ pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)
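# Illustrative sketch (ds is assumed): with maxcount, segments follow the
# {stem}-{segment:06d}{suffix} template above and can be recombined with plain pandas.
import glob
import pandas as pd
ds.to_parquet("out.parquet", maxcount=50_000)   # -> out-000000.parquet, out-000001.parquet, ...
df = pd.concat(pd.read_parquet(p) for p in sorted(glob.glob("out-*.parquet")))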
947
1087
 
948
1088
  def wrap(self, sample: WDSRawSample) -> ST:
949
- """Wrap a raw msgpack sample into the appropriate dataset-specific type.
950
-
951
- Args:
952
- sample: A dictionary containing at minimum a ``'msgpack'`` key with
953
- serialized sample bytes.
954
-
955
- Returns:
956
- A deserialized sample of type ``ST``, optionally transformed through
957
- a lens if ``as_type()`` was called.
958
- """
1089
+ """Deserialize a raw WDS sample dict into type ``ST``."""
959
1090
  if "msgpack" not in sample:
960
1091
  raise ValueError(
961
1092
  f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}"
@@ -972,20 +1103,7 @@ class Dataset(Generic[ST]):
972
1103
  return self._output_lens(source_sample)
973
1104
 
974
1105
  def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]:
975
- """Wrap a batch of raw msgpack samples into a typed SampleBatch.
976
-
977
- Args:
978
- batch: A dictionary containing a ``'msgpack'`` key with a list of
979
- serialized sample bytes.
980
-
981
- Returns:
982
- A ``SampleBatch[ST]`` containing deserialized samples, optionally
983
- transformed through a lens if ``as_type()`` was called.
984
-
985
- Note:
986
- This implementation deserializes samples one at a time, then
987
- aggregates them into a batch.
988
- """
1106
+ """Deserialize a raw WDS batch dict into ``SampleBatch[ST]``."""
989
1107
 
990
1108
  if "msgpack" not in batch:
991
1109
  raise ValueError(
@@ -1009,24 +1127,12 @@ _T = TypeVar("_T")
1009
1127
 
1010
1128
 
1011
1129
  @dataclass_transform()
1012
- def packable(cls: type[_T]) -> type[_T]:
1013
- """Decorator to convert a regular class into a ``PackableSample``.
1014
-
1015
- This decorator transforms a class into a dataclass that inherits from
1016
- ``PackableSample``, enabling automatic msgpack serialization/deserialization
1017
- with special handling for NDArray fields.
1018
-
1019
- The resulting class satisfies the ``Packable`` protocol, making it compatible
1020
- with all atdata APIs that accept packable types (e.g., ``publish_schema``,
1021
- lens transformations, etc.).
1022
-
1023
- Args:
1024
- cls: The class to convert. Should have type annotations for its fields.
1130
+ def packable(cls: type[_T]) -> type[Packable]:
1131
+ """Convert a class into a ``PackableSample`` dataclass with msgpack serialization.
1025
1132
 
1026
- Returns:
1027
- A new dataclass that inherits from ``PackableSample`` with the same
1028
- name and annotations as the original class. The class satisfies the
1029
- ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
1133
+ The resulting class gains ``packed``, ``as_wds``, ``from_bytes``, and
1134
+ ``from_data`` methods, and satisfies the ``Packable`` protocol.
1135
+ NDArray fields are automatically handled during serialization.
1030
1136
 
1031
1137
  Examples:
1032
1138
  >>> @packable
@@ -1035,11 +1141,7 @@ def packable(cls: type[_T]) -> type[_T]:
1035
1141
  ... values: NDArray
1036
1142
  ...
1037
1143
  >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
1038
- >>> bytes_data = sample.packed
1039
- >>> restored = MyData.from_bytes(bytes_data)
1040
- >>>
1041
- >>> # Works with Packable-typed APIs
1042
- >>> index.publish_schema(MyData, version="1.0.0") # Type-safe
1144
+ >>> restored = MyData.from_bytes(sample.packed)
1043
1145
  """
1044
1146
 
1045
1147
  ##
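Taken together, the new convenience layer (the GH#38 block above) turns common inspection tasks into one-liners. A hedged sketch of the 0.3.0 surface, with the sample type and paths illustrative and the top-level import path assumed:

    from atdata import Dataset          # import path assumed

    ds = Dataset[MyData]("data.tar")
    ds.schema                           # {'name': <class 'str'>, 'embedding': numpy.ndarray}
    len(ds)                             # full iteration on first call, cached afterwards
    ds.head(3)                          # first three samples
    df = ds.to_pandas(limit=100)        # bounded materialization as a DataFrame
    cols = ds.to_dict(limit=10)         # column-oriented dict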