atdata 0.1.3b4__py3-none-any.whl → 0.2.2b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +44 -8
- atdata/_cid.py +150 -0
- atdata/_hf_api.py +692 -0
- atdata/_protocols.py +519 -0
- atdata/_schema_codec.py +442 -0
- atdata/_sources.py +515 -0
- atdata/_stub_manager.py +529 -0
- atdata/_type_utils.py +90 -0
- atdata/atmosphere/__init__.py +332 -0
- atdata/atmosphere/_types.py +331 -0
- atdata/atmosphere/client.py +533 -0
- atdata/atmosphere/lens.py +284 -0
- atdata/atmosphere/records.py +509 -0
- atdata/atmosphere/schema.py +239 -0
- atdata/atmosphere/store.py +208 -0
- atdata/cli/__init__.py +213 -0
- atdata/cli/diagnose.py +165 -0
- atdata/cli/local.py +280 -0
- atdata/dataset.py +510 -324
- atdata/lens.py +63 -112
- atdata/local.py +1707 -0
- atdata/promote.py +199 -0
- atdata-0.2.2b1.dist-info/METADATA +272 -0
- atdata-0.2.2b1.dist-info/RECORD +28 -0
- {atdata-0.1.3b4.dist-info → atdata-0.2.2b1.dist-info}/WHEEL +1 -1
- atdata-0.1.3b4.dist-info/METADATA +0 -172
- atdata-0.1.3b4.dist-info/RECORD +0 -9
- {atdata-0.1.3b4.dist-info → atdata-0.2.2b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.1.3b4.dist-info → atdata-0.2.2b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py
CHANGED
@@ -14,15 +14,17 @@ during serialization, enabling efficient storage of numerical data in WebDataset
 archives.
 
 Example:
-
-
-
-
-
-
-
-
-
+    ::
+
+        >>> @packable
+        ... class ImageSample:
+        ...     image: NDArray
+        ...     label: str
+        ...
+        >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
+        >>> for batch in ds.shuffled(batch_size=32):
+        ...     images = batch.image  # Stacked numpy array (32, H, W, C)
+        ...     labels = batch.label  # List of 32 strings
 """
 
 ##
@@ -32,7 +34,6 @@ import webdataset as wds
 
 from pathlib import Path
 import uuid
-import functools
 
 import dataclasses
 import types
@@ -40,40 +41,33 @@ from dataclasses import (
     dataclass,
     asdict,
 )
-from abc import
-
-
-
+from abc import ABC
+
+from ._sources import URLSource, S3Source
+from ._protocols import DataSource
 
 from tqdm import tqdm
 import numpy as np
 import pandas as pd
+import requests
 
 import typing
 from typing import (
     Any,
     Optional,
     Dict,
+    Iterator,
     Sequence,
     Iterable,
     Callable,
-    Union,
-    #
     Self,
     Generic,
     Type,
    TypeVar,
    TypeAlias,
+    dataclass_transform,
 )
-
-from numpy.typing import (
-    NDArray,
-    ArrayLike,
-)
-
-#
-
-# import ekumen.atmosphere as eat
+from numpy.typing import NDArray
 
 import msgpack
 import ormsgpack
@@ -86,6 +80,7 @@ from .lens import Lens, LensNetwork
 
 Pathlike = str | Path
 
+# WebDataset sample/batch dictionaries (contain __key__, msgpack, etc.)
 WDSRawSample: TypeAlias = Dict[str, Any]
 WDSRawBatch: TypeAlias = Dict[str, Any]
 
@@ -96,83 +91,191 @@ SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
 ##
 # Main base classes
 
-# TODO Check for best way to ensure this typevar is used as a dataclass type
-# DT = TypeVar( 'DT', bound = dataclass.__class__ )
 DT = TypeVar( 'DT' )
 
-MsgpackRawSample: TypeAlias = Dict[str, Any]
-
-# @dataclass
-# class ArrayBytes:
-#     """Annotates bytes that should be interpreted as the raw contents of a
-#     numpy NDArray"""
-
-#     raw_bytes: bytes
-#     """The raw bytes of the corresponding NDArray"""
-
-#     def __init__( self,
-#             array: Optional[ArrayLike] = None,
-#             raw: Optional[bytes] = None,
-#             ):
-#         """TODO"""
-
-#         if array is not None:
-#             array = np.array( array )
-#             self.raw_bytes = eh.array_to_bytes( array )
-
-#         elif raw is not None:
-#             self.raw_bytes = raw
-
-#         else:
-#             raise ValueError( 'Must provide either `array` or `raw` bytes' )
-
-#     @property
-#     def to_numpy( self ) -> NDArray:
-#         """Return the `raw_bytes` data as an NDArray"""
-#         return eh.bytes_to_array( self.raw_bytes )
 
 def _make_packable( x ):
-    """Convert
-
-    Args:
-        x: A value to convert. If it's a numpy array, converts to bytes.
-            Otherwise returns the value unchanged.
-
-    Returns:
-        The value in a format suitable for msgpack serialization.
-    """
-    # if isinstance( x, ArrayBytes ):
-    #     return x.raw_bytes
+    """Convert numpy arrays to bytes; pass through other values unchanged."""
     if isinstance( x, np.ndarray ):
         return eh.array_to_bytes( x )
     return x
 
-def _is_possibly_ndarray_type( t ):
-    """Check if a type annotation is or contains NDArray.
-
-    Args:
-        t: A type annotation to check.
-
-    Returns:
-        ``True`` if the type is ``NDArray`` or a union containing ``NDArray``
-        (e.g., ``NDArray | None``), ``False`` otherwise.
-    """
 
-
+def _is_possibly_ndarray_type( t ):
+    """Return True if type annotation is NDArray or Optional[NDArray]."""
     if t == NDArray:
-        # print( 'is an NDArray' )
         return True
-
-    # Check for Optionals (i.e., NDArray | None)
     if isinstance( t, types.UnionType ):
-
-        if any( x == NDArray
-                for x in t_parts ):
-            return True
-
-    # Not an NDArray
+        return any( x == NDArray for x in t.__args__ )
     return False
 
+class DictSample:
+    """Dynamic sample type providing dict-like access to raw msgpack data.
+
+    This class is the default sample type for datasets when no explicit type is
+    specified. It stores the raw unpacked msgpack data and provides both
+    attribute-style (``sample.field``) and dict-style (``sample["field"]``)
+    access to fields.
+
+    ``DictSample`` is useful for:
+    - Exploring datasets without defining a schema first
+    - Working with datasets that have variable schemas
+    - Prototyping before committing to a typed schema
+
+    To convert to a typed schema, use ``Dataset.as_type()`` with a
+    ``@packable``-decorated class. Every ``@packable`` class automatically
+    registers a lens from ``DictSample``, making this conversion seamless.
+
+    Example:
+        ::
+
+            >>> ds = load_dataset("path/to/data.tar")  # Returns Dataset[DictSample]
+            >>> for sample in ds.ordered():
+            ...     print(sample.some_field)      # Attribute access
+            ...     print(sample["other_field"])  # Dict access
+            ...     print(sample.keys())          # Inspect available fields
+            ...
+            >>> # Convert to typed schema
+            >>> typed_ds = ds.as_type(MyTypedSample)
+
+    Note:
+        NDArray fields are stored as raw bytes in DictSample. They are only
+        converted to numpy arrays when accessed through a typed sample class.
+    """
+
+    __slots__ = ('_data',)
+
+    def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None:
+        """Create a DictSample from a dictionary or keyword arguments.
+
+        Args:
+            _data: Raw data dictionary. If provided, kwargs are ignored.
+            **kwargs: Field values if _data is not provided.
+        """
+        if _data is not None:
+            object.__setattr__(self, '_data', _data)
+        else:
+            object.__setattr__(self, '_data', kwargs)
+
+    @classmethod
+    def from_data(cls, data: dict[str, Any]) -> 'DictSample':
+        """Create a DictSample from unpacked msgpack data.
+
+        Args:
+            data: Dictionary with field names as keys.
+
+        Returns:
+            New DictSample instance wrapping the data.
+        """
+        return cls(_data=data)
+
+    @classmethod
+    def from_bytes(cls, bs: bytes) -> 'DictSample':
+        """Create a DictSample from raw msgpack bytes.
+
+        Args:
+            bs: Raw bytes from a msgpack-serialized sample.
+
+        Returns:
+            New DictSample instance with the unpacked data.
+        """
+        return cls.from_data(ormsgpack.unpackb(bs))
+
+    def __getattr__(self, name: str) -> Any:
+        """Access a field by attribute name.
+
+        Args:
+            name: Field name to access.
+
+        Returns:
+            The field value.
+
+        Raises:
+            AttributeError: If the field doesn't exist.
+        """
+        # Avoid infinite recursion for _data lookup
+        if name == '_data':
+            raise AttributeError(name)
+        try:
+            return self._data[name]
+        except KeyError:
+            raise AttributeError(
+                f"'{type(self).__name__}' has no field '{name}'. "
+                f"Available fields: {list(self._data.keys())}"
+            ) from None
+
+    def __getitem__(self, key: str) -> Any:
+        """Access a field by dict key.
+
+        Args:
+            key: Field name to access.
+
+        Returns:
+            The field value.
+
+        Raises:
+            KeyError: If the field doesn't exist.
+        """
+        return self._data[key]
+
+    def __contains__(self, key: str) -> bool:
+        """Check if a field exists."""
+        return key in self._data
+
+    def keys(self) -> list[str]:
+        """Return list of field names."""
+        return list(self._data.keys())
+
+    def values(self) -> list[Any]:
+        """Return list of field values."""
+        return list(self._data.values())
+
+    def items(self) -> list[tuple[str, Any]]:
+        """Return list of (field_name, value) tuples."""
+        return list(self._data.items())
+
+    def get(self, key: str, default: Any = None) -> Any:
+        """Get a field value with optional default.
+
+        Args:
+            key: Field name to access.
+            default: Value to return if field doesn't exist.
+
+        Returns:
+            The field value or default.
+        """
+        return self._data.get(key, default)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Return a copy of the underlying data dictionary."""
+        return dict(self._data)
+
+    @property
+    def packed(self) -> bytes:
+        """Pack this sample's data into msgpack bytes.
+
+        Returns:
+            Raw msgpack bytes representing this sample's data.
+        """
+        return msgpack.packb(self._data)
+
+    @property
+    def as_wds(self) -> 'WDSRawSample':
+        """Pack this sample's data for writing to WebDataset.
+
+        Returns:
+            A dictionary with ``__key__`` and ``msgpack`` fields.
+        """
+        return {
+            '__key__': str(uuid.uuid1(0, 0)),
+            'msgpack': self.packed,
+        }
+
+    def __repr__(self) -> str:
+        fields = ', '.join(f'{k}=...' for k in self._data.keys())
+        return f'DictSample({fields})'
+
+
 @dataclass
 class PackableSample( ABC ):
     """Base class for samples that can be serialized with msgpack.
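The DictSample hunk above, together with the @packable lens auto-registration later in this file, defines the intended no-schema-first workflow. A minimal sketch of that flow, using only names that appear in this diff (the shard path and the `some_field`/`embedding` fields are illustrative):

    import numpy as np
    from numpy.typing import NDArray
    from atdata.dataset import Dataset, DictSample, packable

    # Explore an unknown shard without declaring a schema first.
    ds = Dataset[DictSample]("data-000000.tar")   # illustrative path
    for sample in ds.ordered():
        print(sample.keys())           # inspect available fields
        print(sample["some_field"])    # dict-style access
        break

    # Once the fields are known, commit to a typed schema; @packable
    # auto-registers a DictSample -> MyTyped lens, so as_type() works.
    @packable
    class MyTyped:
        some_field: str
        embedding: NDArray             # stored as bytes until typed access

    typed_ds = ds.as_type(MyTyped)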
@@ -187,28 +290,20 @@ class PackableSample( ABC ):
    2. Using the ``@packable`` decorator (recommended)
 
    Example:
-
-
-
-
-
-
-
-
+        ::
+
+            >>> @packable
+            ... class MyData:
+            ...     name: str
+            ...     embeddings: NDArray
+            ...
+            >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
+            >>> packed = sample.packed  # Serialize to bytes
+            >>> restored = MyData.from_bytes(packed)  # Deserialize
    """
 
    def _ensure_good( self ):
-        """
-
-        This method scans all dataclass fields and for any field annotated as
-        ``NDArray`` or ``NDArray | None``, automatically converts bytes values
-        to numpy arrays using the helper deserialization function. This enables
-        transparent handling of array serialization in msgpack data.
-
-        Note:
-            This is called during ``__post_init__`` to ensure proper type
-            conversion after deserialization.
-        """
+        """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
 
        # Auto-convert known types when annotated
        # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
@@ -226,16 +321,13 @@ class PackableSample( ABC ):
            # based on what is provided
 
            if isinstance( var_cur_value, np.ndarray ):
-                #
-
-
-                # elif isinstance( var_cur_value, ArrayBytes ):
-                #     setattr( self, var_name, var_cur_value.to_numpy )
+                # Already the correct type, no conversion needed
+                continue
 
            elif isinstance( var_cur_value, bytes ):
-                #
-                #
-                # as
+                # Design note: bytes in NDArray-typed fields are always interpreted
+                # as serialized arrays. This means raw bytes fields must not be
+                # annotated as NDArray.
                setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
 
    def __post_init__( self ):
@@ -244,20 +336,16 @@ class PackableSample( ABC ):
    ##
 
    @classmethod
-    def from_data( cls, data:
+    def from_data( cls, data: WDSRawSample ) -> Self:
        """Create a sample instance from unpacked msgpack data.
 
        Args:
-            data:
-                the sample's field names.
+            data: Dictionary with keys matching the sample's field names.
 
        Returns:
-
-            the data dictionary and NDArray fields auto-converted from bytes.
+            New instance with NDArray fields auto-converted from bytes.
        """
-
-        ret._ensure_good()
-        return ret
+        return cls( **data )
 
    @classmethod
    def from_bytes( cls, bs: bytes ) -> Self:
@@ -299,7 +387,6 @@ class PackableSample( ABC ):
 
        return ret
 
-    # TODO Expand to allow for specifying explicit __key__
    @property
    def as_wds( self ) -> WDSRawSample:
        """Pack this sample's data for writing to WebDataset.
@@ -309,7 +396,8 @@ class PackableSample( ABC ):
        ``msgpack`` (packed sample data) fields suitable for WebDataset.
 
        Note:
-
+            Keys are auto-generated as UUID v1 for time-sortable ordering.
+            Custom key specification is not currently supported.
        """
        return {
            # Generates a UUID that is timelike-sortable
@@ -318,25 +406,11 @@ class PackableSample( ABC ):
        }
 
 def _batch_aggregate( xs: Sequence ):
-    """
-
-    Args:
-        xs: A sequence of values to aggregate. If the first element is a numpy
-            array, all elements are stacked into a single array. Otherwise,
-            returns a list.
-
-    Returns:
-        A numpy array (if elements are arrays) or a list (otherwise).
-    """
-
+    """Stack arrays into numpy array with batch dim; otherwise return list."""
    if not xs:
-        # Empty sequence
        return []
-
-    # Aggregate
    if isinstance( xs[0], np.ndarray ):
        return np.array( list( xs ) )
-
    return list( xs )
 
 class SampleBatch( Generic[DT] ):
@@ -350,17 +424,27 @@ class SampleBatch( Generic[DT] ):
    NDArray fields are stacked into a numpy array with a batch dimension.
    Other fields are aggregated into a list.
 
-
+    Parameters:
        DT: The sample type, must derive from ``PackableSample``.
 
    Attributes:
        samples: The list of sample instances in this batch.
 
    Example:
-
-
-
+        ::
+
+            >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
+            >>> batch.embeddings  # Returns stacked numpy array of shape (3, ...)
+            >>> batch.names  # Returns list of names
+
+    Note:
+        This class uses Python's ``__orig_class__`` mechanism to extract the
+        type parameter at runtime. Instances must be created using the
+        subscripted syntax ``SampleBatch[MyType](samples)`` rather than
+        calling the constructor directly with an unsubscripted class.
    """
+    # Design note: The docstring uses "Parameters:" for type parameters because
+    # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
 
    def __init__( self, samples: Sequence[DT] ):
        """Create a batch from a sequence of samples.
@@ -372,6 +456,7 @@ class SampleBatch( Generic[DT] ):
        """
        self.samples = list( samples )
        self._aggregate_cache = dict()
+        self._sample_type_cache: Type | None = None
 
    @property
    def sample_type( self ) -> Type:
@@ -380,7 +465,10 @@ class SampleBatch( Generic[DT] ):
        Returns:
            The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
        """
-
+        if self._sample_type_cache is None:
+            self._sample_type_cache = typing.get_args( self.__orig_class__)[0]
+        assert self._sample_type_cache is not None
+        return self._sample_type_cache
 
    def __getattr__( self, name ):
        """Aggregate an attribute across all samples in the batch.
@@ -411,23 +499,44 @@ class SampleBatch( Generic[DT] ):
        raise AttributeError( f'No sample attribute named {name}' )
 
 
-
-
-
+ST = TypeVar( 'ST', bound = PackableSample )
+RT = TypeVar( 'RT', bound = PackableSample )
+
+
+class _ShardListStage(wds.utils.PipelineStage):
+    """Pipeline stage that yields {url: shard_id} dicts from a DataSource.
 
-
-
-
+    This is analogous to SimpleShardList but works with any DataSource.
+    Used as the first stage before split_by_worker.
+    """
 
+    def __init__(self, source: DataSource):
+        self.source = source
 
-
-
+    def run(self):
+        """Yield {url: shard_id} dicts for each shard."""
+        for shard_id in self.source.list_shards():
+            yield {"url": shard_id}
 
-RT = TypeVar( 'RT', bound = PackableSample )
 
-
-
-
+class _StreamOpenerStage(wds.utils.PipelineStage):
+    """Pipeline stage that opens streams from a DataSource.
+
+    Takes {url: shard_id} dicts and adds a stream using source.open_shard().
+    This replaces WebDataset's url_opener stage.
+    """
+
+    def __init__(self, source: DataSource):
+        self.source = source
+
+    def run(self, src):
+        """Open streams for each shard dict."""
+        for sample in src:
+            shard_id = sample["url"]
+            stream = self.source.open_shard(shard_id)
+            sample["stream"] = stream
+            yield sample
+
 
 class Dataset( Generic[ST] ):
    """A typed dataset built on WebDataset with lens transformations.
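_ShardListStage and _StreamOpenerStage only call `list_shards()` and `open_shard()` on the source, so any object providing those two methods can drive the pipeline. A toy sketch of such a source (the full `DataSource` protocol lives in `_protocols.py`, which is not shown here and may require more than these two methods; the directory layout is illustrative):

    from pathlib import Path
    from atdata.dataset import Dataset, DictSample

    class DirectorySource:
        """Hypothetical source: every *.tar under a directory is one shard."""

        def __init__(self, root: str):
            self.root = Path(root)

        def list_shards(self) -> list[str]:
            return sorted(str(p) for p in self.root.glob("*.tar"))

        def open_shard(self, shard_id: str):
            # The stage stores this under sample["stream"] for tar_file_expander
            return open(shard_id, "rb")

    ds = Dataset[DictSample](DirectorySource("shards/"))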
@@ -442,26 +551,31 @@ class Dataset( Generic[ST] ):
    - Type transformations via the lens system (``as_type()``)
    - Export to parquet format
 
-
+    Parameters:
        ST: The sample type for this dataset, must derive from ``PackableSample``.
 
    Attributes:
        url: WebDataset brace-notation URL for the tar file(s).
 
    Example:
-
-
-
-
-
-
-
+        ::
+
+            >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
+            >>> for sample in ds.ordered(batch_size=32):
+            ...     # sample is SampleBatch[MyData] with batch_size samples
+            ...     embeddings = sample.embeddings  # shape: (32, ...)
+            ...
+            >>> # Transform to a different view
+            >>> ds_view = ds.as_type(MyDataView)
+
+    Note:
+        This class uses Python's ``__orig_class__`` mechanism to extract the
+        type parameter at runtime. Instances must be created using the
+        subscripted syntax ``Dataset[MyType](url)`` rather than calling the
+        constructor directly with an unsubscripted class.
    """
-
-    #
-    # """The type of each returned sample from this `Dataset`'s iterator"""
-    # batch_class: Type = get_bound( BT )
-    # """The type of a batch built from `sample_class`"""
+    # Design note: The docstring uses "Parameters:" for type parameters because
+    # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
 
    @property
    def sample_type( self ) -> Type:
@@ -469,12 +583,11 @@ class Dataset( Generic[ST] ):
 
        Returns:
            The type parameter ``ST`` used when creating this ``Dataset[ST]``.
-
-        Note:
-            Extracts the type parameter at runtime using ``__orig_class__``.
        """
-
-
+        if self._sample_type_cache is None:
+            self._sample_type_cache = typing.get_args( self.__orig_class__ )[0]
+        assert self._sample_type_cache is not None
+        return self._sample_type_cache
    @property
    def batch_type( self ) -> Type:
        """The type of batches produced by this dataset.
@@ -482,28 +595,60 @@ class Dataset( Generic[ST] ):
        Returns:
            ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
        """
-        # return self.__orig_class__.__args__[1]
        return SampleBatch[self.sample_type]
 
+    def __init__( self,
+            source: DataSource | str | None = None,
+            metadata_url: str | None = None,
+            *,
+            url: str | None = None,
+            ) -> None:
+        """Create a dataset from a DataSource or URL.
 
-
-
-
-
+        Args:
+            source: Either a DataSource implementation or a WebDataset-compatible
+                URL string. If a string is provided, it's wrapped in URLSource
+                for backward compatibility.
 
-
-
+                Examples:
+                - String URL: ``"path/to/file-{000000..000009}.tar"``
+                - URLSource: ``URLSource("https://example.com/data.tar")``
+                - S3Source: ``S3Source(bucket="my-bucket", keys=["data.tar"])``
 
-
-            url:
-                ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
-                ``"path/to/file-000000.tar"`` for a single shard.
+            metadata_url: Optional URL to msgpack-encoded metadata for this dataset.
+            url: Deprecated. Use ``source`` instead. Kept for backward compatibility.
        """
        super().__init__()
-        self.url = url
 
-        #
+        # Handle backward compatibility: url= keyword argument
+        if source is None and url is not None:
+            source = url
+        elif source is None:
+            raise TypeError("Dataset() missing required argument: 'source' or 'url'")
+
+        # Normalize source: strings become URLSource for backward compatibility
+        if isinstance(source, str):
+            self._source: DataSource = URLSource(source)
+            self.url = source
+        else:
+            self._source = source
+            # For compatibility, expose URL if source has list_shards
+            shards = source.list_shards()
+            # Design note: Using first shard as url for legacy compatibility.
+            # Full shard list is available via list_shards() method.
+            self.url = shards[0] if shards else ""
+
+        self._metadata: dict[str, Any] | None = None
+        self.metadata_url: str | None = metadata_url
+        """Optional URL to msgpack-encoded metadata for this dataset."""
+
        self._output_lens: Lens | None = None
+        self._sample_type_cache: Type | None = None
+
+    @property
+    def source(self) -> DataSource:
+        """The underlying data source for this dataset."""
+        return self._source
 
    def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
        """View this dataset through a different sample type using a registered lens.
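The reworked constructor accepts either a DataSource or a plain URL string. A sketch of the construction forms described in the docstring above (bucket names and paths are illustrative, and remote sources must actually exist for iteration to work):

    from atdata.dataset import Dataset, DictSample
    from atdata._sources import URLSource, S3Source

    # Plain WebDataset URL string, wrapped in URLSource internally
    ds_a = Dataset[DictSample]("path/to/file-{000000..000009}.tar")

    # Explicit URLSource
    ds_b = Dataset[DictSample](URLSource("https://example.com/data.tar"))

    # S3-backed source, as in the docstring example
    ds_c = Dataset[DictSample](S3Source(bucket="my-bucket", keys=["data.tar"]))

    # Deprecated keyword form, still accepted for older callers
    ds_d = Dataset[DictSample](url="path/to/file-000000.tar")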
@@ -521,76 +666,104 @@ class Dataset( Generic[ST] ):
            ValueError: If no registered lens exists between the current
                sample type and the target type.
        """
-        ret = Dataset[other]( self.
+        ret = Dataset[other]( self._source )
        # Get the singleton lens registry
        lenses = LensNetwork()
        ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
        return ret
 
-    # @classmethod
-    # def register( cls, uri: str,
-    #         sample_class: Type,
-    #         batch_class: Optional[Type] = None,
-    #         ):
-    #     """Register an `ekumen` schema to use a particular dataset sample class"""
-    #     cls._schema_registry_sample[uri] = sample_class
-    #     cls._schema_registry_batch[uri] = batch_class
-
-    # @classmethod
-    # def at( cls, uri: str ) -> 'Dataset':
-    #     """Create a Dataset for the `ekumen` index entry at `uri`"""
-    #     client = eat.Client()
-    #     return cls( )
-
-    # Common functionality
-
    @property
-    def
-    """
-
+    def shards(self) -> Iterator[str]:
+        """Lazily iterate over shard identifiers.
+
+        Yields:
+            Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
+
+        Example:
+            ::
+
+                >>> for shard in ds.shards:
+                ...     print(f"Processing {shard}")
+        """
+        return iter(self._source.list_shards())
+
+    def list_shards(self) -> list[str]:
+        """Get list of individual dataset shards.
+
        Returns:
            A full (non-lazy) list of the individual ``tar`` files within the
            source WebDataset.
        """
-
-
-
+        return self._source.list_shards()
+
+    # Legacy alias for backwards compatibility
+    @property
+    def shard_list(self) -> list[str]:
+        """List of individual dataset shards (deprecated, use list_shards()).
+
+        .. deprecated::
+            Use :meth:`list_shards` instead.
+        """
+        import warnings
+        warnings.warn(
+            "shard_list is deprecated, use list_shards() instead",
+            DeprecationWarning,
+            stacklevel=2,
        )
-        return
+        return self.list_shards()
+
+    @property
+    def metadata( self ) -> dict[str, Any] | None:
+        """Fetch and cache metadata from metadata_url.
+
+        Returns:
+            Deserialized metadata dictionary, or None if no metadata_url is set.
+
+        Raises:
+            requests.HTTPError: If metadata fetch fails.
+        """
+        if self.metadata_url is None:
+            return None
+
+        if self._metadata is None:
+            with requests.get( self.metadata_url, stream = True ) as response:
+                response.raise_for_status()
+                self._metadata = msgpack.unpackb( response.content, raw = False )
+
+        # Use our cached values
+        return self._metadata
 
    def ordered( self,
-            batch_size: int | None =
+            batch_size: int | None = None,
            ) -> Iterable[ST]:
        """Iterate over the dataset in order
-
+
        Args:
            batch_size (:obj:`int`, optional): The size of iterated batches.
-                Default:
-                with no batch dimension.
-
+                Default: None (unbatched). If ``None``, iterates over one
+                sample at a time with no batch dimension.
+
        Returns:
            :obj:`webdataset.DataPipeline` A data pipeline that iterates over
            the dataset in its original sample order
-
-        """
 
+        """
        if batch_size is None:
-            # TODO Duplication here
            return wds.pipeline.DataPipeline(
-
+                _ShardListStage(self._source),
                wds.shardlists.split_by_worker,
-
-                wds.tariterators.
-
+                _StreamOpenerStage(self._source),
+                wds.tariterators.tar_file_expander,
+                wds.tariterators.group_by_keys,
                wds.filters.map( self.wrap ),
            )
 
        return wds.pipeline.DataPipeline(
-
+            _ShardListStage(self._source),
            wds.shardlists.split_by_worker,
-
-            wds.tariterators.
-
+            _StreamOpenerStage(self._source),
+            wds.tariterators.tar_file_expander,
+            wds.tariterators.group_by_keys,
            wds.filters.batched( batch_size ),
            wds.filters.map( self.wrap_batch ),
        )
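The new metadata property pairs with the metadata_url constructor argument: the first access fetches the msgpack blob over HTTP with requests and caches it. A short sketch, with illustrative URLs:

    from atdata.dataset import Dataset, DictSample

    ds = Dataset[DictSample](
        "https://example.com/shards/data-{000000..000003}.tar",
        metadata_url="https://example.com/shards/meta.msgpack",
    )
    print(ds.metadata)   # dict decoded from msgpack, fetched once and then cached
                         # (None when no metadata_url was given)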
@@ -598,7 +771,7 @@ class Dataset( Generic[ST] ):
    def shuffled( self,
            buffer_shards: int = 100,
            buffer_samples: int = 10_000,
-            batch_size: int | None =
+            batch_size: int | None = None,
            ) -> Iterable[ST]:
        """Iterate over the dataset in random order.
 
@@ -609,8 +782,9 @@ class Dataset( Generic[ST] ):
            buffer_samples: Number of samples to buffer for shuffling within
                shards. Larger values increase randomness but use more memory.
                Default: 10,000.
-            batch_size: The size of iterated batches. Default:
-                iterates over one sample at a time with no batch
+            batch_size: The size of iterated batches. Default: None (unbatched).
+                If ``None``, iterates over one sample at a time with no batch
+                dimension.
 
        Returns:
            A WebDataset data pipeline that iterates over the dataset in
@@ -618,44 +792,74 @@ class Dataset( Generic[ST] ):
            ``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
            samples.
        """
-
        if batch_size is None:
-            # TODO Duplication here
            return wds.pipeline.DataPipeline(
-
+                _ShardListStage(self._source),
                wds.filters.shuffle( buffer_shards ),
                wds.shardlists.split_by_worker,
-
-                wds.tariterators.
-
-                # wds.map( self.preprocess ),
+                _StreamOpenerStage(self._source),
+                wds.tariterators.tar_file_expander,
+                wds.tariterators.group_by_keys,
                wds.filters.shuffle( buffer_samples ),
                wds.filters.map( self.wrap ),
            )
 
        return wds.pipeline.DataPipeline(
-
+            _ShardListStage(self._source),
            wds.filters.shuffle( buffer_shards ),
            wds.shardlists.split_by_worker,
-
-            wds.tariterators.
-
-            # wds.map( self.preprocess ),
+            _StreamOpenerStage(self._source),
+            wds.tariterators.tar_file_expander,
+            wds.tariterators.group_by_keys,
            wds.filters.shuffle( buffer_samples ),
            wds.filters.batched( batch_size ),
            wds.filters.map( self.wrap_batch ),
        )
 
-    #
-    #
+    # Design note: Uses pandas for parquet export. Could be replaced with
+    # direct fastparquet calls to reduce dependencies if needed.
    def to_parquet( self, path: Pathlike,
            sample_map: Optional[SampleExportMap] = None,
            maxcount: Optional[int] = None,
            **kwargs,
            ):
-        """
+        """Export dataset contents to parquet format.
+
+        Converts all samples to a pandas DataFrame and saves to parquet file(s).
+        Useful for interoperability with data analysis tools.
 
-
+        Args:
+            path: Output path for the parquet file. If ``maxcount`` is specified,
+                files are named ``{stem}-{segment:06d}.parquet``.
+            sample_map: Optional function to convert samples to dictionaries.
+                Defaults to ``dataclasses.asdict``.
+            maxcount: If specified, split output into multiple files with at most
+                this many samples each. Recommended for large datasets.
+            **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
+                Common options include ``compression``, ``index``, ``engine``.
+
+        Warning:
+            **Memory Usage**: When ``maxcount=None`` (default), this method loads
+            the **entire dataset into memory** as a pandas DataFrame before writing.
+            For large datasets, this can cause memory exhaustion.
+
+            For datasets larger than available RAM, always specify ``maxcount``::
+
+                # Safe for large datasets - processes in chunks
+                ds.to_parquet("output.parquet", maxcount=10000)
+
+            This creates multiple parquet files: ``output-000000.parquet``,
+            ``output-000001.parquet``, etc.
+
+        Example:
+            ::
+
+                >>> ds = Dataset[MySample]("data.tar")
+                >>> # Small dataset - load all at once
+                >>> ds.to_parquet("output.parquet")
+                >>>
+                >>> # Large dataset - process in chunks
+                >>> ds.to_parquet("output.parquet", maxcount=50000)
        """
        ##
 
@@ -683,11 +887,11 @@ class Dataset( Generic[ST] ):
 
        cur_segment = 0
        cur_buffer = []
-        path_template = (path.parent / f'{path.stem}
+        path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix()
 
        for x in self.ordered( batch_size = None ):
            cur_buffer.append( sample_map( x ) )
-
+
            if len( cur_buffer ) >= maxcount:
                # Write current segment
                cur_path = path_template.format( cur_segment )
@@ -703,24 +907,7 @@ class Dataset( Generic[ST] ):
        df = pd.DataFrame( cur_buffer )
        df.to_parquet( cur_path, **kwargs )
 
-
-    # Implemented by specific subclasses
-
-    # @property
-    # @abstractmethod
-    # def url( self ) -> str:
-    #     """str: Brace-notation URL of the underlying full WebDataset"""
-    #     pass
-
-    # @classmethod
-    # # TODO replace Any with IT
-    # def preprocess( cls, sample: WDSRawSample ) -> Any:
-    #     """Pre-built preprocessor for a raw `sample` from the given dataset"""
-    #     return sample
-
-    # @classmethod
-    # TODO replace Any with IT
-    def wrap( self, sample: MsgpackRawSample ) -> ST:
+    def wrap( self, sample: WDSRawSample ) -> ST:
        """Wrap a raw msgpack sample into the appropriate dataset-specific type.
 
        Args:
@@ -731,27 +918,16 @@ class Dataset( Generic[ST] ):
            A deserialized sample of type ``ST``, optionally transformed through
            a lens if ``as_type()`` was called.
        """
-
-
-
+        if 'msgpack' not in sample:
+            raise ValueError(f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}")
+        if not isinstance(sample['msgpack'], bytes):
+            raise ValueError(f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}")
+
        if self._output_lens is None:
            return self.sample_type.from_bytes( sample['msgpack'] )
 
        source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
        return self._output_lens( source_sample )
-
-        # try:
-        #     assert type( sample ) == dict
-        #     return cls.sample_class( **{
-        #         k: v
-        #         for k, v in sample.items() if k != '__key__'
-        #     } )
-
-        # except Exception as e:
-        #     # Sample constructor failed -- revert to default
-        #     return AnySample(
-        #         value = sample,
-        #     )
 
    def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
        """Wrap a batch of raw msgpack samples into a typed SampleBatch.
@@ -769,7 +945,8 @@ class Dataset( Generic[ST] ):
        aggregates them into a batch.
        """
 
-
+        if 'msgpack' not in batch:
+            raise ValueError(f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}")
 
        if self._output_lens is None:
            batch_unpacked = [ self.sample_type.from_bytes( bs )
@@ -782,58 +959,44 @@ class Dataset( Generic[ST] ):
                for s in batch_source ]
        return SampleBatch[self.sample_type]( batch_view )
 
-    # # @classmethod
-    # def wrap_batch( self, batch: WDSRawBatch ) -> BT:
-    #     """Wrap a `batch` of samples into the appropriate dataset-specific type
-
-    #     This default implementation simply creates a list one sample at a time
-    #     """
-    #     assert cls.batch_class is not None, 'No batch class specified'
-    #     return cls.batch_class( **batch )
-
 
-
-# Shortcut decorators
-
-# def packable( cls ):
-#     """TODO"""
-
-#     def decorator( cls ):
-#         # Create a new class dynamically
-#         # The new class inherits from the new_parent_class first, then the original cls
-#         new_bases = (PackableSample,) + cls.__bases__
-#         new_cls = type(cls.__name__, new_bases, dict(cls.__dict__))
+_T = TypeVar('_T')
 
-#         # Optionally, update __module__ and __qualname__ for better introspection
-#         new_cls.__module__ = cls.__module__
-#         new_cls.__qualname__ = cls.__qualname__
 
-
-
-
-def packable( cls ):
+@dataclass_transform()
+def packable( cls: type[_T] ) -> type[_T]:
    """Decorator to convert a regular class into a ``PackableSample``.
 
    This decorator transforms a class into a dataclass that inherits from
    ``PackableSample``, enabling automatic msgpack serialization/deserialization
    with special handling for NDArray fields.
 
+    The resulting class satisfies the ``Packable`` protocol, making it compatible
+    with all atdata APIs that accept packable types (e.g., ``publish_schema``,
+    lens transformations, etc.).
+
    Args:
        cls: The class to convert. Should have type annotations for its fields.
 
    Returns:
        A new dataclass that inherits from ``PackableSample`` with the same
-        name and annotations as the original class.
-
-
-
-
-
-
-
-
-
-
+        name and annotations as the original class. The class satisfies the
+        ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
+
+    Examples:
+        This is a test of the functionality::
+
+            @packable
+            class MyData:
+                name: str
+                values: NDArray
+
+            sample = MyData(name="test", values=np.array([1, 2, 3]))
+            bytes_data = sample.packed
+            restored = MyData.from_bytes(bytes_data)
+
+            # Works with Packable-typed APIs
+            index.publish_schema(MyData, version="1.0.0")  # Type-safe
    """
 
    ##
@@ -850,9 +1013,32 @@ def packable( cls ):
        def __post_init__( self ):
            return PackableSample.__post_init__( self )
 
-    #
+    # Restore original class identity for better repr/debugging
    as_packable.__name__ = class_name
+    as_packable.__qualname__ = class_name
+    as_packable.__module__ = cls.__module__
    as_packable.__annotations__ = class_annotations
+    if cls.__doc__:
+        as_packable.__doc__ = cls.__doc__
+
+    # Fix qualnames of dataclass-generated methods so they don't show
+    # 'packable.<locals>.as_packable' in help() and IDE hints
+    old_qualname_prefix = 'packable.<locals>.as_packable'
+    for attr_name in ('__init__', '__repr__', '__eq__', '__post_init__'):
+        attr = getattr(as_packable, attr_name, None)
+        if attr is not None and hasattr(attr, '__qualname__'):
+            if attr.__qualname__.startswith(old_qualname_prefix):
+                attr.__qualname__ = attr.__qualname__.replace(
+                    old_qualname_prefix, class_name, 1
+                )
+
+    # Auto-register lens from DictSample to this type
+    # This enables ds.as_type(MyType) when ds is Dataset[DictSample]
+    def _dict_to_typed(ds: DictSample) -> as_packable:
+        return as_packable.from_data(ds._data)
+
+    _dict_lens = Lens(_dict_to_typed)
+    LensNetwork().register(_dict_lens)
 
    ##
 
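Taken together, the as_wds property, the new pipeline stages, and the ordered()/shuffled() readers give a full write-then-read loop. A sketch of that round trip, reusing the ImageSample type from the module docstring; writing the shard with webdataset's TarWriter is assumed here and is not shown in this diff:

    import numpy as np
    import webdataset as wds
    from numpy.typing import NDArray
    from atdata.dataset import Dataset, packable

    @packable
    class ImageSample:
        image: NDArray
        label: str

    # Write a tiny shard; as_wds yields {'__key__': ..., 'msgpack': <bytes>}
    with wds.TarWriter("toy-000000.tar") as sink:
        for i in range(4):
            sink.write(ImageSample(image=np.zeros((2, 2)), label=str(i)).as_wds)

    # Read it back, typed end to end
    ds = Dataset[ImageSample]("toy-000000.tar")
    for batch in ds.ordered(batch_size=2):
        print(batch.image.shape, batch.label)   # stacked array plus a list of labels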