atdata 0.2.0a1__py3-none-any.whl → 0.2.2b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +43 -10
- atdata/_cid.py +150 -0
- atdata/_hf_api.py +692 -0
- atdata/_protocols.py +519 -0
- atdata/_schema_codec.py +442 -0
- atdata/_sources.py +515 -0
- atdata/_stub_manager.py +529 -0
- atdata/_type_utils.py +90 -0
- atdata/atmosphere/__init__.py +278 -7
- atdata/atmosphere/_types.py +9 -7
- atdata/atmosphere/client.py +146 -6
- atdata/atmosphere/lens.py +29 -25
- atdata/atmosphere/records.py +197 -30
- atdata/atmosphere/schema.py +41 -98
- atdata/atmosphere/store.py +208 -0
- atdata/cli/__init__.py +213 -0
- atdata/cli/diagnose.py +165 -0
- atdata/cli/local.py +280 -0
- atdata/dataset.py +482 -167
- atdata/lens.py +61 -57
- atdata/local.py +1400 -185
- atdata/promote.py +199 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/METADATA +105 -14
- atdata-0.2.2b1.dist-info/RECORD +28 -0
- atdata-0.2.0a1.dist-info/RECORD +0 -16
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py
CHANGED
@@ -14,15 +14,17 @@ during serialization, enabling efficient storage of numerical data in WebDataset
 archives.
 
 Example:
-
-
-
-
-
-
-
-
-
+    ::
+
+        >>> @packable
+        ... class ImageSample:
+        ...     image: NDArray
+        ...     label: str
+        ...
+        >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
+        >>> for batch in ds.shuffled(batch_size=32):
+        ...     images = batch.image   # Stacked numpy array (32, H, W, C)
+        ...     labels = batch.label   # List of 32 strings
 """
 
 ##
@@ -41,6 +43,9 @@ from dataclasses import (
 )
 from abc import ABC
 
+from ._sources import URLSource, S3Source
+from ._protocols import DataSource
+
 from tqdm import tqdm
 import numpy as np
 import pandas as pd
@@ -51,16 +56,16 @@ from typing import (
     Any,
     Optional,
     Dict,
+    Iterator,
     Sequence,
     Iterable,
     Callable,
-    Union,
-    #
     Self,
     Generic,
     Type,
     TypeVar,
     TypeAlias,
+    dataclass_transform,
 )
 from numpy.typing import NDArray
 
@@ -75,6 +80,7 @@ from .lens import Lens, LensNetwork
 
 Pathlike = str | Path
 
+# WebDataset sample/batch dictionaries (contain __key__, msgpack, etc.)
 WDSRawSample: TypeAlias = Dict[str, Any]
 WDSRawBatch: TypeAlias = Dict[str, Any]
 
@@ -87,49 +93,189 @@ SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
 
 DT = TypeVar( 'DT' )
 
-MsgpackRawSample: TypeAlias = Dict[str, Any]
-
 
 def _make_packable( x ):
-    """Convert
-
-    Args:
-        x: A value to convert. If it's a numpy array, converts to bytes.
-            Otherwise returns the value unchanged.
-
-    Returns:
-        The value in a format suitable for msgpack serialization.
-    """
+    """Convert numpy arrays to bytes; pass through other values unchanged."""
     if isinstance( x, np.ndarray ):
         return eh.array_to_bytes( x )
     return x
 
-def _is_possibly_ndarray_type( t ):
-    """Check if a type annotation is or contains NDArray.
 
-
-
-
-    Returns:
-        ``True`` if the type is ``NDArray`` or a union containing ``NDArray``
-        (e.g., ``NDArray | None``), ``False`` otherwise.
-    """
-
-    # Directly an NDArray
+def _is_possibly_ndarray_type( t ):
+    """Return True if type annotation is NDArray or Optional[NDArray]."""
     if t == NDArray:
-        # print( 'is an NDArray' )
         return True
-
-    # Check for Optionals (i.e., NDArray | None)
     if isinstance( t, types.UnionType ):
-
-        if any( x == NDArray
-                for x in t_parts ):
-            return True
-
-    # Not an NDArray
+        return any( x == NDArray for x in t.__args__ )
     return False
 
+class DictSample:
+    """Dynamic sample type providing dict-like access to raw msgpack data.
+
+    This class is the default sample type for datasets when no explicit type is
+    specified. It stores the raw unpacked msgpack data and provides both
+    attribute-style (``sample.field``) and dict-style (``sample["field"]``)
+    access to fields.
+
+    ``DictSample`` is useful for:
+    - Exploring datasets without defining a schema first
+    - Working with datasets that have variable schemas
+    - Prototyping before committing to a typed schema
+
+    To convert to a typed schema, use ``Dataset.as_type()`` with a
+    ``@packable``-decorated class. Every ``@packable`` class automatically
+    registers a lens from ``DictSample``, making this conversion seamless.
+
+    Example:
+        ::
+
+            >>> ds = load_dataset("path/to/data.tar")  # Returns Dataset[DictSample]
+            >>> for sample in ds.ordered():
+            ...     print(sample.some_field)      # Attribute access
+            ...     print(sample["other_field"])  # Dict access
+            ...     print(sample.keys())          # Inspect available fields
+            ...
+            >>> # Convert to typed schema
+            >>> typed_ds = ds.as_type(MyTypedSample)
+
+    Note:
+        NDArray fields are stored as raw bytes in DictSample. They are only
+        converted to numpy arrays when accessed through a typed sample class.
+    """
+
+    __slots__ = ('_data',)
+
+    def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None:
+        """Create a DictSample from a dictionary or keyword arguments.
+
+        Args:
+            _data: Raw data dictionary. If provided, kwargs are ignored.
+            **kwargs: Field values if _data is not provided.
+        """
+        if _data is not None:
+            object.__setattr__(self, '_data', _data)
+        else:
+            object.__setattr__(self, '_data', kwargs)
+
+    @classmethod
+    def from_data(cls, data: dict[str, Any]) -> 'DictSample':
+        """Create a DictSample from unpacked msgpack data.
+
+        Args:
+            data: Dictionary with field names as keys.
+
+        Returns:
+            New DictSample instance wrapping the data.
+        """
+        return cls(_data=data)
+
+    @classmethod
+    def from_bytes(cls, bs: bytes) -> 'DictSample':
+        """Create a DictSample from raw msgpack bytes.
+
+        Args:
+            bs: Raw bytes from a msgpack-serialized sample.
+
+        Returns:
+            New DictSample instance with the unpacked data.
+        """
+        return cls.from_data(ormsgpack.unpackb(bs))
+
+    def __getattr__(self, name: str) -> Any:
+        """Access a field by attribute name.
+
+        Args:
+            name: Field name to access.
+
+        Returns:
+            The field value.
+
+        Raises:
+            AttributeError: If the field doesn't exist.
+        """
+        # Avoid infinite recursion for _data lookup
+        if name == '_data':
+            raise AttributeError(name)
+        try:
+            return self._data[name]
+        except KeyError:
+            raise AttributeError(
+                f"'{type(self).__name__}' has no field '{name}'. "
+                f"Available fields: {list(self._data.keys())}"
+            ) from None
+
+    def __getitem__(self, key: str) -> Any:
+        """Access a field by dict key.
+
+        Args:
+            key: Field name to access.
+
+        Returns:
+            The field value.
+
+        Raises:
+            KeyError: If the field doesn't exist.
+        """
+        return self._data[key]
+
+    def __contains__(self, key: str) -> bool:
+        """Check if a field exists."""
+        return key in self._data
+
+    def keys(self) -> list[str]:
+        """Return list of field names."""
+        return list(self._data.keys())
+
+    def values(self) -> list[Any]:
+        """Return list of field values."""
+        return list(self._data.values())
+
+    def items(self) -> list[tuple[str, Any]]:
+        """Return list of (field_name, value) tuples."""
+        return list(self._data.items())
+
+    def get(self, key: str, default: Any = None) -> Any:
+        """Get a field value with optional default.
+
+        Args:
+            key: Field name to access.
+            default: Value to return if field doesn't exist.
+
+        Returns:
+            The field value or default.
+        """
+        return self._data.get(key, default)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Return a copy of the underlying data dictionary."""
+        return dict(self._data)
+
+    @property
+    def packed(self) -> bytes:
+        """Pack this sample's data into msgpack bytes.
+
+        Returns:
+            Raw msgpack bytes representing this sample's data.
+        """
+        return msgpack.packb(self._data)
+
+    @property
+    def as_wds(self) -> 'WDSRawSample':
+        """Pack this sample's data for writing to WebDataset.
+
+        Returns:
+            A dictionary with ``__key__`` and ``msgpack`` fields.
+        """
+        return {
+            '__key__': str(uuid.uuid1(0, 0)),
+            'msgpack': self.packed,
+        }
+
+    def __repr__(self) -> str:
+        fields = ', '.join(f'{k}=...' for k in self._data.keys())
+        return f'DictSample({fields})'
+
+
 @dataclass
 class PackableSample( ABC ):
     """Base class for samples that can be serialized with msgpack.
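
A rough sketch of the workflow the new ``DictSample`` docstring describes: explore a dataset without a schema, then convert to a typed view. ``load_dataset`` and ``MyTypedSample`` are the illustrative names from that docstring, and the ``label`` field is a hypothetical example, not something defined in this diff:

    ds = load_dataset("path/to/data.tar")      # Dataset[DictSample]
    first = next(iter(ds.ordered()))
    print(first.keys())                        # inspect available fields
    print(first["label"], first.label)         # dict-style and attribute-style access

    # Once the fields are known, commit to a typed schema and convert;
    # this relies on the lens that @packable auto-registers from DictSample.
    typed_ds = ds.as_type(MyTypedSample)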
@@ -144,28 +290,20 @@ class PackableSample( ABC ):
     2. Using the ``@packable`` decorator (recommended)
 
     Example:
-
-
-
-
-
-
-
-
+        ::
+
+            >>> @packable
+            ... class MyData:
+            ...     name: str
+            ...     embeddings: NDArray
+            ...
+            >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
+            >>> packed = sample.packed                # Serialize to bytes
+            >>> restored = MyData.from_bytes(packed)  # Deserialize
     """
 
     def _ensure_good( self ):
-        """
-
-        This method scans all dataclass fields and for any field annotated as
-        ``NDArray`` or ``NDArray | None``, automatically converts bytes values
-        to numpy arrays using the helper deserialization function. This enables
-        transparent handling of array serialization in msgpack data.
-
-        Note:
-            This is called during ``__post_init__`` to ensure proper type
-            conversion after deserialization.
-        """
+        """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
 
         # Auto-convert known types when annotated
         # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
@@ -187,9 +325,9 @@ class PackableSample( ABC ):
                 continue
 
             elif isinstance( var_cur_value, bytes ):
-                #
-                #
-                # as
+                # Design note: bytes in NDArray-typed fields are always interpreted
+                # as serialized arrays. This means raw bytes fields must not be
+                # annotated as NDArray.
                 setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
 
     def __post_init__( self ):
@@ -198,20 +336,16 @@ class PackableSample( ABC ):
     ##
 
     @classmethod
-    def from_data( cls, data:
+    def from_data( cls, data: WDSRawSample ) -> Self:
         """Create a sample instance from unpacked msgpack data.
 
         Args:
-            data:
-                the sample's field names.
+            data: Dictionary with keys matching the sample's field names.
 
         Returns:
-
-            the data dictionary and NDArray fields auto-converted from bytes.
+            New instance with NDArray fields auto-converted from bytes.
         """
-
-        ret._ensure_good()
-        return ret
+        return cls( **data )
 
     @classmethod
     def from_bytes( cls, bs: bytes ) -> Self:
@@ -253,7 +387,6 @@ class PackableSample( ABC ):
 
         return ret
 
-    # TODO Expand to allow for specifying explicit __key__
     @property
     def as_wds( self ) -> WDSRawSample:
         """Pack this sample's data for writing to WebDataset.
@@ -263,7 +396,8 @@ class PackableSample( ABC ):
            ``msgpack`` (packed sample data) fields suitable for WebDataset.
 
        Note:
-
+            Keys are auto-generated as UUID v1 for time-sortable ordering.
+            Custom key specification is not currently supported.
        """
        return {
            # Generates a UUID that is timelike-sortable
@@ -272,25 +406,11 @@ class PackableSample( ABC ):
        }
 
 def _batch_aggregate( xs: Sequence ):
-    """
-
-    Args:
-        xs: A sequence of values to aggregate. If the first element is a numpy
-            array, all elements are stacked into a single array. Otherwise,
-            returns a list.
-
-    Returns:
-        A numpy array (if elements are arrays) or a list (otherwise).
-    """
-
+    """Stack arrays into numpy array with batch dim; otherwise return list."""
     if not xs:
-        # Empty sequence
         return []
-
-    # Aggregate
     if isinstance( xs[0], np.ndarray ):
         return np.array( list( xs ) )
-
     return list( xs )
 
 class SampleBatch( Generic[DT] ):
@@ -304,17 +424,27 @@ class SampleBatch( Generic[DT] ):
    NDArray fields are stacked into a numpy array with a batch dimension.
    Other fields are aggregated into a list.
 
-
+    Parameters:
        DT: The sample type, must derive from ``PackableSample``.
 
    Attributes:
        samples: The list of sample instances in this batch.
 
    Example:
-
-
-
+        ::
+
+            >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
+            >>> batch.embeddings  # Returns stacked numpy array of shape (3, ...)
+            >>> batch.names       # Returns list of names
+
+    Note:
+        This class uses Python's ``__orig_class__`` mechanism to extract the
+        type parameter at runtime. Instances must be created using the
+        subscripted syntax ``SampleBatch[MyType](samples)`` rather than
+        calling the constructor directly with an unsubscripted class.
    """
+    # Design note: The docstring uses "Parameters:" for type parameters because
+    # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
 
    def __init__( self, samples: Sequence[DT] ):
        """Create a batch from a sequence of samples.
@@ -326,6 +456,7 @@ class SampleBatch( Generic[DT] ):
        """
        self.samples = list( samples )
        self._aggregate_cache = dict()
+        self._sample_type_cache: Type | None = None
 
    @property
    def sample_type( self ) -> Type:
@@ -334,7 +465,10 @@ class SampleBatch( Generic[DT] ):
        Returns:
            The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
        """
-
+        if self._sample_type_cache is None:
+            self._sample_type_cache = typing.get_args( self.__orig_class__ )[0]
+        assert self._sample_type_cache is not None
+        return self._sample_type_cache
 
    def __getattr__( self, name ):
        """Aggregate an attribute across all samples in the batch.
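
Because ``sample_type`` is now recovered (and cached) from ``__orig_class__`` at runtime, the batch has to be constructed through the subscripted generic, as the new docstring note says. A minimal sketch of the difference, with ``MyData`` standing in for any ``@packable`` class:

    batch = SampleBatch[MyData]([sample1, sample2, sample3])   # OK: type parameter recoverable
    batch.sample_type                                          # -> MyData (cached after first access)

    bad = SampleBatch([sample1, sample2, sample3])             # no __orig_class__; sample_type cannot resolve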
@@ -368,6 +502,42 @@ class SampleBatch( Generic[DT] ):
 ST = TypeVar( 'ST', bound = PackableSample )
 RT = TypeVar( 'RT', bound = PackableSample )
 
+
+class _ShardListStage(wds.utils.PipelineStage):
+    """Pipeline stage that yields {url: shard_id} dicts from a DataSource.
+
+    This is analogous to SimpleShardList but works with any DataSource.
+    Used as the first stage before split_by_worker.
+    """
+
+    def __init__(self, source: DataSource):
+        self.source = source
+
+    def run(self):
+        """Yield {url: shard_id} dicts for each shard."""
+        for shard_id in self.source.list_shards():
+            yield {"url": shard_id}
+
+
+class _StreamOpenerStage(wds.utils.PipelineStage):
+    """Pipeline stage that opens streams from a DataSource.
+
+    Takes {url: shard_id} dicts and adds a stream using source.open_shard().
+    This replaces WebDataset's url_opener stage.
+    """
+
+    def __init__(self, source: DataSource):
+        self.source = source
+
+    def run(self, src):
+        """Open streams for each shard dict."""
+        for sample in src:
+            shard_id = sample["url"]
+            stream = self.source.open_shard(shard_id)
+            sample["stream"] = stream
+            yield sample
+
+
 class Dataset( Generic[ST] ):
    """A typed dataset built on WebDataset with lens transformations.
 
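
The two stages above only rely on a source exposing ``list_shards()`` and ``open_shard()``, which is what makes ``DataSource`` pluggable. A minimal sketch of a custom source with that shape; ``LocalDirSource`` and its directory layout are hypothetical and not part of this release, and the real ``DataSource`` protocol in ``_protocols.py`` may require more than these two methods:

    from pathlib import Path

    class LocalDirSource:
        """Hypothetical source: serves every .tar shard found in a directory."""

        def __init__(self, root: str) -> None:
            self.root = Path(root)

        def list_shards(self) -> list[str]:
            return sorted(str(p) for p in self.root.glob("*.tar"))

        def open_shard(self, shard_id: str):
            # _StreamOpenerStage places this stream into sample["stream"]
            return open(shard_id, "rb")

    ds = Dataset[DictSample](LocalDirSource("/data/shards"))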
@@ -381,22 +551,31 @@ class Dataset( Generic[ST] ):
    - Type transformations via the lens system (``as_type()``)
    - Export to parquet format
 
-
+    Parameters:
        ST: The sample type for this dataset, must derive from ``PackableSample``.
 
    Attributes:
        url: WebDataset brace-notation URL for the tar file(s).
 
    Example:
-
-
-
-
-
-
-
-
+        ::
+
+            >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
+            >>> for sample in ds.ordered(batch_size=32):
+            ...     # sample is SampleBatch[MyData] with batch_size samples
+            ...     embeddings = sample.embeddings  # shape: (32, ...)
+            ...
+            >>> # Transform to a different view
+            >>> ds_view = ds.as_type(MyDataView)
+
+    Note:
+        This class uses Python's ``__orig_class__`` mechanism to extract the
+        type parameter at runtime. Instances must be created using the
+        subscripted syntax ``Dataset[MyType](url)`` rather than calling the
+        constructor directly with an unsubscripted class.
    """
+    # Design note: The docstring uses "Parameters:" for type parameters because
+    # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
 
    @property
    def sample_type( self ) -> Type:
@@ -404,12 +583,11 @@ class Dataset( Generic[ST] ):
 
        Returns:
            The type parameter ``ST`` used when creating this ``Dataset[ST]``.
-
-        Note:
-            Extracts the type parameter at runtime using ``__orig_class__``.
        """
-
-
+        if self._sample_type_cache is None:
+            self._sample_type_cache = typing.get_args( self.__orig_class__ )[0]
+        assert self._sample_type_cache is not None
+        return self._sample_type_cache
    @property
    def batch_type( self ) -> Type:
        """The type of batches produced by this dataset.
@@ -419,29 +597,58 @@ class Dataset( Generic[ST] ):
        """
        return SampleBatch[self.sample_type]
 
-    def __init__( self,
+    def __init__( self,
+                  source: DataSource | str | None = None,
                  metadata_url: str | None = None,
+                  *,
+                  url: str | None = None,
                ) -> None:
-        """Create a dataset from a
+        """Create a dataset from a DataSource or URL.
 
        Args:
-
-
-
+            source: Either a DataSource implementation or a WebDataset-compatible
+                URL string. If a string is provided, it's wrapped in URLSource
+                for backward compatibility.
+
+                Examples:
+                - String URL: ``"path/to/file-{000000..000009}.tar"``
+                - URLSource: ``URLSource("https://example.com/data.tar")``
+                - S3Source: ``S3Source(bucket="my-bucket", keys=["data.tar"])``
+
+            metadata_url: Optional URL to msgpack-encoded metadata for this dataset.
+            url: Deprecated. Use ``source`` instead. Kept for backward compatibility.
        """
        super().__init__()
-
-
-
-
-
+
+        # Handle backward compatibility: url= keyword argument
+        if source is None and url is not None:
+            source = url
+        elif source is None:
+            raise TypeError("Dataset() missing required argument: 'source' or 'url'")
+
+        # Normalize source: strings become URLSource for backward compatibility
+        if isinstance(source, str):
+            self._source: DataSource = URLSource(source)
+            self.url = source
+        else:
+            self._source = source
+            # For compatibility, expose URL if source has list_shards
+            shards = source.list_shards()
+            # Design note: Using first shard as url for legacy compatibility.
+            # Full shard list is available via list_shards() method.
+            self.url = shards[0] if shards else ""
 
        self._metadata: dict[str, Any] | None = None
        self.metadata_url: str | None = metadata_url
        """Optional URL to msgpack-encoded metadata for this dataset."""
 
-        # Allow addition of automatic transformation of raw underlying data
        self._output_lens: Lens | None = None
+        self._sample_type_cache: Type | None = None
+
+    @property
+    def source(self) -> DataSource:
+        """The underlying data source for this dataset."""
+        return self._source
 
    def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
        """View this dataset through a different sample type using a registered lens.
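
The reworked constructor accepts either a plain URL string or a ``DataSource``. A short sketch of the spellings listed in the new docstring (bucket and key names are placeholders taken from those examples):

    ds1 = Dataset[MyData]("path/to/file-{000000..000009}.tar")            # string, wrapped in URLSource
    ds2 = Dataset[MyData](URLSource("https://example.com/data.tar"))       # explicit URLSource
    ds3 = Dataset[MyData](S3Source(bucket="my-bucket", keys=["data.tar"])) # S3-backed source

    # Deprecated keyword form still accepted for older call sites:
    ds4 = Dataset[MyData](url="path/to/file-000000.tar")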
@@ -459,25 +666,51 @@ class Dataset( Generic[ST] ):
            ValueError: If no registered lens exists between the current
                sample type and the target type.
        """
-        ret = Dataset[other]( self.
+        ret = Dataset[other]( self._source )
        # Get the singleton lens registry
        lenses = LensNetwork()
        ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
        return ret
 
    @property
-    def
-    """
-
+    def shards(self) -> Iterator[str]:
+        """Lazily iterate over shard identifiers.
+
+        Yields:
+            Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
+
+        Example:
+            ::
+
+                >>> for shard in ds.shards:
+                ...     print(f"Processing {shard}")
+        """
+        return iter(self._source.list_shards())
+
+    def list_shards(self) -> list[str]:
+        """Get list of individual dataset shards.
+
        Returns:
            A full (non-lazy) list of the individual ``tar`` files within the
            source WebDataset.
        """
-
-
-
+        return self._source.list_shards()
+
+    # Legacy alias for backwards compatibility
+    @property
+    def shard_list(self) -> list[str]:
+        """List of individual dataset shards (deprecated, use list_shards()).
+
+        .. deprecated::
+            Use :meth:`list_shards` instead.
+        """
+        import warnings
+        warnings.warn(
+            "shard_list is deprecated, use list_shards() instead",
+            DeprecationWarning,
+            stacklevel=2,
        )
-        return
+        return self.list_shards()
 
    @property
    def metadata( self ) -> dict[str, Any] | None:
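
Migration from the old property to the new shard accessors is mechanical; a brief sketch:

    for shard in ds.shards:          # lazy iteration over shard identifiers
        print(f"Processing {shard}")

    all_shards = ds.list_shards()    # eager, full list
    legacy = ds.shard_list           # still works, but raises a DeprecationWarning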
@@ -501,33 +734,36 @@ class Dataset( Generic[ST] ):
        return self._metadata
 
    def ordered( self,
-                 batch_size: int | None =
+                 batch_size: int | None = None,
               ) -> Iterable[ST]:
        """Iterate over the dataset in order
-
+
        Args:
            batch_size (:obj:`int`, optional): The size of iterated batches.
-                Default:
-                with no batch dimension.
-
+                Default: None (unbatched). If ``None``, iterates over one
+                sample at a time with no batch dimension.
+
        Returns:
            :obj:`webdataset.DataPipeline` A data pipeline that iterates over
            the dataset in its original sample order
-
-        """
 
+        """
        if batch_size is None:
            return wds.pipeline.DataPipeline(
-
+                _ShardListStage(self._source),
                wds.shardlists.split_by_worker,
-
+                _StreamOpenerStage(self._source),
+                wds.tariterators.tar_file_expander,
+                wds.tariterators.group_by_keys,
                wds.filters.map( self.wrap ),
            )
 
        return wds.pipeline.DataPipeline(
-
+            _ShardListStage(self._source),
            wds.shardlists.split_by_worker,
-
+            _StreamOpenerStage(self._source),
+            wds.tariterators.tar_file_expander,
+            wds.tariterators.group_by_keys,
            wds.filters.batched( batch_size ),
            wds.filters.map( self.wrap_batch ),
        )
@@ -535,7 +771,7 @@ class Dataset( Generic[ST] ):
    def shuffled( self,
                  buffer_shards: int = 100,
                  buffer_samples: int = 10_000,
-                 batch_size: int | None =
+                 batch_size: int | None = None,
               ) -> Iterable[ST]:
        """Iterate over the dataset in random order.
 
@@ -546,8 +782,9 @@ class Dataset( Generic[ST] ):
            buffer_samples: Number of samples to buffer for shuffling within
                shards. Larger values increase randomness but use more memory.
                Default: 10,000.
-            batch_size: The size of iterated batches. Default:
-                iterates over one sample at a time with no batch
+            batch_size: The size of iterated batches. Default: None (unbatched).
+                If ``None``, iterates over one sample at a time with no batch
+                dimension.
 
        Returns:
            A WebDataset data pipeline that iterates over the dataset in
@@ -557,34 +794,72 @@ class Dataset( Generic[ST] ):
        """
        if batch_size is None:
            return wds.pipeline.DataPipeline(
-
+                _ShardListStage(self._source),
                wds.filters.shuffle( buffer_shards ),
                wds.shardlists.split_by_worker,
-
+                _StreamOpenerStage(self._source),
+                wds.tariterators.tar_file_expander,
+                wds.tariterators.group_by_keys,
                wds.filters.shuffle( buffer_samples ),
                wds.filters.map( self.wrap ),
            )
 
        return wds.pipeline.DataPipeline(
-
+            _ShardListStage(self._source),
            wds.filters.shuffle( buffer_shards ),
            wds.shardlists.split_by_worker,
-
+            _StreamOpenerStage(self._source),
+            wds.tariterators.tar_file_expander,
+            wds.tariterators.group_by_keys,
            wds.filters.shuffle( buffer_samples ),
            wds.filters.batched( batch_size ),
            wds.filters.map( self.wrap_batch ),
        )
 
-    #
-    #
+    # Design note: Uses pandas for parquet export. Could be replaced with
+    # direct fastparquet calls to reduce dependencies if needed.
    def to_parquet( self, path: Pathlike,
                    sample_map: Optional[SampleExportMap] = None,
                    maxcount: Optional[int] = None,
                    **kwargs,
                  ):
-        """
+        """Export dataset contents to parquet format.
+
+        Converts all samples to a pandas DataFrame and saves to parquet file(s).
+        Useful for interoperability with data analysis tools.
 
-
+        Args:
+            path: Output path for the parquet file. If ``maxcount`` is specified,
+                files are named ``{stem}-{segment:06d}.parquet``.
+            sample_map: Optional function to convert samples to dictionaries.
+                Defaults to ``dataclasses.asdict``.
+            maxcount: If specified, split output into multiple files with at most
+                this many samples each. Recommended for large datasets.
+            **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
+                Common options include ``compression``, ``index``, ``engine``.
+
+        Warning:
+            **Memory Usage**: When ``maxcount=None`` (default), this method loads
+            the **entire dataset into memory** as a pandas DataFrame before writing.
+            For large datasets, this can cause memory exhaustion.
+
+            For datasets larger than available RAM, always specify ``maxcount``::
+
+                # Safe for large datasets - processes in chunks
+                ds.to_parquet("output.parquet", maxcount=10000)
+
+            This creates multiple parquet files: ``output-000000.parquet``,
+            ``output-000001.parquet``, etc.
+
+        Example:
+            ::
+
+                >>> ds = Dataset[MySample]("data.tar")
+                >>> # Small dataset - load all at once
+                >>> ds.to_parquet("output.parquet")
+                >>>
+                >>> # Large dataset - process in chunks
+                >>> ds.to_parquet("output.parquet", maxcount=50000)
        """
        ##
 
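
Putting the new pipeline defaults and the export path together; a sketch assuming a ``@packable`` sample type ``MyData``:

    ds = Dataset[MyData]("data-{000000..000009}.tar")

    # Unbatched, in-order iteration (batch_size now defaults to None)
    for sample in ds.ordered():
        ...

    # Shuffled, batched iteration: each item is a SampleBatch[MyData]
    for batch in ds.shuffled(buffer_shards=100, buffer_samples=10_000, batch_size=32):
        embeddings = batch.embeddings    # stacked array with leading dim 32

    # Chunked parquet export, per the memory warning in to_parquet()
    ds.to_parquet("export.parquet", maxcount=10_000)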
@@ -632,7 +907,7 @@ class Dataset( Generic[ST] ):
            df = pd.DataFrame( cur_buffer )
            df.to_parquet( cur_path, **kwargs )
 
-    def wrap( self, sample:
+    def wrap( self, sample: WDSRawSample ) -> ST:
        """Wrap a raw msgpack sample into the appropriate dataset-specific type.
 
        Args:
@@ -643,9 +918,11 @@ class Dataset( Generic[ST] ):
            A deserialized sample of type ``ST``, optionally transformed through
            a lens if ``as_type()`` was called.
        """
-
-
-
+        if 'msgpack' not in sample:
+            raise ValueError(f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}")
+        if not isinstance(sample['msgpack'], bytes):
+            raise ValueError(f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}")
+
        if self._output_lens is None:
            return self.sample_type.from_bytes( sample['msgpack'] )
 
@@ -668,7 +945,8 @@ class Dataset( Generic[ST] ):
        aggregates them into a batch.
        """
 
-
+        if 'msgpack' not in batch:
+            raise ValueError(f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}")
 
        if self._output_lens is None:
            batch_unpacked = [ self.sample_type.from_bytes( bs )
@@ -682,29 +960,43 @@ class Dataset( Generic[ST] ):
        return SampleBatch[self.sample_type]( batch_view )
 
 
-def packable( cls ):
+_T = TypeVar('_T')
+
+
+@dataclass_transform()
+def packable( cls: type[_T] ) -> type[_T]:
    """Decorator to convert a regular class into a ``PackableSample``.
 
    This decorator transforms a class into a dataclass that inherits from
    ``PackableSample``, enabling automatic msgpack serialization/deserialization
    with special handling for NDArray fields.
 
+    The resulting class satisfies the ``Packable`` protocol, making it compatible
+    with all atdata APIs that accept packable types (e.g., ``publish_schema``,
+    lens transformations, etc.).
+
    Args:
        cls: The class to convert. Should have type annotations for its fields.
 
    Returns:
        A new dataclass that inherits from ``PackableSample`` with the same
-        name and annotations as the original class.
-
-
-
-
-
-
-
-
-
-
+        name and annotations as the original class. The class satisfies the
+        ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
+
+    Examples:
+        This is a test of the functionality::
+
+            @packable
+            class MyData:
+                name: str
+                values: NDArray
+
+            sample = MyData(name="test", values=np.array([1, 2, 3]))
+            bytes_data = sample.packed
+            restored = MyData.from_bytes(bytes_data)
+
+            # Works with Packable-typed APIs
+            index.publish_schema(MyData, version="1.0.0")  # Type-safe
    """
 
    ##
@@ -721,9 +1013,32 @@ def packable( cls ):
        def __post_init__( self ):
            return PackableSample.__post_init__( self )
 
-    #
+    # Restore original class identity for better repr/debugging
    as_packable.__name__ = class_name
+    as_packable.__qualname__ = class_name
+    as_packable.__module__ = cls.__module__
    as_packable.__annotations__ = class_annotations
+    if cls.__doc__:
+        as_packable.__doc__ = cls.__doc__
+
+    # Fix qualnames of dataclass-generated methods so they don't show
+    # 'packable.<locals>.as_packable' in help() and IDE hints
+    old_qualname_prefix = 'packable.<locals>.as_packable'
+    for attr_name in ('__init__', '__repr__', '__eq__', '__post_init__'):
+        attr = getattr(as_packable, attr_name, None)
+        if attr is not None and hasattr(attr, '__qualname__'):
+            if attr.__qualname__.startswith(old_qualname_prefix):
+                attr.__qualname__ = attr.__qualname__.replace(
+                    old_qualname_prefix, class_name, 1
+                )
+
+    # Auto-register lens from DictSample to this type
+    # This enables ds.as_type(MyType) when ds is Dataset[DictSample]
+    def _dict_to_typed(ds: DictSample) -> as_packable:
+        return as_packable.from_data(ds._data)
+
+    _dict_lens = Lens(_dict_to_typed)
+    LensNetwork().register(_dict_lens)
 
    ##
 
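
A short sketch tying the decorator changes together: the decorated class keeps its own name, qualname, and module for repr/debugging, and the auto-registered lens is what lets an untyped dataset be re-viewed as the new type. The field names and shard URL here are illustrative, not taken from the package:

    @packable
    class Caption:
        text: str
        embedding: NDArray

    Caption.__qualname__                  # 'Caption', not 'packable.<locals>.as_packable'

    raw_ds = Dataset[DictSample]("captions-000000.tar")
    typed_ds = raw_ds.as_type(Caption)    # resolved via the DictSample -> Caption lens registered above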