atdata 0.1.3b3__py3-none-any.whl → 0.2.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +39 -1
- atdata/_helpers.py +39 -3
- atdata/atmosphere/__init__.py +61 -0
- atdata/atmosphere/_types.py +329 -0
- atdata/atmosphere/client.py +393 -0
- atdata/atmosphere/lens.py +280 -0
- atdata/atmosphere/records.py +342 -0
- atdata/atmosphere/schema.py +296 -0
- atdata/dataset.py +336 -203
- atdata/lens.py +177 -77
- atdata/local.py +492 -0
- atdata-0.2.0a1.dist-info/METADATA +181 -0
- atdata-0.2.0a1.dist-info/RECORD +16 -0
- {atdata-0.1.3b3.dist-info → atdata-0.2.0a1.dist-info}/WHEEL +1 -1
- atdata-0.1.3b3.dist-info/METADATA +0 -18
- atdata-0.1.3b3.dist-info/RECORD +0 -9
- {atdata-0.1.3b3.dist-info → atdata-0.2.0a1.dist-info}/entry_points.txt +0 -0
- {atdata-0.1.3b3.dist-info → atdata-0.2.0a1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py
CHANGED
|
@@ -1,4 +1,29 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Core dataset and sample infrastructure for typed WebDatasets.
|
|
2
|
+
|
|
3
|
+
This module provides the core components for working with typed, msgpack-serialized
|
|
4
|
+
samples in WebDataset format:
|
|
5
|
+
|
|
6
|
+
- ``PackableSample``: Base class for msgpack-serializable samples with automatic
|
|
7
|
+
NDArray handling
|
|
8
|
+
- ``SampleBatch``: Automatic batching with attribute aggregation
|
|
9
|
+
- ``Dataset``: Generic typed dataset wrapper for WebDataset tar files
|
|
10
|
+
- ``@packable``: Decorator to convert regular classes into PackableSample subclasses
|
|
11
|
+
|
|
12
|
+
The implementation handles automatic conversion between numpy arrays and bytes
|
|
13
|
+
during serialization, enabling efficient storage of numerical data in WebDataset
|
|
14
|
+
archives.
|
|
15
|
+
|
|
16
|
+
Example:
|
|
17
|
+
>>> @packable
|
|
18
|
+
... class ImageSample:
|
|
19
|
+
... image: NDArray
|
|
20
|
+
... label: str
|
|
21
|
+
...
|
|
22
|
+
>>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
|
|
23
|
+
>>> for batch in ds.shuffled(batch_size=32):
|
|
24
|
+
... images = batch.image # Stacked numpy array (32, H, W, C)
|
|
25
|
+
... labels = batch.label # List of 32 strings
|
|
26
|
+
"""
|
|
2
27
|
|
|
3
28
|
##
|
|
4
29
|
# Imports
|
|
@@ -7,7 +32,6 @@ import webdataset as wds
|
|
|
7
32
|
|
|
8
33
|
from pathlib import Path
|
|
9
34
|
import uuid
|
|
10
|
-
import functools
|
|
11
35
|
|
|
12
36
|
import dataclasses
|
|
13
37
|
import types
|
|
@@ -15,14 +39,12 @@ from dataclasses import (
|
|
|
15
39
|
dataclass,
|
|
16
40
|
asdict,
|
|
17
41
|
)
|
|
18
|
-
from abc import
|
|
19
|
-
ABC,
|
|
20
|
-
abstractmethod,
|
|
21
|
-
)
|
|
42
|
+
from abc import ABC
|
|
22
43
|
|
|
23
44
|
from tqdm import tqdm
|
|
24
45
|
import numpy as np
|
|
25
46
|
import pandas as pd
|
|
47
|
+
import requests
|
|
26
48
|
|
|
27
49
|
import typing
|
|
28
50
|
from typing import (
|
|
@@ -40,15 +62,7 @@ from typing import (
|
|
|
40
62
|
TypeVar,
|
|
41
63
|
TypeAlias,
|
|
42
64
|
)
|
|
43
|
-
|
|
44
|
-
from numpy.typing import (
|
|
45
|
-
NDArray,
|
|
46
|
-
ArrayLike,
|
|
47
|
-
)
|
|
48
|
-
|
|
49
|
-
#
|
|
50
|
-
|
|
51
|
-
# import ekumen.atmosphere as eat
|
|
65
|
+
from numpy.typing import NDArray
|
|
52
66
|
|
|
53
67
|
import msgpack
|
|
54
68
|
import ormsgpack
|
|
@@ -71,50 +85,35 @@ SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
|
|
|
71
85
|
##
|
|
72
86
|
# Main base classes
|
|
73
87
|
|
|
74
|
-
# TODO Check for best way to ensure this typevar is used as a dataclass type
|
|
75
|
-
# DT = TypeVar( 'DT', bound = dataclass.__class__ )
|
|
76
88
|
DT = TypeVar( 'DT' )
|
|
77
89
|
|
|
78
90
|
MsgpackRawSample: TypeAlias = Dict[str, Any]
|
|
79
91
|
|
|
80
|
-
# @dataclass
|
|
81
|
-
# class ArrayBytes:
|
|
82
|
-
# """Annotates bytes that should be interpreted as the raw contents of a
|
|
83
|
-
# numpy NDArray"""
|
|
84
|
-
|
|
85
|
-
# raw_bytes: bytes
|
|
86
|
-
# """The raw bytes of the corresponding NDArray"""
|
|
87
|
-
|
|
88
|
-
# def __init__( self,
|
|
89
|
-
# array: Optional[ArrayLike] = None,
|
|
90
|
-
# raw: Optional[bytes] = None,
|
|
91
|
-
# ):
|
|
92
|
-
# """TODO"""
|
|
93
|
-
|
|
94
|
-
# if array is not None:
|
|
95
|
-
# array = np.array( array )
|
|
96
|
-
# self.raw_bytes = eh.array_to_bytes( array )
|
|
97
|
-
|
|
98
|
-
# elif raw is not None:
|
|
99
|
-
# self.raw_bytes = raw
|
|
100
|
-
|
|
101
|
-
# else:
|
|
102
|
-
# raise ValueError( 'Must provide either `array` or `raw` bytes' )
|
|
103
|
-
|
|
104
|
-
# @property
|
|
105
|
-
# def to_numpy( self ) -> NDArray:
|
|
106
|
-
# """Return the `raw_bytes` data as an NDArray"""
|
|
107
|
-
# return eh.bytes_to_array( self.raw_bytes )
|
|
108
92
|
|
|
109
93
|
def _make_packable( x ):
|
|
110
|
-
|
|
111
|
-
|
|
94
|
+
"""Convert a value to a msgpack-compatible format.
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
x: A value to convert. If it's a numpy array, converts to bytes.
|
|
98
|
+
Otherwise returns the value unchanged.
|
|
99
|
+
|
|
100
|
+
Returns:
|
|
101
|
+
The value in a format suitable for msgpack serialization.
|
|
102
|
+
"""
|
|
112
103
|
if isinstance( x, np.ndarray ):
|
|
113
104
|
return eh.array_to_bytes( x )
|
|
114
105
|
return x
|
|
115
106
|
|
|
116
107
|
def _is_possibly_ndarray_type( t ):
|
|
117
|
-
"""
|
|
108
|
+
"""Check if a type annotation is or contains NDArray.
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
t: A type annotation to check.
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
``True`` if the type is ``NDArray`` or a union containing ``NDArray``
|
|
115
|
+
(e.g., ``NDArray | None``), ``False`` otherwise.
|
|
116
|
+
"""
|
|
118
117
|
|
|
119
118
|
# Directly an NDArray
|
|
120
119
|
if t == NDArray:
|
|
@@ -133,10 +132,40 @@ def _is_possibly_ndarray_type( t ):
|
|
|
133
132
|
|
|
134
133
|
@dataclass
|
|
135
134
|
class PackableSample( ABC ):
|
|
136
|
-
"""
|
|
135
|
+
"""Base class for samples that can be serialized with msgpack.
|
|
136
|
+
|
|
137
|
+
This abstract base class provides automatic serialization/deserialization
|
|
138
|
+
for dataclass-based samples. Fields annotated as ``NDArray`` or
|
|
139
|
+
``NDArray | None`` are automatically converted between numpy arrays and
|
|
140
|
+
bytes during packing/unpacking.
|
|
141
|
+
|
|
142
|
+
Subclasses should be defined either by:
|
|
143
|
+
1. Direct inheritance with the ``@dataclass`` decorator
|
|
144
|
+
2. Using the ``@packable`` decorator (recommended)
|
|
145
|
+
|
|
146
|
+
Example:
|
|
147
|
+
>>> @packable
|
|
148
|
+
... class MyData:
|
|
149
|
+
... name: str
|
|
150
|
+
... embeddings: NDArray
|
|
151
|
+
...
|
|
152
|
+
>>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
|
|
153
|
+
>>> packed = sample.packed # Serialize to bytes
|
|
154
|
+
>>> restored = MyData.from_bytes(packed) # Deserialize
|
|
155
|
+
"""
|
|
137
156
|
|
|
138
157
|
def _ensure_good( self ):
|
|
139
|
-
"""
|
|
158
|
+
"""Auto-convert annotated NDArray fields from bytes to numpy arrays.
|
|
159
|
+
|
|
160
|
+
This method scans all dataclass fields and for any field annotated as
|
|
161
|
+
``NDArray`` or ``NDArray | None``, automatically converts bytes values
|
|
162
|
+
to numpy arrays using the helper deserialization function. This enables
|
|
163
|
+
transparent handling of array serialization in msgpack data.
|
|
164
|
+
|
|
165
|
+
Note:
|
|
166
|
+
This is called during ``__post_init__`` to ensure proper type
|
|
167
|
+
conversion after deserialization.
|
|
168
|
+
"""
|
|
140
169
|
|
|
141
170
|
# Auto-convert known types when annotated
|
|
142
171
|
# for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
|
|
@@ -154,11 +183,8 @@ class PackableSample( ABC ):
|
|
|
154
183
|
# based on what is provided
|
|
155
184
|
|
|
156
185
|
if isinstance( var_cur_value, np.ndarray ):
|
|
157
|
-
#
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
# elif isinstance( var_cur_value, ArrayBytes ):
|
|
161
|
-
# setattr( self, var_name, var_cur_value.to_numpy )
|
|
186
|
+
# Already the correct type, no conversion needed
|
|
187
|
+
continue
|
|
162
188
|
|
|
163
189
|
elif isinstance( var_cur_value, bytes ):
|
|
164
190
|
# TODO This does create a constraint that serialized bytes
|
|
@@ -173,19 +199,45 @@ class PackableSample( ABC ):
|
|
|
173
199
|
|
|
174
200
|
@classmethod
|
|
175
201
|
def from_data( cls, data: MsgpackRawSample ) -> Self:
|
|
176
|
-
"""Create a sample instance from unpacked msgpack data
|
|
202
|
+
"""Create a sample instance from unpacked msgpack data.
|
|
203
|
+
|
|
204
|
+
Args:
|
|
205
|
+
data: A dictionary of unpacked msgpack data with keys matching
|
|
206
|
+
the sample's field names.
|
|
207
|
+
|
|
208
|
+
Returns:
|
|
209
|
+
A new instance of this sample class with fields populated from
|
|
210
|
+
the data dictionary and NDArray fields auto-converted from bytes.
|
|
211
|
+
"""
|
|
177
212
|
ret = cls( **data )
|
|
178
213
|
ret._ensure_good()
|
|
179
214
|
return ret
|
|
180
215
|
|
|
181
216
|
@classmethod
|
|
182
217
|
def from_bytes( cls, bs: bytes ) -> Self:
|
|
183
|
-
"""Create a sample instance from raw msgpack bytes
|
|
218
|
+
"""Create a sample instance from raw msgpack bytes.
|
|
219
|
+
|
|
220
|
+
Args:
|
|
221
|
+
bs: Raw bytes from a msgpack-serialized sample.
|
|
222
|
+
|
|
223
|
+
Returns:
|
|
224
|
+
A new instance of this sample class deserialized from the bytes.
|
|
225
|
+
"""
|
|
184
226
|
return cls.from_data( ormsgpack.unpackb( bs ) )
|
|
185
227
|
|
|
186
228
|
@property
|
|
187
229
|
def packed( self ) -> bytes:
|
|
188
|
-
"""Pack this sample's data into msgpack bytes
|
|
230
|
+
"""Pack this sample's data into msgpack bytes.
|
|
231
|
+
|
|
232
|
+
NDArray fields are automatically converted to bytes before packing.
|
|
233
|
+
All other fields are packed as-is if they're msgpack-compatible.
|
|
234
|
+
|
|
235
|
+
Returns:
|
|
236
|
+
Raw msgpack bytes representing this sample's data.
|
|
237
|
+
|
|
238
|
+
Raises:
|
|
239
|
+
RuntimeError: If msgpack serialization fails.
|
|
240
|
+
"""
|
|
189
241
|
|
|
190
242
|
# Make sure that all of our (possibly unpackable) data is in a packable
|
|
191
243
|
# format
|
|
@@ -204,7 +256,15 @@ class PackableSample( ABC ):
|
|
|
204
256
|
# TODO Expand to allow for specifying explicit __key__
|
|
205
257
|
@property
|
|
206
258
|
def as_wds( self ) -> WDSRawSample:
|
|
207
|
-
"""Pack this sample's data for writing to
|
|
259
|
+
"""Pack this sample's data for writing to WebDataset.
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
A dictionary with ``__key__`` (UUID v1 for sortable keys) and
|
|
263
|
+
``msgpack`` (packed sample data) fields suitable for WebDataset.
|
|
264
|
+
|
|
265
|
+
Note:
|
|
266
|
+
TODO: Expand to allow specifying explicit ``__key__`` values.
|
|
267
|
+
"""
|
|
208
268
|
return {
|
|
209
269
|
# Generates a UUID that is timelike-sortable
|
|
210
270
|
'__key__': str( uuid.uuid1( 0, 0 ) ),
|
|
@@ -212,30 +272,86 @@ class PackableSample( ABC ):
|
|
|
212
272
|
}
|
|
213
273
|
|
|
214
274
|
def _batch_aggregate( xs: Sequence ):
|
|
275
|
+
"""Aggregate a sequence of values into a batch-appropriate format.
|
|
276
|
+
|
|
277
|
+
Args:
|
|
278
|
+
xs: A sequence of values to aggregate. If the first element is a numpy
|
|
279
|
+
array, all elements are stacked into a single array. Otherwise,
|
|
280
|
+
returns a list.
|
|
281
|
+
|
|
282
|
+
Returns:
|
|
283
|
+
A numpy array (if elements are arrays) or a list (otherwise).
|
|
284
|
+
"""
|
|
215
285
|
|
|
216
286
|
if not xs:
|
|
217
287
|
# Empty sequence
|
|
218
288
|
return []
|
|
219
289
|
|
|
220
|
-
# Aggregate
|
|
290
|
+
# Aggregate
|
|
221
291
|
if isinstance( xs[0], np.ndarray ):
|
|
222
292
|
return np.array( list( xs ) )
|
|
223
293
|
|
|
224
294
|
return list( xs )
|
|
225
295
|
|
|
226
296
|
class SampleBatch( Generic[DT] ):
|
|
297
|
+
"""A batch of samples with automatic attribute aggregation.
|
|
298
|
+
|
|
299
|
+
This class wraps a sequence of samples and provides magic ``__getattr__``
|
|
300
|
+
access to aggregate sample attributes. When you access an attribute that
|
|
301
|
+
exists on the sample type, it automatically aggregates values across all
|
|
302
|
+
samples in the batch.
|
|
303
|
+
|
|
304
|
+
NDArray fields are stacked into a numpy array with a batch dimension.
|
|
305
|
+
Other fields are aggregated into a list.
|
|
306
|
+
|
|
307
|
+
Type Parameters:
|
|
308
|
+
DT: The sample type, must derive from ``PackableSample``.
|
|
309
|
+
|
|
310
|
+
Attributes:
|
|
311
|
+
samples: The list of sample instances in this batch.
|
|
312
|
+
|
|
313
|
+
Example:
|
|
314
|
+
>>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
|
|
315
|
+
>>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
|
|
316
|
+
>>> batch.names # Returns list of names
|
|
317
|
+
"""
|
|
227
318
|
|
|
228
319
|
def __init__( self, samples: Sequence[DT] ):
|
|
229
|
-
"""
|
|
320
|
+
"""Create a batch from a sequence of samples.
|
|
321
|
+
|
|
322
|
+
Args:
|
|
323
|
+
samples: A sequence of sample instances to aggregate into a batch.
|
|
324
|
+
Each sample must be an instance of a type derived from
|
|
325
|
+
``PackableSample``.
|
|
326
|
+
"""
|
|
230
327
|
self.samples = list( samples )
|
|
231
328
|
self._aggregate_cache = dict()
|
|
232
329
|
|
|
233
330
|
@property
|
|
234
331
|
def sample_type( self ) -> Type:
|
|
235
|
-
"""The type of each sample in this batch
|
|
332
|
+
"""The type of each sample in this batch.
|
|
333
|
+
|
|
334
|
+
Returns:
|
|
335
|
+
The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
|
|
336
|
+
"""
|
|
236
337
|
return typing.get_args( self.__orig_class__)[0]
|
|
237
338
|
|
|
238
339
|
def __getattr__( self, name ):
|
|
340
|
+
"""Aggregate an attribute across all samples in the batch.
|
|
341
|
+
|
|
342
|
+
This magic method enables attribute-style access to aggregated sample
|
|
343
|
+
fields. Results are cached for efficiency.
|
|
344
|
+
|
|
345
|
+
Args:
|
|
346
|
+
name: The attribute name to aggregate across samples.
|
|
347
|
+
|
|
348
|
+
Returns:
|
|
349
|
+
For NDArray fields: a stacked numpy array with batch dimension.
|
|
350
|
+
For other fields: a list of values from each sample.
|
|
351
|
+
|
|
352
|
+
Raises:
|
|
353
|
+
AttributeError: If the attribute doesn't exist on the sample type.
|
|
354
|
+
"""
|
|
239
355
|
# Aggregate named params of sample type
|
|
240
356
|
if name in vars( self.sample_type )['__annotations__']:
|
|
241
357
|
if name not in self._aggregate_cache:
|
|
@@ -243,91 +359,112 @@ class SampleBatch( Generic[DT] ):
|
|
|
243
359
|
[ getattr( x, name )
|
|
244
360
|
for x in self.samples ]
|
|
245
361
|
)
|
|
246
|
-
|
|
247
|
-
return self._aggregate_cache[name]
|
|
248
|
-
|
|
249
|
-
raise AttributeError( f'No sample attribute named {name}' )
|
|
250
|
-
|
|
251
362
|
|
|
252
|
-
|
|
253
|
-
# """A sample that can hold anything"""
|
|
254
|
-
# value: Any
|
|
363
|
+
return self._aggregate_cache[name]
|
|
255
364
|
|
|
256
|
-
|
|
257
|
-
# """A batch of `AnySample`s"""
|
|
258
|
-
# values: list[AnySample]
|
|
365
|
+
raise AttributeError( f'No sample attribute named {name}' )
|
|
259
366
|
|
|
260
367
|
|
|
261
368
|
ST = TypeVar( 'ST', bound = PackableSample )
|
|
262
|
-
# BT = TypeVar( 'BT' )
|
|
263
|
-
|
|
264
369
|
RT = TypeVar( 'RT', bound = PackableSample )
|
|
265
370
|
|
|
266
|
-
# TODO For python 3.13
|
|
267
|
-
# BT = TypeVar( 'BT', default = None )
|
|
268
|
-
# IT = TypeVar( 'IT', default = Any )
|
|
269
|
-
|
|
270
371
|
class Dataset( Generic[ST] ):
|
|
271
|
-
"""A dataset
|
|
372
|
+
"""A typed dataset built on WebDataset with lens transformations.
|
|
373
|
+
|
|
374
|
+
This class wraps WebDataset tar archives and provides type-safe iteration
|
|
375
|
+
over samples of a specific ``PackableSample`` type. Samples are stored as
|
|
376
|
+
msgpack-serialized data within WebDataset shards.
|
|
377
|
+
|
|
378
|
+
The dataset supports:
|
|
379
|
+
- Ordered and shuffled iteration
|
|
380
|
+
- Automatic batching with ``SampleBatch``
|
|
381
|
+
- Type transformations via the lens system (``as_type()``)
|
|
382
|
+
- Export to parquet format
|
|
383
|
+
|
|
384
|
+
Type Parameters:
|
|
385
|
+
ST: The sample type for this dataset, must derive from ``PackableSample``.
|
|
386
|
+
|
|
387
|
+
Attributes:
|
|
388
|
+
url: WebDataset brace-notation URL for the tar file(s).
|
|
389
|
+
|
|
390
|
+
Example:
|
|
391
|
+
>>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
|
|
392
|
+
>>> for sample in ds.ordered(batch_size=32):
|
|
393
|
+
... # sample is SampleBatch[MyData] with batch_size samples
|
|
394
|
+
... embeddings = sample.embeddings # shape: (32, ...)
|
|
395
|
+
...
|
|
396
|
+
>>> # Transform to a different view
|
|
397
|
+
>>> ds_view = ds.as_type(MyDataView)
|
|
272
398
|
|
|
273
|
-
(Abstract base for subclassing)
|
|
274
399
|
"""
|
|
275
400
|
|
|
276
|
-
# sample_class: Type = get_parameters( )
|
|
277
|
-
# """The type of each returned sample from this `Dataset`'s iterator"""
|
|
278
|
-
# batch_class: Type = get_bound( BT )
|
|
279
|
-
# """The type of a batch built from `sample_class`"""
|
|
280
|
-
|
|
281
401
|
@property
|
|
282
402
|
def sample_type( self ) -> Type:
|
|
283
|
-
"""The type of each returned sample from this
|
|
284
|
-
|
|
403
|
+
"""The type of each returned sample from this dataset's iterator.
|
|
404
|
+
|
|
405
|
+
Returns:
|
|
406
|
+
The type parameter ``ST`` used when creating this ``Dataset[ST]``.
|
|
407
|
+
|
|
408
|
+
Note:
|
|
409
|
+
Extracts the type parameter at runtime using ``__orig_class__``.
|
|
410
|
+
"""
|
|
411
|
+
# NOTE: Linting may fail here due to __orig_class__ being a runtime attribute
|
|
285
412
|
return typing.get_args( self.__orig_class__ )[0]
|
|
286
413
|
@property
|
|
287
414
|
def batch_type( self ) -> Type:
|
|
288
|
-
"""The type of
|
|
289
|
-
# return self.__orig_class__.__args__[1]
|
|
290
|
-
return SampleBatch[self.sample_type]
|
|
291
|
-
|
|
415
|
+
"""The type of batches produced by this dataset.
|
|
292
416
|
|
|
293
|
-
|
|
294
|
-
|
|
417
|
+
Returns:
|
|
418
|
+
``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
|
|
419
|
+
"""
|
|
420
|
+
return SampleBatch[self.sample_type]
|
|
295
421
|
|
|
296
|
-
|
|
422
|
+
def __init__( self, url: str,
|
|
423
|
+
metadata_url: str | None = None,
|
|
424
|
+
) -> None:
|
|
425
|
+
"""Create a dataset from a WebDataset URL.
|
|
297
426
|
|
|
298
|
-
|
|
299
|
-
|
|
427
|
+
Args:
|
|
428
|
+
url: WebDataset brace-notation URL pointing to tar files, e.g.,
|
|
429
|
+
``"path/to/file-{000000..000009}.tar"`` for multiple shards or
|
|
430
|
+
``"path/to/file-000000.tar"`` for a single shard.
|
|
431
|
+
"""
|
|
300
432
|
super().__init__()
|
|
301
433
|
self.url = url
|
|
434
|
+
"""WebDataset brace-notation URL pointing to tar files, e.g.,
|
|
435
|
+
``"path/to/file-{000000..000009}.tar"`` for multiple shards or
|
|
436
|
+
``"path/to/file-000000.tar"`` for a single shard.
|
|
437
|
+
"""
|
|
438
|
+
|
|
439
|
+
self._metadata: dict[str, Any] | None = None
|
|
440
|
+
self.metadata_url: str | None = metadata_url
|
|
441
|
+
"""Optional URL to msgpack-encoded metadata for this dataset."""
|
|
302
442
|
|
|
303
443
|
# Allow addition of automatic transformation of raw underlying data
|
|
304
444
|
self._output_lens: Lens | None = None
|
|
305
445
|
|
|
306
446
|
def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
|
|
307
|
-
"""
|
|
447
|
+
"""View this dataset through a different sample type using a registered lens.
|
|
448
|
+
|
|
449
|
+
Args:
|
|
450
|
+
other: The target sample type to transform into. Must be a type
|
|
451
|
+
derived from ``PackableSample``.
|
|
452
|
+
|
|
453
|
+
Returns:
|
|
454
|
+
A new ``Dataset`` instance that yields samples of type ``other``
|
|
455
|
+
by applying the appropriate lens transformation from the global
|
|
456
|
+
``LensNetwork`` registry.
|
|
457
|
+
|
|
458
|
+
Raises:
|
|
459
|
+
ValueError: If no registered lens exists between the current
|
|
460
|
+
sample type and the target type.
|
|
461
|
+
"""
|
|
308
462
|
ret = Dataset[other]( self.url )
|
|
309
463
|
# Get the singleton lens registry
|
|
310
464
|
lenses = LensNetwork()
|
|
311
465
|
ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
|
|
312
466
|
return ret
|
|
313
467
|
|
|
314
|
-
# @classmethod
|
|
315
|
-
# def register( cls, uri: str,
|
|
316
|
-
# sample_class: Type,
|
|
317
|
-
# batch_class: Optional[Type] = None,
|
|
318
|
-
# ):
|
|
319
|
-
# """Register an `ekumen` schema to use a particular dataset sample class"""
|
|
320
|
-
# cls._schema_registry_sample[uri] = sample_class
|
|
321
|
-
# cls._schema_registry_batch[uri] = batch_class
|
|
322
|
-
|
|
323
|
-
# @classmethod
|
|
324
|
-
# def at( cls, uri: str ) -> 'Dataset':
|
|
325
|
-
# """Create a Dataset for the `ekumen` index entry at `uri`"""
|
|
326
|
-
# client = eat.Client()
|
|
327
|
-
# return cls( )
|
|
328
|
-
|
|
329
|
-
# Common functionality
|
|
330
|
-
|
|
331
468
|
@property
|
|
332
469
|
def shard_list( self ) -> list[str]:
|
|
333
470
|
"""List of individual dataset shards
|
|
@@ -341,6 +478,27 @@ class Dataset( Generic[ST] ):
|
|
|
341
478
|
wds.filters.map( lambda x: x['url'] )
|
|
342
479
|
)
|
|
343
480
|
return list( pipe )
|
|
481
|
+
|
|
482
|
+
@property
|
|
483
|
+
def metadata( self ) -> dict[str, Any] | None:
|
|
484
|
+
"""Fetch and cache metadata from metadata_url.
|
|
485
|
+
|
|
486
|
+
Returns:
|
|
487
|
+
Deserialized metadata dictionary, or None if no metadata_url is set.
|
|
488
|
+
|
|
489
|
+
Raises:
|
|
490
|
+
requests.HTTPError: If metadata fetch fails.
|
|
491
|
+
"""
|
|
492
|
+
if self.metadata_url is None:
|
|
493
|
+
return None
|
|
494
|
+
|
|
495
|
+
if self._metadata is None:
|
|
496
|
+
with requests.get( self.metadata_url, stream = True ) as response:
|
|
497
|
+
response.raise_for_status()
|
|
498
|
+
self._metadata = msgpack.unpackb( response.content, raw = False )
|
|
499
|
+
|
|
500
|
+
# Use our cached values
|
|
501
|
+
return self._metadata
|
|
344
502
|
|
|
345
503
|
def ordered( self,
|
|
346
504
|
batch_size: int | None = 1,
|
|
@@ -359,22 +517,17 @@ class Dataset( Generic[ST] ):
|
|
|
359
517
|
"""
|
|
360
518
|
|
|
361
519
|
if batch_size is None:
|
|
362
|
-
# TODO Duplication here
|
|
363
520
|
return wds.pipeline.DataPipeline(
|
|
364
521
|
wds.shardlists.SimpleShardList( self.url ),
|
|
365
522
|
wds.shardlists.split_by_worker,
|
|
366
|
-
#
|
|
367
523
|
wds.tariterators.tarfile_to_samples(),
|
|
368
|
-
# wds.map( self.preprocess ),
|
|
369
524
|
wds.filters.map( self.wrap ),
|
|
370
525
|
)
|
|
371
526
|
|
|
372
527
|
return wds.pipeline.DataPipeline(
|
|
373
528
|
wds.shardlists.SimpleShardList( self.url ),
|
|
374
529
|
wds.shardlists.split_by_worker,
|
|
375
|
-
#
|
|
376
530
|
wds.tariterators.tarfile_to_samples(),
|
|
377
|
-
# wds.map( self.preprocess ),
|
|
378
531
|
wds.filters.batched( batch_size ),
|
|
379
532
|
wds.filters.map( self.wrap_batch ),
|
|
380
533
|
)
|
|
@@ -384,30 +537,30 @@ class Dataset( Generic[ST] ):
|
|
|
384
537
|
buffer_samples: int = 10_000,
|
|
385
538
|
batch_size: int | None = 1,
|
|
386
539
|
) -> Iterable[ST]:
|
|
387
|
-
"""Iterate over the dataset in random order
|
|
388
|
-
|
|
540
|
+
"""Iterate over the dataset in random order.
|
|
541
|
+
|
|
389
542
|
Args:
|
|
390
|
-
buffer_shards
|
|
391
|
-
|
|
392
|
-
Default:
|
|
393
|
-
|
|
394
|
-
|
|
543
|
+
buffer_shards: Number of shards to buffer for shuffling at the
|
|
544
|
+
shard level. Larger values increase randomness but use more
|
|
545
|
+
memory. Default: 100.
|
|
546
|
+
buffer_samples: Number of samples to buffer for shuffling within
|
|
547
|
+
shards. Larger values increase randomness but use more memory.
|
|
548
|
+
Default: 10,000.
|
|
549
|
+
batch_size: The size of iterated batches. Default: 1. If ``None``,
|
|
550
|
+
iterates over one sample at a time with no batch dimension.
|
|
551
|
+
|
|
395
552
|
Returns:
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
553
|
+
A WebDataset data pipeline that iterates over the dataset in
|
|
554
|
+
randomized order. If ``batch_size`` is not ``None``, yields
|
|
555
|
+
``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
|
|
556
|
+
samples.
|
|
399
557
|
"""
|
|
400
|
-
|
|
401
558
|
if batch_size is None:
|
|
402
|
-
# TODO Duplication here
|
|
403
559
|
return wds.pipeline.DataPipeline(
|
|
404
560
|
wds.shardlists.SimpleShardList( self.url ),
|
|
405
561
|
wds.filters.shuffle( buffer_shards ),
|
|
406
562
|
wds.shardlists.split_by_worker,
|
|
407
|
-
#
|
|
408
563
|
wds.tariterators.tarfile_to_samples(),
|
|
409
|
-
# wds.shuffle( buffer_samples ),
|
|
410
|
-
# wds.map( self.preprocess ),
|
|
411
564
|
wds.filters.shuffle( buffer_samples ),
|
|
412
565
|
wds.filters.map( self.wrap ),
|
|
413
566
|
)
|
|
@@ -416,10 +569,7 @@ class Dataset( Generic[ST] ):
|
|
|
416
569
|
wds.shardlists.SimpleShardList( self.url ),
|
|
417
570
|
wds.filters.shuffle( buffer_shards ),
|
|
418
571
|
wds.shardlists.split_by_worker,
|
|
419
|
-
#
|
|
420
572
|
wds.tariterators.tarfile_to_samples(),
|
|
421
|
-
# wds.shuffle( buffer_samples ),
|
|
422
|
-
# wds.map( self.preprocess ),
|
|
423
573
|
wds.filters.shuffle( buffer_samples ),
|
|
424
574
|
wds.filters.batched( batch_size ),
|
|
425
575
|
wds.filters.map( self.wrap_batch ),
|
|
@@ -462,11 +612,11 @@ class Dataset( Generic[ST] ):
|
|
|
462
612
|
|
|
463
613
|
cur_segment = 0
|
|
464
614
|
cur_buffer = []
|
|
465
|
-
path_template = (path.parent / f'{path.stem}
|
|
615
|
+
path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix()
|
|
466
616
|
|
|
467
617
|
for x in self.ordered( batch_size = None ):
|
|
468
618
|
cur_buffer.append( sample_map( x ) )
|
|
469
|
-
|
|
619
|
+
|
|
470
620
|
if len( cur_buffer ) >= maxcount:
|
|
471
621
|
# Write current segment
|
|
472
622
|
cur_path = path_template.format( cur_segment )
|
|
@@ -482,25 +632,17 @@ class Dataset( Generic[ST] ):
|
|
|
482
632
|
df = pd.DataFrame( cur_buffer )
|
|
483
633
|
df.to_parquet( cur_path, **kwargs )
|
|
484
634
|
|
|
635
|
+
def wrap( self, sample: MsgpackRawSample ) -> ST:
|
|
636
|
+
"""Wrap a raw msgpack sample into the appropriate dataset-specific type.
|
|
485
637
|
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
# @abstractmethod
|
|
490
|
-
# def url( self ) -> str:
|
|
491
|
-
# """str: Brace-notation URL of the underlying full WebDataset"""
|
|
492
|
-
# pass
|
|
493
|
-
|
|
494
|
-
# @classmethod
|
|
495
|
-
# # TODO replace Any with IT
|
|
496
|
-
# def preprocess( cls, sample: WDSRawSample ) -> Any:
|
|
497
|
-
# """Pre-built preprocessor for a raw `sample` from the given dataset"""
|
|
498
|
-
# return sample
|
|
638
|
+
Args:
|
|
639
|
+
sample: A dictionary containing at minimum a ``'msgpack'`` key with
|
|
640
|
+
serialized sample bytes.
|
|
499
641
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
"""
|
|
642
|
+
Returns:
|
|
643
|
+
A deserialized sample of type ``ST``, optionally transformed through
|
|
644
|
+
a lens if ``as_type()`` was called.
|
|
645
|
+
"""
|
|
504
646
|
assert 'msgpack' in sample
|
|
505
647
|
assert type( sample['msgpack'] ) == bytes
|
|
506
648
|
|
|
@@ -509,24 +651,21 @@ class Dataset( Generic[ST] ):
|
|
|
509
651
|
|
|
510
652
|
source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
|
|
511
653
|
return self._output_lens( source_sample )
|
|
512
|
-
|
|
513
|
-
# try:
|
|
514
|
-
# assert type( sample ) == dict
|
|
515
|
-
# return cls.sample_class( **{
|
|
516
|
-
# k: v
|
|
517
|
-
# for k, v in sample.items() if k != '__key__'
|
|
518
|
-
# } )
|
|
519
|
-
|
|
520
|
-
# except Exception as e:
|
|
521
|
-
# # Sample constructor failed -- revert to default
|
|
522
|
-
# return AnySample(
|
|
523
|
-
# value = sample,
|
|
524
|
-
# )
|
|
525
654
|
|
|
526
655
|
def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
|
|
527
|
-
"""Wrap a
|
|
528
|
-
|
|
529
|
-
|
|
656
|
+
"""Wrap a batch of raw msgpack samples into a typed SampleBatch.
|
|
657
|
+
|
|
658
|
+
Args:
|
|
659
|
+
batch: A dictionary containing a ``'msgpack'`` key with a list of
|
|
660
|
+
serialized sample bytes.
|
|
661
|
+
|
|
662
|
+
Returns:
|
|
663
|
+
A ``SampleBatch[ST]`` containing deserialized samples, optionally
|
|
664
|
+
transformed through a lens if ``as_type()`` was called.
|
|
665
|
+
|
|
666
|
+
Note:
|
|
667
|
+
This implementation deserializes samples one at a time, then
|
|
668
|
+
aggregates them into a batch.
|
|
530
669
|
"""
|
|
531
670
|
|
|
532
671
|
assert 'msgpack' in batch
|
|
@@ -542,38 +681,32 @@ class Dataset( Generic[ST] ):
|
|
|
542
681
|
for s in batch_source ]
|
|
543
682
|
return SampleBatch[self.sample_type]( batch_view )
|
|
544
683
|
|
|
545
|
-
# # @classmethod
|
|
546
|
-
# def wrap_batch( self, batch: WDSRawBatch ) -> BT:
|
|
547
|
-
# """Wrap a `batch` of samples into the appropriate dataset-specific type
|
|
548
|
-
|
|
549
|
-
# This default implementation simply creates a list one sample at a time
|
|
550
|
-
# """
|
|
551
|
-
# assert cls.batch_class is not None, 'No batch class specified'
|
|
552
|
-
# return cls.batch_class( **batch )
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
##
|
|
556
|
-
# Shortcut decorators
|
|
557
|
-
|
|
558
|
-
# def packable( cls ):
|
|
559
|
-
# """TODO"""
|
|
560
|
-
|
|
561
|
-
# def decorator( cls ):
|
|
562
|
-
# # Create a new class dynamically
|
|
563
|
-
# # The new class inherits from the new_parent_class first, then the original cls
|
|
564
|
-
# new_bases = (PackableSample,) + cls.__bases__
|
|
565
|
-
# new_cls = type(cls.__name__, new_bases, dict(cls.__dict__))
|
|
566
|
-
|
|
567
|
-
# # Optionally, update __module__ and __qualname__ for better introspection
|
|
568
|
-
# new_cls.__module__ = cls.__module__
|
|
569
|
-
# new_cls.__qualname__ = cls.__qualname__
|
|
570
|
-
|
|
571
|
-
# return new_cls
|
|
572
|
-
# return decorator
|
|
573
684
|
|
|
574
685
|
def packable( cls ):
|
|
575
|
-
"""
|
|
576
|
-
|
|
686
|
+
"""Decorator to convert a regular class into a ``PackableSample``.
|
|
687
|
+
|
|
688
|
+
This decorator transforms a class into a dataclass that inherits from
|
|
689
|
+
``PackableSample``, enabling automatic msgpack serialization/deserialization
|
|
690
|
+
with special handling for NDArray fields.
|
|
691
|
+
|
|
692
|
+
Args:
|
|
693
|
+
cls: The class to convert. Should have type annotations for its fields.
|
|
694
|
+
|
|
695
|
+
Returns:
|
|
696
|
+
A new dataclass that inherits from ``PackableSample`` with the same
|
|
697
|
+
name and annotations as the original class.
|
|
698
|
+
|
|
699
|
+
Example:
|
|
700
|
+
>>> @packable
|
|
701
|
+
... class MyData:
|
|
702
|
+
... name: str
|
|
703
|
+
... values: NDArray
|
|
704
|
+
...
|
|
705
|
+
>>> sample = MyData(name="test", values=np.array([1, 2, 3]))
|
|
706
|
+
>>> bytes_data = sample.packed
|
|
707
|
+
>>> restored = MyData.from_bytes(bytes_data)
|
|
708
|
+
"""
|
|
709
|
+
|
|
577
710
|
##
|
|
578
711
|
|
|
579
712
|
class_name = cls.__name__
|