atdata 0.1.3b4__py3-none-any.whl → 0.2.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atdata/dataset.py CHANGED
@@ -32,7 +32,6 @@ import webdataset as wds
 
 from pathlib import Path
 import uuid
-import functools
 
 import dataclasses
 import types
@@ -40,14 +39,12 @@ from dataclasses import (
     dataclass,
     asdict,
 )
-from abc import (
-    ABC,
-    abstractmethod,
-)
+from abc import ABC
 
 from tqdm import tqdm
 import numpy as np
 import pandas as pd
+import requests
 
 import typing
 from typing import (
@@ -65,15 +62,7 @@ from typing import (
     TypeVar,
     TypeAlias,
 )
-# from typing_inspect import get_bound, get_parameters
-from numpy.typing import (
-    NDArray,
-    ArrayLike,
-)
-
-#
-
-# import ekumen.atmosphere as eat
+from numpy.typing import NDArray
 
 import msgpack
 import ormsgpack
@@ -96,40 +85,10 @@ SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
 ##
 # Main base classes
 
-# TODO Check for best way to ensure this typevar is used as a dataclass type
-# DT = TypeVar( 'DT', bound = dataclass.__class__ )
 DT = TypeVar( 'DT' )
 
 MsgpackRawSample: TypeAlias = Dict[str, Any]
 
-# @dataclass
-# class ArrayBytes:
-#     """Annotates bytes that should be interpreted as the raw contents of a
-#     numpy NDArray"""
-
-#     raw_bytes: bytes
-#     """The raw bytes of the corresponding NDArray"""
-
-#     def __init__( self,
-#             array: Optional[ArrayLike] = None,
-#             raw: Optional[bytes] = None,
-#         ):
-#         """TODO"""
-
-#         if array is not None:
-#             array = np.array( array )
-#             self.raw_bytes = eh.array_to_bytes( array )
-
-#         elif raw is not None:
-#             self.raw_bytes = raw
-
-#         else:
-#             raise ValueError( 'Must provide either `array` or `raw` bytes' )
-
-#     @property
-#     def to_numpy( self ) -> NDArray:
-#         """Return the `raw_bytes` data as an NDArray"""
-#         return eh.bytes_to_array( self.raw_bytes )
-
 
 def _make_packable( x ):
     """Convert a value to a msgpack-compatible format.
@@ -141,8 +100,6 @@ def _make_packable( x ):
     Returns:
         The value in a format suitable for msgpack serialization.
     """
-    # if isinstance( x, ArrayBytes ):
-    #     return x.raw_bytes
     if isinstance( x, np.ndarray ):
         return eh.array_to_bytes( x )
     return x
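Net effect: numpy arrays remain the only specially handled values, flattened to raw bytes before packing. A round-trip sketch (assuming `eh.array_to_bytes` and `eh.bytes_to_array` are inverse helpers, as the removed `ArrayBytes` draft implied):

    import numpy as np

    arr = np.arange( 4, dtype = np.float32 )
    packed = _make_packable( arr )          # raw bytes via eh.array_to_bytes
    assert isinstance( packed, bytes )
    restored = eh.bytes_to_array( packed )  # inverse helper per the removed draft
    assert np.array_equal( arr, restored )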
@@ -226,11 +183,8 @@ class PackableSample( ABC ):
             # based on what is provided
 
             if isinstance( var_cur_value, np.ndarray ):
-                # we're good!
-                pass
-
-            # elif isinstance( var_cur_value, ArrayBytes ):
-            #     setattr( self, var_name, var_cur_value.to_numpy )
+                # Already the correct type, no conversion needed
+                continue
 
             elif isinstance( var_cur_value, bytes ):
                 # TODO This does create a constraint that serialized bytes
@@ -411,24 +365,9 @@ class SampleBatch( Generic[DT] ):
         raise AttributeError( f'No sample attribute named {name}' )
 
 
-# class AnySample( BaseModel ):
-#     """A sample that can hold anything"""
-#     value: Any
-
-# class AnyBatch( BaseModel ):
-#     """A batch of `AnySample`s"""
-#     values: list[AnySample]
-
-
 ST = TypeVar( 'ST', bound = PackableSample )
-# BT = TypeVar( 'BT' )
-
 RT = TypeVar( 'RT', bound = PackableSample )
 
-# TODO For python 3.13
-# BT = TypeVar( 'BT', default = None )
-# IT = TypeVar( 'IT', default = Any )
-
 class Dataset( Generic[ST] ):
     """A typed dataset built on WebDataset with lens transformations.
 
@@ -456,13 +395,9 @@ class Dataset( Generic[ST] ):
     ...
     >>> # Transform to a different view
    >>> ds_view = ds.as_type(MyDataView)
+
     """
 
-    # sample_class: Type = get_parameters( )
-    # """The type of each returned sample from this `Dataset`'s iterator"""
-    # batch_class: Type = get_bound( BT )
-    # """The type of a batch built from `sample_class`"""
-
     @property
     def sample_type( self ) -> Type:
         """The type of each returned sample from this dataset's iterator.
@@ -482,16 +417,11 @@ class Dataset( Generic[ST] ):
         Returns:
             ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
         """
-        # return self.__orig_class__.__args__[1]
         return SampleBatch[self.sample_type]
 
-
-    # _schema_registry_sample: dict[str, Type]
-    # _schema_registry_batch: dict[str, Type | None]
-
-    #
-
-    def __init__( self, url: str ) -> None:
+    def __init__( self, url: str,
+            metadata_url: str | None = None,
+        ) -> None:
         """Create a dataset from a WebDataset URL.
 
         Args:
@@ -501,6 +431,14 @@
         """
         super().__init__()
         self.url = url
+        """WebDataset brace-notation URL pointing to tar files, e.g.,
+        ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
+        ``"path/to/file-000000.tar"`` for a single shard.
+        """
+
+        self._metadata: dict[str, Any] | None = None
+        self.metadata_url: str | None = metadata_url
+        """Optional URL to msgpack-encoded metadata for this dataset."""
 
         # Allow addition of automatic transformation of raw underlying data
         self._output_lens: Lens | None = None
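For orientation, constructing a dataset with the new signature might look like this (sketch only; the shard pattern and metadata URL are hypothetical, and `MySample` is assumed to be a `PackableSample` subclass defined elsewhere):

    from atdata.dataset import Dataset

    ds = Dataset[MySample](
        'data/train-{000000..000009}.tar',                        # brace-notation shard URL
        metadata_url = 'https://example.com/train-meta.msgpack',  # optional metadata blob
    )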
@@ -527,23 +465,6 @@ class Dataset( Generic[ST] ):
         ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
         return ret
 
-    # @classmethod
-    # def register( cls, uri: str,
-    #         sample_class: Type,
-    #         batch_class: Optional[Type] = None,
-    #     ):
-    #     """Register an `ekumen` schema to use a particular dataset sample class"""
-    #     cls._schema_registry_sample[uri] = sample_class
-    #     cls._schema_registry_batch[uri] = batch_class
-
-    # @classmethod
-    # def at( cls, uri: str ) -> 'Dataset':
-    #     """Create a Dataset for the `ekumen` index entry at `uri`"""
-    #     client = eat.Client()
-    #     return cls( )
-
-    # Common functionality
-
     @property
     def shard_list( self ) -> list[str]:
         """List of individual dataset shards
@@ -557,6 +478,27 @@
             wds.filters.map( lambda x: x['url'] )
         )
         return list( pipe )
+
+    @property
+    def metadata( self ) -> dict[str, Any] | None:
+        """Fetch and cache metadata from metadata_url.
+
+        Returns:
+            Deserialized metadata dictionary, or None if no metadata_url is set.
+
+        Raises:
+            requests.HTTPError: If metadata fetch fails.
+        """
+        if self.metadata_url is None:
+            return None
+
+        if self._metadata is None:
+            with requests.get( self.metadata_url, stream = True ) as response:
+                response.raise_for_status()
+                self._metadata = msgpack.unpackb( response.content, raw = False )
+
+        # Use our cached values
+        return self._metadata
 
     def ordered( self,
             batch_size: int | None = 1,
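The new property is lazy: the first access performs one HTTP GET and unpacks the body; subsequent accesses return the cached dict without touching the network. The producer side only needs to host msgpack bytes at `metadata_url` (sketch; the keys shown are hypothetical):

    import msgpack

    # Producer: serialize a metadata dict and host the bytes at metadata_url
    blob = msgpack.packb( { 'num_shards': 10, 'split': 'train' } )

    # Consumer (the `ds` sketched above)
    meta = ds.metadata        # first access: GET + unpackb + cache
    meta_again = ds.metadata  # served from cache, no request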
@@ -575,22 +517,17 @@
         """
 
         if batch_size is None:
-            # TODO Duplication here
            return wds.pipeline.DataPipeline(
                wds.shardlists.SimpleShardList( self.url ),
                wds.shardlists.split_by_worker,
-                #
                wds.tariterators.tarfile_to_samples(),
-                # wds.map( self.preprocess ),
                wds.filters.map( self.wrap ),
            )
 
         return wds.pipeline.DataPipeline(
             wds.shardlists.SimpleShardList( self.url ),
             wds.shardlists.split_by_worker,
-            #
             wds.tariterators.tarfile_to_samples(),
-            # wds.map( self.preprocess ),
             wds.filters.batched( batch_size ),
             wds.filters.map( self.wrap_batch ),
         )
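Iterating the pipelines might look like the following (sketch, continuing the hypothetical `ds` above; `process` and `process_batch` are stand-ins):

    # Deterministic pass over individual samples
    for sample in ds.ordered( batch_size = None ):
        process( sample )

    # Batched pass: yields SampleBatch[ST] objects of up to 32 samples each
    for batch in ds.ordered( batch_size = 32 ):
        process_batch( batch )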
@@ -618,17 +555,12 @@
            ``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
            samples.
         """
-
         if batch_size is None:
-            # TODO Duplication here
            return wds.pipeline.DataPipeline(
                wds.shardlists.SimpleShardList( self.url ),
                wds.filters.shuffle( buffer_shards ),
                wds.shardlists.split_by_worker,
-                #
                wds.tariterators.tarfile_to_samples(),
-                # wds.shuffle( buffer_samples ),
-                # wds.map( self.preprocess ),
                wds.filters.shuffle( buffer_samples ),
                wds.filters.map( self.wrap ),
            )
@@ -637,10 +569,7 @@
             wds.shardlists.SimpleShardList( self.url ),
             wds.filters.shuffle( buffer_shards ),
             wds.shardlists.split_by_worker,
-            #
             wds.tariterators.tarfile_to_samples(),
-            # wds.shuffle( buffer_samples ),
-            # wds.map( self.preprocess ),
             wds.filters.shuffle( buffer_samples ),
             wds.filters.batched( batch_size ),
             wds.filters.map( self.wrap_batch ),
@@ -683,11 +612,11 @@
 
         cur_segment = 0
         cur_buffer = []
-        path_template = (path.parent / f'{path.stem}-%06d.{path.suffix}').as_posix()
+        path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix()
 
         for x in self.ordered( batch_size = None ):
             cur_buffer.append( sample_map( x ) )
-
+
             if len( cur_buffer ) >= maxcount:
                 # Write current segment
                 cur_path = path_template.format( cur_segment )
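This one-line change fixes two bugs at once: `%06d` is a printf-style placeholder that the later `path_template.format( cur_segment )` call never substitutes, and the explicit dot duplicated the one `Path.suffix` already includes. The escaped `{{:06d}}` survives the f-string as a real format field:

    from pathlib import Path

    path = Path( 'out/export.parquet' )
    template = ( path.parent / f'{path.stem}-{{:06d}}{path.suffix}' ).as_posix()
    assert template == 'out/export-{:06d}.parquet'  # old code gave 'out/export-%06d..parquet'
    assert template.format( 3 ) == 'out/export-000003.parquet'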
@@ -703,23 +632,6 @@
         df = pd.DataFrame( cur_buffer )
         df.to_parquet( cur_path, **kwargs )
 
-
-    # Implemented by specific subclasses
-
-    # @property
-    # @abstractmethod
-    # def url( self ) -> str:
-    #     """str: Brace-notation URL of the underlying full WebDataset"""
-    #     pass
-
-    # @classmethod
-    # # TODO replace Any with IT
-    # def preprocess( cls, sample: WDSRawSample ) -> Any:
-    #     """Pre-built preprocessor for a raw `sample` from the given dataset"""
-    #     return sample
-
-    # @classmethod
-    # TODO replace Any with IT
     def wrap( self, sample: MsgpackRawSample ) -> ST:
         """Wrap a raw msgpack sample into the appropriate dataset-specific type.
 
@@ -739,19 +651,6 @@
 
         source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
         return self._output_lens( source_sample )
-
-        # try:
-        #     assert type( sample ) == dict
-        #     return cls.sample_class( **{
-        #         k: v
-        #         for k, v in sample.items() if k != '__key__'
-        #     } )
-
-        # except Exception as e:
-        #     # Sample constructor failed -- revert to default
-        #     return AnySample(
-        #         value = sample,
-        #     )
 
     def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
         """Wrap a batch of raw msgpack samples into a typed SampleBatch.
@@ -782,34 +681,6 @@
             for s in batch_source ]
         return SampleBatch[self.sample_type]( batch_view )
 
-    # # @classmethod
-    # def wrap_batch( self, batch: WDSRawBatch ) -> BT:
-    #     """Wrap a `batch` of samples into the appropriate dataset-specific type
-
-    #     This default implementation simply creates a list one sample at a time
-    #     """
-    #     assert cls.batch_class is not None, 'No batch class specified'
-    #     return cls.batch_class( **batch )
-
-
-##
-# Shortcut decorators
-
-# def packable( cls ):
-#     """TODO"""
-
-#     def decorator( cls ):
-#         # Create a new class dynamically
-#         # The new class inherits from the new_parent_class first, then the original cls
-#         new_bases = (PackableSample,) + cls.__bases__
-#         new_cls = type(cls.__name__, new_bases, dict(cls.__dict__))
-
-#         # Optionally, update __module__ and __qualname__ for better introspection
-#         new_cls.__module__ = cls.__module__
-#         new_cls.__qualname__ = cls.__qualname__
-
-#         return new_cls
-#     return decorator
 
 def packable( cls ):
     """Decorator to convert a regular class into a ``PackableSample``.
atdata/lens.py CHANGED
@@ -201,22 +201,6 @@ class Lens( Generic[S, V] ):
         """
         return self._getter( s )
 
-# TODO Figure out how to properly parameterize this
-# def _lens_factory[S, V]( register: bool = True ):
-#     """Register the annotated function `f` as the getter of a sample lens"""
-
-#     # The actual lens decorator taking a lens getter function to a lens object
-#     def _decorator( f: LensGetter[S, V] ) -> Lens[S, V]:
-#         ret = Lens[S, V]( f )
-#         if register:
-#             _network.register( ret )
-#         return ret
-
-#     # Return the lens decorator
-#     return _decorator
-
-# # For convenience
-# lens = _lens_factory
 
 def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
     """Decorator to create and register a lens transformation.
@@ -246,12 +230,6 @@ def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
     return ret
 
 
-##
-# Global registry of used lenses
-
-# _registered_lenses: Dict[LensSignature, Lens] = dict()
-# """TODO"""
-
 class LensNetwork:
     """Global registry for lens transformations between sample types.
 
@@ -292,18 +270,6 @@ class LensNetwork:
         If a lens already exists for the same type pair, it will be
         overwritten.
         """
-
-        # sig = inspect.signature( _lens.get )
-        # input_types = list( sig.parameters.values() )
-        # assert len( input_types ) == 1, \
-        #     'Wrong number of input args for lens: should only have one'
-
-        # input_type = input_types[0].annotation
-        # print( input_type )
-        # output_type = sig.return_annotation
-
-        # self._registry[input_type, output_type] = _lens
-        # print( _lens.source_type )
         self._registry[_lens.source_type, _lens.view_type] = _lens
 
     def transform( self, source: DatasetType, view: DatasetType ) -> Lens:
@@ -323,8 +289,6 @@
         Currently only supports direct transformations. Compositional
         transformations (chaining multiple lenses) are not yet implemented.
         """
-
-        # TODO Handle compositional closure
         ret = self._registry.get( (source, view), None )
         if ret is None:
             raise ValueError( f'No registered lens from source {source} to view {view}' )
@@ -332,22 +296,5 @@
         return ret
 
 
-# Create global singleton registry instance
-_network = LensNetwork()
-
-# def lens( f: LensPutter ) -> Lens:
-#     """Register the annotated function `f` as a sample lens"""
-#     ##
-
-#     sig = inspect.signature( f )
-
-#     input_types = list( sig.parameters.values() )
-#     output_type = sig.return_annotation
-
-#     _registered_lenses[]
-
-#     f.lens = Lens(
-
-#     )
-
-#     return f
+# Global singleton registry instance
+_network = LensNetwork()