atdata 0.2.2b1__py3-none-any.whl → 0.2.3b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +1 -1
- atdata/_cid.py +29 -35
- atdata/_helpers.py +7 -5
- atdata/_hf_api.py +48 -50
- atdata/_protocols.py +56 -71
- atdata/_schema_codec.py +33 -37
- atdata/_sources.py +57 -64
- atdata/_stub_manager.py +31 -26
- atdata/_type_utils.py +19 -5
- atdata/atmosphere/__init__.py +20 -23
- atdata/atmosphere/_types.py +11 -11
- atdata/atmosphere/client.py +11 -8
- atdata/atmosphere/lens.py +27 -30
- atdata/atmosphere/records.py +31 -37
- atdata/atmosphere/schema.py +33 -29
- atdata/atmosphere/store.py +16 -20
- atdata/cli/__init__.py +12 -3
- atdata/cli/diagnose.py +12 -8
- atdata/cli/local.py +4 -1
- atdata/dataset.py +284 -241
- atdata/lens.py +77 -82
- atdata/local.py +182 -169
- atdata/promote.py +18 -22
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/METADATA +2 -1
- atdata-0.2.3b1.dist-info/RECORD +28 -0
- atdata-0.2.2b1.dist-info/RECORD +0 -28
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/licenses/LICENSE +0 -0
atdata/__init__.py
CHANGED
atdata/_cid.py
CHANGED
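The hunks below regroup this module's doctest examples under an "Examples:" heading and wrap the raw-CID construction across lines. As context for the comments in those hunks (CIDv1 = version + codec + multihash; multihash = code + size + digest), here is a hedged, standard-library-only sketch of how such a CID string could be assembled. The constant values assume the usual multicodec table, and the package itself delegates the final encoding to libipld.encode_cid rather than this manual base32 step.

import base64
import hashlib

# Assumed multicodec values: CIDv1 = 0x01, dag-cbor = 0x71, sha2-256 = 0x12.
CID_VERSION_1 = 0x01
CODEC_DAG_CBOR = 0x71
HASH_SHA256 = 0x12
SHA256_SIZE = 32

def sketch_cid(dag_cbor_bytes: bytes) -> str:
    # CIDv1 = version + codec + multihash; multihash = code + size + digest.
    digest = hashlib.sha256(dag_cbor_bytes).digest()
    raw = bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + digest
    # Base32 multibase: 'b' prefix, lowercase RFC 4648 alphabet, no padding.
    return "b" + base64.b32encode(raw).decode("ascii").lower().rstrip("=")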
@@ -12,13 +12,11 @@ The CIDs generated here use:
 This ensures compatibility with ATProto's CID requirements and enables
 seamless promotion from local storage to atmosphere (ATProto network).

-
-
-
-
-
->>> print(cid)
-bafyreihffx5a2e7k6r5zqgp5iwpjqr2gfyheqhzqtlxagvqjqyxzqpzqaa
+Examples:
+>>> schema = {"name": "ImageSample", "version": "1.0.0", "fields": [...]}
+>>> cid = generate_cid(schema)
+>>> print(cid)
+bafyreihffx5a2e7k6r5zqgp5iwpjqr2gfyheqhzqtlxagvqjqyxzqpzqaa
 """

 import hashlib

@@ -50,11 +48,9 @@ def generate_cid(data: Any) -> str:
 Raises:
 ValueError: If the data cannot be encoded as DAG-CBOR.

-
-
-
->>> generate_cid({"name": "test", "value": 42})
-'bafyrei...'
+Examples:
+>>> generate_cid({"name": "test", "value": 42})
+'bafyrei...'
 """
 # Encode data as DAG-CBOR
 try:

@@ -68,7 +64,9 @@ def generate_cid(data: Any) -> str:
 # Build raw CID bytes:
 # CIDv1 = version(1) + codec(dag-cbor) + multihash
 # Multihash = code(sha256) + size(32) + digest
-raw_cid_bytes =
+raw_cid_bytes = (
+bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash
+)

 # Encode to base32 multibase string
 return libipld.encode_cid(raw_cid_bytes)

@@ -86,14 +84,14 @@ def generate_cid_from_bytes(data_bytes: bytes) -> str:
 Returns:
 CIDv1 string in base32 multibase format.

-
-
-
->>> cbor_bytes = libipld.encode_dag_cbor({"key": "value"})
->>> cid = generate_cid_from_bytes(cbor_bytes)
+Examples:
+>>> cbor_bytes = libipld.encode_dag_cbor({"key": "value"})
+>>> cid = generate_cid_from_bytes(cbor_bytes)
 """
 sha256_hash = hashlib.sha256(data_bytes).digest()
-raw_cid_bytes =
+raw_cid_bytes = (
+bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash
+)
 return libipld.encode_cid(raw_cid_bytes)


@@ -107,14 +105,12 @@ def verify_cid(cid: str, data: Any) -> bool:
 Returns:
 True if the CID matches the data, False otherwise.

-
-
-
-
-
-
->>> verify_cid(cid, {"name": "different"})
-False
+Examples:
+>>> cid = generate_cid({"name": "test"})
+>>> verify_cid(cid, {"name": "test"})
+True
+>>> verify_cid(cid, {"name": "different"})
+False
 """
 expected_cid = generate_cid(data)
 return cid == expected_cid

@@ -130,14 +126,12 @@ def parse_cid(cid: str) -> dict:
 Dictionary with 'version', 'codec', and 'hash' keys.
 The 'hash' value is itself a dict with 'code', 'size', and 'digest'.

-
-
-
-
-
-
->>> info['codec']
-113 # 0x71 = dag-cbor
+Examples:
+>>> info = parse_cid('bafyrei...')
+>>> info['version']
+1
+>>> info['codec']
+113 # 0x71 = dag-cbor
 """
 return libipld.decode_cid(cid)

atdata/_helpers.py
CHANGED
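The two hunks below complete the truncated bodies of array_to_bytes and bytes_to_array, which round-trip arrays through numpy's native .npy format. A minimal stand-alone sketch of that round trip (plain numpy, no atdata imports):

from io import BytesIO

import numpy as np

arr = np.arange(6, dtype=np.float32).reshape(2, 3)

# Serialize: np.save writes the .npy header (dtype + shape) plus the data.
buf = BytesIO()
np.save(buf, arr, allow_pickle=True)
payload = buf.getvalue()  # bytes suitable for msgpack packing

# Deserialize: np.load restores an identical array from those bytes.
restored = np.load(BytesIO(payload), allow_pickle=True)
assert restored.dtype == arr.dtype and np.array_equal(restored, arr)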
@@ -22,7 +22,8 @@ import numpy as np

 ##

-
+
+def array_to_bytes(x: np.ndarray) -> bytes:
 """Convert a numpy array to bytes for msgpack serialization.

 Uses numpy's native ``save()`` format to preserve array dtype and shape.

@@ -37,10 +38,11 @@ def array_to_bytes( x: np.ndarray ) -> bytes:
 Uses ``allow_pickle=True`` to support object dtypes.
 """
 np_bytes = BytesIO()
-np.save(
+np.save(np_bytes, x, allow_pickle=True)
 return np_bytes.getvalue()

-
+
+def bytes_to_array(b: bytes) -> np.ndarray:
 """Convert serialized bytes back to a numpy array.

 Reverses the serialization performed by ``array_to_bytes()``.

@@ -54,5 +56,5 @@ def bytes_to_array( b: bytes ) -> np.ndarray:
 Note:
 Uses ``allow_pickle=True`` to support object dtypes.
 """
-np_bytes = BytesIO(
-return np.load(
+np_bytes = BytesIO(b)
+return np.load(np_bytes, allow_pickle=True)
atdata/_hf_api.py
CHANGED
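Besides the docstring rewrites, the hunks below wrap two long expressions and finish a truncated hasattr check in _resolve_indexed_path. That check is an optional-attribute probe: an index may or may not expose a data_store, so the code feature-tests for it rather than depending on a concrete class. A small illustrative sketch of the pattern (pick_store and MiniIndex are hypothetical names, not part of atdata):

from typing import Any, Optional

def pick_store(index: Any) -> Optional[Any]:
    # Use the index's data store only when one is actually attached.
    if hasattr(index, "data_store") and index.data_store is not None:
        return index.data_store
    return None

class MiniIndex:
    data_store = None  # no store configured

assert pick_store(MiniIndex()) is None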
@@ -9,23 +9,21 @@ Key differences from HuggingFace Datasets:
 - Built on WebDataset for efficient streaming of large datasets
 - No Arrow caching layer (WebDataset handles remote/local transparently)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
->>> ds_dict = load_dataset("path/to/{train,test}-*.tar", MyData)
->>> train_ds = ds_dict["train"]
+Examples:
+>>> import atdata
+>>> from atdata import load_dataset
+>>>
+>>> @atdata.packable
+... class MyData:
+... text: str
+... label: int
+>>>
+>>> # Load a single split
+>>> ds = load_dataset("path/to/train-{000000..000099}.tar", MyData, split="train")
+>>>
+>>> # Load all splits (returns DatasetDict)
+>>> ds_dict = load_dataset("path/to/{train,test}-*.tar", MyData)
+>>> train_ds = ds_dict["train"]
 """

 from __future__ import annotations

@@ -48,7 +46,6 @@ from ._protocols import DataSource

 if TYPE_CHECKING:
 from ._protocols import AbstractIndex
-from .local import S3DataStore

 ##
 # Type variables

@@ -70,17 +67,16 @@ class DatasetDict(Generic[ST], dict):
 Parameters:
 ST: The sample type for all datasets in this dict.

-
-
-
-
-
-
-
-
->>> for split_name, dataset in ds_dict.items():
-... print(f"{split_name}: {len(dataset.shard_list)} shards")
+Examples:
+>>> ds_dict = load_dataset("path/to/data", MyData)
+>>> train = ds_dict["train"]
+>>> test = ds_dict["test"]
+>>>
+>>> # Iterate over all splits
+>>> for split_name, dataset in ds_dict.items():
+... print(f"{split_name}: {len(dataset.shard_list)} shards")
 """
+
 # TODO The above has a line for "Parameters:" that should be "Type Parameters:"; this is a temporary fix for `quartodoc` auto-generation bugs.

 def __init__(

@@ -468,7 +464,7 @@ def _resolve_indexed_path(
 data_urls = entry.data_urls

 # Check if index has a data store
-if hasattr(index,
+if hasattr(index, "data_store") and index.data_store is not None:
 store = index.data_store

 # Import here to avoid circular imports at module level

@@ -613,25 +609,23 @@ def load_dataset(
 FileNotFoundError: If no data files are found at the path.
 KeyError: If dataset not found in index.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
->>> index = LocalIndex()
->>> ds = load_dataset("@local/my-dataset", index=index, split="train")
+Examples:
+>>> # Load without type - get DictSample for exploration
+>>> ds = load_dataset("./data/train.tar", split="train")
+>>> for sample in ds.ordered():
+... print(sample.keys()) # Explore fields
+... print(sample["text"]) # Dict-style access
+... print(sample.label) # Attribute access
+>>>
+>>> # Convert to typed schema
+>>> typed_ds = ds.as_type(TextData)
+>>>
+>>> # Or load with explicit type directly
+>>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
+>>>
+>>> # Load from index with auto-type resolution
+>>> index = LocalIndex()
+>>> ds = load_dataset("@local/my-dataset", index=index, split="train")
 """
 # Handle @handle/dataset indexed path resolution
 if _is_indexed_path(path):

@@ -644,7 +638,9 @@
 source, schema_ref = _resolve_indexed_path(path, index)

 # Resolve sample_type from schema if not provided
-resolved_type: Type =
+resolved_type: Type = (
+sample_type if sample_type is not None else index.decode_schema(schema_ref)
+)

 # Create dataset from the resolved source (includes credentials if S3)
 ds = Dataset[resolved_type](source)

@@ -653,7 +649,9 @@
 # Indexed datasets are single-split by default
 return ds

-return DatasetDict(
+return DatasetDict(
+{"train": ds}, sample_type=resolved_type, streaming=streaming
+)

 # Use DictSample as default when no type specified
 resolved_type = sample_type if sample_type is not None else DictSample
atdata/_protocols.py
CHANGED
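The hunks below move this module's doctest examples under "Examples:" headings and drop an unused ClassVar import. The module's point, per its docstring, is structural typing: any object with the right methods satisfies a Protocol, whether it is a LocalIndex or an AtmosphereIndex. A minimal sketch of that idea using typing.Protocol (the SupportsListDatasets and MiniIndex names are illustrative, not atdata's API):

from typing import Iterator, Protocol

class SupportsListDatasets(Protocol):
    def list_datasets(self) -> Iterator[str]: ...

class MiniIndex:  # does not inherit from the protocol
    def list_datasets(self) -> Iterator[str]:
        yield "my-dataset"

def show(index: SupportsListDatasets) -> None:
    for name in index.list_datasets():
        print(name)

show(MiniIndex())  # accepted via structural (duck) typing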
@@ -19,22 +19,19 @@ Protocols:
 AbstractIndex: Protocol for index operations (schemas, datasets, lenses)
 AbstractDataStore: Protocol for data storage operations

-
-
-
-
-
-
-
-
->>> process_datasets(local_index)
->>> process_datasets(atmosphere_index)
+Examples:
+>>> def process_datasets(index: AbstractIndex) -> None:
+... for entry in index.list_datasets():
+... print(f"{entry.name}: {entry.data_urls}")
+...
+>>> # Works with either LocalIndex or AtmosphereIndex
+>>> process_datasets(local_index)
+>>> process_datasets(atmosphere_index)
 """

 from typing import (
 IO,
 Any,
-ClassVar,
 Iterator,
 Optional,
 Protocol,

@@ -67,18 +64,16 @@ class Packable(Protocol):
 - Schema publishing (class introspection via dataclass fields)
 - Serialization/deserialization (packed, from_bytes)

-
-
-
-
-
-
-
-
-
-
-... instance = sample_type.from_bytes(data)
-... print(instance.packed)
+Examples:
+>>> @packable
+... class MySample:
+... name: str
+... value: int
+...
+>>> def process(sample_type: Type[Packable]) -> None:
+... # Type checker knows sample_type has from_bytes, packed, etc.
+... instance = sample_type.from_bytes(data)
+... print(instance.packed)
 """

 @classmethod

@@ -169,21 +164,19 @@ class AbstractIndex(Protocol):
 - ``data_store``: An AbstractDataStore for reading/writing dataset shards.
 If present, ``load_dataset`` will use it for S3 credential resolution.

-
-
-
-
-
-
-
-
-
-
-
-
-
-... for entry in index.list_datasets():
-... print(f"{entry.name} -> {entry.schema_ref}")
+Examples:
+>>> def publish_and_list(index: AbstractIndex) -> None:
+... # Publish schemas for different types
+... schema1 = index.publish_schema(ImageSample, version="1.0.0")
+... schema2 = index.publish_schema(TextSample, version="1.0.0")
+...
+... # Insert datasets of different types
+... index.insert_dataset(image_ds, name="images")
+... index.insert_dataset(text_ds, name="texts")
+...
+... # List all datasets (mixed types)
+... for entry in index.list_datasets():
+... print(f"{entry.name} -> {entry.schema_ref}")
 """

 @property

@@ -341,14 +334,12 @@ class AbstractIndex(Protocol):
 KeyError: If schema not found.
 ValueError: If schema cannot be decoded (unsupported field types).

-
-
-
-
-
-
->>> for sample in ds.ordered():
-... print(sample) # sample is instance of SampleType
+Examples:
+>>> entry = index.get_dataset("my-dataset")
+>>> SampleType = index.decode_schema(entry.schema_ref)
+>>> ds = Dataset[SampleType](entry.data_urls[0])
+>>> for sample in ds.ordered():
+... print(sample) # sample is instance of SampleType
 """
 ...


@@ -368,13 +359,11 @@ class AbstractDataStore(Protocol):
 flexible deployment: local index with S3 storage, atmosphere index with
 S3 storage, or atmosphere index with PDS blobs.

-
-
-
-
-
->>> print(urls)
-['s3://my-bucket/training/v1/shard-000000.tar', ...]
+Examples:
+>>> store = S3DataStore(credentials, bucket="my-bucket")
+>>> urls = store.write_shards(dataset, prefix="training/v1")
+>>> print(urls)
+['s3://my-bucket/training/v1/shard-000000.tar', ...]
 """

 def write_shards(

@@ -443,18 +432,16 @@ class DataSource(Protocol):
 - ATProto blob streaming
 - Any other source that can provide file-like objects

-
-
-
-
-
-
-
-
-
-
->>> for sample in ds.ordered():
-... print(sample)
+Examples:
+>>> source = S3Source(
+... bucket="my-bucket",
+... keys=["data-000.tar", "data-001.tar"],
+... endpoint="https://r2.example.com",
+... credentials=creds,
+... )
+>>> ds = Dataset[MySample](source)
+>>> for sample in ds.ordered():
+... print(sample)
 """

 @property

@@ -467,12 +454,10 @@ class DataSource(Protocol):
 Yields:
 Tuple of (shard_identifier, file_like_stream).

-
-
-
-
-... print(f"Processing {shard_id}")
-... data = stream.read()
+Examples:
+>>> for shard_id, stream in source.shards:
+... print(f"Processing {shard_id}")
+... data = stream.read()
 """
 ...

atdata/_schema_codec.py
CHANGED
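The first hunk below shows the codec importing make_dataclass, which schema_to_type uses to build a sample class from a schema dict at runtime. A simplified sketch of that construction for primitive-only fields (the real codec also handles ndarray/array/ref field types and attaches PackableSample behavior, which this sketch omits):

from dataclasses import make_dataclass

schema = {
    "name": "TextSample",
    "version": "1.0.0",
    "fields": [
        {"name": "text", "fieldType": {"primitive": "str"}},
        {"name": "label", "fieldType": {"primitive": "int"}},
    ],
}

# Map schema primitive names onto Python types.
_PRIMITIVES = {"str": str, "int": int, "float": float, "bool": bool, "bytes": bytes}

TextSample = make_dataclass(
    schema["name"],
    [(f["name"], _PRIMITIVES[f["fieldType"]["primitive"]]) for f in schema["fields"]],
    namespace={"__schema_version__": schema["version"]},
)

sample = TextSample(text="hello", label=1)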
@@ -9,19 +9,17 @@ The schema format follows the ATProto record structure defined in
 ``atmosphere/_types.py``, with field types supporting primitives, ndarrays,
 arrays, and schema references.

-
-
-
-
-
-
-
-
-
-
-
->>> ImageSample = schema_to_type(schema)
->>> sample = ImageSample(image=np.zeros((64, 64)), label="cat")
+Examples:
+>>> schema = {
+... "name": "ImageSample",
+... "version": "1.0.0",
+... "fields": [
+... {"name": "image", "fieldType": {"$type": "...#ndarray", "dtype": "float32"}, "optional": False},
+... {"name": "label", "fieldType": {"$type": "...#primitive", "primitive": "str"}, "optional": False},
+... ]
+... }
+>>> ImageSample = schema_to_type(schema)
+>>> sample = ImageSample(image=np.zeros((64, 64)), label="cat")
 """

 from dataclasses import field, make_dataclass

@@ -151,14 +149,12 @@ def schema_to_type(
 Raises:
 ValueError: If schema is malformed or contains unsupported types.

-
-
-
-
-
-
->>> for sample in ds.ordered():
-... print(sample)
+Examples:
+>>> schema = index.get_schema("local://schemas/MySample@1.0.0")
+>>> MySample = schema_to_type(schema)
+>>> ds = Dataset[MySample]("data.tar")
+>>> for sample in ds.ordered():
+... print(sample)
 """
 # Check cache first
 if use_cache:

@@ -207,7 +203,9 @@ schema_to_type(
 namespace={
 "__post_init__": lambda self: PackableSample.__post_init__(self),
 "__schema_version__": version,
-"__schema_ref__": schema.get(
+"__schema_ref__": schema.get(
+"$ref", None
+), # Store original ref if available
 },
 )

@@ -243,7 +241,9 @@ def _field_type_to_stub_str(field_type: dict, optional: bool = False) -> str:

 if kind == "primitive":
 primitive = field_type.get("primitive", "str")
-py_type =
+py_type = (
+primitive # str, int, float, bool, bytes are all valid Python type names
+)
 elif kind == "ndarray":
 py_type = "NDArray[Any]"
 elif kind == "array":

@@ -282,14 +282,12 @@ def generate_stub(schema: dict) -> str:
 Returns:
 String content for a .pyi stub file.

-
-
-
-
-
-
->>> with open("stubs/my_sample.pyi", "w") as f:
-... f.write(stub_content)
+Examples:
+>>> schema = index.get_schema("atdata://local/sampleSchema/MySample@1.0.0")
+>>> stub_content = generate_stub(schema.to_dict())
+>>> # Save to a stubs directory configured in your IDE
+>>> with open("stubs/my_sample.pyi", "w") as f:
+... f.write(stub_content)
 """
 name = schema.get("name", "UnknownSample")
 version = schema.get("version", "1.0.0")

@@ -360,12 +358,10 @@ def generate_module(schema: dict) -> str:
 Returns:
 String content for a .py module file.

-
-
-
-
->>> module_content = generate_module(schema.to_dict())
->>> # The module can be imported after being saved
+Examples:
+>>> schema = index.get_schema("atdata://local/sampleSchema/MySample@1.0.0")
+>>> module_content = generate_module(schema.to_dict())
+>>> # The module can be imported after being saved
 """
 name = schema.get("name", "UnknownSample")
 version = schema.get("version", "1.0.0")