atdata-0.2.2b1-py3-none-any.whl → atdata-0.2.3b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +1 -1
- atdata/_cid.py +29 -35
- atdata/_helpers.py +7 -5
- atdata/_hf_api.py +48 -50
- atdata/_protocols.py +56 -71
- atdata/_schema_codec.py +33 -37
- atdata/_sources.py +57 -64
- atdata/_stub_manager.py +31 -26
- atdata/_type_utils.py +19 -5
- atdata/atmosphere/__init__.py +20 -23
- atdata/atmosphere/_types.py +11 -11
- atdata/atmosphere/client.py +11 -8
- atdata/atmosphere/lens.py +27 -30
- atdata/atmosphere/records.py +31 -37
- atdata/atmosphere/schema.py +33 -29
- atdata/atmosphere/store.py +16 -20
- atdata/cli/__init__.py +12 -3
- atdata/cli/diagnose.py +12 -8
- atdata/cli/local.py +4 -1
- atdata/dataset.py +284 -241
- atdata/lens.py +77 -82
- atdata/local.py +182 -169
- atdata/promote.py +18 -22
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/METADATA +2 -1
- atdata-0.2.3b1.dist-info/RECORD +28 -0
- atdata-0.2.2b1.dist-info/RECORD +0 -28
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py
CHANGED
@@ -13,18 +13,16 @@ The implementation handles automatic conversion between numpy arrays and bytes
 during serialization, enabling efficient storage of numerical data in WebDataset
 archives.
 
-
-
-
-
-
-
-
-
-
-
-    ...     images = batch.image  # Stacked numpy array (32, H, W, C)
-    ...     labels = batch.label  # List of 32 strings
+Examples:
+    >>> @packable
+    ... class ImageSample:
+    ...     image: NDArray
+    ...     label: str
+    ...
+    >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
+    >>> for batch in ds.shuffled(batch_size=32):
+    ...     images = batch.image  # Stacked numpy array (32, H, W, C)
+    ...     labels = batch.label  # List of 32 strings
 """
 
 ##
@@ -43,7 +41,7 @@ from dataclasses import (
 )
 from abc import ABC
 
-from ._sources import URLSource
+from ._sources import URLSource
 from ._protocols import DataSource
 
 from tqdm import tqdm
@@ -66,6 +64,7 @@ from typing import (
     TypeVar,
     TypeAlias,
     dataclass_transform,
+    overload,
 )
 from numpy.typing import NDArray
 
@@ -85,30 +84,31 @@ WDSRawSample: TypeAlias = Dict[str, Any]
 WDSRawBatch: TypeAlias = Dict[str, Any]
 
 SampleExportRow: TypeAlias = Dict[str, Any]
-SampleExportMap: TypeAlias = Callable[[ 'PackableSample' ], SampleExportRow]
+SampleExportMap: TypeAlias = Callable[["PackableSample"], SampleExportRow]
 
 
 ##
 # Main base classes
 
-DT = TypeVar( 'DT' )
+DT = TypeVar("DT")
 
 
-def _make_packable( x ):
+def _make_packable(x):
     """Convert numpy arrays to bytes; pass through other values unchanged."""
-    if isinstance( x, np.ndarray ):
-        return eh.array_to_bytes( x )
+    if isinstance(x, np.ndarray):
+        return eh.array_to_bytes(x)
     return x
 
 
-def _is_possibly_ndarray_type( t ):
+def _is_possibly_ndarray_type(t):
     """Return True if type annotation is NDArray or Optional[NDArray]."""
     if t == NDArray:
         return True
-    if isinstance( t, types.UnionType ):
-        return any( x == NDArray for x in t.__args__ )
+    if isinstance(t, types.UnionType):
+        return any(x == NDArray for x in t.__args__)
     return False
 
+
 class DictSample:
     """Dynamic sample type providing dict-like access to raw msgpack data.
 
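`_make_packable` above (and `PackableSample._ensure_good` further down) lean on the `atdata._helpers` pair `eh.array_to_bytes` / `eh.bytes_to_array`, which this diff does not show. A minimal sketch of the round-trip contract such helpers must satisfy, assuming a numpy `.npy` encoding (the package's actual encoding may differ):

```python
import io
import numpy as np

def array_to_bytes(arr: np.ndarray) -> bytes:
    # Serialize dtype, shape, and data using numpy's .npy container.
    buf = io.BytesIO()
    np.save(buf, arr, allow_pickle=False)
    return buf.getvalue()

def bytes_to_array(bs: bytes) -> np.ndarray:
    # Inverse of array_to_bytes.
    return np.load(io.BytesIO(bs), allow_pickle=False)

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
restored = bytes_to_array(array_to_bytes(arr))
assert (restored == arr).all() and restored.dtype == arr.dtype
```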
@@ -126,24 +126,22 @@ class DictSample:
     ``@packable``-decorated class. Every ``@packable`` class automatically
     registers a lens from ``DictSample``, making this conversion seamless.
 
-
-
-
-
-
-
-
-
-
-        >>> # Convert to typed schema
-        >>> typed_ds = ds.as_type(MyTypedSample)
+    Examples:
+        >>> ds = load_dataset("path/to/data.tar")  # Returns Dataset[DictSample]
+        >>> for sample in ds.ordered():
+        ...     print(sample.some_field)      # Attribute access
+        ...     print(sample["other_field"])  # Dict access
+        ...     print(sample.keys())          # Inspect available fields
+        ...
+        >>> # Convert to typed schema
+        >>> typed_ds = ds.as_type(MyTypedSample)
 
     Note:
         NDArray fields are stored as raw bytes in DictSample. They are only
        converted to numpy arrays when accessed through a typed sample class.
    """
 
-    __slots__ = ( '_data', )
+    __slots__ = ("_data",)
 
     def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None:
         """Create a DictSample from a dictionary or keyword arguments.
@@ -153,12 +151,12 @@ class DictSample:
         **kwargs: Field values if _data is not provided.
         """
         if _data is not None:
-            object.__setattr__(self, '_data', _data)
+            object.__setattr__(self, "_data", _data)
         else:
-            object.__setattr__(self, '_data', kwargs)
+            object.__setattr__(self, "_data", kwargs)
 
     @classmethod
-    def from_data(cls, data: dict[str, Any]) -> 'DictSample':
+    def from_data(cls, data: dict[str, Any]) -> "DictSample":
         """Create a DictSample from unpacked msgpack data.
 
         Args:
@@ -170,7 +168,7 @@ class DictSample:
         return cls(_data=data)
 
     @classmethod
-    def from_bytes(cls, bs: bytes) -> 'DictSample':
+    def from_bytes(cls, bs: bytes) -> "DictSample":
         """Create a DictSample from raw msgpack bytes.
 
         Args:
@@ -194,7 +192,7 @@ class DictSample:
             AttributeError: If the field doesn't exist.
         """
         # Avoid infinite recursion for _data lookup
-        if name == '_data':
+        if name == "_data":
             raise AttributeError(name)
         try:
             return self._data[name]
@@ -260,24 +258,24 @@ class DictSample:
         return msgpack.packb(self._data)
 
     @property
-    def as_wds(self) -> 'WDSRawSample':
+    def as_wds(self) -> "WDSRawSample":
         """Pack this sample's data for writing to WebDataset.
 
         Returns:
             A dictionary with ``__key__`` and ``msgpack`` fields.
         """
         return {
-            '__key__': str( uuid.uuid1( 0, 0 ) ),
-            'msgpack': self.packed,
+            "__key__": str(uuid.uuid1(0, 0)),
+            "msgpack": self.packed,
         }
 
     def __repr__(self) -> str:
-        fields = ', '.join( f'{k}=...' for k in self._data.keys() )
-        return f'DictSample({fields})'
+        fields = ", ".join(f"{k}=..." for k in self._data.keys())
+        return f"DictSample({fields})"
 
 
 @dataclass
-class PackableSample( ABC ):
+class PackableSample(ABC):
     """Base class for samples that can be serialized with msgpack.
 
     This abstract base class provides automatic serialization/deserialization
@@ -289,54 +287,52 @@ class PackableSample( ABC ):
     1. Direct inheritance with the ``@dataclass`` decorator
     2. Using the ``@packable`` decorator (recommended)
 
-
-
-
-
-
-
-
-
-        >>> packed = sample.packed  # Serialize to bytes
-        >>> restored = MyData.from_bytes(packed)  # Deserialize
+    Examples:
+        >>> @packable
+        ... class MyData:
+        ...     name: str
+        ...     embeddings: NDArray
+        ...
+        >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
+        >>> packed = sample.packed  # Serialize to bytes
+        >>> restored = MyData.from_bytes(packed)  # Deserialize
     """
 
-    def _ensure_good( self ):
+    def _ensure_good(self):
         """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
 
         # Auto-convert known types when annotated
         # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
-        for field in dataclasses.fields( self ):
+        for field in dataclasses.fields(self):
             var_name = field.name
             var_type = field.type
 
             # Annotation for this variable is to be an NDArray
-            if _is_possibly_ndarray_type( var_type ):
+            if _is_possibly_ndarray_type(var_type):
                 # ... so, we'll always auto-convert to numpy
 
-                var_cur_value = getattr( self, var_name )
+                var_cur_value = getattr(self, var_name)
 
                 # Execute the appropriate conversion for intermediate data
                 # based on what is provided
 
-                if isinstance( var_cur_value, np.ndarray ):
+                if isinstance(var_cur_value, np.ndarray):
                     # Already the correct type, no conversion needed
                     continue
 
-                elif isinstance( var_cur_value, bytes ):
+                elif isinstance(var_cur_value, bytes):
                     # Design note: bytes in NDArray-typed fields are always interpreted
                     # as serialized arrays. This means raw bytes fields must not be
                     # annotated as NDArray.
-                    setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
+                    setattr(self, var_name, eh.bytes_to_array(var_cur_value))
 
-    def __post_init__( self ):
+    def __post_init__(self):
         self._ensure_good()
 
     ##
 
     @classmethod
-    def from_data( cls, data: WDSRawSample ) -> Self:
+    def from_data(cls, data: WDSRawSample) -> Self:
         """Create a sample instance from unpacked msgpack data.
 
         Args:
@@ -345,10 +341,10 @@ class PackableSample( ABC ):
         Returns:
             New instance with NDArray fields auto-converted from bytes.
         """
-        return cls( **data )
-
+        return cls(**data)
+
     @classmethod
-    def from_bytes( cls, bs: bytes ) -> Self:
+    def from_bytes(cls, bs: bytes) -> Self:
         """Create a sample instance from raw msgpack bytes.
 
         Args:
@@ -357,10 +353,10 @@ class PackableSample( ABC ):
         Returns:
             A new instance of this sample class deserialized from the bytes.
         """
-        return cls.from_data( ormsgpack.unpackb( bs ) )
+        return cls.from_data(ormsgpack.unpackb(bs))
 
     @property
-    def packed( self ) -> bytes:
+    def packed(self) -> bytes:
         """Pack this sample's data into msgpack bytes.
 
         NDArray fields are automatically converted to bytes before packing.
@@ -375,20 +371,17 @@ class PackableSample( ABC ):
 
         # Make sure that all of our (possibly unpackable) data is in a packable
         # format
-        o = {
-            k: _make_packable( v )
-            for k, v in vars( self ).items()
-        }
+        o = {k: _make_packable(v) for k, v in vars(self).items()}
 
-        ret = msgpack.packb( o )
+        ret = msgpack.packb(o)
 
         if ret is None:
-            raise RuntimeError( f'Failed to pack sample to bytes: {o}' )
+            raise RuntimeError(f"Failed to pack sample to bytes: {o}")
 
         return ret
-
+
     @property
-    def as_wds( self ) -> WDSRawSample:
+    def as_wds(self) -> WDSRawSample:
         """Pack this sample's data for writing to WebDataset.
 
         Returns:
@@ -401,19 +394,21 @@ class PackableSample( ABC ):
         """
         return {
             # Generates a UUID that is timelike-sortable
-            '__key__': str( uuid.uuid1( 0, 0 ) ),
-            'msgpack': self.packed,
+            "__key__": str(uuid.uuid1(0, 0)),
+            "msgpack": self.packed,
         }
 
-def _batch_aggregate( xs: Sequence ):
+
+def _batch_aggregate(xs: Sequence):
     """Stack arrays into numpy array with batch dim; otherwise return list."""
     if not xs:
         return []
-    if isinstance( xs[0], np.ndarray ):
-        return np.array( list( xs ) )
-    return list( xs )
+    if isinstance(xs[0], np.ndarray):
+        return np.array(list(xs))
+    return list(xs)
+
 
-class SampleBatch( Generic[DT] ):
+class SampleBatch(Generic[DT]):
     """A batch of samples with automatic attribute aggregation.
 
     This class wraps a sequence of samples and provides magic ``__getattr__``
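The new module-level `_batch_aggregate` defines the batching contract used by `SampleBatch`: ndarray attributes stack into one array with a leading batch dimension, while everything else becomes a plain list. A standalone copy to illustrate:

```python
from typing import Sequence
import numpy as np

def _batch_aggregate(xs: Sequence):
    """Stack arrays into numpy array with batch dim; otherwise return list."""
    if not xs:
        return []
    if isinstance(xs[0], np.ndarray):
        return np.array(list(xs))
    return list(xs)

assert _batch_aggregate([np.zeros(4), np.ones(4)]).shape == (2, 4)  # stacked
assert _batch_aggregate(["cat", "dog"]) == ["cat", "dog"]           # plain list
```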
@@ -430,12 +425,10 @@ class SampleBatch( Generic[DT] ):
     Attributes:
         samples: The list of sample instances in this batch.
 
-
-
-
-
-        >>> batch.embeddings  # Returns stacked numpy array of shape (3, ...)
-        >>> batch.names       # Returns list of names
+    Examples:
+        >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
+        >>> batch.embeddings  # Returns stacked numpy array of shape (3, ...)
+        >>> batch.names  # Returns list of names
 
     Note:
         This class uses Python's ``__orig_class__`` mechanism to extract the
@@ -443,10 +436,11 @@ class SampleBatch( Generic[DT] ):
     subscripted syntax ``SampleBatch[MyType](samples)`` rather than
     calling the constructor directly with an unsubscripted class.
     """
+
     # Design note: The docstring uses "Parameters:" for type parameters because
     # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
 
-    def __init__( self, samples: Sequence[DT] ):
+    def __init__(self, samples: Sequence[DT]):
         """Create a batch from a sequence of samples.
 
         Args:
@@ -454,23 +448,23 @@ class SampleBatch( Generic[DT] ):
             Each sample must be an instance of a type derived from
             ``PackableSample``.
         """
-        self.samples = list( samples )
+        self.samples = list(samples)
         self._aggregate_cache = dict()
         self._sample_type_cache: Type | None = None
 
     @property
-    def sample_type( self ) -> Type:
+    def sample_type(self) -> Type:
         """The type of each sample in this batch.
 
         Returns:
             The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
         """
         if self._sample_type_cache is None:
-            self._sample_type_cache = typing.get_args( self.__orig_class__ )[0]
+            self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
         assert self._sample_type_cache is not None
         return self._sample_type_cache
 
-    def __getattr__( self, name ):
+    def __getattr__(self, name):
         """Aggregate an attribute across all samples in the batch.
 
         This magic method enables attribute-style access to aggregated sample
@@ -487,20 +481,19 @@ class SampleBatch( Generic[DT] ):
             AttributeError: If the attribute doesn't exist on the sample type.
         """
         # Aggregate named params of sample type
-        if name in vars( self.sample_type )['__annotations__']:
+        if name in vars(self.sample_type)["__annotations__"]:
             if name not in self._aggregate_cache:
                 self._aggregate_cache[name] = _batch_aggregate(
-                    [ getattr( x, name )
-                      for x in self.samples ]
+                    [getattr(x, name) for x in self.samples]
                 )
 
             return self._aggregate_cache[name]
 
-        raise AttributeError( f'No sample attribute named {name}' )
+        raise AttributeError(f"No sample attribute named {name}")
 
 
-ST = TypeVar( 'ST', bound = PackableSample )
-RT = TypeVar( 'RT', bound = PackableSample )
+ST = TypeVar("ST", bound=PackableSample)
+RT = TypeVar("RT", bound=PackableSample)
 
 
 class _ShardListStage(wds.utils.PipelineStage):
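Both `SampleBatch.sample_type` and `Dataset.sample_type` (further down) recover their type parameter at runtime from `__orig_class__`, which Python attaches to instances constructed through a subscripted generic alias. A minimal illustration of the mechanism, independent of atdata:

```python
import typing
from typing import Generic, TypeVar

T = TypeVar("T")

class Box(Generic[T]):
    @property
    def item_type(self) -> type:
        # __orig_class__ is set only when the instance was created as Box[SomeType](...)
        return typing.get_args(self.__orig_class__)[0]

assert Box[int]().item_type is int
# A bare Box() has no __orig_class__, which is why these classes require the
# subscripted syntax rather than a direct unsubscripted constructor call.
```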
@@ -538,7 +531,7 @@ class _StreamOpenerStage(wds.utils.PipelineStage):
             yield sample
 
 
-class Dataset( Generic[ST] ):
+class Dataset(Generic[ST]):
     """A typed dataset built on WebDataset with lens transformations.
 
     This class wraps WebDataset tar archives and provides type-safe iteration
@@ -557,16 +550,14 @@ class Dataset( Generic[ST] ):
     Attributes:
         url: WebDataset brace-notation URL for the tar file(s).
 
-
-
-
-
-
-
-
-
-        >>> # Transform to a different view
-        >>> ds_view = ds.as_type(MyDataView)
+    Examples:
+        >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
+        >>> for sample in ds.ordered(batch_size=32):
+        ...     # sample is SampleBatch[MyData] with batch_size samples
+        ...     embeddings = sample.embeddings  # shape: (32, ...)
+        ...
+        >>> # Transform to a different view
+        >>> ds_view = ds.as_type(MyDataView)
 
     Note:
         This class uses Python's ``__orig_class__`` mechanism to extract the
@@ -574,22 +565,24 @@ class Dataset( Generic[ST] ):
     subscripted syntax ``Dataset[MyType](url)`` rather than calling the
     constructor directly with an unsubscripted class.
     """
+
     # Design note: The docstring uses "Parameters:" for type parameters because
     # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
 
     @property
-    def sample_type( self ) -> Type:
+    def sample_type(self) -> Type:
         """The type of each returned sample from this dataset's iterator.
 
         Returns:
             The type parameter ``ST`` used when creating this ``Dataset[ST]``.
         """
         if self._sample_type_cache is None:
-            self._sample_type_cache = typing.get_args( self.__orig_class__ )[0]
+            self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
         assert self._sample_type_cache is not None
         return self._sample_type_cache
+
     @property
-    def batch_type( self ) -> Type:
+    def batch_type(self) -> Type:
         """The type of batches produced by this dataset.
 
         Returns:
@@ -597,12 +590,13 @@ class Dataset( Generic[ST] ):
         """
         return SampleBatch[self.sample_type]
 
-    def __init__(
-
-
-
-
-
+    def __init__(
+        self,
+        source: DataSource | str | None = None,
+        metadata_url: str | None = None,
+        *,
+        url: str | None = None,
+    ) -> None:
         """Create a dataset from a DataSource or URL.
 
         Args:
@@ -650,7 +644,7 @@ class Dataset( Generic[ST] ):
         """The underlying data source for this dataset."""
         return self._source
 
-    def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
+    def as_type(self, other: Type[RT]) -> "Dataset[RT]":
         """View this dataset through a different sample type using a registered lens.
 
         Args:
@@ -666,10 +660,10 @@ class Dataset( Generic[ST] ):
             ValueError: If no registered lens exists between the current
                 sample type and the target type.
         """
-        ret = Dataset[other]( self._source )
+        ret = Dataset[other](self._source)
         # Get the singleton lens registry
         lenses = LensNetwork()
-        ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
+        ret._output_lens = lenses.transform(self.sample_type, ret.sample_type)
         return ret
 
     @property
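In use, `as_type` keeps the same shards and source but swaps the decoding path through a registered lens. A hypothetical two-schema sketch (names are illustrative, not from the package, and it assumes a `Raw -> Thumbnail` lens has been registered with the lens machinery):

```python
@packable
class Raw:
    image: NDArray
    label: str

@packable
class Thumbnail:
    image: NDArray

ds = Dataset[Raw]("data-{000000..000009}.tar")
thumbs = ds.as_type(Thumbnail)  # Dataset[Thumbnail] over the same tar shards
# Iteration now routes each decoded Raw sample through the registered lens,
# so thumbs.ordered() yields Thumbnail instances.
```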
@@ -679,11 +673,9 @@ class Dataset( Generic[ST] ):
         Yields:
             Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
 
-
-
-
-            >>> for shard in ds.shards:
-            ...     print(f"Processing {shard}")
+        Examples:
+            >>> for shard in ds.shards:
+            ...     print(f"Processing {shard}")
         """
         return iter(self._source.list_shards())
 
@@ -705,6 +697,7 @@ class Dataset( Generic[ST] ):
         Use :meth:`list_shards` instead.
         """
         import warnings
+
         warnings.warn(
             "shard_list is deprecated, use list_shards() instead",
             DeprecationWarning,
@@ -713,7 +706,7 @@ class Dataset( Generic[ST] ):
         return self.list_shards()
 
     @property
-    def metadata( self ) -> dict[str, Any] | None:
+    def metadata(self) -> dict[str, Any] | None:
         """Fetch and cache metadata from metadata_url.
 
         Returns:
@@ -726,27 +719,47 @@ class Dataset( Generic[ST] ):
             return None
 
         if self._metadata is None:
-            with requests.get( self.metadata_url, stream = True ) as response:
+            with requests.get(self.metadata_url, stream=True) as response:
                 response.raise_for_status()
-                self._metadata = msgpack.unpackb( response.content, raw = False )
-
+                self._metadata = msgpack.unpackb(response.content, raw=False)
+
         # Use our cached values
         return self._metadata
-
-
-
-
-
+
+    @overload
+    def ordered(
+        self,
+        batch_size: None = None,
+    ) -> Iterable[ST]: ...
+
+    @overload
+    def ordered(
+        self,
+        batch_size: int,
+    ) -> Iterable[SampleBatch[ST]]: ...
+
+    def ordered(
+        self,
+        batch_size: int | None = None,
+    ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
+        """Iterate over the dataset in order.
 
         Args:
-            batch_size
-
-
+            batch_size: The size of iterated batches. Default: None (unbatched).
+                If ``None``, iterates over one sample at a time with no batch
+                dimension.
 
         Returns:
-
-
+            A data pipeline that iterates over the dataset in its original
+            sample order. When ``batch_size`` is ``None``, yields individual
+            samples of type ``ST``. When ``batch_size`` is an integer, yields
+            ``SampleBatch[ST]`` instances containing that many samples.
 
+        Examples:
+            >>> for sample in ds.ordered():
+            ...     process(sample)  # sample is ST
+            >>> for batch in ds.ordered(batch_size=32):
+            ...     process(batch)  # batch is SampleBatch[ST]
         """
         if batch_size is None:
             return wds.pipeline.DataPipeline(
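The two `@overload` stubs added to `ordered()` exist only for type checkers; at runtime there is a single implementation, but callers passing `batch_size=None` see `Iterable[ST]` while callers passing an `int` see `Iterable[SampleBatch[ST]]`. The same pattern in a self-contained sketch:

```python
from typing import Iterable, TypeVar, overload

T = TypeVar("T")

@overload
def batched(items: list[T], batch_size: None = None) -> Iterable[T]: ...
@overload
def batched(items: list[T], batch_size: int) -> Iterable[list[T]]: ...

def batched(items, batch_size=None):
    # One runtime body; the overloads above only refine the static return type.
    if batch_size is None:
        yield from items
    else:
        for i in range(0, len(items), batch_size):
            yield items[i : i + batch_size]

assert list(batched([1, 2, 3, 4], batch_size=2)) == [[1, 2], [3, 4]]
```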
@@ -755,7 +768,7 @@ class Dataset( Generic[ST] ):
             _StreamOpenerStage(self._source),
             wds.tariterators.tar_file_expander,
             wds.tariterators.group_by_keys,
-            wds.filters.map( self.wrap ),
+            wds.filters.map(self.wrap),
         )
 
         return wds.pipeline.DataPipeline(
@@ -764,15 +777,33 @@ class Dataset( Generic[ST] ):
             _StreamOpenerStage(self._source),
             wds.tariterators.tar_file_expander,
             wds.tariterators.group_by_keys,
-            wds.filters.batched( batch_size ),
-            wds.filters.map( self.wrap_batch ),
+            wds.filters.batched(batch_size),
+            wds.filters.map(self.wrap_batch),
         )
 
-
-
-
-
-
+    @overload
+    def shuffled(
+        self,
+        buffer_shards: int = 100,
+        buffer_samples: int = 10_000,
+        batch_size: None = None,
+    ) -> Iterable[ST]: ...
+
+    @overload
+    def shuffled(
+        self,
+        buffer_shards: int = 100,
+        buffer_samples: int = 10_000,
+        *,
+        batch_size: int,
+    ) -> Iterable[SampleBatch[ST]]: ...
+
+    def shuffled(
+        self,
+        buffer_shards: int = 100,
+        buffer_samples: int = 10_000,
+        batch_size: int | None = None,
+    ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
         """Iterate over the dataset in random order.
 
         Args:
@@ -787,42 +818,50 @@ class Dataset( Generic[ST] ):
                 dimension.
 
         Returns:
-            A
-
-            ``
-            samples.
+            A data pipeline that iterates over the dataset in randomized order.
+            When ``batch_size`` is ``None``, yields individual samples of type
+            ``ST``. When ``batch_size`` is an integer, yields ``SampleBatch[ST]``
+            instances containing that many samples.
+
+        Examples:
+            >>> for sample in ds.shuffled():
+            ...     process(sample)  # sample is ST
+            >>> for batch in ds.shuffled(batch_size=32):
+            ...     process(batch)  # batch is SampleBatch[ST]
         """
         if batch_size is None:
             return wds.pipeline.DataPipeline(
                 _ShardListStage(self._source),
-                wds.filters.shuffle( buffer_shards ),
+                wds.filters.shuffle(buffer_shards),
                 wds.shardlists.split_by_worker,
                 _StreamOpenerStage(self._source),
                 wds.tariterators.tar_file_expander,
                 wds.tariterators.group_by_keys,
-                wds.filters.shuffle( buffer_samples ),
-                wds.filters.map( self.wrap ),
+                wds.filters.shuffle(buffer_samples),
+                wds.filters.map(self.wrap),
             )
 
         return wds.pipeline.DataPipeline(
             _ShardListStage(self._source),
-            wds.filters.shuffle( buffer_shards ),
+            wds.filters.shuffle(buffer_shards),
             wds.shardlists.split_by_worker,
             _StreamOpenerStage(self._source),
             wds.tariterators.tar_file_expander,
            wds.tariterators.group_by_keys,
-            wds.filters.shuffle( buffer_samples ),
-            wds.filters.batched( batch_size ),
-            wds.filters.map( self.wrap_batch ),
+            wds.filters.shuffle(buffer_samples),
+            wds.filters.batched(batch_size),
+            wds.filters.map(self.wrap_batch),
         )
-
+
     # Design note: Uses pandas for parquet export. Could be replaced with
     # direct fastparquet calls to reduce dependencies if needed.
-    def to_parquet(
-
-
-
-
+    def to_parquet(
+        self,
+        path: Pathlike,
+        sample_map: Optional[SampleExportMap] = None,
+        maxcount: Optional[int] = None,
+        **kwargs,
+    ):
         """Export dataset contents to parquet format.
 
         Converts all samples to a pandas DataFrame and saves to parquet file(s).
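Note that `shuffled()` randomizes at two levels: shard order (`buffer_shards`) before the worker split, then sample order (`buffer_samples`) within the stream, so even a modest sample buffer mixes samples across nearby shards. Hypothetical usage trading buffer memory for shuffle quality:

```python
ds = Dataset[MyData]("data-{000000..000099}.tar")  # MyData as in the examples above

# Smaller buffers use less memory but approximate a global shuffle less closely:
for batch in ds.shuffled(buffer_shards=20, buffer_samples=2_000, batch_size=16):
    ...
```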
@@ -851,63 +890,62 @@ class Dataset( Generic[ST] ):
         This creates multiple parquet files: ``output-000000.parquet``,
         ``output-000001.parquet``, etc.
 
-
-
-
-
-
-
-
-            >>> # Large dataset - process in chunks
-            >>> ds.to_parquet("output.parquet", maxcount=50000)
+        Examples:
+            >>> ds = Dataset[MySample]("data.tar")
+            >>> # Small dataset - load all at once
+            >>> ds.to_parquet("output.parquet")
+            >>>
+            >>> # Large dataset - process in chunks
+            >>> ds.to_parquet("output.parquet", maxcount=50000)
         """
         ##
 
         # Normalize args
-        path = Path( path )
+        path = Path(path)
         if sample_map is None:
             sample_map = asdict
-
-        verbose = kwargs.get( 'verbose', False )
 
-
+        verbose = kwargs.get("verbose", False)
+
+        it = self.ordered(batch_size=None)
         if verbose:
-            it = tqdm( it )
+            it = tqdm(it)
 
         #
 
         if maxcount is None:
             # Load and save full dataset
-            df = pd.DataFrame(
-
-
-
+            df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)])
+            df.to_parquet(path, **kwargs)
+
         else:
             # Load and save dataset in segments of size `maxcount`
 
             cur_segment = 0
             cur_buffer = []
-            path_template = ( path.parent / f'{path.stem}-{{:06d}}{path.suffix}' ).as_posix()
+            path_template = (
+                path.parent / f"{path.stem}-{{:06d}}{path.suffix}"
+            ).as_posix()
 
-            for x in self.ordered( batch_size = None ):
-                cur_buffer.append( sample_map( x ) )
+            for x in self.ordered(batch_size=None):
+                cur_buffer.append(sample_map(x))
 
-                if len( cur_buffer ) >= maxcount:
+                if len(cur_buffer) >= maxcount:
                     # Write current segment
-                    cur_path = path_template.format( cur_segment )
-                    df = pd.DataFrame( cur_buffer )
-                    df.to_parquet( cur_path, **kwargs )
+                    cur_path = path_template.format(cur_segment)
+                    df = pd.DataFrame(cur_buffer)
+                    df.to_parquet(cur_path, **kwargs)
 
                     cur_segment += 1
                     cur_buffer = []
-
-            if len( cur_buffer ) > 0:
+
+            if len(cur_buffer) > 0:
                 # Write one last segment with remainder
-                cur_path = path_template.format( cur_segment )
-                df = pd.DataFrame( cur_buffer )
-                df.to_parquet( cur_path, **kwargs )
+                cur_path = path_template.format(cur_segment)
+                df = pd.DataFrame(cur_buffer)
+                df.to_parquet(cur_path, **kwargs)
 
-    def wrap(
+    def wrap(self, sample: WDSRawSample) -> ST:
         """Wrap a raw msgpack sample into the appropriate dataset-specific type.
 
         Args:
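The chunked branch of `to_parquet` derives segment names from the requested path via `path_template`. A standalone check of that naming scheme, mirroring the code above:

```python
from pathlib import Path

path = Path("output.parquet")
path_template = (path.parent / f"{path.stem}-{{:06d}}{path.suffix}").as_posix()

assert path_template.format(0) == "output-000000.parquet"
assert path_template.format(17) == "output-000017.parquet"
```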
@@ -918,18 +956,22 @@ class Dataset( Generic[ST] ):
             A deserialized sample of type ``ST``, optionally transformed through
             a lens if ``as_type()`` was called.
         """
-        if
-            raise ValueError(
-
-
+        if "msgpack" not in sample:
+            raise ValueError(
+                f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}"
+            )
+        if not isinstance(sample["msgpack"], bytes):
+            raise ValueError(
+                f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}"
+            )
 
         if self._output_lens is None:
-            return self.sample_type.from_bytes( sample['msgpack'] )
+            return self.sample_type.from_bytes(sample["msgpack"])
 
-        source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
-        return self._output_lens( source_sample )
+        source_sample = self._output_lens.source_type.from_bytes(sample["msgpack"])
+        return self._output_lens(source_sample)
 
-    def wrap_batch(
+    def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]:
         """Wrap a batch of raw msgpack samples into a typed SampleBatch.
 
         Args:
@@ -945,26 +987,29 @@ class Dataset( Generic[ST] ):
             aggregates them into a batch.
         """
 
-        if
-            raise ValueError(
+        if "msgpack" not in batch:
+            raise ValueError(
+                f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}"
+            )
 
         if self._output_lens is None:
-            batch_unpacked = [
-
-
+            batch_unpacked = [
+                self.sample_type.from_bytes(bs) for bs in batch["msgpack"]
+            ]
+            return SampleBatch[self.sample_type](batch_unpacked)
 
-        batch_source = [
-
-
-
-        return SampleBatch[self.sample_type](
+        batch_source = [
+            self._output_lens.source_type.from_bytes(bs) for bs in batch["msgpack"]
+        ]
+        batch_view = [self._output_lens(s) for s in batch_source]
+        return SampleBatch[self.sample_type](batch_view)
 
 
-_T = TypeVar( '_T' )
+_T = TypeVar("_T")
 
 
 @dataclass_transform()
-def packable( cls: type[_T] ) -> type[_T]:
+def packable(cls: type[_T]) -> type[_T]:
     """Decorator to convert a regular class into a ``PackableSample``.
 
     This decorator transforms a class into a dataclass that inherits from
@@ -984,19 +1029,17 @@ def packable( cls: type[_T] ) -> type[_T]:
     ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
 
     Examples:
-
-
-
-
-
-
-
-
-
-
-
-        # Works with Packable-typed APIs
-        index.publish_schema(MyData, version="1.0.0")  # Type-safe
+        >>> @packable
+        ... class MyData:
+        ...     name: str
+        ...     values: NDArray
+        ...
+        >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
+        >>> bytes_data = sample.packed
+        >>> restored = MyData.from_bytes(bytes_data)
+        >>>
+        >>> # Works with Packable-typed APIs
+        >>> index.publish_schema(MyData, version="1.0.0")  # Type-safe
     """
 
     ##
@@ -1005,14 +1048,14 @@ def packable( cls: type[_T] ) -> type[_T]:
     class_annotations = cls.__annotations__
 
     # Add in dataclass niceness to original class
-    as_dataclass = dataclass( cls )
+    as_dataclass = dataclass(cls)
 
     # This triggers a bunch of behind-the-scenes stuff for the newly annotated class
     @dataclass
-    class as_packable( as_dataclass, PackableSample ):
-        def __post_init__( self ):
-            return PackableSample.__post_init__( self )
-
+    class as_packable(as_dataclass, PackableSample):
+        def __post_init__(self):
+            return PackableSample.__post_init__(self)
+
     # Restore original class identity for better repr/debugging
     as_packable.__name__ = class_name
     as_packable.__qualname__ = class_name
@@ -1023,10 +1066,10 @@ def packable( cls: type[_T] ) -> type[_T]:
 
     # Fix qualnames of dataclass-generated methods so they don't show
     # 'packable.<locals>.as_packable' in help() and IDE hints
-    old_qualname_prefix = 'packable.<locals>.as_packable'
-    for attr_name in ( '__init__', '__repr__', '__eq__', '__post_init__' ):
+    old_qualname_prefix = "packable.<locals>.as_packable"
+    for attr_name in ("__init__", "__repr__", "__eq__", "__post_init__"):
         attr = getattr(as_packable, attr_name, None)
-        if attr is not None and hasattr(attr, '__qualname__'):
+        if attr is not None and hasattr(attr, "__qualname__"):
             if attr.__qualname__.startswith(old_qualname_prefix):
                 attr.__qualname__ = attr.__qualname__.replace(
                     old_qualname_prefix, class_name, 1
@@ -1042,4 +1085,4 @@ def packable( cls: type[_T] ) -> type[_T]:
 
     ##
 
-    return as_packable
+    return as_packable