atdata 0.2.2b1-py3-none-any.whl → 0.3.0b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +31 -1
- atdata/_cid.py +29 -35
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +33 -17
- atdata/_hf_api.py +109 -59
- atdata/_logging.py +70 -0
- atdata/_protocols.py +74 -132
- atdata/_schema_codec.py +38 -41
- atdata/_sources.py +57 -64
- atdata/_stub_manager.py +31 -26
- atdata/_type_utils.py +47 -7
- atdata/atmosphere/__init__.py +31 -24
- atdata/atmosphere/_types.py +11 -11
- atdata/atmosphere/client.py +11 -8
- atdata/atmosphere/lens.py +27 -30
- atdata/atmosphere/records.py +34 -39
- atdata/atmosphere/schema.py +35 -31
- atdata/atmosphere/store.py +16 -20
- atdata/cli/__init__.py +163 -168
- atdata/cli/diagnose.py +12 -8
- atdata/cli/inspect.py +69 -0
- atdata/cli/local.py +5 -2
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +678 -533
- atdata/lens.py +85 -83
- atdata/local/__init__.py +71 -0
- atdata/local/_entry.py +157 -0
- atdata/local/_index.py +940 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/local/_s3.py +349 -0
- atdata/local/_schema.py +380 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +20 -24
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/testing.py +337 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +5 -1
- atdata-0.3.0b1.dist-info/RECORD +54 -0
- atdata/local.py +0 -1707
- atdata-0.2.2b1.dist-info/RECORD +0 -28
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/.gitignore
ADDED
@@ -0,0 +1 @@
+!manifest/
atdata/__init__.py
CHANGED
@@ -55,6 +55,8 @@ from .lens import (
 from ._hf_api import (
     load_dataset as load_dataset,
     DatasetDict as DatasetDict,
+    get_default_index as get_default_index,
+    set_default_index as set_default_index,
 )

 from ._protocols import (
@@ -71,10 +73,29 @@ from ._sources import (
     BlobSource as BlobSource,
 )

+from ._exceptions import (
+    AtdataError as AtdataError,
+    LensNotFoundError as LensNotFoundError,
+    SchemaError as SchemaError,
+    SampleKeyError as SampleKeyError,
+    ShardError as ShardError,
+    PartialFailureError as PartialFailureError,
+)
+
 from ._schema_codec import (
     schema_to_type as schema_to_type,
 )

+from ._logging import (
+    configure_logging as configure_logging,
+    get_logger as get_logger,
+)
+
+from .repository import (
+    Repository as Repository,
+    create_repository as create_repository,
+)
+
 from ._cid import (
     generate_cid as generate_cid,
     verify_cid as verify_cid,
@@ -84,8 +105,17 @@ from .promote import (
     promote_to_atmosphere as promote_to_atmosphere,
 )

+from .manifest import (
+    ManifestField as ManifestField,
+    ManifestBuilder as ManifestBuilder,
+    ShardManifest as ShardManifest,
+    ManifestWriter as ManifestWriter,
+    QueryExecutor as QueryExecutor,
+    SampleLocation as SampleLocation,
+)
+
 # ATProto integration (lazy import to avoid requiring atproto package)
 from . import atmosphere as atmosphere

 # CLI entry point
-from .cli import main as main
+from .cli import main as main
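As a quick check of the expanded top-level API, the sketch below only exercises names that the new `__init__.py` re-exports; it makes no assumptions about call signatures beyond what this diff shows.

```python
# Smoke test of the new top-level re-exports (import surface only)
from atdata import (
    AtdataError,
    LensNotFoundError,
    PartialFailureError,
    Repository,
    create_repository,
    ManifestBuilder,
    QueryExecutor,
    get_default_index,
    set_default_index,
)

# Every new exception shares the AtdataError base class, so one except
# clause can catch the whole hierarchy.
print(issubclass(LensNotFoundError, AtdataError))    # True
print(issubclass(PartialFailureError, AtdataError))  # True
```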
atdata/_cid.py
CHANGED
@@ -12,13 +12,11 @@ The CIDs generated here use:
 This ensures compatibility with ATProto's CID requirements and enables
 seamless promotion from local storage to atmosphere (ATProto network).

-
-
-
-
-
-    >>> print(cid)
-    bafyreihffx5a2e7k6r5zqgp5iwpjqr2gfyheqhzqtlxagvqjqyxzqpzqaa
+Examples:
+    >>> schema = {"name": "ImageSample", "version": "1.0.0", "fields": [...]}
+    >>> cid = generate_cid(schema)
+    >>> print(cid)
+    bafyreihffx5a2e7k6r5zqgp5iwpjqr2gfyheqhzqtlxagvqjqyxzqpzqaa
 """

 import hashlib
@@ -50,11 +48,9 @@ def generate_cid(data: Any) -> str:
     Raises:
         ValueError: If the data cannot be encoded as DAG-CBOR.

-
-
-
-        >>> generate_cid({"name": "test", "value": 42})
-        'bafyrei...'
+    Examples:
+        >>> generate_cid({"name": "test", "value": 42})
+        'bafyrei...'
     """
     # Encode data as DAG-CBOR
     try:
@@ -68,7 +64,9 @@ def generate_cid(data: Any) -> str:
     # Build raw CID bytes:
     # CIDv1 = version(1) + codec(dag-cbor) + multihash
     # Multihash = code(sha256) + size(32) + digest
-    raw_cid_bytes =
+    raw_cid_bytes = (
+        bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash
+    )

     # Encode to base32 multibase string
     return libipld.encode_cid(raw_cid_bytes)
@@ -86,14 +84,14 @@ def generate_cid_from_bytes(data_bytes: bytes) -> str:
     Returns:
         CIDv1 string in base32 multibase format.

-
-
-
-        >>> cbor_bytes = libipld.encode_dag_cbor({"key": "value"})
-        >>> cid = generate_cid_from_bytes(cbor_bytes)
+    Examples:
+        >>> cbor_bytes = libipld.encode_dag_cbor({"key": "value"})
+        >>> cid = generate_cid_from_bytes(cbor_bytes)
     """
     sha256_hash = hashlib.sha256(data_bytes).digest()
-    raw_cid_bytes =
+    raw_cid_bytes = (
+        bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash
+    )
     return libipld.encode_cid(raw_cid_bytes)


@@ -107,14 +105,12 @@ def verify_cid(cid: str, data: Any) -> bool:
     Returns:
         True if the CID matches the data, False otherwise.

-
-
-
-
-
-
-        >>> verify_cid(cid, {"name": "different"})
-        False
+    Examples:
+        >>> cid = generate_cid({"name": "test"})
+        >>> verify_cid(cid, {"name": "test"})
+        True
+        >>> verify_cid(cid, {"name": "different"})
+        False
     """
     expected_cid = generate_cid(data)
     return cid == expected_cid
@@ -130,14 +126,12 @@ def parse_cid(cid: str) -> dict:
         Dictionary with 'version', 'codec', and 'hash' keys.
         The 'hash' value is itself a dict with 'code', 'size', and 'digest'.

-
-
-
-
-
-
-        >>> info['codec']
-        113  # 0x71 = dag-cbor
+    Examples:
+        >>> info = parse_cid('bafyrei...')
+        >>> info['version']
+        1
+        >>> info['codec']
+        113  # 0x71 = dag-cbor
     """
     return libipld.decode_cid(cid)

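As a cross-check on the byte layout this diff spells out, here is a minimal standalone sketch that assembles the same CIDv1 structure using only the standard library. The constant values are the standard multicodec codes implied by the module (0x01 CIDv1, 0x71 dag-cbor, 0x12 sha2-256), and the manual base32 multibase step stands in for `libipld.encode_cid`; it is not the library's implementation.

```python
import base64
import hashlib

# Standard multicodec values matching the constants used in _cid.py
CID_VERSION_1 = 0x01
CODEC_DAG_CBOR = 0x71
HASH_SHA256 = 0x12
SHA256_SIZE = 32


def sketch_cid(dag_cbor_bytes: bytes) -> str:
    """Assemble version + codec + multihash, then base32-multibase encode ('b' prefix)."""
    digest = hashlib.sha256(dag_cbor_bytes).digest()
    raw = bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + digest
    # Multibase base32: lowercase, unpadded, prefixed with 'b'
    return "b" + base64.b32encode(raw).decode("ascii").lower().rstrip("=")


# Deterministic for identical input bytes; b"\xa1dnamedtest" is DAG-CBOR for {"name": "test"}
print(sketch_cid(b"\xa1dnamedtest"))
```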
atdata/_exceptions.py
ADDED
@@ -0,0 +1,168 @@
+"""Custom exception hierarchy for atdata.
+
+Provides actionable error messages with contextual help, available
+alternatives, and suggested fix code snippets.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Type
+
+
+class AtdataError(Exception):
+    """Base exception for all atdata errors."""
+
+
+class LensNotFoundError(AtdataError, ValueError):
+    """No lens registered to transform between two sample types.
+
+    Attributes:
+        source_type: The source sample type.
+        view_type: The target view type.
+        available_targets: Types reachable from the source via registered lenses.
+    """
+
+    def __init__(
+        self,
+        source_type: Type,
+        view_type: Type,
+        available_targets: list[tuple[Type, str]] | None = None,
+    ) -> None:
+        self.source_type = source_type
+        self.view_type = view_type
+        self.available_targets = available_targets or []
+
+        src_name = source_type.__name__
+        view_name = view_type.__name__
+
+        lines = [f"No lens transforms {src_name} \u2192 {view_name}"]
+
+        if self.available_targets:
+            lines.append("")
+            lines.append(f"Available lenses from {src_name}:")
+            for target_type, lens_name in self.available_targets:
+                lines.append(
+                    f" - {src_name} \u2192 {target_type.__name__} (via {lens_name})"
+                )
+
+        lines.append("")
+        lines.append("Did you mean to define:")
+        lines.append(" @lens")
+        lines.append(
+            f" def {src_name.lower()}_to_{view_name.lower()}(source: {src_name}) -> {view_name}:"
+        )
+        lines.append(f" return {view_name}(...)")
+
+        super().__init__("\n".join(lines))
+
+
+class SchemaError(AtdataError):
+    """Schema mismatch during sample deserialization.
+
+    Raised when the data in a shard doesn't match the expected sample type.
+
+    Attributes:
+        expected_fields: Fields expected by the sample type.
+        actual_fields: Fields found in the data.
+        sample_type_name: Name of the target sample type.
+    """
+
+    def __init__(
+        self,
+        sample_type_name: str,
+        expected_fields: list[str],
+        actual_fields: list[str],
+    ) -> None:
+        self.sample_type_name = sample_type_name
+        self.expected_fields = expected_fields
+        self.actual_fields = actual_fields
+
+        missing = sorted(set(expected_fields) - set(actual_fields))
+        extra = sorted(set(actual_fields) - set(expected_fields))
+
+        lines = [f"Schema mismatch for {sample_type_name}"]
+        if missing:
+            lines.append(f" Missing fields: {', '.join(missing)}")
+        if extra:
+            lines.append(f" Unexpected fields: {', '.join(extra)}")
+        lines.append("")
+        lines.append(f"Expected: {', '.join(sorted(expected_fields))}")
+        lines.append(f"Got: {', '.join(sorted(actual_fields))}")
+
+        super().__init__("\n".join(lines))
+
+
+class SampleKeyError(AtdataError, KeyError):
+    """Sample with the given key was not found in the dataset.
+
+    Attributes:
+        key: The key that was not found.
+    """
+
+    def __init__(self, key: str) -> None:
+        self.key = key
+        super().__init__(
+            f"Sample with key '{key}' not found in dataset. "
+            f"Note: key lookup requires scanning all shards and is O(n)."
+        )
+
+
+class ShardError(AtdataError):
+    """Error accessing or reading a dataset shard.
+
+    Attributes:
+        shard_id: Identifier of the shard that failed.
+        reason: Human-readable description of what went wrong.
+    """
+
+    def __init__(self, shard_id: str, reason: str) -> None:
+        self.shard_id = shard_id
+        self.reason = reason
+        super().__init__(f"Failed to read shard '{shard_id}': {reason}")
+
+
+class PartialFailureError(AtdataError):
+    """Some shards succeeded but others failed during processing.
+
+    Raised by :meth:`Dataset.process_shards` when at least one shard fails.
+    Provides access to both the successful results and the per-shard errors,
+    enabling retry of only the failed shards.
+
+    Attributes:
+        succeeded_shards: List of shard identifiers that succeeded.
+        failed_shards: List of shard identifiers that failed.
+        errors: Mapping from shard identifier to the exception that occurred.
+        results: Mapping from shard identifier to the result for succeeded shards.
+    """
+
+    def __init__(
+        self,
+        succeeded_shards: list[str],
+        failed_shards: list[str],
+        errors: dict[str, Exception],
+        results: dict[str, object],
+    ) -> None:
+        self.succeeded_shards = succeeded_shards
+        self.failed_shards = failed_shards
+        self.errors = errors
+        self.results = results
+
+        n_ok = len(succeeded_shards)
+        n_fail = len(failed_shards)
+        total = n_ok + n_fail
+
+        lines = [f"{n_fail}/{total} shards failed during processing"]
+        for shard_id in failed_shards[:5]:
+            lines.append(f" {shard_id}: {errors[shard_id]}")
+        if n_fail > 5:
+            lines.append(f" ... and {n_fail - 5} more")
+        lines.append("")
+        lines.append(
+            f"Access .succeeded_shards ({n_ok}) and .failed_shards ({n_fail}) "
+            f"to inspect or retry."
+        )
+
+        super().__init__("\n".join(lines))
atdata/_helpers.py
CHANGED
@@ -1,8 +1,7 @@
 """Helper utilities for numpy array serialization.

 This module provides utility functions for converting numpy arrays to and from
-bytes for msgpack serialization.
-format to preserve array dtype and shape information.
+bytes for msgpack serialization.

 Functions:
     - ``array_to_bytes()``: Serialize numpy array to bytes
@@ -15,44 +14,61 @@ handling of NDArray fields during msgpack packing/unpacking.
 ##
 # Imports

+import struct
 from io import BytesIO

 import numpy as np

+# .npy format magic prefix (used for backward-compatible deserialization)
+_NPY_MAGIC = b"\x93NUMPY"
+

 ##

-
+
+def array_to_bytes(x: np.ndarray) -> bytes:
     """Convert a numpy array to bytes for msgpack serialization.

-    Uses
+    Uses a compact binary format: a short header (dtype + shape) followed by
+    raw array bytes via ``ndarray.tobytes()``. Falls back to numpy's ``.npy``
+    format for object dtypes that cannot be represented as raw bytes.

     Args:
         x: A numpy array to serialize.

     Returns:
         Raw bytes representing the serialized array.
-
-    Note:
-        Uses ``allow_pickle=True`` to support object dtypes.
     """
-
-
-
+    if x.dtype == object:
+        buf = BytesIO()
+        np.save(buf, x, allow_pickle=True)
+        return buf.getvalue()

-
+    dtype_str = x.dtype.str.encode()  # e.g. b'<f4'
+    header = struct.pack(f"<B{len(x.shape)}q", len(x.shape), *x.shape)
+    return struct.pack("<B", len(dtype_str)) + dtype_str + header + x.tobytes()
+
+
+def bytes_to_array(b: bytes) -> np.ndarray:
     """Convert serialized bytes back to a numpy array.

-
+    Transparently handles both the compact format produced by the current
+    ``array_to_bytes()`` and the legacy ``.npy`` format.

     Args:
         b: Raw bytes from a serialized numpy array.

     Returns:
         The deserialized numpy array with original dtype and shape.
-
-    Note:
-        Uses ``allow_pickle=True`` to support object dtypes.
     """
-
-
+    if b[:6] == _NPY_MAGIC:
+        return np.load(BytesIO(b), allow_pickle=True)
+
+    # Compact format: dtype_len(1B) + dtype_str + ndim(1B) + shape(ndim×8B) + data
+    dlen = b[0]
+    dtype = np.dtype(b[1 : 1 + dlen].decode())
+    ndim = b[1 + dlen]
+    offset = 2 + dlen
+    shape = struct.unpack_from(f"<{ndim}q", b, offset)
+    offset += ndim * 8
+    return np.frombuffer(b, dtype=dtype, offset=offset).reshape(shape).copy()
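A round-trip sketch of the new compact layout, assuming the private helpers stay importable from `atdata._helpers`; the size arithmetic mirrors the header layout described in the comment above (1-byte dtype-string length, dtype string, 1-byte ndim, 8-byte shape entries, raw data).

```python
import io

import numpy as np

from atdata._helpers import array_to_bytes, bytes_to_array  # private module, per this diff

x = np.arange(12, dtype=np.float32).reshape(3, 4)
blob = array_to_bytes(x)

# 1 (dtype length) + 3 (b"<f4") + 1 (ndim) + 2 * 8 (shape) + 48 (raw float32 data)
assert len(blob) == 1 + 3 + 1 + 2 * 8 + x.nbytes

y = bytes_to_array(blob)
assert y.dtype == x.dtype and y.shape == x.shape
assert np.array_equal(x, y)

# Legacy .npy payloads (the old format) still decode via the magic-prefix check
buf = io.BytesIO()
np.save(buf, x)
assert np.array_equal(bytes_to_array(buf.getvalue()), x)
```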
atdata/_hf_api.py
CHANGED
@@ -9,28 +9,27 @@ Key differences from HuggingFace Datasets:
 - Built on WebDataset for efficient streaming of large datasets
 - No Arrow caching layer (WebDataset handles remote/local transparently)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
->>> ds_dict = load_dataset("path/to/{train,test}-*.tar", MyData)
->>> train_ds = ds_dict["train"]
+Examples:
+    >>> import atdata
+    >>> from atdata import load_dataset
+    >>>
+    >>> @atdata.packable
+    ... class MyData:
+    ...     text: str
+    ...     label: int
+    >>>
+    >>> # Load a single split
+    >>> ds = load_dataset("path/to/train-{000000..000099}.tar", MyData, split="train")
+    >>>
+    >>> # Load all splits (returns DatasetDict)
+    >>> ds_dict = load_dataset("path/to/{train,test}-*.tar", MyData)
+    >>> train_ds = ds_dict["train"]
 """

 from __future__ import annotations

 import re
+import threading
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
@@ -42,18 +41,70 @@ from typing import (
     overload,
 )

-from .dataset import Dataset,
+from .dataset import Dataset, DictSample
 from ._sources import URLSource, S3Source
-from ._protocols import DataSource
+from ._protocols import DataSource, Packable

 if TYPE_CHECKING:
     from ._protocols import AbstractIndex
-    from .local import S3DataStore

 ##
 # Type variables

-ST = TypeVar("ST", bound=
+ST = TypeVar("ST", bound=Packable)
+
+
+##
+# Default Index singleton
+
+_default_index: "Index | None" = None  # noqa: F821 (forward ref)
+_default_index_lock = threading.Lock()
+
+
+def get_default_index() -> "Index":  # noqa: F821
+    """Get or create the module-level default Index.
+
+    The default Index uses Redis for local storage (backwards-compatible
+    default) and an anonymous AtmosphereClient for read-only public data
+    resolution.
+
+    The default is created lazily on first access and cached for the
+    lifetime of the process.
+
+    Returns:
+        The default Index instance.
+
+    Examples:
+        >>> index = get_default_index()
+        >>> entry = index.get_dataset("local/mnist")
+    """
+    global _default_index
+    if _default_index is None:
+        with _default_index_lock:
+            if _default_index is None:
+                from .local import Index
+
+                _default_index = Index()
+    return _default_index
+
+
+def set_default_index(index: "Index") -> None:  # noqa: F821
+    """Override the module-level default Index.
+
+    Use this to configure a custom default Index with specific repositories,
+    an authenticated atmosphere client, or non-default providers.
+
+    Args:
+        index: The Index instance to use as the default.
+
+    Examples:
+        >>> from atdata.local import Index
+        >>> from atdata.providers import create_provider
+        >>> custom = Index(provider=create_provider("sqlite"))
+        >>> set_default_index(custom)
+    """
+    global _default_index
+    _default_index = index


 ##
@@ -70,18 +121,18 @@ class DatasetDict(Generic[ST], dict):
     Parameters:
         ST: The sample type for all datasets in this dict.

-
-
-
-
-
-
-
-
-        >>> for split_name, dataset in ds_dict.items():
-        ...     print(f"{split_name}: {len(dataset.shard_list)} shards")
+    Examples:
+        >>> ds_dict = load_dataset("path/to/data", MyData)
+        >>> train = ds_dict["train"]
+        >>> test = ds_dict["test"]
+        >>>
+        >>> # Iterate over all splits
+        >>> for split_name, dataset in ds_dict.items():
+        ...     print(f"{split_name}: {len(dataset.list_shards())} shards")
     """
-
+
+    # Note: The docstring uses "Parameters:" for type parameters as a workaround
+    # for quartodoc not supporting "Type Parameters:" sections.

     def __init__(
         self,
@@ -463,12 +514,12 @@ def _resolve_indexed_path(
     handle_or_did, dataset_name = _parse_indexed_path(path)

     # For AtmosphereIndex, we need to resolve handle to DID first
-    # For
+    # For local Index, the handle is ignored and we just look up by name
     entry = index.get_dataset(dataset_name)
     data_urls = entry.data_urls

     # Check if index has a data store
-    if hasattr(index,
+    if hasattr(index, "data_store") and index.data_store is not None:
         store = index.data_store

         # Import here to avoid circular imports at module level
@@ -613,38 +664,35 @@ def load_dataset(
         FileNotFoundError: If no data files are found at the path.
         KeyError: If dataset not found in index.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        >>> index = LocalIndex()
-        >>> ds = load_dataset("@local/my-dataset", index=index, split="train")
+    Examples:
+        >>> # Load without type - get DictSample for exploration
+        >>> ds = load_dataset("./data/train.tar", split="train")
+        >>> for sample in ds.ordered():
+        ...     print(sample.keys())  # Explore fields
+        ...     print(sample["text"])  # Dict-style access
+        ...     print(sample.label)  # Attribute access
+        >>>
+        >>> # Convert to typed schema
+        >>> typed_ds = ds.as_type(TextData)
+        >>>
+        >>> # Or load with explicit type directly
+        >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
+        >>>
+        >>> # Load from index with auto-type resolution
+        >>> index = Index()
+        >>> ds = load_dataset("@local/my-dataset", index=index, split="train")
     """
     # Handle @handle/dataset indexed path resolution
     if _is_indexed_path(path):
        if index is None:
-
-                f"Index required for indexed path: {path}. "
-                "Pass index=LocalIndex() or index=AtmosphereIndex(client)."
-            )
+            index = get_default_index()

         source, schema_ref = _resolve_indexed_path(path, index)

         # Resolve sample_type from schema if not provided
-        resolved_type: Type =
+        resolved_type: Type = (
+            sample_type if sample_type is not None else index.decode_schema(schema_ref)
+        )

         # Create dataset from the resolved source (includes credentials if S3)
         ds = Dataset[resolved_type](source)
@@ -653,7 +701,9 @@
             # Indexed datasets are single-split by default
             return ds

-        return DatasetDict(
+        return DatasetDict(
+            {"train": ds}, sample_type=resolved_type, streaming=streaming
+        )

     # Use DictSample as default when no type specified
     resolved_type = sample_type if sample_type is not None else DictSample