atdata-0.2.2b1-py3-none-any.whl → atdata-0.3.0b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +31 -1
- atdata/_cid.py +29 -35
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +33 -17
- atdata/_hf_api.py +109 -59
- atdata/_logging.py +70 -0
- atdata/_protocols.py +74 -132
- atdata/_schema_codec.py +38 -41
- atdata/_sources.py +57 -64
- atdata/_stub_manager.py +31 -26
- atdata/_type_utils.py +47 -7
- atdata/atmosphere/__init__.py +31 -24
- atdata/atmosphere/_types.py +11 -11
- atdata/atmosphere/client.py +11 -8
- atdata/atmosphere/lens.py +27 -30
- atdata/atmosphere/records.py +34 -39
- atdata/atmosphere/schema.py +35 -31
- atdata/atmosphere/store.py +16 -20
- atdata/cli/__init__.py +163 -168
- atdata/cli/diagnose.py +12 -8
- atdata/cli/inspect.py +69 -0
- atdata/cli/local.py +5 -2
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +678 -533
- atdata/lens.py +85 -83
- atdata/local/__init__.py +71 -0
- atdata/local/_entry.py +157 -0
- atdata/local/_index.py +940 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/local/_s3.py +349 -0
- atdata/local/_schema.py +380 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +20 -24
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/testing.py +337 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +5 -1
- atdata-0.3.0b1.dist-info/RECORD +54 -0
- atdata/local.py +0 -1707
- atdata-0.2.2b1.dist-info/RECORD +0 -28
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py
CHANGED
@@ -13,18 +13,16 @@ The implementation handles automatic conversion between numpy arrays and bytes
 during serialization, enabling efficient storage of numerical data in WebDataset
 archives.
 
-
-
-
-
-
-
-
-
-
-
-    ...     images = batch.image  # Stacked numpy array (32, H, W, C)
-    ...     labels = batch.label  # List of 32 strings
+Examples:
+    >>> @packable
+    ... class ImageSample:
+    ...     image: NDArray
+    ...     label: str
+    ...
+    >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
+    >>> for batch in ds.shuffled(batch_size=32):
+    ...     images = batch.image  # Stacked numpy array (32, H, W, C)
+    ...     labels = batch.label  # List of 32 strings
 """
 
 ##
@@ -33,6 +31,7 @@ Example:
 import webdataset as wds
 
 from pathlib import Path
+import itertools
 import uuid
 
 import dataclasses
@@ -43,16 +42,17 @@ from dataclasses import (
 )
 from abc import ABC
 
-from ._sources import URLSource
-from ._protocols import DataSource
+from ._sources import URLSource
+from ._protocols import DataSource, Packable
+from ._exceptions import SampleKeyError, PartialFailureError
 
-from tqdm import tqdm
 import numpy as np
 import pandas as pd
 import requests
 
 import typing
 from typing import (
+    TYPE_CHECKING,
     Any,
     Optional,
     Dict,
@@ -66,7 +66,11 @@ from typing import (
     TypeVar,
     TypeAlias,
     dataclass_transform,
+    overload,
 )
+
+if TYPE_CHECKING:
+    from .manifest._query import SampleLocation
 from numpy.typing import NDArray
 
 import msgpack
@@ -85,30 +89,31 @@ WDSRawSample: TypeAlias = Dict[str, Any]
 WDSRawBatch: TypeAlias = Dict[str, Any]
 
 SampleExportRow: TypeAlias = Dict[str, Any]
-SampleExportMap: TypeAlias = Callable[[
+SampleExportMap: TypeAlias = Callable[["PackableSample"], SampleExportRow]
 
 
 ##
 # Main base classes
 
-DT = TypeVar(
+DT = TypeVar("DT")
 
 
-def _make_packable(
+def _make_packable(x):
     """Convert numpy arrays to bytes; pass through other values unchanged."""
-    if isinstance(
-    return eh.array_to_bytes(
+    if isinstance(x, np.ndarray):
+        return eh.array_to_bytes(x)
     return x
 
 
-def _is_possibly_ndarray_type(
+def _is_possibly_ndarray_type(t):
     """Return True if type annotation is NDArray or Optional[NDArray]."""
     if t == NDArray:
         return True
-    if isinstance(
-    return any(
+    if isinstance(t, types.UnionType):
+        return any(x == NDArray for x in t.__args__)
     return False
 
+
 class DictSample:
     """Dynamic sample type providing dict-like access to raw msgpack data.
 
@@ -126,24 +131,22 @@ class DictSample:
     ``@packable``-decorated class. Every ``@packable`` class automatically
     registers a lens from ``DictSample``, making this conversion seamless.
 
-
-
-
-
-
-
-
-
-
-        >>> # Convert to typed schema
-        >>> typed_ds = ds.as_type(MyTypedSample)
+    Examples:
+        >>> ds = load_dataset("path/to/data.tar")  # Returns Dataset[DictSample]
+        >>> for sample in ds.ordered():
+        ...     print(sample.some_field)      # Attribute access
+        ...     print(sample["other_field"])  # Dict access
+        ...     print(sample.keys())          # Inspect available fields
+        ...
+        >>> # Convert to typed schema
+        >>> typed_ds = ds.as_type(MyTypedSample)
 
     Note:
         NDArray fields are stored as raw bytes in DictSample. They are only
         converted to numpy arrays when accessed through a typed sample class.
     """
 
-    __slots__ = (
+    __slots__ = ("_data",)
 
     def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None:
         """Create a DictSample from a dictionary or keyword arguments.
