atdata 0.1.3b3__py3-none-any.whl → 0.2.0a1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
atdata/dataset.py CHANGED
@@ -1,4 +1,29 @@
1
- """Schematized WebDatasets"""
1
+ """Core dataset and sample infrastructure for typed WebDatasets.
2
+
3
+ This module provides the core components for working with typed, msgpack-serialized
4
+ samples in WebDataset format:
5
+
6
+ - ``PackableSample``: Base class for msgpack-serializable samples with automatic
7
+ NDArray handling
8
+ - ``SampleBatch``: Automatic batching with attribute aggregation
9
+ - ``Dataset``: Generic typed dataset wrapper for WebDataset tar files
10
+ - ``@packable``: Decorator to convert regular classes into PackableSample subclasses
11
+
12
+ The implementation handles automatic conversion between numpy arrays and bytes
13
+ during serialization, enabling efficient storage of numerical data in WebDataset
14
+ archives.
15
+
16
+ Example:
17
+ >>> @packable
18
+ ... class ImageSample:
19
+ ... image: NDArray
20
+ ... label: str
21
+ ...
22
+ >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
23
+ >>> for batch in ds.shuffled(batch_size=32):
24
+ ... images = batch.image # Stacked numpy array (32, H, W, C)
25
+ ... labels = batch.label # List of 32 strings
26
+ """
2
27
 
3
28
  ##
4
29
  # Imports
@@ -7,7 +32,6 @@ import webdataset as wds
7
32
 
8
33
  from pathlib import Path
9
34
  import uuid
10
- import functools
11
35
 
12
36
  import dataclasses
13
37
  import types
@@ -15,14 +39,12 @@ from dataclasses import (
15
39
  dataclass,
16
40
  asdict,
17
41
  )
18
- from abc import (
19
- ABC,
20
- abstractmethod,
21
- )
42
+ from abc import ABC
22
43
 
23
44
  from tqdm import tqdm
24
45
  import numpy as np
25
46
  import pandas as pd
47
+ import requests
26
48
 
27
49
  import typing
28
50
  from typing import (
@@ -40,15 +62,7 @@ from typing import (
40
62
  TypeVar,
41
63
  TypeAlias,
42
64
  )
43
- # from typing_inspect import get_bound, get_parameters
44
- from numpy.typing import (
45
- NDArray,
46
- ArrayLike,
47
- )
48
-
49
- #
50
-
51
- # import ekumen.atmosphere as eat
65
+ from numpy.typing import NDArray
52
66
 
53
67
  import msgpack
54
68
  import ormsgpack
@@ -71,50 +85,35 @@ SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
71
85
  ##
72
86
  # Main base classes
73
87
 
74
- # TODO Check for best way to ensure this typevar is used as a dataclass type
75
- # DT = TypeVar( 'DT', bound = dataclass.__class__ )
76
88
  DT = TypeVar( 'DT' )
77
89
 
78
90
  MsgpackRawSample: TypeAlias = Dict[str, Any]
79
91
 
80
- # @dataclass
81
- # class ArrayBytes:
82
- # """Annotates bytes that should be interpreted as the raw contents of a
83
- # numpy NDArray"""
84
-
85
- # raw_bytes: bytes
86
- # """The raw bytes of the corresponding NDArray"""
87
-
88
- # def __init__( self,
89
- # array: Optional[ArrayLike] = None,
90
- # raw: Optional[bytes] = None,
91
- # ):
92
- # """TODO"""
93
-
94
- # if array is not None:
95
- # array = np.array( array )
96
- # self.raw_bytes = eh.array_to_bytes( array )
97
-
98
- # elif raw is not None:
99
- # self.raw_bytes = raw
100
-
101
- # else:
102
- # raise ValueError( 'Must provide either `array` or `raw` bytes' )
103
-
104
- # @property
105
- # def to_numpy( self ) -> NDArray:
106
- # """Return the `raw_bytes` data as an NDArray"""
107
- # return eh.bytes_to_array( self.raw_bytes )
108
92
 
109
93
  def _make_packable( x ):
110
- # if isinstance( x, ArrayBytes ):
111
- # return x.raw_bytes
94
+ """Convert a value to a msgpack-compatible format.
95
+
96
+ Args:
97
+ x: A value to convert. If it's a numpy array, converts to bytes.
98
+ Otherwise returns the value unchanged.
99
+
100
+ Returns:
101
+ The value in a format suitable for msgpack serialization.
102
+ """
112
103
  if isinstance( x, np.ndarray ):
113
104
  return eh.array_to_bytes( x )
114
105
  return x
115
106
 
116
107
  def _is_possibly_ndarray_type( t ):
117
- """Checks if a type annotation is possibly an NDArray."""
108
+ """Check if a type annotation is or contains NDArray.
109
+
110
+ Args:
111
+ t: A type annotation to check.
112
+
113
+ Returns:
114
+ ``True`` if the type is ``NDArray`` or a union containing ``NDArray``
115
+ (e.g., ``NDArray | None``), ``False`` otherwise.
116
+ """
118
117
 
119
118
  # Directly an NDArray
120
119
  if t == NDArray:
@@ -133,10 +132,40 @@ def _is_possibly_ndarray_type( t ):
133
132
 
134
133
  @dataclass
135
134
  class PackableSample( ABC ):
136
- """A sample that can be packed and unpacked with msgpack"""
135
+ """Base class for samples that can be serialized with msgpack.
136
+
137
+ This abstract base class provides automatic serialization/deserialization
138
+ for dataclass-based samples. Fields annotated as ``NDArray`` or
139
+ ``NDArray | None`` are automatically converted between numpy arrays and
140
+ bytes during packing/unpacking.
141
+
142
+ Subclasses should be defined either by:
143
+ 1. Direct inheritance with the ``@dataclass`` decorator
144
+ 2. Using the ``@packable`` decorator (recommended)
145
+
146
+ Example:
147
+ >>> @packable
148
+ ... class MyData:
149
+ ... name: str
150
+ ... embeddings: NDArray
151
+ ...
152
+ >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
153
+ >>> packed = sample.packed # Serialize to bytes
154
+ >>> restored = MyData.from_bytes(packed) # Deserialize
155
+ """
137
156
 
138
157
  def _ensure_good( self ):
139
- """TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""
158
+ """Auto-convert annotated NDArray fields from bytes to numpy arrays.
159
+
160
+ This method scans all dataclass fields and for any field annotated as
161
+ ``NDArray`` or ``NDArray | None``, automatically converts bytes values
162
+ to numpy arrays using the helper deserialization function. This enables
163
+ transparent handling of array serialization in msgpack data.
164
+
165
+ Note:
166
+ This is called during ``__post_init__`` to ensure proper type
167
+ conversion after deserialization.
168
+ """
140
169
 
141
170
  # Auto-convert known types when annotated
142
171
  # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
@@ -154,11 +183,8 @@ class PackableSample( ABC ):
154
183
  # based on what is provided
155
184
 
156
185
  if isinstance( var_cur_value, np.ndarray ):
157
- # we're good!
158
- pass
159
-
160
- # elif isinstance( var_cur_value, ArrayBytes ):
161
- # setattr( self, var_name, var_cur_value.to_numpy )
186
+ # Already the correct type, no conversion needed
187
+ continue
162
188
 
163
189
  elif isinstance( var_cur_value, bytes ):
164
190
  # TODO This does create a constraint that serialized bytes
@@ -173,19 +199,45 @@ class PackableSample( ABC ):
173
199
 
174
200
  @classmethod
175
201
  def from_data( cls, data: MsgpackRawSample ) -> Self:
176
- """Create a sample instance from unpacked msgpack data"""
202
+ """Create a sample instance from unpacked msgpack data.
203
+
204
+ Args:
205
+ data: A dictionary of unpacked msgpack data with keys matching
206
+ the sample's field names.
207
+
208
+ Returns:
209
+ A new instance of this sample class with fields populated from
210
+ the data dictionary and NDArray fields auto-converted from bytes.
211
+ """
177
212
  ret = cls( **data )
178
213
  ret._ensure_good()
179
214
  return ret
180
215
 
181
216
  @classmethod
182
217
  def from_bytes( cls, bs: bytes ) -> Self:
183
- """Create a sample instance from raw msgpack bytes"""
218
+ """Create a sample instance from raw msgpack bytes.
219
+
220
+ Args:
221
+ bs: Raw bytes from a msgpack-serialized sample.
222
+
223
+ Returns:
224
+ A new instance of this sample class deserialized from the bytes.
225
+ """
184
226
  return cls.from_data( ormsgpack.unpackb( bs ) )
185
227
 
186
228
  @property
187
229
  def packed( self ) -> bytes:
188
- """Pack this sample's data into msgpack bytes"""
230
+ """Pack this sample's data into msgpack bytes.
231
+
232
+ NDArray fields are automatically converted to bytes before packing.
233
+ All other fields are packed as-is if they're msgpack-compatible.
234
+
235
+ Returns:
236
+ Raw msgpack bytes representing this sample's data.
237
+
238
+ Raises:
239
+ RuntimeError: If msgpack serialization fails.
240
+ """
189
241
 
190
242
  # Make sure that all of our (possibly unpackable) data is in a packable
191
243
  # format
@@ -204,7 +256,15 @@ class PackableSample( ABC ):
204
256
  # TODO Expand to allow for specifying explicit __key__
205
257
  @property
206
258
  def as_wds( self ) -> WDSRawSample:
207
- """Pack this sample's data for writing to webdataset"""
259
+ """Pack this sample's data for writing to WebDataset.
260
+
261
+ Returns:
262
+ A dictionary with ``__key__`` (UUID v1 for sortable keys) and
263
+ ``msgpack`` (packed sample data) fields suitable for WebDataset.
264
+
265
+ Note:
266
+ TODO: Expand to allow specifying explicit ``__key__`` values.
267
+ """
208
268
  return {
209
269
  # Generates a UUID that is timelike-sortable
210
270
  '__key__': str( uuid.uuid1( 0, 0 ) ),
@@ -212,30 +272,86 @@ class PackableSample( ABC ):
212
272
  }
213
273
 
214
274
  def _batch_aggregate( xs: Sequence ):
275
+ """Aggregate a sequence of values into a batch-appropriate format.
276
+
277
+ Args:
278
+ xs: A sequence of values to aggregate. If the first element is a numpy
279
+ array, all elements are stacked into a single array. Otherwise,
280
+ returns a list.
281
+
282
+ Returns:
283
+ A numpy array (if elements are arrays) or a list (otherwise).
284
+ """
215
285
 
216
286
  if not xs:
217
287
  # Empty sequence
218
288
  return []
219
289
 
220
- # Aggregate
290
+ # Aggregate
221
291
  if isinstance( xs[0], np.ndarray ):
222
292
  return np.array( list( xs ) )
223
293
 
224
294
  return list( xs )
225
295
 
226
296
  class SampleBatch( Generic[DT] ):
297
+ """A batch of samples with automatic attribute aggregation.
298
+
299
+ This class wraps a sequence of samples and provides magic ``__getattr__``
300
+ access to aggregate sample attributes. When you access an attribute that
301
+ exists on the sample type, it automatically aggregates values across all
302
+ samples in the batch.
303
+
304
+ NDArray fields are stacked into a numpy array with a batch dimension.
305
+ Other fields are aggregated into a list.
306
+
307
+ Type Parameters:
308
+ DT: The sample type, must derive from ``PackableSample``.
309
+
310
+ Attributes:
311
+ samples: The list of sample instances in this batch.
312
+
313
+ Example:
314
+ >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
315
+ >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
316
+ >>> batch.name # Returns list of names
317
+ """
227
318
 
228
319
  def __init__( self, samples: Sequence[DT] ):
229
- """TODO"""
320
+ """Create a batch from a sequence of samples.
321
+
322
+ Args:
323
+ samples: A sequence of sample instances to aggregate into a batch.
324
+ Each sample must be an instance of a type derived from
325
+ ``PackableSample``.
326
+ """
230
327
  self.samples = list( samples )
231
328
  self._aggregate_cache = dict()
232
329
 
233
330
  @property
234
331
  def sample_type( self ) -> Type:
235
- """The type of each sample in this batch"""
332
+ """The type of each sample in this batch.
333
+
334
+ Returns:
335
+ The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
336
+ """
236
337
  return typing.get_args( self.__orig_class__)[0]
237
338
 
238
339
  def __getattr__( self, name ):
340
+ """Aggregate an attribute across all samples in the batch.
341
+
342
+ This magic method enables attribute-style access to aggregated sample
343
+ fields. Results are cached for efficiency.
344
+
345
+ Args:
346
+ name: The attribute name to aggregate across samples.
347
+
348
+ Returns:
349
+ For NDArray fields: a stacked numpy array with batch dimension.
350
+ For other fields: a list of values from each sample.
351
+
352
+ Raises:
353
+ AttributeError: If the attribute doesn't exist on the sample type.
354
+ """
239
355
  # Aggregate named params of sample type
240
356
  if name in vars( self.sample_type )['__annotations__']:
241
357
  if name not in self._aggregate_cache:
@@ -243,91 +359,112 @@ class SampleBatch( Generic[DT] ):
243
359
  [ getattr( x, name )
244
360
  for x in self.samples ]
245
361
  )
246
-
247
- return self._aggregate_cache[name]
248
-
249
- raise AttributeError( f'No sample attribute named {name}' )
250
-
251
362
 
252
- # class AnySample( BaseModel ):
253
- # """A sample that can hold anything"""
254
- # value: Any
363
+ return self._aggregate_cache[name]
255
364
 
256
- # class AnyBatch( BaseModel ):
257
- # """A batch of `AnySample`s"""
258
- # values: list[AnySample]
365
+ raise AttributeError( f'No sample attribute named {name}' )
259
366
 
260
367
 
261
368
  ST = TypeVar( 'ST', bound = PackableSample )
262
- # BT = TypeVar( 'BT' )
263
-
264
369
  RT = TypeVar( 'RT', bound = PackableSample )
265
370
 
266
- # TODO For python 3.13
267
- # BT = TypeVar( 'BT', default = None )
268
- # IT = TypeVar( 'IT', default = Any )
269
-
270
371
  class Dataset( Generic[ST] ):
271
- """A dataset that ingests and formats raw samples from a WebDataset
372
+ """A typed dataset built on WebDataset with lens transformations.
373
+
374
+ This class wraps WebDataset tar archives and provides type-safe iteration
375
+ over samples of a specific ``PackableSample`` type. Samples are stored as
376
+ msgpack-serialized data within WebDataset shards.
377
+
378
+ The dataset supports:
379
+ - Ordered and shuffled iteration
380
+ - Automatic batching with ``SampleBatch``
381
+ - Type transformations via the lens system (``as_type()``)
382
+ - Export to parquet format
383
+
384
+ Type Parameters:
385
+ ST: The sample type for this dataset, must derive from ``PackableSample``.
386
+
387
+ Attributes:
388
+ url: WebDataset brace-notation URL for the tar file(s).
389
+
390
+ Example:
391
+ >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
392
+ >>> for sample in ds.ordered(batch_size=32):
393
+ ... # sample is SampleBatch[MyData] with batch_size samples
394
+ ... embeddings = sample.embeddings # shape: (32, ...)
395
+ ...
396
+ >>> # Transform to a different view
397
+ >>> ds_view = ds.as_type(MyDataView)
272
398
 
273
- (Abstract base for subclassing)
274
399
  """
275
400
 
276
- # sample_class: Type = get_parameters( )
277
- # """The type of each returned sample from this `Dataset`'s iterator"""
278
- # batch_class: Type = get_bound( BT )
279
- # """The type of a batch built from `sample_class`"""
280
-
281
401
  @property
282
402
  def sample_type( self ) -> Type:
283
- """The type of each returned sample from this `Dataset`'s iterator"""
284
- # TODO Figure out why linting fails here
403
+ """The type of each returned sample from this dataset's iterator.
404
+
405
+ Returns:
406
+ The type parameter ``ST`` used when creating this ``Dataset[ST]``.
407
+
408
+ Note:
409
+ Extracts the type parameter at runtime using ``__orig_class__``.
410
+ """
411
+ # NOTE: Linting may fail here due to __orig_class__ being a runtime attribute
285
412
  return typing.get_args( self.__orig_class__ )[0]
286
413
  @property
287
414
  def batch_type( self ) -> Type:
288
- """The type of a batch built from `sample_class`"""
289
- # return self.__orig_class__.__args__[1]
290
- return SampleBatch[self.sample_type]
291
-
415
+ """The type of batches produced by this dataset.
292
416
 
293
- # _schema_registry_sample: dict[str, Type]
294
- # _schema_registry_batch: dict[str, Type | None]
417
+ Returns:
418
+ ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
419
+ """
420
+ return SampleBatch[self.sample_type]
295
421
 
296
- #
422
+ def __init__( self, url: str,
423
+ metadata_url: str | None = None,
424
+ ) -> None:
425
+ """Create a dataset from a WebDataset URL.
297
426
 
298
- def __init__( self, url: str ) -> None:
299
- """TODO"""
427
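+ Args:
428
+ url: WebDataset brace-notation URL pointing to tar files, e.g.,
429
+ ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
430
+ ``"path/to/file-000000.tar"`` for a single shard.
+ metadata_url: Optional URL to msgpack-encoded metadata for this
+ dataset. The metadata is fetched on first access of ``metadata``
+ and cached afterwards. Default: ``None`` (no metadata).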
431
+ """
300
432
  super().__init__()
301
433
  self.url = url
434
+ """WebDataset brace-notation URL pointing to tar files, e.g.,
435
+ ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
436
+ ``"path/to/file-000000.tar"`` for a single shard.
437
+ """
438
+
439
+ self._metadata: dict[str, Any] | None = None
440
+ self.metadata_url: str | None = metadata_url
441
+ """Optional URL to msgpack-encoded metadata for this dataset."""
302
442
 
303
443
  # Allow addition of automatic transformation of raw underlying data
304
444
  self._output_lens: Lens | None = None
305
445
 
306
446
  def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
307
- """TODO"""
447
+ """View this dataset through a different sample type using a registered lens.
448
+
449
+ Args:
450
+ other: The target sample type to transform into. Must be a type
451
+ derived from ``PackableSample``.
452
+
453
+ Returns:
454
+ A new ``Dataset`` instance that yields samples of type ``other``
455
+ by applying the appropriate lens transformation from the global
456
+ ``LensNetwork`` registry.
457
+
458
+ Raises:
459
+ ValueError: If no registered lens exists between the current
460
+ sample type and the target type.
461
+ """
308
462
  ret = Dataset[other]( self.url )
309
463
  # Get the singleton lens registry
310
464
  lenses = LensNetwork()
311
465
  ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
312
466
  return ret
313
467
 
314
- # @classmethod
315
- # def register( cls, uri: str,
316
- # sample_class: Type,
317
- # batch_class: Optional[Type] = None,
318
- # ):
319
- # """Register an `ekumen` schema to use a particular dataset sample class"""
320
- # cls._schema_registry_sample[uri] = sample_class
321
- # cls._schema_registry_batch[uri] = batch_class
322
-
323
- # @classmethod
324
- # def at( cls, uri: str ) -> 'Dataset':
325
- # """Create a Dataset for the `ekumen` index entry at `uri`"""
326
- # client = eat.Client()
327
- # return cls( )
328
-
329
- # Common functionality
330
-
331
468
  @property
332
469
  def shard_list( self ) -> list[str]:
333
470
  """List of individual dataset shards
@@ -341,6 +478,27 @@ class Dataset( Generic[ST] ):
341
478
  wds.filters.map( lambda x: x['url'] )
342
479
  )
343
480
  return list( pipe )
481
+
482
+ @property
483
+ def metadata( self ) -> dict[str, Any] | None:
484
+ """Fetch and cache metadata from metadata_url.
485
+
486
+ Returns:
487
+ Deserialized metadata dictionary, or None if no metadata_url is set.
488
+
489
+ Raises:
490
+ requests.HTTPError: If metadata fetch fails.
491
+ """
492
+ if self.metadata_url is None:
493
+ return None
494
+
495
+ if self._metadata is None:
496
+ with requests.get( self.metadata_url, stream = True ) as response:
497
+ response.raise_for_status()
498
+ self._metadata = msgpack.unpackb( response.content, raw = False )
499
+
500
+ # Use our cached values
501
+ return self._metadata
344
502
 
345
503
  def ordered( self,
346
504
  batch_size: int | None = 1,
@@ -359,22 +517,17 @@ class Dataset( Generic[ST] ):
359
517
  """
360
518
 
361
519
  if batch_size is None:
362
- # TODO Duplication here
363
520
  return wds.pipeline.DataPipeline(
364
521
  wds.shardlists.SimpleShardList( self.url ),
365
522
  wds.shardlists.split_by_worker,
366
- #
367
523
  wds.tariterators.tarfile_to_samples(),
368
- # wds.map( self.preprocess ),
369
524
  wds.filters.map( self.wrap ),
370
525
  )
371
526
 
372
527
  return wds.pipeline.DataPipeline(
373
528
  wds.shardlists.SimpleShardList( self.url ),
374
529
  wds.shardlists.split_by_worker,
375
- #
376
530
  wds.tariterators.tarfile_to_samples(),
377
- # wds.map( self.preprocess ),
378
531
  wds.filters.batched( batch_size ),
379
532
  wds.filters.map( self.wrap_batch ),
380
533
  )
@@ -384,30 +537,30 @@ class Dataset( Generic[ST] ):
384
537
  buffer_samples: int = 10_000,
385
538
  batch_size: int | None = 1,
386
539
  ) -> Iterable[ST]:
387
- """Iterate over the dataset in random order
388
-
540
+ """Iterate over the dataset in random order.
541
+
389
542
  Args:
390
- buffer_shards (int): Asdf
391
- batch_size (:obj:`int`, optional) The size of iterated batches.
392
- Default: 1. If ``None``, iterates over one sample at a time
393
- with no batch dimension.
394
-
543
+ buffer_shards: Number of shards to buffer for shuffling at the
544
+ shard level. Larger values increase randomness but use more
545
+ memory. Default: 100.
546
+ buffer_samples: Number of samples to buffer for shuffling within
547
+ shards. Larger values increase randomness but use more memory.
548
+ Default: 10,000.
549
+ batch_size: The size of iterated batches. Default: 1. If ``None``,
550
+ iterates over one sample at a time with no batch dimension.
551
+
395
552
  Returns:
396
- :obj:`webdataset.DataPipeline` A data pipeline that iterates over
397
- the dataset in its original sample order
398
-
553
+ A WebDataset data pipeline that iterates over the dataset in
554
+ randomized order. If ``batch_size`` is not ``None``, yields
555
+ ``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
556
+ samples.
399
557
  """
400
-
401
558
  if batch_size is None:
402
- # TODO Duplication here
403
559
  return wds.pipeline.DataPipeline(
404
560
  wds.shardlists.SimpleShardList( self.url ),
405
561
  wds.filters.shuffle( buffer_shards ),
406
562
  wds.shardlists.split_by_worker,
407
- #
408
563
  wds.tariterators.tarfile_to_samples(),
409
- # wds.shuffle( buffer_samples ),
410
- # wds.map( self.preprocess ),
411
564
  wds.filters.shuffle( buffer_samples ),
412
565
  wds.filters.map( self.wrap ),
413
566
  )
@@ -416,10 +569,7 @@ class Dataset( Generic[ST] ):
416
569
  wds.shardlists.SimpleShardList( self.url ),
417
570
  wds.filters.shuffle( buffer_shards ),
418
571
  wds.shardlists.split_by_worker,
419
- #
420
572
  wds.tariterators.tarfile_to_samples(),
421
- # wds.shuffle( buffer_samples ),
422
- # wds.map( self.preprocess ),
423
573
  wds.filters.shuffle( buffer_samples ),
424
574
  wds.filters.batched( batch_size ),
425
575
  wds.filters.map( self.wrap_batch ),
@@ -462,11 +612,11 @@ class Dataset( Generic[ST] ):
462
612
 
463
613
  cur_segment = 0
464
614
  cur_buffer = []
465
- path_template = (path.parent / f'{path.stem}-%06d.{path.suffix}').as_posix()
615
+ path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix()
466
616
 
467
617
  for x in self.ordered( batch_size = None ):
468
618
  cur_buffer.append( sample_map( x ) )
469
-
619
+
470
620
  if len( cur_buffer ) >= maxcount:
471
621
  # Write current segment
472
622
  cur_path = path_template.format( cur_segment )
@@ -482,25 +632,17 @@ class Dataset( Generic[ST] ):
482
632
  df = pd.DataFrame( cur_buffer )
483
633
  df.to_parquet( cur_path, **kwargs )
484
634
 
635
+ def wrap( self, sample: MsgpackRawSample ) -> ST:
636
+ """Wrap a raw msgpack sample into the appropriate dataset-specific type.
485
637
 
486
- # Implemented by specific subclasses
487
-
488
- # @property
489
- # @abstractmethod
490
- # def url( self ) -> str:
491
- # """str: Brace-notation URL of the underlying full WebDataset"""
492
- # pass
493
-
494
- # @classmethod
495
- # # TODO replace Any with IT
496
- # def preprocess( cls, sample: WDSRawSample ) -> Any:
497
- # """Pre-built preprocessor for a raw `sample` from the given dataset"""
498
- # return sample
638
+ Args:
639
+ sample: A dictionary containing at minimum a ``'msgpack'`` key with
640
+ serialized sample bytes.
499
641
 
500
- # @classmethod
501
- # TODO replace Any with IT
502
- def wrap( self, sample: MsgpackRawSample ) -> ST:
503
- """Wrap a `sample` into the appropriate dataset-specific type"""
642
+ Returns:
643
+ A deserialized sample of type ``ST``, optionally transformed through
644
+ a lens if ``as_type()`` was called.
645
+ """
504
646
  assert 'msgpack' in sample
505
647
  assert type( sample['msgpack'] ) == bytes
506
648
 
@@ -509,24 +651,21 @@ class Dataset( Generic[ST] ):
509
651
 
510
652
  source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
511
653
  return self._output_lens( source_sample )
512
-
513
- # try:
514
- # assert type( sample ) == dict
515
- # return cls.sample_class( **{
516
- # k: v
517
- # for k, v in sample.items() if k != '__key__'
518
- # } )
519
-
520
- # except Exception as e:
521
- # # Sample constructor failed -- revert to default
522
- # return AnySample(
523
- # value = sample,
524
- # )
525
654
 
526
655
  def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
527
- """Wrap a `batch` of samples into the appropriate dataset-specific type
528
-
529
- This default implementation simply creates a list one sample at a time
656
+ """Wrap a batch of raw msgpack samples into a typed SampleBatch.
657
+
658
+ Args:
659
+ batch: A dictionary containing a ``'msgpack'`` key with a list of
660
+ serialized sample bytes.
661
+
662
+ Returns:
663
+ A ``SampleBatch[ST]`` containing deserialized samples, optionally
664
+ transformed through a lens if ``as_type()`` was called.
665
+
666
+ Note:
667
+ This implementation deserializes samples one at a time, then
668
+ aggregates them into a batch.
530
669
  """
531
670
 
532
671
  assert 'msgpack' in batch
@@ -542,38 +681,32 @@ class Dataset( Generic[ST] ):
542
681
  for s in batch_source ]
543
682
  return SampleBatch[self.sample_type]( batch_view )
544
683
 
545
- # # @classmethod
546
- # def wrap_batch( self, batch: WDSRawBatch ) -> BT:
547
- # """Wrap a `batch` of samples into the appropriate dataset-specific type
548
-
549
- # This default implementation simply creates a list one sample at a time
550
- # """
551
- # assert cls.batch_class is not None, 'No batch class specified'
552
- # return cls.batch_class( **batch )
553
-
554
-
555
- ##
556
- # Shortcut decorators
557
-
558
- # def packable( cls ):
559
- # """TODO"""
560
-
561
- # def decorator( cls ):
562
- # # Create a new class dynamically
563
- # # The new class inherits from the new_parent_class first, then the original cls
564
- # new_bases = (PackableSample,) + cls.__bases__
565
- # new_cls = type(cls.__name__, new_bases, dict(cls.__dict__))
566
-
567
- # # Optionally, update __module__ and __qualname__ for better introspection
568
- # new_cls.__module__ = cls.__module__
569
- # new_cls.__qualname__ = cls.__qualname__
570
-
571
- # return new_cls
572
- # return decorator
573
684
 
574
685
  def packable( cls ):
575
- """TODO"""
576
-
686
+ """Decorator to convert a regular class into a ``PackableSample``.
687
+
688
+ This decorator transforms a class into a dataclass that inherits from
689
+ ``PackableSample``, enabling automatic msgpack serialization/deserialization
690
+ with special handling for NDArray fields.
691
+
692
+ Args:
693
+ cls: The class to convert. Should have type annotations for its fields.
694
+
695
+ Returns:
696
+ A new dataclass that inherits from ``PackableSample`` with the same
697
+ name and annotations as the original class.
698
+
699
+ Example:
700
+ >>> @packable
701
+ ... class MyData:
702
+ ... name: str
703
+ ... values: NDArray
704
+ ...
705
+ >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
706
+ >>> bytes_data = sample.packed
707
+ >>> restored = MyData.from_bytes(bytes_data)
708
+ """
709
+
577
710
  ##
578
711
 
579
712
  class_name = cls.__name__