atdata 0.2.0a1__py3-none-any.whl → 0.2.3b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +43 -10
- atdata/_cid.py +144 -0
- atdata/_helpers.py +7 -5
- atdata/_hf_api.py +690 -0
- atdata/_protocols.py +504 -0
- atdata/_schema_codec.py +438 -0
- atdata/_sources.py +508 -0
- atdata/_stub_manager.py +534 -0
- atdata/_type_utils.py +104 -0
- atdata/atmosphere/__init__.py +269 -1
- atdata/atmosphere/_types.py +4 -2
- atdata/atmosphere/client.py +146 -3
- atdata/atmosphere/lens.py +4 -3
- atdata/atmosphere/records.py +168 -7
- atdata/atmosphere/schema.py +29 -82
- atdata/atmosphere/store.py +204 -0
- atdata/cli/__init__.py +222 -0
- atdata/cli/diagnose.py +169 -0
- atdata/cli/local.py +283 -0
- atdata/dataset.py +615 -257
- atdata/lens.py +53 -54
- atdata/local.py +1456 -228
- atdata/promote.py +195 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/METADATA +106 -14
- atdata-0.2.3b1.dist-info/RECORD +28 -0
- atdata-0.2.0a1.dist-info/RECORD +0 -16
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py
CHANGED
@@ -13,7 +13,7 @@ The implementation handles automatic conversion between numpy arrays and bytes
 during serialization, enabling efficient storage of numerical data in WebDataset
 archives.
 
-
+Examples:
     >>> @packable
     ... class ImageSample:
    ...     image: NDArray
@@ -41,6 +41,9 @@ from dataclasses import (
 )
 from abc import ABC
 
+from ._sources import URLSource
+from ._protocols import DataSource
+
 from tqdm import tqdm
 import numpy as np
 import pandas as pd
@@ -51,16 +54,17 @@ from typing import (
     Any,
     Optional,
     Dict,
+    Iterator,
     Sequence,
     Iterable,
     Callable,
-    Union,
-    #
     Self,
     Generic,
     Type,
     TypeVar,
     TypeAlias,
+    dataclass_transform,
+    overload,
 )
 from numpy.typing import NDArray
 
@@ -75,63 +79,203 @@ from .lens import Lens, LensNetwork
 
 Pathlike = str | Path
 
+# WebDataset sample/batch dictionaries (contain __key__, msgpack, etc.)
 WDSRawSample: TypeAlias = Dict[str, Any]
 WDSRawBatch: TypeAlias = Dict[str, Any]
 
 SampleExportRow: TypeAlias = Dict[str, Any]
-SampleExportMap: TypeAlias = Callable[[
+SampleExportMap: TypeAlias = Callable[["PackableSample"], SampleExportRow]
 
 
 ##
 # Main base classes
 
-DT = TypeVar(
+DT = TypeVar("DT")
 
-MsgpackRawSample: TypeAlias = Dict[str, Any]
 
+def _make_packable(x):
+    """Convert numpy arrays to bytes; pass through other values unchanged."""
+    if isinstance(x, np.ndarray):
+        return eh.array_to_bytes(x)
+    return x
 
-def _make_packable( x ):
-    """Convert a value to a msgpack-compatible format.
 
-
-
-
+def _is_possibly_ndarray_type(t):
+    """Return True if type annotation is NDArray or Optional[NDArray]."""
+    if t == NDArray:
+        return True
+    if isinstance(t, types.UnionType):
+        return any(x == NDArray for x in t.__args__)
+    return False
 
-    Returns:
-        The value in a format suitable for msgpack serialization.
-    """
-    if isinstance( x, np.ndarray ):
-        return eh.array_to_bytes( x )
-    return x
 
-
-    """
+class DictSample:
+    """Dynamic sample type providing dict-like access to raw msgpack data.
 
-
-
+    This class is the default sample type for datasets when no explicit type is
+    specified. It stores the raw unpacked msgpack data and provides both
+    attribute-style (``sample.field``) and dict-style (``sample["field"]``)
+    access to fields.
 
-
-
-
+    ``DictSample`` is useful for:
+    - Exploring datasets without defining a schema first
+    - Working with datasets that have variable schemas
+    - Prototyping before committing to a typed schema
+
+    To convert to a typed schema, use ``Dataset.as_type()`` with a
+    ``@packable``-decorated class. Every ``@packable`` class automatically
+    registers a lens from ``DictSample``, making this conversion seamless.
+
+    Examples:
+        >>> ds = load_dataset("path/to/data.tar")  # Returns Dataset[DictSample]
+        >>> for sample in ds.ordered():
+        ...     print(sample.some_field)      # Attribute access
+        ...     print(sample["other_field"])  # Dict access
+        ...     print(sample.keys())          # Inspect available fields
+        ...
+        >>> # Convert to typed schema
+        >>> typed_ds = ds.as_type(MyTypedSample)
+
+    Note:
+        NDArray fields are stored as raw bytes in DictSample. They are only
+        converted to numpy arrays when accessed through a typed sample class.
     """
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    __slots__ = ("_data",)
+
+    def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None:
+        """Create a DictSample from a dictionary or keyword arguments.
+
+        Args:
+            _data: Raw data dictionary. If provided, kwargs are ignored.
+            **kwargs: Field values if _data is not provided.
+        """
+        if _data is not None:
+            object.__setattr__(self, "_data", _data)
+        else:
+            object.__setattr__(self, "_data", kwargs)
+
+    @classmethod
+    def from_data(cls, data: dict[str, Any]) -> "DictSample":
+        """Create a DictSample from unpacked msgpack data.
+
+        Args:
+            data: Dictionary with field names as keys.
+
+        Returns:
+            New DictSample instance wrapping the data.
+        """
+        return cls(_data=data)
+
+    @classmethod
+    def from_bytes(cls, bs: bytes) -> "DictSample":
+        """Create a DictSample from raw msgpack bytes.
+
+        Args:
+            bs: Raw bytes from a msgpack-serialized sample.
+
+        Returns:
+            New DictSample instance with the unpacked data.
+        """
+        return cls.from_data(ormsgpack.unpackb(bs))
+
+    def __getattr__(self, name: str) -> Any:
+        """Access a field by attribute name.
+
+        Args:
+            name: Field name to access.
+
+        Returns:
+            The field value.
+
+        Raises:
+            AttributeError: If the field doesn't exist.
+        """
+        # Avoid infinite recursion for _data lookup
+        if name == "_data":
+            raise AttributeError(name)
+        try:
+            return self._data[name]
+        except KeyError:
+            raise AttributeError(
+                f"'{type(self).__name__}' has no field '{name}'. "
+                f"Available fields: {list(self._data.keys())}"
+            ) from None
+
+    def __getitem__(self, key: str) -> Any:
+        """Access a field by dict key.
+
+        Args:
+            key: Field name to access.
+
+        Returns:
+            The field value.
+
+        Raises:
+            KeyError: If the field doesn't exist.
+        """
+        return self._data[key]
+
+    def __contains__(self, key: str) -> bool:
+        """Check if a field exists."""
+        return key in self._data
+
+    def keys(self) -> list[str]:
+        """Return list of field names."""
+        return list(self._data.keys())
+
+    def values(self) -> list[Any]:
+        """Return list of field values."""
+        return list(self._data.values())
+
+    def items(self) -> list[tuple[str, Any]]:
+        """Return list of (field_name, value) tuples."""
+        return list(self._data.items())
+
+    def get(self, key: str, default: Any = None) -> Any:
+        """Get a field value with optional default.
+
+        Args:
+            key: Field name to access.
+            default: Value to return if field doesn't exist.
+
+        Returns:
+            The field value or default.
+        """
+        return self._data.get(key, default)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Return a copy of the underlying data dictionary."""
+        return dict(self._data)
+
+    @property
+    def packed(self) -> bytes:
+        """Pack this sample's data into msgpack bytes.
+
+        Returns:
+            Raw msgpack bytes representing this sample's data.
+        """
+        return msgpack.packb(self._data)
+
+    @property
+    def as_wds(self) -> "WDSRawSample":
+        """Pack this sample's data for writing to WebDataset.
+
+        Returns:
+            A dictionary with ``__key__`` and ``msgpack`` fields.
+        """
+        return {
+            "__key__": str(uuid.uuid1(0, 0)),
+            "msgpack": self.packed,
+        }
+
+    def __repr__(self) -> str:
+        fields = ", ".join(f"{k}=..." for k in self._data.keys())
+        return f"DictSample({fields})"
+
 
 @dataclass
-class PackableSample(
+class PackableSample(ABC):
     """Base class for samples that can be serialized with msgpack.
 
     This abstract base class provides automatic serialization/deserialization
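For orientation, here is a minimal usage sketch (not part of the diff) of the new DictSample API added above, using only methods shown in this hunk (from_data, attribute/dict access, keys, get, packed, from_bytes); the field names are hypothetical:

    from atdata.dataset import DictSample

    raw = DictSample.from_data({"label": 3, "note": "hello"})  # hypothetical fields
    raw.label             # attribute access -> 3
    raw["note"]           # dict access -> "hello"
    raw.keys()            # ['label', 'note']
    raw.get("missing")    # -> None when a field is absent

    blob = raw.packed                   # msgpack bytes
    again = DictSample.from_bytes(blob)
    again.to_dict()                     # {'label': 3, 'note': 'hello'}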
@@ -143,7 +287,7 @@ class PackableSample( ABC ):
     1. Direct inheritance with the ``@dataclass`` decorator
     2. Using the ``@packable`` decorator (recommended)
 
-
+    Examples:
         >>> @packable
         ... class MyData:
         ...     name: str
@@ -154,67 +298,53 @@ class PackableSample( ABC ):
         >>> restored = MyData.from_bytes(packed)  # Deserialize
     """
 
-    def _ensure_good(
-        """
-
-        This method scans all dataclass fields and for any field annotated as
-        ``NDArray`` or ``NDArray | None``, automatically converts bytes values
-        to numpy arrays using the helper deserialization function. This enables
-        transparent handling of array serialization in msgpack data.
-
-        Note:
-            This is called during ``__post_init__`` to ensure proper type
-            conversion after deserialization.
-        """
+    def _ensure_good(self):
+        """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
 
         # Auto-convert known types when annotated
         # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
-        for field in dataclasses.fields(
+        for field in dataclasses.fields(self):
             var_name = field.name
             var_type = field.type
 
             # Annotation for this variable is to be an NDArray
-            if _is_possibly_ndarray_type(
+            if _is_possibly_ndarray_type(var_type):
                 # ... so, we'll always auto-convert to numpy
 
-                var_cur_value = getattr(
+                var_cur_value = getattr(self, var_name)
 
                 # Execute the appropriate conversion for intermediate data
                 # based on what is provided
 
-                if isinstance(
+                if isinstance(var_cur_value, np.ndarray):
                     # Already the correct type, no conversion needed
                     continue
 
-                elif isinstance(
-                    #
-                    #
-                    # as
-                    setattr(
+                elif isinstance(var_cur_value, bytes):
+                    # Design note: bytes in NDArray-typed fields are always interpreted
+                    # as serialized arrays. This means raw bytes fields must not be
+                    # annotated as NDArray.
+                    setattr(self, var_name, eh.bytes_to_array(var_cur_value))
 
-    def __post_init__(
+    def __post_init__(self):
         self._ensure_good()
 
     ##
 
     @classmethod
-    def from_data(
+    def from_data(cls, data: WDSRawSample) -> Self:
         """Create a sample instance from unpacked msgpack data.
 
         Args:
-            data:
-                the sample's field names.
+            data: Dictionary with keys matching the sample's field names.
 
         Returns:
-
-            the data dictionary and NDArray fields auto-converted from bytes.
+            New instance with NDArray fields auto-converted from bytes.
         """
-
-
-        return ret
-
+        return cls(**data)
+
     @classmethod
-    def from_bytes(
+    def from_bytes(cls, bs: bytes) -> Self:
         """Create a sample instance from raw msgpack bytes.
 
         Args:
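A small sketch (not part of the diff) of the round trip this hunk implements: bytes stored in an NDArray-annotated field are converted back into an array by _ensure_good during __post_init__. The Reading class is hypothetical:

    import numpy as np
    from numpy.typing import NDArray
    from atdata.dataset import packable

    @packable
    class Reading:                    # hypothetical sample type
        name: str
        values: NDArray

    r = Reading(name="probe-1", values=np.arange(4, dtype=np.float32))
    blob = r.packed                   # the array is serialized to bytes inside the msgpack payload
    restored = Reading.from_bytes(blob)
    isinstance(restored.values, np.ndarray)   # True: bytes were auto-converted back to an array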
@@ -223,10 +353,10 @@ class PackableSample( ABC ):
         Returns:
             A new instance of this sample class deserialized from the bytes.
         """
-        return cls.from_data(
+        return cls.from_data(ormsgpack.unpackb(bs))
 
     @property
-    def packed(
+    def packed(self) -> bytes:
         """Pack this sample's data into msgpack bytes.
 
         NDArray fields are automatically converted to bytes before packing.
@@ -241,21 +371,17 @@ class PackableSample( ABC ):
 
         # Make sure that all of our (possibly unpackable) data is in a packable
         # format
-        o = {
-            k: _make_packable( v )
-            for k, v in vars( self ).items()
-        }
+        o = {k: _make_packable(v) for k, v in vars(self).items()}
 
-        ret = msgpack.packb(
+        ret = msgpack.packb(o)
 
         if ret is None:
-            raise RuntimeError(
+            raise RuntimeError(f"Failed to pack sample to bytes: {o}")
 
         return ret
-
-    # TODO Expand to allow for specifying explicit __key__
+
     @property
-    def as_wds(
+    def as_wds(self) -> WDSRawSample:
         """Pack this sample's data for writing to WebDataset.
 
         Returns:
@@ -263,37 +389,26 @@ class PackableSample( ABC ):
         ``msgpack`` (packed sample data) fields suitable for WebDataset.
 
         Note:
-
+            Keys are auto-generated as UUID v1 for time-sortable ordering.
+            Custom key specification is not currently supported.
         """
         return {
             # Generates a UUID that is timelike-sortable
-
-
+            "__key__": str(uuid.uuid1(0, 0)),
+            "msgpack": self.packed,
         }
 
-def _batch_aggregate( xs: Sequence ):
-    """Aggregate a sequence of values into a batch-appropriate format.
-
-    Args:
-        xs: A sequence of values to aggregate. If the first element is a numpy
-            array, all elements are stacked into a single array. Otherwise,
-            returns a list.
-
-    Returns:
-        A numpy array (if elements are arrays) or a list (otherwise).
-    """
 
+def _batch_aggregate(xs: Sequence):
+    """Stack arrays into numpy array with batch dim; otherwise return list."""
     if not xs:
-        # Empty sequence
         return []
+    if isinstance(xs[0], np.ndarray):
+        return np.array(list(xs))
+    return list(xs)
 
-    # Aggregate
-    if isinstance( xs[0], np.ndarray ):
-        return np.array( list( xs ) )
 
-
-
-class SampleBatch( Generic[DT] ):
+class SampleBatch(Generic[DT]):
     """A batch of samples with automatic attribute aggregation.
 
     This class wraps a sequence of samples and provides magic ``__getattr__``
@@ -304,19 +419,28 @@ class SampleBatch( Generic[DT] ):
     NDArray fields are stacked into a numpy array with a batch dimension.
     Other fields are aggregated into a list.
 
-
+    Parameters:
         DT: The sample type, must derive from ``PackableSample``.
 
     Attributes:
         samples: The list of sample instances in this batch.
 
-
+    Examples:
         >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
         >>> batch.embeddings  # Returns stacked numpy array of shape (3, ...)
         >>> batch.names  # Returns list of names
+
+    Note:
+        This class uses Python's ``__orig_class__`` mechanism to extract the
+        type parameter at runtime. Instances must be created using the
+        subscripted syntax ``SampleBatch[MyType](samples)`` rather than
+        calling the constructor directly with an unsubscripted class.
     """
 
-
+    # Design note: The docstring uses "Parameters:" for type parameters because
+    # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
+
+    def __init__(self, samples: Sequence[DT]):
         """Create a batch from a sequence of samples.
 
         Args:
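Continuing the hypothetical Reading sketch above (not part of the diff), the Note added here means batches have to be built with the subscripted form so that sample_type can be recovered from __orig_class__:

    from atdata.dataset import SampleBatch

    samples = [Reading(name=f"probe-{i}", values=np.zeros(4)) for i in range(3)]
    batch = SampleBatch[Reading](samples)   # subscripted: sample_type resolves to Reading
    batch.values                            # ndarray of shape (3, 4) -- NDArray fields are stacked
    batch.name                              # ['probe-0', 'probe-1', 'probe-2'] -- other fields become lists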
@@ -324,19 +448,23 @@ class SampleBatch( Generic[DT] ):
                 Each sample must be an instance of a type derived from
                 ``PackableSample``.
         """
-        self.samples = list(
+        self.samples = list(samples)
         self._aggregate_cache = dict()
+        self._sample_type_cache: Type | None = None
 
     @property
-    def sample_type(
+    def sample_type(self) -> Type:
         """The type of each sample in this batch.
 
         Returns:
             The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
         """
-
+        if self._sample_type_cache is None:
+            self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
+        assert self._sample_type_cache is not None
+        return self._sample_type_cache
 
-    def __getattr__(
+    def __getattr__(self, name):
         """Aggregate an attribute across all samples in the batch.
 
         This magic method enables attribute-style access to aggregated sample
@@ -353,22 +481,57 @@ class SampleBatch( Generic[DT] ):
             AttributeError: If the attribute doesn't exist on the sample type.
         """
         # Aggregate named params of sample type
-        if name in vars(
+        if name in vars(self.sample_type)["__annotations__"]:
             if name not in self._aggregate_cache:
                 self._aggregate_cache[name] = _batch_aggregate(
-                    [
-                      for x in self.samples ]
+                    [getattr(x, name) for x in self.samples]
                 )
 
             return self._aggregate_cache[name]
 
-        raise AttributeError(
+        raise AttributeError(f"No sample attribute named {name}")
+
+
+ST = TypeVar("ST", bound=PackableSample)
+RT = TypeVar("RT", bound=PackableSample)
 
 
-
-
+class _ShardListStage(wds.utils.PipelineStage):
+    """Pipeline stage that yields {url: shard_id} dicts from a DataSource.
 
-
+    This is analogous to SimpleShardList but works with any DataSource.
+    Used as the first stage before split_by_worker.
+    """
+
+    def __init__(self, source: DataSource):
+        self.source = source
+
+    def run(self):
+        """Yield {url: shard_id} dicts for each shard."""
+        for shard_id in self.source.list_shards():
+            yield {"url": shard_id}
+
+
+class _StreamOpenerStage(wds.utils.PipelineStage):
+    """Pipeline stage that opens streams from a DataSource.
+
+    Takes {url: shard_id} dicts and adds a stream using source.open_shard().
+    This replaces WebDataset's url_opener stage.
+    """
+
+    def __init__(self, source: DataSource):
+        self.source = source
+
+    def run(self, src):
+        """Open streams for each shard dict."""
+        for sample in src:
+            shard_id = sample["url"]
+            stream = self.source.open_shard(shard_id)
+            sample["stream"] = stream
+            yield sample
+
+
+class Dataset(Generic[ST]):
     """A typed dataset built on WebDataset with lens transformations.
 
     This class wraps WebDataset tar archives and provides type-safe iteration
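The two pipeline stages above only call list_shards() and open_shard() on the source, so a minimal local-directory source could look like the sketch below. This is an illustration, not part of the diff; the full DataSource protocol lives in atdata/_protocols.py and may require more than these two methods:

    from pathlib import Path

    class DirectorySource:                       # hypothetical source
        def __init__(self, root: str):
            self.root = Path(root)

        def list_shards(self) -> list[str]:
            # Shard identifiers are the tar paths under the directory
            return sorted(str(p) for p in self.root.glob("*.tar"))

        def open_shard(self, shard_id: str):
            # Return a readable binary stream for one shard
            return open(shard_id, "rb")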
@@ -381,13 +544,13 @@ class Dataset( Generic[ST] ):
     - Type transformations via the lens system (``as_type()``)
     - Export to parquet format
 
-
+    Parameters:
         ST: The sample type for this dataset, must derive from ``PackableSample``.
 
     Attributes:
         url: WebDataset brace-notation URL for the tar file(s).
 
-
+    Examples:
         >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
         >>> for sample in ds.ordered(batch_size=32):
         ...     # sample is SampleBatch[MyData] with batch_size samples
@@ -395,23 +558,31 @@ class Dataset( Generic[ST] ):
         ...
         >>> # Transform to a different view
         >>> ds_view = ds.as_type(MyDataView)
-
+
+    Note:
+        This class uses Python's ``__orig_class__`` mechanism to extract the
+        type parameter at runtime. Instances must be created using the
+        subscripted syntax ``Dataset[MyType](url)`` rather than calling the
+        constructor directly with an unsubscripted class.
     """
 
+    # Design note: The docstring uses "Parameters:" for type parameters because
+    # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
+
     @property
-    def sample_type(
+    def sample_type(self) -> Type:
         """The type of each returned sample from this dataset's iterator.
 
         Returns:
             The type parameter ``ST`` used when creating this ``Dataset[ST]``.
-
-        Note:
-            Extracts the type parameter at runtime using ``__orig_class__``.
         """
-
-
+        if self._sample_type_cache is None:
+            self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
+        assert self._sample_type_cache is not None
+        return self._sample_type_cache
+
     @property
-    def batch_type(
+    def batch_type(self) -> Type:
         """The type of batches produced by this dataset.
 
         Returns:
@@ -419,31 +590,61 @@ class Dataset( Generic[ST] ):
         """
         return SampleBatch[self.sample_type]
 
-    def __init__(
-
-
-
+    def __init__(
+        self,
+        source: DataSource | str | None = None,
+        metadata_url: str | None = None,
+        *,
+        url: str | None = None,
+    ) -> None:
+        """Create a dataset from a DataSource or URL.
 
         Args:
-
-
-
+            source: Either a DataSource implementation or a WebDataset-compatible
+                URL string. If a string is provided, it's wrapped in URLSource
+                for backward compatibility.
+
+                Examples:
+                - String URL: ``"path/to/file-{000000..000009}.tar"``
+                - URLSource: ``URLSource("https://example.com/data.tar")``
+                - S3Source: ``S3Source(bucket="my-bucket", keys=["data.tar"])``
+
+            metadata_url: Optional URL to msgpack-encoded metadata for this dataset.
+            url: Deprecated. Use ``source`` instead. Kept for backward compatibility.
         """
         super().__init__()
-
-
-
-
-
+
+        # Handle backward compatibility: url= keyword argument
+        if source is None and url is not None:
+            source = url
+        elif source is None:
+            raise TypeError("Dataset() missing required argument: 'source' or 'url'")
+
+        # Normalize source: strings become URLSource for backward compatibility
+        if isinstance(source, str):
+            self._source: DataSource = URLSource(source)
+            self.url = source
+        else:
+            self._source = source
+            # For compatibility, expose URL if source has list_shards
+            shards = source.list_shards()
+            # Design note: Using first shard as url for legacy compatibility.
+            # Full shard list is available via list_shards() method.
+            self.url = shards[0] if shards else ""
 
         self._metadata: dict[str, Any] | None = None
         self.metadata_url: str | None = metadata_url
         """Optional URL to msgpack-encoded metadata for this dataset."""
 
-        # Allow addition of automatic transformation of raw underlying data
         self._output_lens: Lens | None = None
+        self._sample_type_cache: Type | None = None
 
-
+    @property
+    def source(self) -> DataSource:
+        """The underlying data source for this dataset."""
+        return self._source
+
+    def as_type(self, other: Type[RT]) -> "Dataset[RT]":
         """View this dataset through a different sample type using a registered lens.
 
         Args:
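A sketch (not part of the diff) of the three construction paths handled by the new __init__: a plain string (wrapped in URLSource internally), an explicit source object, and the deprecated url= keyword. MySample is hypothetical:

    from atdata.dataset import Dataset, packable
    from atdata._sources import URLSource
    from numpy.typing import NDArray

    @packable
    class MySample:                  # hypothetical sample type
        name: str
        values: NDArray

    ds_from_str = Dataset[MySample]("path/to/data-{000000..000009}.tar")
    ds_from_src = Dataset[MySample](URLSource("https://example.com/data.tar"))
    ds_legacy = Dataset[MySample](url="path/to/data.tar")   # deprecated keyword, still accepted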
@@ -459,28 +660,53 @@ class Dataset( Generic[ST] ):
             ValueError: If no registered lens exists between the current
                 sample type and the target type.
         """
-        ret = Dataset[other](
+        ret = Dataset[other](self._source)
         # Get the singleton lens registry
         lenses = LensNetwork()
-        ret._output_lens = lenses.transform(
+        ret._output_lens = lenses.transform(self.sample_type, ret.sample_type)
         return ret
 
     @property
-    def
-        """
-
+    def shards(self) -> Iterator[str]:
+        """Lazily iterate over shard identifiers.
+
+        Yields:
+            Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
+
+        Examples:
+            >>> for shard in ds.shards:
+            ...     print(f"Processing {shard}")
+        """
+        return iter(self._source.list_shards())
+
+    def list_shards(self) -> list[str]:
+        """Get list of individual dataset shards.
+
         Returns:
             A full (non-lazy) list of the individual ``tar`` files within the
             source WebDataset.
         """
-
-
-
+        return self._source.list_shards()
+
+    # Legacy alias for backwards compatibility
+    @property
+    def shard_list(self) -> list[str]:
+        """List of individual dataset shards (deprecated, use list_shards()).
+
+        .. deprecated::
+            Use :meth:`list_shards` instead.
+        """
+        import warnings
+
+        warnings.warn(
+            "shard_list is deprecated, use list_shards() instead",
+            DeprecationWarning,
+            stacklevel=2,
         )
-        return
+        return self.list_shards()
 
     @property
-    def metadata(
+    def metadata(self) -> dict[str, Any] | None:
         """Fetch and cache metadata from metadata_url.
 
         Returns:
@@ -493,50 +719,91 @@ class Dataset( Generic[ST] ):
             return None
 
         if self._metadata is None:
-            with requests.get(
+            with requests.get(self.metadata_url, stream=True) as response:
                 response.raise_for_status()
-                self._metadata = msgpack.unpackb(
-
+                self._metadata = msgpack.unpackb(response.content, raw=False)
+
         # Use our cached values
         return self._metadata
-
-
-
-
-
-
+
+    @overload
+    def ordered(
+        self,
+        batch_size: None = None,
+    ) -> Iterable[ST]: ...
+
+    @overload
+    def ordered(
+        self,
+        batch_size: int,
+    ) -> Iterable[SampleBatch[ST]]: ...
+
+    def ordered(
+        self,
+        batch_size: int | None = None,
+    ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
+        """Iterate over the dataset in order.
+
         Args:
-            batch_size
-
-
+            batch_size: The size of iterated batches. Default: None (unbatched).
+                If ``None``, iterates over one sample at a time with no batch
+                dimension.
+
         Returns:
-
-
-
+            A data pipeline that iterates over the dataset in its original
+            sample order. When ``batch_size`` is ``None``, yields individual
+            samples of type ``ST``. When ``batch_size`` is an integer, yields
+            ``SampleBatch[ST]`` instances containing that many samples.
+
+        Examples:
+            >>> for sample in ds.ordered():
+            ...     process(sample)  # sample is ST
+            >>> for batch in ds.ordered(batch_size=32):
+            ...     process(batch)  # batch is SampleBatch[ST]
         """
-
         if batch_size is None:
             return wds.pipeline.DataPipeline(
-
+                _ShardListStage(self._source),
                 wds.shardlists.split_by_worker,
-
-                wds.
+                _StreamOpenerStage(self._source),
+                wds.tariterators.tar_file_expander,
+                wds.tariterators.group_by_keys,
+                wds.filters.map(self.wrap),
             )
 
         return wds.pipeline.DataPipeline(
-
+            _ShardListStage(self._source),
            wds.shardlists.split_by_worker,
-
-            wds.
-            wds.
+            _StreamOpenerStage(self._source),
+            wds.tariterators.tar_file_expander,
+            wds.tariterators.group_by_keys,
+            wds.filters.batched(batch_size),
+            wds.filters.map(self.wrap_batch),
        )
 
-
-
-
-
-
+    @overload
+    def shuffled(
+        self,
+        buffer_shards: int = 100,
+        buffer_samples: int = 10_000,
+        batch_size: None = None,
+    ) -> Iterable[ST]: ...
+
+    @overload
+    def shuffled(
+        self,
+        buffer_shards: int = 100,
+        buffer_samples: int = 10_000,
+        *,
+        batch_size: int,
+    ) -> Iterable[SampleBatch[ST]]: ...
+
+    def shuffled(
+        self,
+        buffer_shards: int = 100,
+        buffer_samples: int = 10_000,
+        batch_size: int | None = None,
+    ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
         """Iterate over the dataset in random order.
 
         Args:
@@ -546,93 +813,139 @@ class Dataset( Generic[ST] ):
             buffer_samples: Number of samples to buffer for shuffling within
                 shards. Larger values increase randomness but use more memory.
                 Default: 10,000.
-            batch_size: The size of iterated batches. Default:
-                iterates over one sample at a time with no batch
+            batch_size: The size of iterated batches. Default: None (unbatched).
+                If ``None``, iterates over one sample at a time with no batch
+                dimension.
 
         Returns:
-            A
-
-            ``
-            samples.
+            A data pipeline that iterates over the dataset in randomized order.
+            When ``batch_size`` is ``None``, yields individual samples of type
+            ``ST``. When ``batch_size`` is an integer, yields ``SampleBatch[ST]``
+            instances containing that many samples.
+
+        Examples:
+            >>> for sample in ds.shuffled():
+            ...     process(sample)  # sample is ST
+            >>> for batch in ds.shuffled(batch_size=32):
+            ...     process(batch)  # batch is SampleBatch[ST]
         """
         if batch_size is None:
             return wds.pipeline.DataPipeline(
-
-                wds.filters.shuffle(
+                _ShardListStage(self._source),
+                wds.filters.shuffle(buffer_shards),
                 wds.shardlists.split_by_worker,
-
-                wds.
-                wds.
+                _StreamOpenerStage(self._source),
+                wds.tariterators.tar_file_expander,
+                wds.tariterators.group_by_keys,
+                wds.filters.shuffle(buffer_samples),
+                wds.filters.map(self.wrap),
             )
 
         return wds.pipeline.DataPipeline(
-
-            wds.filters.shuffle(
+            _ShardListStage(self._source),
+            wds.filters.shuffle(buffer_shards),
             wds.shardlists.split_by_worker,
-
-            wds.
-            wds.
-            wds.filters.
+            _StreamOpenerStage(self._source),
+            wds.tariterators.tar_file_expander,
+            wds.tariterators.group_by_keys,
+            wds.filters.shuffle(buffer_samples),
+            wds.filters.batched(batch_size),
+            wds.filters.map(self.wrap_batch),
         )
-
-    #
-    #
-    def to_parquet(
-
-
-
-
-
-
-
+
+    # Design note: Uses pandas for parquet export. Could be replaced with
+    # direct fastparquet calls to reduce dependencies if needed.
+    def to_parquet(
+        self,
+        path: Pathlike,
+        sample_map: Optional[SampleExportMap] = None,
+        maxcount: Optional[int] = None,
+        **kwargs,
+    ):
+        """Export dataset contents to parquet format.
+
+        Converts all samples to a pandas DataFrame and saves to parquet file(s).
+        Useful for interoperability with data analysis tools.
+
+        Args:
+            path: Output path for the parquet file. If ``maxcount`` is specified,
+                files are named ``{stem}-{segment:06d}.parquet``.
+            sample_map: Optional function to convert samples to dictionaries.
+                Defaults to ``dataclasses.asdict``.
+            maxcount: If specified, split output into multiple files with at most
+                this many samples each. Recommended for large datasets.
+            **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
+                Common options include ``compression``, ``index``, ``engine``.
+
+        Warning:
+            **Memory Usage**: When ``maxcount=None`` (default), this method loads
+            the **entire dataset into memory** as a pandas DataFrame before writing.
+            For large datasets, this can cause memory exhaustion.
+
+            For datasets larger than available RAM, always specify ``maxcount``::
+
+                # Safe for large datasets - processes in chunks
+                ds.to_parquet("output.parquet", maxcount=10000)
+
+            This creates multiple parquet files: ``output-000000.parquet``,
+            ``output-000001.parquet``, etc.
+
+        Examples:
+            >>> ds = Dataset[MySample]("data.tar")
+            >>> # Small dataset - load all at once
+            >>> ds.to_parquet("output.parquet")
+            >>>
+            >>> # Large dataset - process in chunks
+            >>> ds.to_parquet("output.parquet", maxcount=50000)
         """
         ##
 
         # Normalize args
-        path = Path(
+        path = Path(path)
         if sample_map is None:
             sample_map = asdict
-
-        verbose = kwargs.get( 'verbose', False )
 
-
+        verbose = kwargs.get("verbose", False)
+
+        it = self.ordered(batch_size=None)
         if verbose:
-            it = tqdm(
+            it = tqdm(it)
 
         #
 
         if maxcount is None:
             # Load and save full dataset
-            df = pd.DataFrame(
-
-
-
+            df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)])
+            df.to_parquet(path, **kwargs)
+
         else:
             # Load and save dataset in segments of size `maxcount`
 
             cur_segment = 0
             cur_buffer = []
-            path_template = (
+            path_template = (
+                path.parent / f"{path.stem}-{{:06d}}{path.suffix}"
+            ).as_posix()
 
-            for x in self.ordered(
-                cur_buffer.append(
+            for x in self.ordered(batch_size=None):
+                cur_buffer.append(sample_map(x))
 
-                if len(
+                if len(cur_buffer) >= maxcount:
                     # Write current segment
-                    cur_path = path_template.format(
-                    df = pd.DataFrame(
-                    df.to_parquet(
+                    cur_path = path_template.format(cur_segment)
+                    df = pd.DataFrame(cur_buffer)
+                    df.to_parquet(cur_path, **kwargs)
 
                     cur_segment += 1
                     cur_buffer = []
-
-            if len(
+
+            if len(cur_buffer) > 0:
                 # Write one last segment with remainder
-                cur_path = path_template.format(
-                df = pd.DataFrame(
-                df.to_parquet(
+                cur_path = path_template.format(cur_segment)
+                df = pd.DataFrame(cur_buffer)
+                df.to_parquet(cur_path, **kwargs)
 
-    def wrap(
+    def wrap(self, sample: WDSRawSample) -> ST:
         """Wrap a raw msgpack sample into the appropriate dataset-specific type.
 
         Args:
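A short sketch (not part of the diff) of a custom sample_map for the export added above; each sample becomes one parquet row, and the default mapping is dataclasses.asdict. Here ds is assumed to be a Dataset[MySample] as in the earlier sketch, and the mapping itself is hypothetical:

    def summarize(sample) -> dict:
        # Hypothetical row mapping: keep the name, store only the array length
        return {"name": sample.name, "n_values": len(sample.values)}

    ds.to_parquet("summary.parquet", sample_map=summarize, maxcount=10_000)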
@@ -643,16 +956,22 @@ class Dataset( Generic[ST] ):
             A deserialized sample of type ``ST``, optionally transformed through
             a lens if ``as_type()`` was called.
         """
-
-
-
+        if "msgpack" not in sample:
+            raise ValueError(
+                f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}"
+            )
+        if not isinstance(sample["msgpack"], bytes):
+            raise ValueError(
+                f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}"
+            )
+
         if self._output_lens is None:
-            return self.sample_type.from_bytes(
+            return self.sample_type.from_bytes(sample["msgpack"])
 
-        source_sample = self._output_lens.source_type.from_bytes(
-        return self._output_lens(
+        source_sample = self._output_lens.source_type.from_bytes(sample["msgpack"])
+        return self._output_lens(source_sample)
 
-    def wrap_batch(
+    def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]:
         """Wrap a batch of raw msgpack samples into a typed SampleBatch.
 
         Args:
@@ -668,35 +987,48 @@ class Dataset( Generic[ST] ):
             aggregates them into a batch.
         """
 
-
+        if "msgpack" not in batch:
+            raise ValueError(
+                f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}"
+            )
 
         if self._output_lens is None:
-            batch_unpacked = [
-
-
+            batch_unpacked = [
+                self.sample_type.from_bytes(bs) for bs in batch["msgpack"]
+            ]
+            return SampleBatch[self.sample_type](batch_unpacked)
+
+        batch_source = [
+            self._output_lens.source_type.from_bytes(bs) for bs in batch["msgpack"]
+        ]
+        batch_view = [self._output_lens(s) for s in batch_source]
+        return SampleBatch[self.sample_type](batch_view)
 
-        batch_source = [ self._output_lens.source_type.from_bytes( bs )
-                         for bs in batch['msgpack'] ]
-        batch_view = [ self._output_lens( s )
-                       for s in batch_source ]
-        return SampleBatch[self.sample_type]( batch_view )
 
+_T = TypeVar("_T")
 
-
+
+@dataclass_transform()
+def packable(cls: type[_T]) -> type[_T]:
     """Decorator to convert a regular class into a ``PackableSample``.
 
     This decorator transforms a class into a dataclass that inherits from
     ``PackableSample``, enabling automatic msgpack serialization/deserialization
     with special handling for NDArray fields.
 
+    The resulting class satisfies the ``Packable`` protocol, making it compatible
+    with all atdata APIs that accept packable types (e.g., ``publish_schema``,
+    lens transformations, etc.).
+
     Args:
         cls: The class to convert. Should have type annotations for its fields.
 
     Returns:
         A new dataclass that inherits from ``PackableSample`` with the same
-        name and annotations as the original class.
+        name and annotations as the original class. The class satisfies the
+        ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
 
-
+    Examples:
         >>> @packable
         ... class MyData:
        ...     name: str
@@ -705,6 +1037,9 @@ def packable( cls ):
         >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
         >>> bytes_data = sample.packed
         >>> restored = MyData.from_bytes(bytes_data)
+        >>>
+        >>> # Works with Packable-typed APIs
+        >>> index.publish_schema(MyData, version="1.0.0")  # Type-safe
     """
 
     ##
@@ -713,18 +1048,41 @@ def packable( cls ):
     class_annotations = cls.__annotations__
 
     # Add in dataclass niceness to original class
-    as_dataclass = dataclass(
+    as_dataclass = dataclass(cls)
 
     # This triggers a bunch of behind-the-scenes stuff for the newly annotated class
     @dataclass
-    class as_packable(
-        def __post_init__(
-            return PackableSample.__post_init__(
-
-    #
+    class as_packable(as_dataclass, PackableSample):
+        def __post_init__(self):
+            return PackableSample.__post_init__(self)
+
+    # Restore original class identity for better repr/debugging
     as_packable.__name__ = class_name
+    as_packable.__qualname__ = class_name
+    as_packable.__module__ = cls.__module__
     as_packable.__annotations__ = class_annotations
+    if cls.__doc__:
+        as_packable.__doc__ = cls.__doc__
+
+    # Fix qualnames of dataclass-generated methods so they don't show
+    # 'packable.<locals>.as_packable' in help() and IDE hints
+    old_qualname_prefix = "packable.<locals>.as_packable"
+    for attr_name in ("__init__", "__repr__", "__eq__", "__post_init__"):
+        attr = getattr(as_packable, attr_name, None)
+        if attr is not None and hasattr(attr, "__qualname__"):
+            if attr.__qualname__.startswith(old_qualname_prefix):
+                attr.__qualname__ = attr.__qualname__.replace(
+                    old_qualname_prefix, class_name, 1
+                )
+
+    # Auto-register lens from DictSample to this type
+    # This enables ds.as_type(MyType) when ds is Dataset[DictSample]
+    def _dict_to_typed(ds: DictSample) -> as_packable:
+        return as_packable.from_data(ds._data)
+
+    _dict_lens = Lens(_dict_to_typed)
+    LensNetwork().register(_dict_lens)
 
     ##
 
-    return as_packable
+    return as_packable