@@ -153,48 +156,28 @@ class DictSample:
             **kwargs: Field values if _data is not provided.
         """
         if _data is not None:
-            object.__setattr__(self,
+            object.__setattr__(self, "_data", _data)
         else:
-            object.__setattr__(self,
+            object.__setattr__(self, "_data", kwargs)
 
     @classmethod
-    def from_data(cls, data: dict[str, Any]) ->
-        """Create a DictSample from unpacked msgpack data.
-
-        Args:
-            data: Dictionary with field names as keys.
-
-        Returns:
-            New DictSample instance wrapping the data.
-        """
+    def from_data(cls, data: dict[str, Any]) -> "DictSample":
+        """Create a DictSample from unpacked msgpack data."""
         return cls(_data=data)
 
     @classmethod
-    def from_bytes(cls, bs: bytes) ->
-        """Create a DictSample from raw msgpack bytes.
-
-        Args:
-            bs: Raw bytes from a msgpack-serialized sample.
-
-        Returns:
-            New DictSample instance with the unpacked data.
-        """
+    def from_bytes(cls, bs: bytes) -> "DictSample":
+        """Create a DictSample from raw msgpack bytes."""
         return cls.from_data(ormsgpack.unpackb(bs))
 
     def __getattr__(self, name: str) -> Any:
         """Access a field by attribute name.
 
-        Args:
-            name: Field name to access.
-
-        Returns:
-            The field value.
-
         Raises:
             AttributeError: If the field doesn't exist.
         """
         # Avoid infinite recursion for _data lookup
-        if name ==
+        if name == "_data":
             raise AttributeError(name)
         try:
             return self._data[name]
@@ -205,21 +188,9 @@ class DictSample:
             ) from None
 
     def __getitem__(self, key: str) -> Any:
-        """Access a field by dict key.
-
-        Args:
-            key: Field name to access.
-
-        Returns:
-            The field value.
-
-        Raises:
-            KeyError: If the field doesn't exist.
-        """
         return self._data[key]
 
     def __contains__(self, key: str) -> bool:
-        """Check if a field exists."""
         return key in self._data
 
     def keys(self) -> list[str]:
@@ -227,23 +198,13 @@ class DictSample:
         return list(self._data.keys())
 
     def values(self) -> list[Any]:
-        """Return list of field values."""
         return list(self._data.values())
 
     def items(self) -> list[tuple[str, Any]]:
-        """Return list of (field_name, value) tuples."""
         return list(self._data.items())
 
     def get(self, key: str, default: Any = None) -> Any:
-        """Get a field value
-
-        Args:
-            key: Field name to access.
-            default: Value to return if field doesn't exist.
-
-        Returns:
-            The field value or default.
-        """
+        """Get a field value, returning *default* if missing."""
         return self._data.get(key, default)
 
     def to_dict(self) -> dict[str, Any]:
@@ -252,32 +213,24 @@ class DictSample:
 
     @property
     def packed(self) -> bytes:
-        """
-
-        Returns:
-            Raw msgpack bytes representing this sample's data.
-        """
+        """Serialize to msgpack bytes."""
         return msgpack.packb(self._data)
 
     @property
-    def as_wds(self) ->
-        """
-
-        Returns:
-            A dictionary with ``__key__`` and ``msgpack`` fields.
-        """
+    def as_wds(self) -> "WDSRawSample":
+        """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
         return {
-
-
+            "__key__": str(uuid.uuid1(0, 0)),
+            "msgpack": self.packed,
         }
 
     def __repr__(self) -> str:
-        fields =
-        return f
+        fields = ", ".join(f"{k}=..." for k in self._data.keys())
+        return f"DictSample({fields})"
 
 
 @dataclass
