atdata-0.2.3b1-py3-none-any.whl → atdata-0.3.1b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. atdata/.gitignore +1 -0
  2. atdata/__init__.py +39 -0
  3. atdata/_cid.py +0 -21
  4. atdata/_exceptions.py +168 -0
  5. atdata/_helpers.py +41 -15
  6. atdata/_hf_api.py +95 -11
  7. atdata/_logging.py +70 -0
  8. atdata/_protocols.py +77 -238
  9. atdata/_schema_codec.py +7 -6
  10. atdata/_stub_manager.py +5 -25
  11. atdata/_type_utils.py +28 -2
  12. atdata/atmosphere/__init__.py +31 -20
  13. atdata/atmosphere/_types.py +4 -4
  14. atdata/atmosphere/client.py +64 -12
  15. atdata/atmosphere/lens.py +11 -12
  16. atdata/atmosphere/records.py +12 -12
  17. atdata/atmosphere/schema.py +16 -18
  18. atdata/atmosphere/store.py +6 -7
  19. atdata/cli/__init__.py +161 -175
  20. atdata/cli/diagnose.py +2 -2
  21. atdata/cli/{local.py → infra.py} +11 -11
  22. atdata/cli/inspect.py +69 -0
  23. atdata/cli/preview.py +63 -0
  24. atdata/cli/schema.py +109 -0
  25. atdata/dataset.py +583 -328
  26. atdata/index/__init__.py +54 -0
  27. atdata/index/_entry.py +157 -0
  28. atdata/index/_index.py +1198 -0
  29. atdata/index/_schema.py +380 -0
  30. atdata/lens.py +9 -2
  31. atdata/lexicons/__init__.py +121 -0
  32. atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
  33. atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
  34. atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
  35. atdata/lexicons/ac.foundation.dataset.record.json +96 -0
  36. atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
  37. atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
  38. atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
  39. atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
  40. atdata/lexicons/ndarray_shim.json +16 -0
  41. atdata/local/__init__.py +70 -0
  42. atdata/local/_repo_legacy.py +218 -0
  43. atdata/manifest/__init__.py +28 -0
  44. atdata/manifest/_aggregates.py +156 -0
  45. atdata/manifest/_builder.py +163 -0
  46. atdata/manifest/_fields.py +154 -0
  47. atdata/manifest/_manifest.py +146 -0
  48. atdata/manifest/_query.py +150 -0
  49. atdata/manifest/_writer.py +74 -0
  50. atdata/promote.py +18 -14
  51. atdata/providers/__init__.py +25 -0
  52. atdata/providers/_base.py +140 -0
  53. atdata/providers/_factory.py +69 -0
  54. atdata/providers/_postgres.py +214 -0
  55. atdata/providers/_redis.py +171 -0
  56. atdata/providers/_sqlite.py +191 -0
  57. atdata/repository.py +323 -0
  58. atdata/stores/__init__.py +23 -0
  59. atdata/stores/_disk.py +123 -0
  60. atdata/stores/_s3.py +349 -0
  61. atdata/testing.py +341 -0
  62. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
  63. atdata-0.3.1b1.dist-info/RECORD +67 -0
  64. atdata/local.py +0 -1720
  65. atdata-0.2.3b1.dist-info/RECORD +0 -28
  66. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
  67. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
  68. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py CHANGED
@@ -31,6 +31,7 @@ Examples:
31
31
  import webdataset as wds
32
32
 
33
33
  from pathlib import Path
34
+ import itertools
34
35
  import uuid
35
36
 
36
37
  import dataclasses
@@ -42,15 +43,16 @@ from dataclasses import (
42
43
  from abc import ABC
43
44
 
44
45
  from ._sources import URLSource
45
- from ._protocols import DataSource
46
+ from ._protocols import DataSource, Packable
47
+ from ._exceptions import SampleKeyError, PartialFailureError
46
48
 
47
- from tqdm import tqdm
48
49
  import numpy as np
49
50
  import pandas as pd
50
51
  import requests
51
52
 
52
53
  import typing
53
54
  from typing import (
55
+ TYPE_CHECKING,
54
56
  Any,
55
57
  Optional,
56
58
  Dict,
@@ -66,6 +68,9 @@ from typing import (
66
68
  dataclass_transform,
67
69
  overload,
68
70
  )
71
+
72
+ if TYPE_CHECKING:
73
+ from .manifest._query import SampleLocation
69
74
  from numpy.typing import NDArray
70
75
 
71
76
  import msgpack
@@ -94,9 +99,11 @@ DT = TypeVar("DT")
94
99
 
95
100
 
96
101
  def _make_packable(x):
97
- """Convert numpy arrays to bytes; pass through other values unchanged."""
102
+ """Convert numpy arrays to bytes; coerce numpy scalars to Python natives."""
98
103
  if isinstance(x, np.ndarray):
99
104
  return eh.array_to_bytes(x)
105
+ if isinstance(x, np.generic):
106
+ return x.item()
100
107
  return x
101
108
 
102
109
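For orientation, a tiny sketch of the scalar coercion added above (values are hypothetical; without the coercion, msgpack.packb typically rejects numpy scalar fields):

    >>> import numpy as np
    >>> _make_packable(np.float32(0.5))      # numpy scalar -> plain Python float
    0.5
    >>> type(_make_packable(np.int64(3)))    # numpy ints -> int
    <class 'int'>
    >>> _make_packable("text")               # non-numpy values pass through unchanged
    'text'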
 
@@ -157,37 +164,17 @@ class DictSample:
157
164
 
158
165
  @classmethod
159
166
  def from_data(cls, data: dict[str, Any]) -> "DictSample":
160
- """Create a DictSample from unpacked msgpack data.
161
-
162
- Args:
163
- data: Dictionary with field names as keys.
164
-
165
- Returns:
166
- New DictSample instance wrapping the data.
167
- """
167
+ """Create a DictSample from unpacked msgpack data."""
168
168
  return cls(_data=data)
169
169
 
170
170
  @classmethod
171
171
  def from_bytes(cls, bs: bytes) -> "DictSample":
172
- """Create a DictSample from raw msgpack bytes.
173
-
174
- Args:
175
- bs: Raw bytes from a msgpack-serialized sample.
176
-
177
- Returns:
178
- New DictSample instance with the unpacked data.
179
- """
172
+ """Create a DictSample from raw msgpack bytes."""
180
173
  return cls.from_data(ormsgpack.unpackb(bs))
181
174
 
182
175
  def __getattr__(self, name: str) -> Any:
183
176
  """Access a field by attribute name.
184
177
 
185
- Args:
186
- name: Field name to access.
187
-
188
- Returns:
189
- The field value.
190
-
191
178
  Raises:
192
179
  AttributeError: If the field doesn't exist.
193
180
  """
@@ -203,21 +190,9 @@ class DictSample:
203
190
  ) from None
204
191
 
205
192
  def __getitem__(self, key: str) -> Any:
206
- """Access a field by dict key.
207
-
208
- Args:
209
- key: Field name to access.
210
-
211
- Returns:
212
- The field value.
213
-
214
- Raises:
215
- KeyError: If the field doesn't exist.
216
- """
217
193
  return self._data[key]
218
194
 
219
195
  def __contains__(self, key: str) -> bool:
220
- """Check if a field exists."""
221
196
  return key in self._data
222
197
 
223
198
  def keys(self) -> list[str]:
@@ -225,23 +200,13 @@ class DictSample:
225
200
  return list(self._data.keys())
226
201
 
227
202
  def values(self) -> list[Any]:
228
- """Return list of field values."""
229
203
  return list(self._data.values())
230
204
 
231
205
  def items(self) -> list[tuple[str, Any]]:
232
- """Return list of (field_name, value) tuples."""
233
206
  return list(self._data.items())
234
207
 
235
208
  def get(self, key: str, default: Any = None) -> Any:
236
- """Get a field value with optional default.
237
-
238
- Args:
239
- key: Field name to access.
240
- default: Value to return if field doesn't exist.
241
-
242
- Returns:
243
- The field value or default.
244
- """
209
+ """Get a field value, returning *default* if missing."""
245
210
  return self._data.get(key, default)
246
211
 
247
212
  def to_dict(self) -> dict[str, Any]:
@@ -250,20 +215,12 @@ class DictSample:
250
215
 
251
216
  @property
252
217
  def packed(self) -> bytes:
253
- """Pack this sample's data into msgpack bytes.
254
-
255
- Returns:
256
- Raw msgpack bytes representing this sample's data.
257
- """
218
+ """Serialize to msgpack bytes."""
258
219
  return msgpack.packb(self._data)
259
220
 
260
221
  @property
261
222
  def as_wds(self) -> "WDSRawSample":
262
- """Pack this sample's data for writing to WebDataset.
263
-
264
- Returns:
265
- A dictionary with ``__key__`` and ``msgpack`` fields.
266
- """
223
+ """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
267
224
  return {
268
225
  "__key__": str(uuid.uuid1(0, 0)),
269
226
  "msgpack": self.packed,
@@ -300,31 +257,13 @@ class PackableSample(ABC):
300
257
 
301
258
  def _ensure_good(self):
302
259
  """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
303
-
304
- # Auto-convert known types when annotated
305
- # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
306
260
  for field in dataclasses.fields(self):
307
- var_name = field.name
308
- var_type = field.type
309
-
310
- # Annotation for this variable is to be an NDArray
311
- if _is_possibly_ndarray_type(var_type):
312
- # ... so, we'll always auto-convert to numpy
313
-
314
- var_cur_value = getattr(self, var_name)
315
-
316
- # Execute the appropriate conversion for intermediate data
317
- # based on what is provided
318
-
319
- if isinstance(var_cur_value, np.ndarray):
320
- # Already the correct type, no conversion needed
261
+ if _is_possibly_ndarray_type(field.type):
262
+ value = getattr(self, field.name)
263
+ if isinstance(value, np.ndarray):
321
264
  continue
322
-
323
- elif isinstance(var_cur_value, bytes):
324
- # Design note: bytes in NDArray-typed fields are always interpreted
325
- # as serialized arrays. This means raw bytes fields must not be
326
- # annotated as NDArray.
327
- setattr(self, var_name, eh.bytes_to_array(var_cur_value))
265
+ elif isinstance(value, bytes):
266
+ setattr(self, field.name, eh.bytes_to_array(value))
328
267
 
329
268
  def __post_init__(self):
330
269
  self._ensure_good()
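A sketch of the NDArray auto-conversion described above, using a hypothetical @packable class; per the design note removed in this hunk, bytes stored in NDArray-annotated fields are always treated as serialized arrays:

    >>> @packable
    ... class Embedding:
    ...     vec: NDArray
    ...
    >>> e = Embedding(vec=np.arange(3))
    >>> restored = Embedding.from_bytes(e.packed)   # vec travels as bytes in the msgpack payload
    >>> restored.vec                                # ...and is converted back in __post_init__
    array([0, 1, 2])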
@@ -333,67 +272,31 @@ class PackableSample(ABC):
333
272
 
334
273
  @classmethod
335
274
  def from_data(cls, data: WDSRawSample) -> Self:
336
- """Create a sample instance from unpacked msgpack data.
337
-
338
- Args:
339
- data: Dictionary with keys matching the sample's field names.
340
-
341
- Returns:
342
- New instance with NDArray fields auto-converted from bytes.
343
- """
275
+ """Create an instance from unpacked msgpack data."""
344
276
  return cls(**data)
345
277
 
346
278
  @classmethod
347
279
  def from_bytes(cls, bs: bytes) -> Self:
348
- """Create a sample instance from raw msgpack bytes.
349
-
350
- Args:
351
- bs: Raw bytes from a msgpack-serialized sample.
352
-
353
- Returns:
354
- A new instance of this sample class deserialized from the bytes.
355
- """
280
+ """Create an instance from raw msgpack bytes."""
356
281
  return cls.from_data(ormsgpack.unpackb(bs))
357
282
 
358
283
  @property
359
284
  def packed(self) -> bytes:
360
- """Pack this sample's data into msgpack bytes.
361
-
362
- NDArray fields are automatically converted to bytes before packing.
363
- All other fields are packed as-is if they're msgpack-compatible.
364
-
365
- Returns:
366
- Raw msgpack bytes representing this sample's data.
285
+ """Serialize to msgpack bytes. NDArray fields are auto-converted.
367
286
 
368
287
  Raises:
369
288
  RuntimeError: If msgpack serialization fails.
370
289
  """
371
-
372
- # Make sure that all of our (possibly unpackable) data is in a packable
373
- # format
374
290
  o = {k: _make_packable(v) for k, v in vars(self).items()}
375
-
376
291
  ret = msgpack.packb(o)
377
-
378
292
  if ret is None:
379
293
  raise RuntimeError(f"Failed to pack sample to bytes: {o}")
380
-
381
294
  return ret
382
295
 
383
296
  @property
384
297
  def as_wds(self) -> WDSRawSample:
385
- """Pack this sample's data for writing to WebDataset.
386
-
387
- Returns:
388
- A dictionary with ``__key__`` (UUID v1 for sortable keys) and
389
- ``msgpack`` (packed sample data) fields suitable for WebDataset.
390
-
391
- Note:
392
- Keys are auto-generated as UUID v1 for time-sortable ordering.
393
- Custom key specification is not currently supported.
394
- """
298
+ """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
395
299
  return {
396
- # Generates a UUID that is timelike-sortable
397
300
  "__key__": str(uuid.uuid1(0, 0)),
398
301
  "msgpack": self.packed,
399
302
  }
@@ -404,82 +307,45 @@ def _batch_aggregate(xs: Sequence):
404
307
  if not xs:
405
308
  return []
406
309
  if isinstance(xs[0], np.ndarray):
407
- return np.array(list(xs))
310
+ return np.stack(xs)
408
311
  return list(xs)
409
312
 
410
313
 
411
314
  class SampleBatch(Generic[DT]):
412
315
  """A batch of samples with automatic attribute aggregation.
413
316
 
414
- This class wraps a sequence of samples and provides magic ``__getattr__``
415
- access to aggregate sample attributes. When you access an attribute that
416
- exists on the sample type, it automatically aggregates values across all
417
- samples in the batch.
418
-
419
- NDArray fields are stacked into a numpy array with a batch dimension.
420
- Other fields are aggregated into a list.
317
+ Accessing an attribute aggregates that field across all samples:
318
+ NDArray fields are stacked into a numpy array with a batch dimension;
319
+ other fields are collected into a list. Results are cached.
421
320
 
422
321
  Parameters:
423
322
  DT: The sample type, must derive from ``PackableSample``.
424
323
 
425
- Attributes:
426
- samples: The list of sample instances in this batch.
427
-
428
324
  Examples:
429
325
  >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
430
- >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
431
- >>> batch.names # Returns list of names
432
-
433
- Note:
434
- This class uses Python's ``__orig_class__`` mechanism to extract the
435
- type parameter at runtime. Instances must be created using the
436
- subscripted syntax ``SampleBatch[MyType](samples)`` rather than
437
- calling the constructor directly with an unsubscripted class.
326
+ >>> batch.embeddings # Stacked numpy array of shape (3, ...)
327
+ >>> batch.names # List of names
438
328
  """
439
329
 
440
- # Design note: The docstring uses "Parameters:" for type parameters because
441
- # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
442
-
443
330
  def __init__(self, samples: Sequence[DT]):
444
- """Create a batch from a sequence of samples.
445
-
446
- Args:
447
- samples: A sequence of sample instances to aggregate into a batch.
448
- Each sample must be an instance of a type derived from
449
- ``PackableSample``.
450
- """
331
+ """Create a batch from a sequence of samples."""
451
332
  self.samples = list(samples)
452
333
  self._aggregate_cache = dict()
453
334
  self._sample_type_cache: Type | None = None
454
335
 
455
336
  @property
456
337
  def sample_type(self) -> Type:
457
- """The type of each sample in this batch.
458
-
459
- Returns:
460
- The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
461
- """
338
+ """The type parameter ``DT`` used when creating this batch."""
462
339
  if self._sample_type_cache is None:
463
340
  self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
464
- assert self._sample_type_cache is not None
341
+ if self._sample_type_cache is None:
342
+ raise TypeError(
343
+ "SampleBatch requires a type parameter, e.g. SampleBatch[MySample]"
344
+ )
465
345
  return self._sample_type_cache
466
346
 
467
347
  def __getattr__(self, name):
468
- """Aggregate an attribute across all samples in the batch.
469
-
470
- This magic method enables attribute-style access to aggregated sample
471
- fields. Results are cached for efficiency.
472
-
473
- Args:
474
- name: The attribute name to aggregate across samples.
475
-
476
- Returns:
477
- For NDArray fields: a stacked numpy array with batch dimension.
478
- For other fields: a list of values from each sample.
479
-
480
- Raises:
481
- AttributeError: If the attribute doesn't exist on the sample type.
482
- """
348
+ """Aggregate a field across all samples (cached)."""
483
349
  # Aggregate named params of sample type
484
350
  if name in vars(self.sample_type)["__annotations__"]:
485
351
  if name not in self._aggregate_cache:
@@ -492,8 +358,8 @@ class SampleBatch(Generic[DT]):
492
358
  raise AttributeError(f"No sample attribute named {name}")
493
359
 
494
360
 
495
- ST = TypeVar("ST", bound=PackableSample)
496
- RT = TypeVar("RT", bound=PackableSample)
361
+ ST = TypeVar("ST", bound=Packable)
362
+ RT = TypeVar("RT", bound=Packable)
497
363
 
498
364
 
499
365
  class _ShardListStage(wds.utils.PipelineStage):
@@ -571,23 +437,18 @@ class Dataset(Generic[ST]):
571
437
 
572
438
  @property
573
439
  def sample_type(self) -> Type:
574
- """The type of each returned sample from this dataset's iterator.
575
-
576
- Returns:
577
- The type parameter ``ST`` used when creating this ``Dataset[ST]``.
578
- """
440
+ """The type parameter ``ST`` used when creating this dataset."""
579
441
  if self._sample_type_cache is None:
580
442
  self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
581
- assert self._sample_type_cache is not None
443
+ if self._sample_type_cache is None:
444
+ raise TypeError(
445
+ "Dataset requires a type parameter, e.g. Dataset[MySample]"
446
+ )
582
447
  return self._sample_type_cache
583
448
 
584
449
  @property
585
450
  def batch_type(self) -> Type:
586
- """The type of batches produced by this dataset.
587
-
588
- Returns:
589
- ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
590
- """
451
+ """``SampleBatch[ST]`` where ``ST`` is this dataset's sample type."""
591
452
  return SampleBatch[self.sample_type]
592
453
 
593
454
  def __init__(
@@ -614,28 +475,21 @@ class Dataset(Generic[ST]):
614
475
  """
615
476
  super().__init__()
616
477
 
617
- # Handle backward compatibility: url= keyword argument
618
478
  if source is None and url is not None:
619
479
  source = url
620
480
  elif source is None:
621
481
  raise TypeError("Dataset() missing required argument: 'source' or 'url'")
622
482
 
623
- # Normalize source: strings become URLSource for backward compatibility
624
483
  if isinstance(source, str):
625
484
  self._source: DataSource = URLSource(source)
626
485
  self.url = source
627
486
  else:
628
487
  self._source = source
629
- # For compatibility, expose URL if source has list_shards
630
488
  shards = source.list_shards()
631
- # Design note: Using first shard as url for legacy compatibility.
632
- # Full shard list is available via list_shards() method.
633
489
  self.url = shards[0] if shards else ""
634
490
 
635
491
  self._metadata: dict[str, Any] | None = None
636
492
  self.metadata_url: str | None = metadata_url
637
- """Optional URL to msgpack-encoded metadata for this dataset."""
638
-
639
493
  self._output_lens: Lens | None = None
640
494
  self._sample_type_cache: Type | None = None
641
495
 
@@ -645,47 +499,23 @@ class Dataset(Generic[ST]):
645
499
  return self._source
646
500
 
647
501
  def as_type(self, other: Type[RT]) -> "Dataset[RT]":
648
- """View this dataset through a different sample type using a registered lens.
649
-
650
- Args:
651
- other: The target sample type to transform into. Must be a type
652
- derived from ``PackableSample``.
653
-
654
- Returns:
655
- A new ``Dataset`` instance that yields samples of type ``other``
656
- by applying the appropriate lens transformation from the global
657
- ``LensNetwork`` registry.
502
+ """View this dataset through a different sample type via a registered lens.
658
503
 
659
504
  Raises:
660
- ValueError: If no registered lens exists between the current
661
- sample type and the target type.
505
+ ValueError: If no lens exists between the current and target types.
662
506
  """
663
507
  ret = Dataset[other](self._source)
664
- # Get the singleton lens registry
665
508
  lenses = LensNetwork()
666
509
  ret._output_lens = lenses.transform(self.sample_type, ret.sample_type)
667
510
  return ret
668
511
 
669
512
  @property
670
513
  def shards(self) -> Iterator[str]:
671
- """Lazily iterate over shard identifiers.
672
-
673
- Yields:
674
- Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
675
-
676
- Examples:
677
- >>> for shard in ds.shards:
678
- ... print(f"Processing {shard}")
679
- """
514
+ """Lazily iterate over shard identifiers."""
680
515
  return iter(self._source.list_shards())
681
516
 
682
517
  def list_shards(self) -> list[str]:
683
- """Get list of individual dataset shards.
684
-
685
- Returns:
686
- A full (non-lazy) list of the individual ``tar`` files within the
687
- source WebDataset.
688
- """
518
+ """Return all shard paths/URLs as a list."""
689
519
  return self._source.list_shards()
690
520
 
691
521
  # Legacy alias for backwards compatibility
@@ -707,14 +537,7 @@ class Dataset(Generic[ST]):
707
537
 
708
538
  @property
709
539
  def metadata(self) -> dict[str, Any] | None:
710
- """Fetch and cache metadata from metadata_url.
711
-
712
- Returns:
713
- Deserialized metadata dictionary, or None if no metadata_url is set.
714
-
715
- Raises:
716
- requests.HTTPError: If metadata fetch fails.
717
- """
540
+ """Fetch and cache metadata from metadata_url, or ``None`` if unset."""
718
541
  if self.metadata_url is None:
719
542
  return None
720
543
 
@@ -726,6 +549,367 @@ class Dataset(Generic[ST]):
726
549
  # Use our cached values
727
550
  return self._metadata
728
551
 
552
+ ##
553
+ # Convenience methods (GH#38 developer experience)
554
+
555
+ @property
556
+ def schema(self) -> dict[str, type]:
557
+ """Field names and types for this dataset's sample type.
558
+
559
+ Examples:
560
+ >>> ds = Dataset[MyData]("data.tar")
561
+ >>> ds.schema
562
+ {'name': <class 'str'>, 'embedding': numpy.ndarray}
563
+ """
564
+ st = self.sample_type
565
+ if st is DictSample:
566
+ return {"_data": dict}
567
+ if dataclasses.is_dataclass(st):
568
+ return {f.name: f.type for f in dataclasses.fields(st)}
569
+ return {}
570
+
571
+ @property
572
+ def column_names(self) -> list[str]:
573
+ """List of field names for this dataset's sample type."""
574
+ st = self.sample_type
575
+ if dataclasses.is_dataclass(st):
576
+ return [f.name for f in dataclasses.fields(st)]
577
+ return []
578
+
579
+ def __iter__(self) -> Iterator[ST]:
580
+ """Shorthand for ``ds.ordered()``."""
581
+ return iter(self.ordered())
582
+
583
+ def __len__(self) -> int:
584
+ """Total sample count (iterates all shards on first call, then cached)."""
585
+ if not hasattr(self, "_len_cache"):
586
+ self._len_cache: int = sum(1 for _ in self.ordered())
587
+ return self._len_cache
588
+
589
+ def head(self, n: int = 5) -> list[ST]:
590
+ """Return the first *n* samples from the dataset.
591
+
592
+ Args:
593
+ n: Number of samples to return. Default: 5.
594
+
595
+ Returns:
596
+ List of up to *n* samples in shard order.
597
+
598
+ Examples:
599
+ >>> samples = ds.head(3)
600
+ >>> len(samples)
601
+ 3
602
+ """
603
+ return list(itertools.islice(self.ordered(), n))
604
+
605
+ def get(self, key: str) -> ST:
606
+ """Retrieve a single sample by its ``__key__``.
607
+
608
+ Scans shards sequentially until a sample with a matching key is found.
609
+ This is O(n) for streaming datasets.
610
+
611
+ Args:
612
+ key: The WebDataset ``__key__`` string to search for.
613
+
614
+ Returns:
615
+ The matching sample.
616
+
617
+ Raises:
618
+ SampleKeyError: If no sample with the given key exists.
619
+
620
+ Examples:
621
+ >>> sample = ds.get("00000001-0001-1000-8000-010000000000")
622
+ """
623
+ pipeline = wds.pipeline.DataPipeline(
624
+ _ShardListStage(self._source),
625
+ wds.shardlists.split_by_worker,
626
+ _StreamOpenerStage(self._source),
627
+ wds.tariterators.tar_file_expander,
628
+ wds.tariterators.group_by_keys,
629
+ )
630
+ for raw_sample in pipeline:
631
+ if raw_sample.get("__key__") == key:
632
+ return self.wrap(raw_sample)
633
+ raise SampleKeyError(key)
634
+
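Since get() scans shards sequentially and raises SampleKeyError on a miss, a tolerant lookup can catch it (the key string is made up; the import path follows the _exceptions module added in this release):

    >>> from atdata._exceptions import SampleKeyError
    >>> try:
    ...     sample = ds.get("no-such-key")
    ... except SampleKeyError:
    ...     sample = None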
635
+ def describe(self) -> dict[str, Any]:
636
+ """Summary statistics: sample_type, fields, num_shards, shards, url, metadata."""
637
+ shards = self.list_shards()
638
+ return {
639
+ "sample_type": self.sample_type.__name__,
640
+ "fields": self.schema,
641
+ "num_shards": len(shards),
642
+ "shards": shards,
643
+ "url": self.url,
644
+ "metadata": self.metadata,
645
+ }
646
+
647
+ def filter(self, predicate: Callable[[ST], bool]) -> "Dataset[ST]":
648
+ """Return a new dataset that yields only samples matching *predicate*.
649
+
650
+ The filter is applied lazily during iteration — no data is copied.
651
+
652
+ Args:
653
+ predicate: A function that takes a sample and returns ``True``
654
+ to keep it or ``False`` to discard it.
655
+
656
+ Returns:
657
+ A new ``Dataset`` whose iterators apply the filter.
658
+
659
+ Examples:
660
+ >>> long_names = ds.filter(lambda s: len(s.name) > 10)
661
+ >>> for sample in long_names:
662
+ ... assert len(sample.name) > 10
663
+ """
664
+ filtered = Dataset[self.sample_type](self._source, self.metadata_url)
665
+ filtered._sample_type_cache = self._sample_type_cache
666
+ filtered._output_lens = self._output_lens
667
+ filtered._filter_fn = predicate
668
+ # Preserve any existing filters
669
+ parent_filters = getattr(self, "_filter_fn", None)
670
+ if parent_filters is not None:
671
+ outer = parent_filters
672
+ filtered._filter_fn = lambda s: outer(s) and predicate(s)
673
+ # Preserve any existing map
674
+ if hasattr(self, "_map_fn"):
675
+ filtered._map_fn = self._map_fn
676
+ return filtered
677
+
678
+ def map(self, fn: Callable[[ST], Any]) -> "Dataset":
679
+ """Return a new dataset that applies *fn* to each sample during iteration.
680
+
681
+ The mapping is applied lazily during iteration — no data is copied.
682
+
683
+ Args:
684
+ fn: A function that takes a sample of type ``ST`` and returns
685
+ a transformed value.
686
+
687
+ Returns:
688
+ A new ``Dataset`` whose iterators apply the mapping.
689
+
690
+ Examples:
691
+ >>> names = ds.map(lambda s: s.name)
692
+ >>> for name in names:
693
+ ... print(name)
694
+ """
695
+ mapped = Dataset[self.sample_type](self._source, self.metadata_url)
696
+ mapped._sample_type_cache = self._sample_type_cache
697
+ mapped._output_lens = self._output_lens
698
+ mapped._map_fn = fn
699
+ # Preserve any existing map
700
+ if hasattr(self, "_map_fn"):
701
+ outer = self._map_fn
702
+ mapped._map_fn = lambda s: fn(outer(s))
703
+ # Preserve any existing filter
704
+ if hasattr(self, "_filter_fn"):
705
+ mapped._filter_fn = self._filter_fn
706
+ return mapped
707
+
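filter() and map() compose: each copy keeps the parent's predicate or mapper, and both run as extra pipeline stages during iteration. A short sketch (MyData and its fields are hypothetical):

    >>> ds = Dataset[MyData]("data.tar")
    >>> names = ds.filter(lambda s: s.score > 0.5).map(lambda s: s.name)
    >>> for name in names.head(3):      # the filter runs first, then the mapping
    ...     print(name)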
708
+ def process_shards(
709
+ self,
710
+ fn: Callable[[list[ST]], Any],
711
+ *,
712
+ shards: list[str] | None = None,
713
+ ) -> dict[str, Any]:
714
+ """Process each shard independently, collecting per-shard results.
715
+
716
+ Unlike :meth:`map` (which is lazy and per-sample), this method eagerly
717
+ processes each shard in turn, calling *fn* with the full list of samples
718
+ from that shard. If some shards fail, raises
719
+ :class:`~atdata._exceptions.PartialFailureError` containing both the
720
+ successful results and the per-shard errors.
721
+
722
+ Args:
723
+ fn: Function receiving a list of samples from one shard and
724
+ returning an arbitrary result.
725
+ shards: Optional list of shard identifiers to process. If ``None``,
726
+ processes all shards in the dataset. Useful for retrying only
727
+ the failed shards from a previous ``PartialFailureError``.
728
+
729
+ Returns:
730
+ Dict mapping shard identifier to *fn*'s return value for each shard.
731
+
732
+ Raises:
733
+ PartialFailureError: If at least one shard fails. The exception
734
+ carries ``.succeeded_shards``, ``.failed_shards``, ``.errors``,
735
+ and ``.results`` for inspection and retry.
736
+
737
+ Examples:
738
+ >>> results = ds.process_shards(lambda samples: len(samples))
739
+ >>> # On partial failure, retry just the failed shards:
740
+ >>> try:
741
+ ... results = ds.process_shards(expensive_fn)
742
+ ... except PartialFailureError as e:
743
+ ... retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
744
+ """
745
+ from ._logging import get_logger
746
+
747
+ log = get_logger()
748
+ shard_ids = shards or self.list_shards()
749
+ log.info("process_shards: starting %d shards", len(shard_ids))
750
+
751
+ succeeded: list[str] = []
752
+ failed: list[str] = []
753
+ errors: dict[str, Exception] = {}
754
+ results: dict[str, Any] = {}
755
+
756
+ for shard_id in shard_ids:
757
+ try:
758
+ shard_ds = Dataset[self.sample_type](shard_id)
759
+ shard_ds._sample_type_cache = self._sample_type_cache
760
+ samples = list(shard_ds.ordered())
761
+ results[shard_id] = fn(samples)
762
+ succeeded.append(shard_id)
763
+ log.debug("process_shards: shard ok %s", shard_id)
764
+ except Exception as exc:
765
+ failed.append(shard_id)
766
+ errors[shard_id] = exc
767
+ log.warning("process_shards: shard failed %s: %s", shard_id, exc)
768
+
769
+ if failed:
770
+ log.error(
771
+ "process_shards: %d/%d shards failed",
772
+ len(failed),
773
+ len(shard_ids),
774
+ )
775
+ raise PartialFailureError(
776
+ succeeded_shards=succeeded,
777
+ failed_shards=failed,
778
+ errors=errors,
779
+ results=results,
780
+ )
781
+
782
+ log.info("process_shards: all %d shards succeeded", len(shard_ids))
783
+ return results
784
+
785
+ def select(self, indices: Sequence[int]) -> list[ST]:
786
+ """Return samples at the given integer indices.
787
+
788
+ Iterates through the dataset in order and collects samples whose
789
+ positional index matches. This is O(n) for streaming datasets.
790
+
791
+ Args:
792
+ indices: Sequence of zero-based indices to select.
793
+
794
+ Returns:
795
+ List of samples at the requested positions, in index order.
796
+
797
+ Examples:
798
+ >>> samples = ds.select([0, 5, 10])
799
+ >>> len(samples)
800
+ 3
801
+ """
802
+ if not indices:
803
+ return []
804
+ target = set(indices)
805
+ max_idx = max(indices)
806
+ result: dict[int, ST] = {}
807
+ for i, sample in enumerate(self.ordered()):
808
+ if i in target:
809
+ result[i] = sample
810
+ if i >= max_idx:
811
+ break
812
+ return [result[i] for i in indices if i in result]
813
+
814
+ def query(
815
+ self,
816
+ where: "Callable[[pd.DataFrame], pd.Series]",
817
+ ) -> "list[SampleLocation]":
818
+ """Query this dataset using per-shard manifest metadata.
819
+
820
+ Requires manifests to have been generated during shard writing.
821
+ Discovers manifest files alongside the tar shards, loads them,
822
+ and executes a two-phase query (shard-level aggregate pruning,
823
+ then sample-level parquet filtering).
824
+
825
+ Args:
826
+ where: Predicate function that receives a pandas DataFrame
827
+ of manifest fields and returns a boolean Series selecting
828
+ matching rows.
829
+
830
+ Returns:
831
+ List of ``SampleLocation`` for matching samples.
832
+
833
+ Raises:
834
+ FileNotFoundError: If no manifest files are found alongside shards.
835
+
836
+ Examples:
837
+ >>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
838
+ >>> len(locs)
839
+ 42
840
+ """
841
+ from .manifest import QueryExecutor
842
+
843
+ shard_urls = self.list_shards()
844
+ executor = QueryExecutor.from_shard_urls(shard_urls)
845
+ return executor.query(where=where)
846
+
847
+ def to_pandas(self, limit: int | None = None) -> "pd.DataFrame":
848
+ """Materialize the dataset (or first *limit* samples) as a DataFrame.
849
+
850
+ Args:
851
+ limit: Maximum number of samples to include. ``None`` means all
852
+ samples (may use significant memory for large datasets).
853
+
854
+ Returns:
855
+ A pandas DataFrame with one row per sample and columns matching
856
+ the sample fields.
857
+
858
+ Warning:
859
+ With ``limit=None`` this loads the entire dataset into memory.
860
+
861
+ Examples:
862
+ >>> df = ds.to_pandas(limit=100)
863
+ >>> df.columns.tolist()
864
+ ['name', 'embedding']
865
+ """
866
+ samples = self.head(limit) if limit is not None else list(self.ordered())
867
+ rows = [
868
+ asdict(s) if dataclasses.is_dataclass(s) else s.to_dict() for s in samples
869
+ ]
870
+ return pd.DataFrame(rows)
871
+
872
+ def to_dict(self, limit: int | None = None) -> dict[str, list[Any]]:
873
+ """Materialize the dataset as a column-oriented dictionary.
874
+
875
+ Args:
876
+ limit: Maximum number of samples to include. ``None`` means all.
877
+
878
+ Returns:
879
+ Dictionary mapping field names to lists of values (one entry
880
+ per sample).
881
+
882
+ Warning:
883
+ With ``limit=None`` this loads the entire dataset into memory.
884
+
885
+ Examples:
886
+ >>> d = ds.to_dict(limit=10)
887
+ >>> d.keys()
888
+ dict_keys(['name', 'embedding'])
889
+ >>> len(d['name'])
890
+ 10
891
+ """
892
+ samples = self.head(limit) if limit is not None else list(self.ordered())
893
+ if not samples:
894
+ return {}
895
+ if dataclasses.is_dataclass(samples[0]):
896
+ fields = [f.name for f in dataclasses.fields(samples[0])]
897
+ return {f: [getattr(s, f) for s in samples] for f in fields}
898
+ # DictSample path
899
+ keys = samples[0].keys()
900
+ return {k: [s[k] for s in samples] for k in keys}
901
+
902
+ def _post_wrap_stages(self) -> list:
903
+ """Build extra pipeline stages for filter/map set via .filter()/.map()."""
904
+ stages: list = []
905
+ filter_fn = getattr(self, "_filter_fn", None)
906
+ if filter_fn is not None:
907
+ stages.append(wds.filters.select(filter_fn))
908
+ map_fn = getattr(self, "_map_fn", None)
909
+ if map_fn is not None:
910
+ stages.append(wds.filters.map(map_fn))
911
+ return stages
912
+
729
913
  @overload
730
914
  def ordered(
731
915
  self,
@@ -769,6 +953,7 @@ class Dataset(Generic[ST]):
769
953
  wds.tariterators.tar_file_expander,
770
954
  wds.tariterators.group_by_keys,
771
955
  wds.filters.map(self.wrap),
956
+ *self._post_wrap_stages(),
772
957
  )
773
958
 
774
959
  return wds.pipeline.DataPipeline(
@@ -839,6 +1024,7 @@ class Dataset(Generic[ST]):
839
1024
  wds.tariterators.group_by_keys,
840
1025
  wds.filters.shuffle(buffer_samples),
841
1026
  wds.filters.map(self.wrap),
1027
+ *self._post_wrap_stages(),
842
1028
  )
843
1029
 
844
1030
  return wds.pipeline.DataPipeline(
@@ -862,100 +1048,47 @@ class Dataset(Generic[ST]):
862
1048
  maxcount: Optional[int] = None,
863
1049
  **kwargs,
864
1050
  ):
865
- """Export dataset contents to parquet format.
866
-
867
- Converts all samples to a pandas DataFrame and saves to parquet file(s).
868
- Useful for interoperability with data analysis tools.
1051
+ """Export dataset to parquet file(s).
869
1052
 
870
1053
  Args:
871
- path: Output path for the parquet file. If ``maxcount`` is specified,
872
- files are named ``{stem}-{segment:06d}.parquet``.
873
- sample_map: Optional function to convert samples to dictionaries.
874
- Defaults to ``dataclasses.asdict``.
875
- maxcount: If specified, split output into multiple files with at most
876
- this many samples each. Recommended for large datasets.
877
- **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
878
- Common options include ``compression``, ``index``, ``engine``.
879
-
880
- Warning:
881
- **Memory Usage**: When ``maxcount=None`` (default), this method loads
882
- the **entire dataset into memory** as a pandas DataFrame before writing.
883
- For large datasets, this can cause memory exhaustion.
884
-
885
- For datasets larger than available RAM, always specify ``maxcount``::
886
-
887
- # Safe for large datasets - processes in chunks
888
- ds.to_parquet("output.parquet", maxcount=10000)
889
-
890
- This creates multiple parquet files: ``output-000000.parquet``,
891
- ``output-000001.parquet``, etc.
1054
+ path: Output path. With *maxcount*, files are named
1055
+ ``{stem}-{segment:06d}.parquet``.
1056
+ sample_map: Convert sample to dict. Defaults to ``dataclasses.asdict``.
1057
+ maxcount: Split into files of at most this many samples.
1058
+ Without it, the entire dataset is loaded into memory.
1059
+ **kwargs: Passed to ``pandas.DataFrame.to_parquet()``.
892
1060
 
893
1061
  Examples:
894
- >>> ds = Dataset[MySample]("data.tar")
895
- >>> # Small dataset - load all at once
896
- >>> ds.to_parquet("output.parquet")
897
- >>>
898
- >>> # Large dataset - process in chunks
899
1062
  >>> ds.to_parquet("output.parquet", maxcount=50000)
900
1063
  """
901
- ##
902
-
903
- # Normalize args
904
1064
  path = Path(path)
905
1065
  if sample_map is None:
906
1066
  sample_map = asdict
907
1067
 
908
- verbose = kwargs.get("verbose", False)
909
-
910
- it = self.ordered(batch_size=None)
911
- if verbose:
912
- it = tqdm(it)
913
-
914
- #
915
-
916
1068
  if maxcount is None:
917
- # Load and save full dataset
918
1069
  df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)])
919
1070
  df.to_parquet(path, **kwargs)
920
-
921
1071
  else:
922
- # Load and save dataset in segments of size `maxcount`
923
-
924
1072
  cur_segment = 0
925
- cur_buffer = []
1073
+ cur_buffer: list = []
926
1074
  path_template = (
927
1075
  path.parent / f"{path.stem}-{{:06d}}{path.suffix}"
928
1076
  ).as_posix()
929
1077
 
930
1078
  for x in self.ordered(batch_size=None):
931
1079
  cur_buffer.append(sample_map(x))
932
-
933
1080
  if len(cur_buffer) >= maxcount:
934
- # Write current segment
935
1081
  cur_path = path_template.format(cur_segment)
936
- df = pd.DataFrame(cur_buffer)
937
- df.to_parquet(cur_path, **kwargs)
938
-
1082
+ pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)
939
1083
  cur_segment += 1
940
1084
  cur_buffer = []
941
1085
 
942
- if len(cur_buffer) > 0:
943
- # Write one last segment with remainder
1086
+ if cur_buffer:
944
1087
  cur_path = path_template.format(cur_segment)
945
- df = pd.DataFrame(cur_buffer)
946
- df.to_parquet(cur_path, **kwargs)
1088
+ pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)
947
1089
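With maxcount set, the segment naming described above expands as, for example:

    >>> ds.to_parquet("out.parquet", maxcount=10_000)
    >>> # writes out-000000.parquet, out-000001.parquet, ... alongside the given path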
 
948
1090
  def wrap(self, sample: WDSRawSample) -> ST:
949
- """Wrap a raw msgpack sample into the appropriate dataset-specific type.
950
-
951
- Args:
952
- sample: A dictionary containing at minimum a ``'msgpack'`` key with
953
- serialized sample bytes.
954
-
955
- Returns:
956
- A deserialized sample of type ``ST``, optionally transformed through
957
- a lens if ``as_type()`` was called.
958
- """
1091
+ """Deserialize a raw WDS sample dict into type ``ST``."""
959
1092
  if "msgpack" not in sample:
960
1093
  raise ValueError(
961
1094
  f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}"
@@ -972,20 +1105,7 @@ class Dataset(Generic[ST]):
972
1105
  return self._output_lens(source_sample)
973
1106
 
974
1107
  def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]:
975
- """Wrap a batch of raw msgpack samples into a typed SampleBatch.
976
-
977
- Args:
978
- batch: A dictionary containing a ``'msgpack'`` key with a list of
979
- serialized sample bytes.
980
-
981
- Returns:
982
- A ``SampleBatch[ST]`` containing deserialized samples, optionally
983
- transformed through a lens if ``as_type()`` was called.
984
-
985
- Note:
986
- This implementation deserializes samples one at a time, then
987
- aggregates them into a batch.
988
- """
1108
+ """Deserialize a raw WDS batch dict into ``SampleBatch[ST]``."""
989
1109
 
990
1110
  if "msgpack" not in batch:
991
1111
  raise ValueError(
@@ -1009,24 +1129,12 @@ _T = TypeVar("_T")
1009
1129
 
1010
1130
 
1011
1131
  @dataclass_transform()
1012
- def packable(cls: type[_T]) -> type[_T]:
1013
- """Decorator to convert a regular class into a ``PackableSample``.
1014
-
1015
- This decorator transforms a class into a dataclass that inherits from
1016
- ``PackableSample``, enabling automatic msgpack serialization/deserialization
1017
- with special handling for NDArray fields.
1132
+ def packable(cls: type[_T]) -> type[Packable]:
1133
+ """Convert a class into a ``PackableSample`` dataclass with msgpack serialization.
1018
1134
 
1019
- The resulting class satisfies the ``Packable`` protocol, making it compatible
1020
- with all atdata APIs that accept packable types (e.g., ``publish_schema``,
1021
- lens transformations, etc.).
1022
-
1023
- Args:
1024
- cls: The class to convert. Should have type annotations for its fields.
1025
-
1026
- Returns:
1027
- A new dataclass that inherits from ``PackableSample`` with the same
1028
- name and annotations as the original class. The class satisfies the
1029
- ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
1135
+ The resulting class gains ``packed``, ``as_wds``, ``from_bytes``, and
1136
+ ``from_data`` methods, and satisfies the ``Packable`` protocol.
1137
+ NDArray fields are automatically handled during serialization.
1030
1138
 
1031
1139
  Examples:
1032
1140
  >>> @packable
@@ -1035,11 +1143,7 @@ def packable(cls: type[_T]) -> type[_T]:
1035
1143
  ... values: NDArray
1036
1144
  ...
1037
1145
  >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
1038
- >>> bytes_data = sample.packed
1039
- >>> restored = MyData.from_bytes(bytes_data)
1040
- >>>
1041
- >>> # Works with Packable-typed APIs
1042
- >>> index.publish_schema(MyData, version="1.0.0") # Type-safe
1146
+ >>> restored = MyData.from_bytes(sample.packed)
1043
1147
  """
1044
1148
 
1045
1149
  ##
@@ -1086,3 +1190,154 @@ def packable(cls: type[_T]) -> type[_T]:
1086
1190
  ##
1087
1191
 
1088
1192
  return as_packable
1193
+
1194
+
1195
+ # ---------------------------------------------------------------------------
1196
+ # write_samples — convenience function for writing samples to tar files
1197
+ # ---------------------------------------------------------------------------
1198
+
1199
+
1200
+ def write_samples(
1201
+ samples: Iterable[ST],
1202
+ path: str | Path,
1203
+ *,
1204
+ maxcount: int | None = None,
1205
+ maxsize: int | None = None,
1206
+ manifest: bool = False,
1207
+ ) -> "Dataset[ST]":
1208
+ """Write an iterable of samples to WebDataset tar file(s).
1209
+
1210
+ Args:
1211
+ samples: Iterable of ``PackableSample`` instances. Must be non-empty.
1212
+ path: Output path for the tar file. For sharded output (when
1213
+ *maxcount* or *maxsize* is set), a ``%06d`` pattern is
1214
+ auto-appended if the path does not already contain ``%``.
1215
+ maxcount: Maximum samples per shard. Triggers multi-shard output.
1216
+ maxsize: Maximum bytes per shard. Triggers multi-shard output.
1217
+ manifest: If True, write per-shard manifest sidecar files
1218
+ (``.manifest.json`` + ``.manifest.parquet``) alongside each
1219
+ tar file. Manifests enable metadata queries via
1220
+ ``QueryExecutor`` without opening the tars.
1221
+
1222
+ Returns:
1223
+ A ``Dataset`` wrapping the written file(s), typed to the sample
1224
+ type of the input samples.
1225
+
1226
+ Raises:
1227
+ ValueError: If *samples* is empty.
1228
+
1229
+ Examples:
1230
+ >>> samples = [MySample(key="0", text="hello")]
1231
+ >>> ds = write_samples(samples, "out.tar")
1232
+ >>> list(ds.ordered())
1233
+ [MySample(key='0', text='hello')]
1234
+ """
1235
+ from ._hf_api import _shards_to_wds_url
1236
+
1237
+ if manifest:
1238
+ from .manifest._builder import ManifestBuilder
1239
+ from .manifest._writer import ManifestWriter
1240
+
1241
+ path = Path(path)
1242
+ path.parent.mkdir(parents=True, exist_ok=True)
1243
+
1244
+ use_shard_writer = maxcount is not None or maxsize is not None
1245
+ sample_type: type | None = None
1246
+ written_paths: list[str] = []
1247
+
1248
+ # Manifest tracking state
1249
+ _current_builder: list = [] # single-element list for nonlocal mutation
1250
+ _builders: list[tuple[str, "ManifestBuilder"]] = []
1251
+ _running_offset: list[int] = [0]
1252
+
1253
+ def _finalize_builder() -> None:
1254
+ """Finalize the current manifest builder and stash it."""
1255
+ if _current_builder:
1256
+ shard_path = written_paths[-1] if written_paths else ""
1257
+ _builders.append((shard_path, _current_builder[0]))
1258
+ _current_builder.clear()
1259
+
1260
+ def _start_builder(shard_path: str) -> None:
1261
+ """Start a new manifest builder for a shard."""
1262
+ _finalize_builder()
1263
+ shard_id = Path(shard_path).stem
1264
+ _current_builder.append(
1265
+ ManifestBuilder(sample_type=sample_type, shard_id=shard_id)
1266
+ )
1267
+ _running_offset[0] = 0
1268
+
1269
+ def _record_sample(sample: "PackableSample", wds_dict: dict) -> None:
1270
+ """Record a sample in the active manifest builder."""
1271
+ if not _current_builder:
1272
+ return
1273
+ packed_bytes = wds_dict["msgpack"]
1274
+ size = len(packed_bytes)
1275
+ _current_builder[0].add_sample(
1276
+ key=wds_dict["__key__"],
1277
+ offset=_running_offset[0],
1278
+ size=size,
1279
+ sample=sample,
1280
+ )
1281
+ _running_offset[0] += size
1282
+
1283
+ if use_shard_writer:
1284
+ # Build shard pattern from path
1285
+ if "%" not in str(path):
1286
+ pattern = str(path.parent / f"{path.stem}-%06d{path.suffix}")
1287
+ else:
1288
+ pattern = str(path)
1289
+
1290
+ writer_kwargs: dict[str, Any] = {}
1291
+ if maxcount is not None:
1292
+ writer_kwargs["maxcount"] = maxcount
1293
+ if maxsize is not None:
1294
+ writer_kwargs["maxsize"] = maxsize
1295
+
1296
+ def _track(p: str) -> None:
1297
+ written_paths.append(str(Path(p).resolve()))
1298
+ if manifest and sample_type is not None:
1299
+ _start_builder(p)
1300
+
1301
+ with wds.writer.ShardWriter(pattern, post=_track, **writer_kwargs) as sink:
1302
+ for sample in samples:
1303
+ if sample_type is None:
1304
+ sample_type = type(sample)
1305
+ wds_dict = sample.as_wds
1306
+ sink.write(wds_dict)
1307
+ if manifest:
1308
+ # The first sample triggers _track before we get here when
1309
+ # ShardWriter opens the first shard, but just in case:
1310
+ if not _current_builder and sample_type is not None:
1311
+ _start_builder(str(path))
1312
+ _record_sample(sample, wds_dict)
1313
+ else:
1314
+ with wds.writer.TarWriter(str(path)) as sink:
1315
+ for sample in samples:
1316
+ if sample_type is None:
1317
+ sample_type = type(sample)
1318
+ wds_dict = sample.as_wds
1319
+ sink.write(wds_dict)
1320
+ if manifest:
1321
+ if not _current_builder and sample_type is not None:
1322
+ _current_builder.append(
1323
+ ManifestBuilder(sample_type=sample_type, shard_id=path.stem)
1324
+ )
1325
+ _record_sample(sample, wds_dict)
1326
+ written_paths.append(str(path.resolve()))
1327
+
1328
+ if sample_type is None:
1329
+ raise ValueError("samples must be non-empty")
1330
+
1331
+ # Finalize and write manifests
1332
+ if manifest:
1333
+ _finalize_builder()
1334
+ for shard_path, builder in _builders:
1335
+ m = builder.build()
1336
+ base = str(Path(shard_path).with_suffix(""))
1337
+ writer = ManifestWriter(base)
1338
+ writer.write(m)
1339
+
1340
+ url = _shards_to_wds_url(written_paths)
1341
+ ds: Dataset = Dataset(url)
1342
+ ds._sample_type_cache = sample_type
1343
+ return ds
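To tie the new pieces together, a minimal end-to-end sketch: MySample and its fields are hypothetical, import paths follow the module shown in this diff, and the query step assumes scalar fields such as confidence are captured in the manifest DataFrame.

    import numpy as np
    from numpy.typing import NDArray

    from atdata.dataset import Dataset, packable, write_samples

    # Hypothetical sample type; field names are illustrative only.
    @packable
    class MySample:
        text: str
        confidence: float
        embedding: NDArray

    samples = [
        MySample(text=f"doc {i}", confidence=i / 100, embedding=np.random.rand(8))
        for i in range(100)
    ]

    # Write 4 tar shards of 25 samples each, plus .manifest.json/.manifest.parquet
    # sidecars next to every shard (manifest=True).
    ds = write_samples(samples, "out/corpus.tar", maxcount=25, manifest=True)

    # Manifest-backed query: shard-level aggregate pruning, then row filtering.
    locations = ds.query(where=lambda df: df["confidence"] > 0.9)

    # Lazy, composable iteration over the written dataset.
    high_conf_texts = ds.filter(lambda s: s.confidence > 0.9).map(lambda s: s.text)
    print(len(locations), high_conf_texts.head(5))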