-class PackableSample(
+class PackableSample(ABC):
     """Base class for samples that can be serialized with msgpack.
 
     This abstract base class provides automatic serialization/deserialization
@@ -289,218 +242,122 @@ class PackableSample( ABC ):
     1. Direct inheritance with the ``@dataclass`` decorator
     2. Using the ``@packable`` decorator (recommended)
 
-
-
-
-
-
-
-
-
-
-        >>> packed = sample.packed  # Serialize to bytes
-        >>> restored = MyData.from_bytes(packed)  # Deserialize
+    Examples:
+        >>> @packable
+        ... class MyData:
+        ...     name: str
+        ...     embeddings: NDArray
+        ...
+        >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
+        >>> packed = sample.packed  # Serialize to bytes
+        >>> restored = MyData.from_bytes(packed)  # Deserialize
     """
 
-    def _ensure_good(
+    def _ensure_good(self):
         """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
-
-
-
-
-            var_name = field.name
-            var_type = field.type
-
-            # Annotation for this variable is to be an NDArray
-            if _is_possibly_ndarray_type( var_type ):
-                # ... so, we'll always auto-convert to numpy
-
-                var_cur_value = getattr( self, var_name )
-
-                # Execute the appropriate conversion for intermediate data
-                # based on what is provided
-
-                if isinstance( var_cur_value, np.ndarray ):
-                    # Already the correct type, no conversion needed
+        for field in dataclasses.fields(self):
+            if _is_possibly_ndarray_type(field.type):
+                value = getattr(self, field.name)
+                if isinstance(value, np.ndarray):
                     continue
+                elif isinstance(value, bytes):
+                    setattr(self, field.name, eh.bytes_to_array(value))
 
-
-                # Design note: bytes in NDArray-typed fields are always interpreted
-                # as serialized arrays. This means raw bytes fields must not be
-                # annotated as NDArray.
-                setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
-
-    def __post_init__( self ):
+    def __post_init__(self):
         self._ensure_good()
 
     ##
 
     @classmethod
-    def from_data(
-        """Create
+    def from_data(cls, data: WDSRawSample) -> Self:
+        """Create an instance from unpacked msgpack data."""
+        return cls(**data)
 
-        Args:
-            data: Dictionary with keys matching the sample's field names.
-
-        Returns:
-            New instance with NDArray fields auto-converted from bytes.
-        """
-        return cls( **data )
-
     @classmethod
-    def from_bytes(
-        """Create
-
-        Args:
-            bs: Raw bytes from a msgpack-serialized sample.
-
-        Returns:
-            A new instance of this sample class deserialized from the bytes.
-        """
-        return cls.from_data( ormsgpack.unpackb( bs ) )
+    def from_bytes(cls, bs: bytes) -> Self:
+        """Create an instance from raw msgpack bytes."""
+        return cls.from_data(ormsgpack.unpackb(bs))
 
     @property
-    def packed(
-        """
-
-        NDArray fields are automatically converted to bytes before packing.
-        All other fields are packed as-is if they're msgpack-compatible.
-
-        Returns:
-            Raw msgpack bytes representing this sample's data.
+    def packed(self) -> bytes:
+        """Serialize to msgpack bytes. NDArray fields are auto-converted.
 
         Raises:
             RuntimeError: If msgpack serialization fails.
         """
-
-
-        # format
-        o = {
-            k: _make_packable( v )
-            for k, v in vars( self ).items()
-        }
-
-        ret = msgpack.packb( o )
-
+        o = {k: _make_packable(v) for k, v in vars(self).items()}
+        ret = msgpack.packb(o)
         if ret is None:
-            raise RuntimeError(
-
+            raise RuntimeError(f"Failed to pack sample to bytes: {o}")
         return ret
-
-    @property
-    def as_wds( self ) -> WDSRawSample:
-        """Pack this sample's data for writing to WebDataset.
-
-        Returns:
-            A dictionary with ``__key__`` (UUID v1 for sortable keys) and
-            ``msgpack`` (packed sample data) fields suitable for WebDataset.
 
-
-
-
-        """
+    @property
+    def as_wds(self) -> WDSRawSample:
+        """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
         return {
-
-
-            'msgpack': self.packed,
+            "__key__": str(uuid.uuid1(0, 0)),
+            "msgpack": self.packed,
         }
 
-
+
+def _batch_aggregate(xs: Sequence):
     """Stack arrays into numpy array with batch dim; otherwise return list."""
     if not xs:
         return []
-    if isinstance(
-    return np.array(
-    return list(
+    if isinstance(xs[0], np.ndarray):
+        return np.array(list(xs))
+    return list(xs)
 
-class SampleBatch( Generic[DT] ):
-    """A batch of samples with automatic attribute aggregation.
 
-
-
-    exists on the sample type, it automatically aggregates values across all
-    samples in the batch.
+class SampleBatch(Generic[DT]):
+    """A batch of samples with automatic attribute aggregation.
 
-
-
+    Accessing an attribute aggregates that field across all samples:
+    NDArray fields are stacked into a numpy array with a batch dimension;
+    other fields are collected into a list. Results are cached.
 
     Parameters:
         DT: The sample type, must derive from ``PackableSample``.
 
-
-
-
-
-    ::
-
-        >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
-        >>> batch.embeddings  # Returns stacked numpy array of shape (3, ...)
-        >>> batch.names  # Returns list of names
-
-    Note:
-        This class uses Python's ``__orig_class__`` mechanism to extract the
-        type parameter at runtime. Instances must be created using the
-        subscripted syntax ``SampleBatch[MyType](samples)`` rather than
-        calling the constructor directly with an unsubscripted class.
+    Examples:
+        >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
+        >>> batch.embeddings  # Stacked numpy array of shape (3, ...)
+        >>> batch.names  # List of names
     """
-    # Design note: The docstring uses "Parameters:" for type parameters because
-    # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
-
-    def __init__( self, samples: Sequence[DT] ):
-        """Create a batch from a sequence of samples.
 
-
-
-
-            ``PackableSample``.
-        """
-        self.samples = list( samples )
+    def __init__(self, samples: Sequence[DT]):
+        """Create a batch from a sequence of samples."""
+        self.samples = list(samples)
         self._aggregate_cache = dict()
         self._sample_type_cache: Type | None = None
 
     @property
-    def sample_type(
-        """The type
-
-        Returns:
-            The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
-        """
+    def sample_type(self) -> Type:
+        """The type parameter ``DT`` used when creating this batch."""
         if self._sample_type_cache is None:
-            self._sample_type_cache = typing.get_args(
-
+            self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
+            if self._sample_type_cache is None:
+                raise TypeError(
+                    "SampleBatch requires a type parameter, e.g. SampleBatch[MySample]"
+                )
         return self._sample_type_cache
 
-    def __getattr__(
-        """Aggregate
-
-        This magic method enables attribute-style access to aggregated sample
-        fields. Results are cached for efficiency.
-
-        Args:
-            name: The attribute name to aggregate across samples.
-
-        Returns:
-            For NDArray fields: a stacked numpy array with batch dimension.
-            For other fields: a list of values from each sample.
-
-        Raises:
-            AttributeError: If the attribute doesn't exist on the sample type.
-        """
+    def __getattr__(self, name):
+        """Aggregate a field across all samples (cached)."""
         # Aggregate named params of sample type
-        if name in vars(
+        if name in vars(self.sample_type)["__annotations__"]:
             if name not in self._aggregate_cache:
                 self._aggregate_cache[name] = _batch_aggregate(
-                    [
-                    for x in self.samples ]
+                    [getattr(x, name) for x in self.samples]
                 )
 
             return self._aggregate_cache[name]
 
-        raise AttributeError(
+        raise AttributeError(f"No sample attribute named {name}")
 
 
-ST = TypeVar(
-RT = TypeVar(
+ST = TypeVar("ST", bound=Packable)
+RT = TypeVar("RT", bound=Packable)
 
 
 class _ShardListStage(wds.utils.PipelineStage):
@@ -538,7 +395,7 @@ class _StreamOpenerStage(wds.utils.PipelineStage):
             yield sample
 
 
-class Dataset(
+class Dataset(Generic[ST]):
     """A typed dataset built on WebDataset with lens transformations.
 
     This class wraps WebDataset tar archives and provides type-safe iteration
@@ -557,16 +414,14 @@ class Dataset( Generic[ST] ):
     Attributes:
         url: WebDataset brace-notation URL for the tar file(s).
 
-
-
-
-
-
-
-
-
-        >>> # Transform to a different view
-        >>> ds_view = ds.as_type(MyDataView)
+    Examples:
+        >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
+        >>> for sample in ds.ordered(batch_size=32):
+        ...     # sample is SampleBatch[MyData] with batch_size samples
+        ...     embeddings = sample.embeddings  # shape: (32, ...)
+        ...
+        >>> # Transform to a different view
+        >>> ds_view = ds.as_type(MyDataView)
 
     Note:
         This class uses Python's ``__orig_class__`` mechanism to extract the
@@ -574,35 +429,33 @@ class Dataset( Generic[ST] ):
     subscripted syntax ``Dataset[MyType](url)`` rather than calling the
     constructor directly with an unsubscripted class.
     """
+
     # Design note: The docstring uses "Parameters:" for type parameters because
     # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
 
     @property
-    def sample_type(
-        """The type
-
-        Returns:
-            The type parameter ``ST`` used when creating this ``Dataset[ST]``.
-        """
+    def sample_type(self) -> Type:
+        """The type parameter ``ST`` used when creating this dataset."""
         if self._sample_type_cache is None:
-            self._sample_type_cache = typing.get_args(
-
+            self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
+            if self._sample_type_cache is None:
+                raise TypeError(
+                    "Dataset requires a type parameter, e.g. Dataset[MySample]"
+                )
         return self._sample_type_cache
-    @property
-    def batch_type( self ) -> Type:
-        """The type of batches produced by this dataset.
 
-
-
-        """
+    @property
+    def batch_type(self) -> Type:
+        """``SampleBatch[ST]`` where ``ST`` is this dataset's sample type."""
         return SampleBatch[self.sample_type]
 
-    def __init__(
-
-
-
-
-
+    def __init__(
+        self,
+        source: DataSource | str | None = None,
+        metadata_url: str | None = None,
+        *,
+        url: str | None = None,
+    ) -> None:
         """Create a dataset from a DataSource or URL.
 
         Args:
@@ -620,28 +473,21 @@ class Dataset( Generic[ST] ):
         """
         super().__init__()
 
-        # Handle backward compatibility: url= keyword argument
         if source is None and url is not None:
             source = url
         elif source is None:
             raise TypeError("Dataset() missing required argument: 'source' or 'url'")
 
-        # Normalize source: strings become URLSource for backward compatibility
         if isinstance(source, str):
             self._source: DataSource = URLSource(source)
             self.url = source
         else:
             self._source = source
-            # For compatibility, expose URL if source has list_shards
             shards = source.list_shards()
-            # Design note: Using first shard as url for legacy compatibility.
-            # Full shard list is available via list_shards() method.
             self.url = shards[0] if shards else ""
 
         self._metadata: dict[str, Any] | None = None
         self.metadata_url: str | None = metadata_url
-        """Optional URL to msgpack-encoded metadata for this dataset."""
-
         self._output_lens: Lens | None = None
         self._sample_type_cache: Type | None = None
 
@@ -650,50 +496,24 @@ class Dataset( Generic[ST] ):
         """The underlying data source for this dataset."""
         return self._source
 
-    def as_type(
-        """View this dataset through a different sample type
-
-        Args:
-            other: The target sample type to transform into. Must be a type
-                derived from ``PackableSample``.
-
-        Returns:
-            A new ``Dataset`` instance that yields samples of type ``other``
-            by applying the appropriate lens transformation from the global
-            ``LensNetwork`` registry.
+    def as_type(self, other: Type[RT]) -> "Dataset[RT]":
+        """View this dataset through a different sample type via a registered lens.
 
         Raises:
-            ValueError: If no
-                sample type and the target type.
+            ValueError: If no lens exists between the current and target types.
         """
-        ret = Dataset[other](
-        # Get the singleton lens registry
+        ret = Dataset[other](self._source)
         lenses = LensNetwork()
-        ret._output_lens = lenses.transform(
+        ret._output_lens = lenses.transform(self.sample_type, ret.sample_type)
         return ret
 
     @property
     def shards(self) -> Iterator[str]:
-        """Lazily iterate over shard identifiers.
-
-        Yields:
-            Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
-
-        Example:
-            ::
-
-                >>> for shard in ds.shards:
-                ...     print(f"Processing {shard}")
-        """
+        """Lazily iterate over shard identifiers."""
         return iter(self._source.list_shards())
 
     def list_shards(self) -> list[str]:
-        """
-
-        Returns:
-            A full (non-lazy) list of the individual ``tar`` files within the
-            source WebDataset.
-        """
+        """Return all shard paths/URLs as a list."""
         return self._source.list_shards()
 
     # Legacy alias for backwards compatibility
@@ -705,6 +525,7 @@ class Dataset( Generic[ST] ):
         Use :meth:`list_shards` instead.
         """
         import warnings
+
         warnings.warn(
             "shard_list is deprecated, use list_shards() instead",
             DeprecationWarning,
@@ -713,40 +534,414 @@ class Dataset( Generic[ST] ):
         return self.list_shards()
 
     @property
-    def metadata(
-        """Fetch and cache metadata from metadata_url.
-
-        Returns:
-            Deserialized metadata dictionary, or None if no metadata_url is set.
-
-        Raises:
-            requests.HTTPError: If metadata fetch fails.
-        """
+    def metadata(self) -> dict[str, Any] | None:
+        """Fetch and cache metadata from metadata_url, or ``None`` if unset."""
         if self.metadata_url is None:
             return None
 
         if self._metadata is None:
-            with requests.get(
+            with requests.get(self.metadata_url, stream=True) as response:
                 response.raise_for_status()
-                self._metadata = msgpack.unpackb(
-
+                self._metadata = msgpack.unpackb(response.content, raw=False)
+
         # Use our cached values
         return self._metadata
-
-
-
-
-
+
+    ##
+    # Convenience methods (GH#38 developer experience)
+
+    @property
+    def schema(self) -> dict[str, type]:
+        """Field names and types for this dataset's sample type.
+
+        Examples:
+            >>> ds = Dataset[MyData]("data.tar")
+            >>> ds.schema
+            {'name': <class 'str'>, 'embedding': numpy.ndarray}
+        """
+        st = self.sample_type
+        if st is DictSample:
+            return {"_data": dict}
+        if dataclasses.is_dataclass(st):
+            return {f.name: f.type for f in dataclasses.fields(st)}
+        return {}
+
+    @property
+    def column_names(self) -> list[str]:
+        """List of field names for this dataset's sample type."""
+        st = self.sample_type
+        if dataclasses.is_dataclass(st):
+            return [f.name for f in dataclasses.fields(st)]
+        return []
+
+    def __iter__(self) -> Iterator[ST]:
+        """Shorthand for ``ds.ordered()``."""
+        return iter(self.ordered())
+
+    def __len__(self) -> int:
+        """Total sample count (iterates all shards on first call, then cached)."""
+        if not hasattr(self, "_len_cache"):
+            self._len_cache: int = sum(1 for _ in self.ordered())
+        return self._len_cache
+
+    def head(self, n: int = 5) -> list[ST]:
+        """Return the first *n* samples from the dataset.
+
+        Args:
+            n: Number of samples to return. Default: 5.
+
+        Returns:
+            List of up to *n* samples in shard order.
+
+        Examples:
+            >>> samples = ds.head(3)
+            >>> len(samples)
+            3
+        """
+        return list(itertools.islice(self.ordered(), n))
+
+    def get(self, key: str) -> ST:
+        """Retrieve a single sample by its ``__key__``.
+
+        Scans shards sequentially until a sample with a matching key is found.
+        This is O(n) for streaming datasets.
 
         Args:
-
-            Default: None (unbatched). If ``None``, iterates over one
-                sample at a time with no batch dimension.
+            key: The WebDataset ``__key__`` string to search for.
 
         Returns:
-
-
+            The matching sample.
+
+        Raises:
+            SampleKeyError: If no sample with the given key exists.
 
+        Examples:
+            >>> sample = ds.get("00000001-0001-1000-8000-010000000000")
+        """
+        pipeline = wds.pipeline.DataPipeline(
+            _ShardListStage(self._source),
+            wds.shardlists.split_by_worker,
+            _StreamOpenerStage(self._source),
+            wds.tariterators.tar_file_expander,
+            wds.tariterators.group_by_keys,
+        )
+        for raw_sample in pipeline:
+            if raw_sample.get("__key__") == key:
+                return self.wrap(raw_sample)
+        raise SampleKeyError(key)
+
+    def describe(self) -> dict[str, Any]:
+        """Summary statistics: sample_type, fields, num_shards, shards, url, metadata."""
+        shards = self.list_shards()
+        return {
+            "sample_type": self.sample_type.__name__,
+            "fields": self.schema,
+            "num_shards": len(shards),
+            "shards": shards,
+            "url": self.url,
+            "metadata": self.metadata,
+        }
+
+    def filter(self, predicate: Callable[[ST], bool]) -> "Dataset[ST]":
+        """Return a new dataset that yields only samples matching *predicate*.
+
+        The filter is applied lazily during iteration — no data is copied.
+
+        Args:
+            predicate: A function that takes a sample and returns ``True``
+                to keep it or ``False`` to discard it.
+
+        Returns:
+            A new ``Dataset`` whose iterators apply the filter.
+
+        Examples:
+            >>> long_names = ds.filter(lambda s: len(s.name) > 10)
+            >>> for sample in long_names:
+            ...     assert len(sample.name) > 10
+        """
+        filtered = Dataset[self.sample_type](self._source, self.metadata_url)
+        filtered._sample_type_cache = self._sample_type_cache
+        filtered._output_lens = self._output_lens
+        filtered._filter_fn = predicate
+        # Preserve any existing filters
+        parent_filters = getattr(self, "_filter_fn", None)
+        if parent_filters is not None:
+            outer = parent_filters
+            filtered._filter_fn = lambda s: outer(s) and predicate(s)
+        # Preserve any existing map
+        if hasattr(self, "_map_fn"):
+            filtered._map_fn = self._map_fn
+        return filtered
+
+    def map(self, fn: Callable[[ST], Any]) -> "Dataset":
+        """Return a new dataset that applies *fn* to each sample during iteration.
+
+        The mapping is applied lazily during iteration — no data is copied.
+
+        Args:
+            fn: A function that takes a sample of type ``ST`` and returns
+                a transformed value.
+
+        Returns:
+            A new ``Dataset`` whose iterators apply the mapping.
+
+        Examples:
+            >>> names = ds.map(lambda s: s.name)
+            >>> for name in names:
+            ...     print(name)
+        """
+        mapped = Dataset[self.sample_type](self._source, self.metadata_url)
+        mapped._sample_type_cache = self._sample_type_cache
+        mapped._output_lens = self._output_lens
+        mapped._map_fn = fn
+        # Preserve any existing map
+        if hasattr(self, "_map_fn"):
+            outer = self._map_fn
+            mapped._map_fn = lambda s: fn(outer(s))
+        # Preserve any existing filter
+        if hasattr(self, "_filter_fn"):
+            mapped._filter_fn = self._filter_fn
+        return mapped
+
+    def process_shards(
+        self,
+        fn: Callable[[list[ST]], Any],
+        *,
+        shards: list[str] | None = None,
+    ) -> dict[str, Any]:
+        """Process each shard independently, collecting per-shard results.
+
+        Unlike :meth:`map` (which is lazy and per-sample), this method eagerly
+        processes each shard in turn, calling *fn* with the full list of samples
+        from that shard. If some shards fail, raises
+        :class:`~atdata._exceptions.PartialFailureError` containing both the
+        successful results and the per-shard errors.
+
+        Args:
+            fn: Function receiving a list of samples from one shard and
+                returning an arbitrary result.
+            shards: Optional list of shard identifiers to process. If ``None``,
+                processes all shards in the dataset. Useful for retrying only
+                the failed shards from a previous ``PartialFailureError``.
+
+        Returns:
+            Dict mapping shard identifier to *fn*'s return value for each shard.
+
+        Raises:
+            PartialFailureError: If at least one shard fails. The exception
+                carries ``.succeeded_shards``, ``.failed_shards``, ``.errors``,
+                and ``.results`` for inspection and retry.
+
+        Examples:
+            >>> results = ds.process_shards(lambda samples: len(samples))
+            >>> # On partial failure, retry just the failed shards:
+            >>> try:
+            ...     results = ds.process_shards(expensive_fn)
+            ... except PartialFailureError as e:
+            ...     retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
+        """
+        from ._logging import get_logger
+
+        log = get_logger()
+        shard_ids = shards or self.list_shards()
+        log.info("process_shards: starting %d shards", len(shard_ids))
+
+        succeeded: list[str] = []
+        failed: list[str] = []
+        errors: dict[str, Exception] = {}
+        results: dict[str, Any] = {}
+
+        for shard_id in shard_ids:
+            try:
+                shard_ds = Dataset[self.sample_type](shard_id)
+                shard_ds._sample_type_cache = self._sample_type_cache
+                samples = list(shard_ds.ordered())
+                results[shard_id] = fn(samples)
+                succeeded.append(shard_id)
+                log.debug("process_shards: shard ok %s", shard_id)
+            except Exception as exc:
+                failed.append(shard_id)
+                errors[shard_id] = exc
+                log.warning("process_shards: shard failed %s: %s", shard_id, exc)
+
+        if failed:
+            log.error(
+                "process_shards: %d/%d shards failed",
+                len(failed),
+                len(shard_ids),
+            )
+            raise PartialFailureError(
+                succeeded_shards=succeeded,
+                failed_shards=failed,
+                errors=errors,
+                results=results,
+            )
+
+        log.info("process_shards: all %d shards succeeded", len(shard_ids))
+        return results
+
+    def select(self, indices: Sequence[int]) -> list[ST]:
+        """Return samples at the given integer indices.
+
+        Iterates through the dataset in order and collects samples whose
+        positional index matches. This is O(n) for streaming datasets.
+
+        Args:
+            indices: Sequence of zero-based indices to select.
+
+        Returns:
+            List of samples at the requested positions, in index order.
+
+        Examples:
+            >>> samples = ds.select([0, 5, 10])
+            >>> len(samples)
+            3
+        """
+        if not indices:
+            return []
+        target = set(indices)
+        max_idx = max(indices)
+        result: dict[int, ST] = {}
+        for i, sample in enumerate(self.ordered()):
+            if i in target:
+                result[i] = sample
+            if i >= max_idx:
+                break
+        return [result[i] for i in indices if i in result]
+
+    def query(
+        self,
+        where: "Callable[[pd.DataFrame], pd.Series]",
+    ) -> "list[SampleLocation]":
+        """Query this dataset using per-shard manifest metadata.
+
+        Requires manifests to have been generated during shard writing.
+        Discovers manifest files alongside the tar shards, loads them,
+        and executes a two-phase query (shard-level aggregate pruning,
+        then sample-level parquet filtering).
+
+        Args:
+            where: Predicate function that receives a pandas DataFrame
+                of manifest fields and returns a boolean Series selecting
+                matching rows.
+
+        Returns:
+            List of ``SampleLocation`` for matching samples.
+
+        Raises:
+            FileNotFoundError: If no manifest files are found alongside shards.
+
+        Examples:
+            >>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
+            >>> len(locs)
+            42
+        """
+        from .manifest import QueryExecutor
+
+        shard_urls = self.list_shards()
+        executor = QueryExecutor.from_shard_urls(shard_urls)
+        return executor.query(where=where)
+
+    def to_pandas(self, limit: int | None = None) -> "pd.DataFrame":
+        """Materialize the dataset (or first *limit* samples) as a DataFrame.
+
+        Args:
+            limit: Maximum number of samples to include. ``None`` means all
+                samples (may use significant memory for large datasets).
+
+        Returns:
+            A pandas DataFrame with one row per sample and columns matching
+            the sample fields.
+
+        Warning:
+            With ``limit=None`` this loads the entire dataset into memory.
+
+        Examples:
+            >>> df = ds.to_pandas(limit=100)
+            >>> df.columns.tolist()
+            ['name', 'embedding']
+        """
+        samples = self.head(limit) if limit is not None else list(self.ordered())
+        rows = [
+            asdict(s) if dataclasses.is_dataclass(s) else s.to_dict() for s in samples
+        ]
+        return pd.DataFrame(rows)
+
+    def to_dict(self, limit: int | None = None) -> dict[str, list[Any]]:
+        """Materialize the dataset as a column-oriented dictionary.
+
+        Args:
+            limit: Maximum number of samples to include. ``None`` means all.
+
+        Returns:
+            Dictionary mapping field names to lists of values (one entry
+            per sample).
+
+        Warning:
+            With ``limit=None`` this loads the entire dataset into memory.
+
+        Examples:
+            >>> d = ds.to_dict(limit=10)
+            >>> d.keys()
+            dict_keys(['name', 'embedding'])
+            >>> len(d['name'])
+            10
+        """
+        samples = self.head(limit) if limit is not None else list(self.ordered())
+        if not samples:
+            return {}
+        if dataclasses.is_dataclass(samples[0]):
+            fields = [f.name for f in dataclasses.fields(samples[0])]
+            return {f: [getattr(s, f) for s in samples] for f in fields}
+        # DictSample path
+        keys = samples[0].keys()
+        return {k: [s[k] for s in samples] for k in keys}
+
+    def _post_wrap_stages(self) -> list:
+        """Build extra pipeline stages for filter/map set via .filter()/.map()."""
+        stages: list = []
+        filter_fn = getattr(self, "_filter_fn", None)
+        if filter_fn is not None:
+            stages.append(wds.filters.select(filter_fn))
+        map_fn = getattr(self, "_map_fn", None)
+        if map_fn is not None:
+            stages.append(wds.filters.map(map_fn))
+        return stages
+
+    @overload
+    def ordered(
+        self,
+        batch_size: None = None,
+    ) -> Iterable[ST]: ...
+
+    @overload
+    def ordered(
+        self,
+        batch_size: int,
+    ) -> Iterable[SampleBatch[ST]]: ...
+
+    def ordered(
+        self,
+        batch_size: int | None = None,
+    ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
+        """Iterate over the dataset in order.
+
+        Args:
+            batch_size: The size of iterated batches. Default: None (unbatched).
+                If ``None``, iterates over one sample at a time with no batch
+                dimension.
+
+        Returns:
+            A data pipeline that iterates over the dataset in its original
+            sample order. When ``batch_size`` is ``None``, yields individual
+            samples of type ``ST``. When ``batch_size`` is an integer, yields
+            ``SampleBatch[ST]`` instances containing that many samples.
+
+        Examples:
+            >>> for sample in ds.ordered():
+            ...     process(sample)  # sample is ST
+            >>> for batch in ds.ordered(batch_size=32):
+            ...     process(batch)  # batch is SampleBatch[ST]
         """
         if batch_size is None:
             return wds.pipeline.DataPipeline(
@@ -755,7 +950,8 @@ class Dataset( Generic[ST] ):
                 _StreamOpenerStage(self._source),
                 wds.tariterators.tar_file_expander,
                 wds.tariterators.group_by_keys,
-                wds.filters.map(
+                wds.filters.map(self.wrap),
+                *self._post_wrap_stages(),
             )
 
         return wds.pipeline.DataPipeline(
@@ -764,15 +960,33 @@ class Dataset( Generic[ST] ):
             _StreamOpenerStage(self._source),
             wds.tariterators.tar_file_expander,
             wds.tariterators.group_by_keys,
-            wds.filters.batched(
-            wds.filters.map(
+            wds.filters.batched(batch_size),
+            wds.filters.map(self.wrap_batch),
         )
 
-
-
-
-
-
+    @overload
+    def shuffled(
+        self,
+        buffer_shards: int = 100,
+        buffer_samples: int = 10_000,
+        batch_size: None = None,
+    ) -> Iterable[ST]: ...
+
+    @overload
+    def shuffled(
+        self,
+        buffer_shards: int = 100,
+        buffer_samples: int = 10_000,
+        *,
+        batch_size: int,
+    ) -> Iterable[SampleBatch[ST]]: ...
+
+    def shuffled(
+        self,
+        buffer_shards: int = 100,
+        buffer_samples: int = 10_000,
+        batch_size: int | None = None,
+    ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
         """Iterate over the dataset in random order.
 
         Args:
@@ -787,216 +1001,147 @@ class Dataset( Generic[ST] ):
                 dimension.
 
         Returns:
-            A
-
-            ``
-            samples.
+            A data pipeline that iterates over the dataset in randomized order.
+            When ``batch_size`` is ``None``, yields individual samples of type
+            ``ST``. When ``batch_size`` is an integer, yields ``SampleBatch[ST]``
+            instances containing that many samples.
+
+        Examples:
+            >>> for sample in ds.shuffled():
+            ...     process(sample)  # sample is ST
+            >>> for batch in ds.shuffled(batch_size=32):
+            ...     process(batch)  # batch is SampleBatch[ST]
         """
         if batch_size is None:
             return wds.pipeline.DataPipeline(
                 _ShardListStage(self._source),
-                wds.filters.shuffle(
+                wds.filters.shuffle(buffer_shards),
                 wds.shardlists.split_by_worker,
                 _StreamOpenerStage(self._source),
                 wds.tariterators.tar_file_expander,
                 wds.tariterators.group_by_keys,
-                wds.filters.shuffle(
-                wds.filters.map(
+                wds.filters.shuffle(buffer_samples),
+                wds.filters.map(self.wrap),
+                *self._post_wrap_stages(),
             )
 
         return wds.pipeline.DataPipeline(
             _ShardListStage(self._source),
-            wds.filters.shuffle(
+            wds.filters.shuffle(buffer_shards),
             wds.shardlists.split_by_worker,
             _StreamOpenerStage(self._source),
             wds.tariterators.tar_file_expander,
             wds.tariterators.group_by_keys,
-            wds.filters.shuffle(
-            wds.filters.batched(
-            wds.filters.map(
+            wds.filters.shuffle(buffer_samples),
+            wds.filters.batched(batch_size),
+            wds.filters.map(self.wrap_batch),
         )
-
+
     # Design note: Uses pandas for parquet export. Could be replaced with
     # direct fastparquet calls to reduce dependencies if needed.
-    def to_parquet(
-
-
-
-
-
-
-        Useful for interoperability with data analysis tools.
+    def to_parquet(
+        self,
+        path: Pathlike,
+        sample_map: Optional[SampleExportMap] = None,
+        maxcount: Optional[int] = None,
+        **kwargs,
+    ):
+        """Export dataset to parquet file(s).
 
         Args:
-            path: Output path
-
-            sample_map:
-
-
-
-
-
-        Warning:
-            **Memory Usage**: When ``maxcount=None`` (default), this method loads
-            the **entire dataset into memory** as a pandas DataFrame before writing.
-            For large datasets, this can cause memory exhaustion.
-
-            For datasets larger than available RAM, always specify ``maxcount``::
-
-                # Safe for large datasets - processes in chunks
-                ds.to_parquet("output.parquet", maxcount=10000)
-
-            This creates multiple parquet files: ``output-000000.parquet``,
-            ``output-000001.parquet``, etc.
-
-        Example:
-            ::
-
-                >>> ds = Dataset[MySample]("data.tar")
-                >>> # Small dataset - load all at once
-                >>> ds.to_parquet("output.parquet")
-                >>>
-                >>> # Large dataset - process in chunks
-                >>> ds.to_parquet("output.parquet", maxcount=50000)
+            path: Output path. With *maxcount*, files are named
+                ``{stem}-{segment:06d}.parquet``.
+            sample_map: Convert sample to dict. Defaults to ``dataclasses.asdict``.
+            maxcount: Split into files of at most this many samples.
+                Without it, the entire dataset is loaded into memory.
+            **kwargs: Passed to ``pandas.DataFrame.to_parquet()``.
+
+        Examples:
+            >>> ds.to_parquet("output.parquet", maxcount=50000)
         """
-
-
-        # Normalize args
-        path = Path( path )
+        path = Path(path)
         if sample_map is None:
             sample_map = asdict
-
-        verbose = kwargs.get( 'verbose', False )
-
-        it = self.ordered( batch_size = None )
-        if verbose:
-            it = tqdm( it )
-
-        #
 
         if maxcount is None:
-
-            df
-                for x in self.ordered( batch_size = None ) ] )
-            df.to_parquet( path, **kwargs )
-
+            df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)])
+            df.to_parquet(path, **kwargs)
         else:
-            # Load and save dataset in segments of size `maxcount`
-
             cur_segment = 0
-            cur_buffer = []
-            path_template = (
-
-
-
-
-
-
-                    cur_path = path_template.format(
-
-                    df.to_parquet( cur_path, **kwargs )
-
+            cur_buffer: list = []
+            path_template = (
+                path.parent / f"{path.stem}-{{:06d}}{path.suffix}"
+            ).as_posix()
+
+            for x in self.ordered(batch_size=None):
+                cur_buffer.append(sample_map(x))
+                if len(cur_buffer) >= maxcount:
+                    cur_path = path_template.format(cur_segment)
+                    pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)
                     cur_segment += 1
                     cur_buffer = []
-
-            if len( cur_buffer ) > 0:
-                # Write one last segment with remainder
-                cur_path = path_template.format( cur_segment )
-                df = pd.DataFrame( cur_buffer )
-                df.to_parquet( cur_path, **kwargs )
-
-    def wrap( self, sample: WDSRawSample ) -> ST:
-        """Wrap a raw msgpack sample into the appropriate dataset-specific type.
 
-
-
-
+            if cur_buffer:
+                cur_path = path_template.format(cur_segment)
+                pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)
 
-
-
-
-
-
-
-        if not isinstance(sample[
-            raise ValueError(
+    def wrap(self, sample: WDSRawSample) -> ST:
+        """Deserialize a raw WDS sample dict into type ``ST``."""
+        if "msgpack" not in sample:
+            raise ValueError(
+                f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}"
+            )
+        if not isinstance(sample["msgpack"], bytes):
+            raise ValueError(
+                f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}"
+            )
 
         if self._output_lens is None:
-            return self.sample_type.from_bytes(
-
-        source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
-        return self._output_lens( source_sample )
+            return self.sample_type.from_bytes(sample["msgpack"])
 
-
-
-
-        Args:
-            batch: A dictionary containing a ``'msgpack'`` key with a list of
-                serialized sample bytes.
-
-        Returns:
-            A ``SampleBatch[ST]`` containing deserialized samples, optionally
-            transformed through a lens if ``as_type()`` was called.
+        source_sample = self._output_lens.source_type.from_bytes(sample["msgpack"])
+        return self._output_lens(source_sample)
 
-
-
-        aggregates them into a batch.
-        """
+    def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]:
+        """Deserialize a raw WDS batch dict into ``SampleBatch[ST]``."""
 
-        if
-            raise ValueError(
+        if "msgpack" not in batch:
+            raise ValueError(
+                f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}"
+            )
 
         if self._output_lens is None:
-            batch_unpacked = [
-
-
+            batch_unpacked = [
+                self.sample_type.from_bytes(bs) for bs in batch["msgpack"]
+            ]
+            return SampleBatch[self.sample_type](batch_unpacked)
 
-        batch_source = [
-
-
-
-        return SampleBatch[self.sample_type](
+        batch_source = [
+            self._output_lens.source_type.from_bytes(bs) for bs in batch["msgpack"]
+        ]
+        batch_view = [self._output_lens(s) for s in batch_source]
+        return SampleBatch[self.sample_type](batch_view)
 
 
-_T = TypeVar(
+_T = TypeVar("_T")
 
 
 @dataclass_transform()
-def packable(
-    """
-
-    This decorator transforms a class into a dataclass that inherits from
-    ``PackableSample``, enabling automatic msgpack serialization/deserialization
-    with special handling for NDArray fields.
-
-    The resulting class satisfies the ``Packable`` protocol, making it compatible
-    with all atdata APIs that accept packable types (e.g., ``publish_schema``,
-    lens transformations, etc.).
+def packable(cls: type[_T]) -> type[Packable]:
+    """Convert a class into a ``PackableSample`` dataclass with msgpack serialization.
 
-
-
-
-    Returns:
-        A new dataclass that inherits from ``PackableSample`` with the same
-        name and annotations as the original class. The class satisfies the
-        ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
+    The resulting class gains ``packed``, ``as_wds``, ``from_bytes``, and
+    ``from_data`` methods, and satisfies the ``Packable`` protocol.
+    NDArray fields are automatically handled during serialization.
 
     Examples:
-
-
-
-
-
-
-
-        sample = MyData(name="test", values=np.array([1, 2, 3]))
-        bytes_data = sample.packed
-        restored = MyData.from_bytes(bytes_data)
-
-        # Works with Packable-typed APIs
-        index.publish_schema(MyData, version="1.0.0")  # Type-safe
+        >>> @packable
+        ... class MyData:
+        ...     name: str
+        ...     values: NDArray
+        ...
+        >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
+        >>> restored = MyData.from_bytes(sample.packed)
     """
 
     ##
@@ -1005,14 +1150,14 @@ def packable( cls: type[_T] ) -> type[_T]:
     class_annotations = cls.__annotations__
 
     # Add in dataclass niceness to original class
-    as_dataclass = dataclass(
+    as_dataclass = dataclass(cls)
 
     # This triggers a bunch of behind-the-scenes stuff for the newly annotated class
     @dataclass
-    class as_packable(
-        def __post_init__(
-            return PackableSample.__post_init__(
-
+    class as_packable(as_dataclass, PackableSample):
+        def __post_init__(self):
+            return PackableSample.__post_init__(self)
+
     # Restore original class identity for better repr/debugging
     as_packable.__name__ = class_name
     as_packable.__qualname__ = class_name
@@ -1023,10 +1168,10 @@ def packable( cls: type[_T] ) -> type[_T]:
 
     # Fix qualnames of dataclass-generated methods so they don't show
     # 'packable.<locals>.as_packable' in help() and IDE hints
-    old_qualname_prefix =
-    for attr_name in (
+    old_qualname_prefix = "packable.<locals>.as_packable"
+    for attr_name in ("__init__", "__repr__", "__eq__", "__post_init__"):
         attr = getattr(as_packable, attr_name, None)
-        if attr is not None and hasattr(attr,
+        if attr is not None and hasattr(attr, "__qualname__"):
             if attr.__qualname__.startswith(old_qualname_prefix):
                 attr.__qualname__ = attr.__qualname__.replace(
                     old_qualname_prefix, class_name, 1
@@ -1042,4 +1187,4 @@ def packable( cls: type[_T] ) -> type[_T]:
 
     ##
 
-    return as_packable
+    return as_packable