atdata 0.2.3b1__py3-none-any.whl → 0.3.1b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +39 -0
- atdata/_cid.py +0 -21
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +41 -15
- atdata/_hf_api.py +95 -11
- atdata/_logging.py +70 -0
- atdata/_protocols.py +77 -238
- atdata/_schema_codec.py +7 -6
- atdata/_stub_manager.py +5 -25
- atdata/_type_utils.py +28 -2
- atdata/atmosphere/__init__.py +31 -20
- atdata/atmosphere/_types.py +4 -4
- atdata/atmosphere/client.py +64 -12
- atdata/atmosphere/lens.py +11 -12
- atdata/atmosphere/records.py +12 -12
- atdata/atmosphere/schema.py +16 -18
- atdata/atmosphere/store.py +6 -7
- atdata/cli/__init__.py +161 -175
- atdata/cli/diagnose.py +2 -2
- atdata/cli/{local.py → infra.py} +11 -11
- atdata/cli/inspect.py +69 -0
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +583 -328
- atdata/index/__init__.py +54 -0
- atdata/index/_entry.py +157 -0
- atdata/index/_index.py +1198 -0
- atdata/index/_schema.py +380 -0
- atdata/lens.py +9 -2
- atdata/lexicons/__init__.py +121 -0
- atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
- atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
- atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
- atdata/lexicons/ac.foundation.dataset.record.json +96 -0
- atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
- atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
- atdata/lexicons/ndarray_shim.json +16 -0
- atdata/local/__init__.py +70 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +18 -14
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/stores/__init__.py +23 -0
- atdata/stores/_disk.py +123 -0
- atdata/stores/_s3.py +349 -0
- atdata/testing.py +341 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
- atdata-0.3.1b1.dist-info/RECORD +67 -0
- atdata/local.py +0 -1720
- atdata-0.2.3b1.dist-info/RECORD +0 -28
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
atdata/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
!manifest/
|
atdata/__init__.py
CHANGED
|
@@ -44,6 +44,7 @@ from .dataset import (
|
|
|
44
44
|
SampleBatch as SampleBatch,
|
|
45
45
|
Dataset as Dataset,
|
|
46
46
|
packable as packable,
|
|
47
|
+
write_samples as write_samples,
|
|
47
48
|
)
|
|
48
49
|
|
|
49
50
|
from .lens import (
|
|
@@ -55,6 +56,8 @@ from .lens import (
|
|
|
55
56
|
from ._hf_api import (
|
|
56
57
|
load_dataset as load_dataset,
|
|
57
58
|
DatasetDict as DatasetDict,
|
|
59
|
+
get_default_index as get_default_index,
|
|
60
|
+
set_default_index as set_default_index,
|
|
58
61
|
)
|
|
59
62
|
|
|
60
63
|
from ._protocols import (
|
|
@@ -71,10 +74,37 @@ from ._sources import (
|
|
|
71
74
|
BlobSource as BlobSource,
|
|
72
75
|
)
|
|
73
76
|
|
|
77
|
+
from ._exceptions import (
|
|
78
|
+
AtdataError as AtdataError,
|
|
79
|
+
LensNotFoundError as LensNotFoundError,
|
|
80
|
+
SchemaError as SchemaError,
|
|
81
|
+
SampleKeyError as SampleKeyError,
|
|
82
|
+
ShardError as ShardError,
|
|
83
|
+
PartialFailureError as PartialFailureError,
|
|
84
|
+
)
|
|
85
|
+
|
|
74
86
|
from ._schema_codec import (
|
|
75
87
|
schema_to_type as schema_to_type,
|
|
76
88
|
)
|
|
77
89
|
|
|
90
|
+
from ._logging import (
|
|
91
|
+
configure_logging as configure_logging,
|
|
92
|
+
get_logger as get_logger,
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
from .repository import (
|
|
96
|
+
Repository as Repository,
|
|
97
|
+
create_repository as create_repository,
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
from .index import (
|
|
101
|
+
Index as Index,
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
from .stores import (
|
|
105
|
+
LocalDiskStore as LocalDiskStore,
|
|
106
|
+
)
|
|
107
|
+
|
|
78
108
|
from ._cid import (
|
|
79
109
|
generate_cid as generate_cid,
|
|
80
110
|
verify_cid as verify_cid,
|
|
@@ -84,6 +114,15 @@ from .promote import (
|
|
|
84
114
|
promote_to_atmosphere as promote_to_atmosphere,
|
|
85
115
|
)
|
|
86
116
|
|
|
117
|
+
from .manifest import (
|
|
118
|
+
ManifestField as ManifestField,
|
|
119
|
+
ManifestBuilder as ManifestBuilder,
|
|
120
|
+
ShardManifest as ShardManifest,
|
|
121
|
+
ManifestWriter as ManifestWriter,
|
|
122
|
+
QueryExecutor as QueryExecutor,
|
|
123
|
+
SampleLocation as SampleLocation,
|
|
124
|
+
)
|
|
125
|
+
|
|
87
126
|
# ATProto integration (lazy import to avoid requiring atproto package)
|
|
88
127
|
from . import atmosphere as atmosphere
|
|
89
128
|
|
atdata/_cid.py
CHANGED
|
@@ -116,29 +116,8 @@ def verify_cid(cid: str, data: Any) -> bool:
|
|
|
116
116
|
return cid == expected_cid
|
|
117
117
|
|
|
118
118
|
|
|
119
|
-
def parse_cid(cid: str) -> dict:
|
|
120
|
-
"""Parse a CID string into its components.
|
|
121
|
-
|
|
122
|
-
Args:
|
|
123
|
-
cid: CID string to parse.
|
|
124
|
-
|
|
125
|
-
Returns:
|
|
126
|
-
Dictionary with 'version', 'codec', and 'hash' keys.
|
|
127
|
-
The 'hash' value is itself a dict with 'code', 'size', and 'digest'.
|
|
128
|
-
|
|
129
|
-
Examples:
|
|
130
|
-
>>> info = parse_cid('bafyrei...')
|
|
131
|
-
>>> info['version']
|
|
132
|
-
1
|
|
133
|
-
>>> info['codec']
|
|
134
|
-
113 # 0x71 = dag-cbor
|
|
135
|
-
"""
|
|
136
|
-
return libipld.decode_cid(cid)
|
|
137
|
-
|
|
138
|
-
|
|
139
119
|
__all__ = [
|
|
140
120
|
"generate_cid",
|
|
141
121
|
"generate_cid_from_bytes",
|
|
142
122
|
"verify_cid",
|
|
143
|
-
"parse_cid",
|
|
144
123
|
]
|
atdata/_exceptions.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""Custom exception hierarchy for atdata.
|
|
2
|
+
|
|
3
|
+
Provides actionable error messages with contextual help, available
|
|
4
|
+
alternatives, and suggested fix code snippets.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from typing import Type
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AtdataError(Exception):
|
|
16
|
+
"""Base exception for all atdata errors."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LensNotFoundError(AtdataError, ValueError):
|
|
20
|
+
"""No lens registered to transform between two sample types.
|
|
21
|
+
|
|
22
|
+
Attributes:
|
|
23
|
+
source_type: The source sample type.
|
|
24
|
+
view_type: The target view type.
|
|
25
|
+
available_targets: Types reachable from the source via registered lenses.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
source_type: Type,
|
|
31
|
+
view_type: Type,
|
|
32
|
+
available_targets: list[tuple[Type, str]] | None = None,
|
|
33
|
+
) -> None:
|
|
34
|
+
self.source_type = source_type
|
|
35
|
+
self.view_type = view_type
|
|
36
|
+
self.available_targets = available_targets or []
|
|
37
|
+
|
|
38
|
+
src_name = source_type.__name__
|
|
39
|
+
view_name = view_type.__name__
|
|
40
|
+
|
|
41
|
+
lines = [f"No lens transforms {src_name} \u2192 {view_name}"]
|
|
42
|
+
|
|
43
|
+
if self.available_targets:
|
|
44
|
+
lines.append("")
|
|
45
|
+
lines.append(f"Available lenses from {src_name}:")
|
|
46
|
+
for target_type, lens_name in self.available_targets:
|
|
47
|
+
lines.append(
|
|
48
|
+
f" - {src_name} \u2192 {target_type.__name__} (via {lens_name})"
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
lines.append("")
|
|
52
|
+
lines.append("Did you mean to define:")
|
|
53
|
+
lines.append(" @lens")
|
|
54
|
+
lines.append(
|
|
55
|
+
f" def {src_name.lower()}_to_{view_name.lower()}(source: {src_name}) -> {view_name}:"
|
|
56
|
+
)
|
|
57
|
+
lines.append(f" return {view_name}(...)")
|
|
58
|
+
|
|
59
|
+
super().__init__("\n".join(lines))
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class SchemaError(AtdataError):
|
|
63
|
+
"""Schema mismatch during sample deserialization.
|
|
64
|
+
|
|
65
|
+
Raised when the data in a shard doesn't match the expected sample type.
|
|
66
|
+
|
|
67
|
+
Attributes:
|
|
68
|
+
expected_fields: Fields expected by the sample type.
|
|
69
|
+
actual_fields: Fields found in the data.
|
|
70
|
+
sample_type_name: Name of the target sample type.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
def __init__(
|
|
74
|
+
self,
|
|
75
|
+
sample_type_name: str,
|
|
76
|
+
expected_fields: list[str],
|
|
77
|
+
actual_fields: list[str],
|
|
78
|
+
) -> None:
|
|
79
|
+
self.sample_type_name = sample_type_name
|
|
80
|
+
self.expected_fields = expected_fields
|
|
81
|
+
self.actual_fields = actual_fields
|
|
82
|
+
|
|
83
|
+
missing = sorted(set(expected_fields) - set(actual_fields))
|
|
84
|
+
extra = sorted(set(actual_fields) - set(expected_fields))
|
|
85
|
+
|
|
86
|
+
lines = [f"Schema mismatch for {sample_type_name}"]
|
|
87
|
+
if missing:
|
|
88
|
+
lines.append(f" Missing fields: {', '.join(missing)}")
|
|
89
|
+
if extra:
|
|
90
|
+
lines.append(f" Unexpected fields: {', '.join(extra)}")
|
|
91
|
+
lines.append("")
|
|
92
|
+
lines.append(f"Expected: {', '.join(sorted(expected_fields))}")
|
|
93
|
+
lines.append(f"Got: {', '.join(sorted(actual_fields))}")
|
|
94
|
+
|
|
95
|
+
super().__init__("\n".join(lines))
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class SampleKeyError(AtdataError, KeyError):
|
|
99
|
+
"""Sample with the given key was not found in the dataset.
|
|
100
|
+
|
|
101
|
+
Attributes:
|
|
102
|
+
key: The key that was not found.
|
|
103
|
+
"""
|
|
104
|
+
|
|
105
|
+
def __init__(self, key: str) -> None:
|
|
106
|
+
self.key = key
|
|
107
|
+
super().__init__(
|
|
108
|
+
f"Sample with key '{key}' not found in dataset. "
|
|
109
|
+
f"Note: key lookup requires scanning all shards and is O(n)."
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class ShardError(AtdataError):
|
|
114
|
+
"""Error accessing or reading a dataset shard.
|
|
115
|
+
|
|
116
|
+
Attributes:
|
|
117
|
+
shard_id: Identifier of the shard that failed.
|
|
118
|
+
reason: Human-readable description of what went wrong.
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
def __init__(self, shard_id: str, reason: str) -> None:
|
|
122
|
+
self.shard_id = shard_id
|
|
123
|
+
self.reason = reason
|
|
124
|
+
super().__init__(f"Failed to read shard '{shard_id}': {reason}")
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class PartialFailureError(AtdataError):
|
|
128
|
+
"""Some shards succeeded but others failed during processing.
|
|
129
|
+
|
|
130
|
+
Raised by :meth:`Dataset.process_shards` when at least one shard fails.
|
|
131
|
+
Provides access to both the successful results and the per-shard errors,
|
|
132
|
+
enabling retry of only the failed shards.
|
|
133
|
+
|
|
134
|
+
Attributes:
|
|
135
|
+
succeeded_shards: List of shard identifiers that succeeded.
|
|
136
|
+
failed_shards: List of shard identifiers that failed.
|
|
137
|
+
errors: Mapping from shard identifier to the exception that occurred.
|
|
138
|
+
results: Mapping from shard identifier to the result for succeeded shards.
|
|
139
|
+
"""
|
|
140
|
+
|
|
141
|
+
def __init__(
|
|
142
|
+
self,
|
|
143
|
+
succeeded_shards: list[str],
|
|
144
|
+
failed_shards: list[str],
|
|
145
|
+
errors: dict[str, Exception],
|
|
146
|
+
results: dict[str, object],
|
|
147
|
+
) -> None:
|
|
148
|
+
self.succeeded_shards = succeeded_shards
|
|
149
|
+
self.failed_shards = failed_shards
|
|
150
|
+
self.errors = errors
|
|
151
|
+
self.results = results
|
|
152
|
+
|
|
153
|
+
n_ok = len(succeeded_shards)
|
|
154
|
+
n_fail = len(failed_shards)
|
|
155
|
+
total = n_ok + n_fail
|
|
156
|
+
|
|
157
|
+
lines = [f"{n_fail}/{total} shards failed during processing"]
|
|
158
|
+
for shard_id in failed_shards[:5]:
|
|
159
|
+
lines.append(f" {shard_id}: {errors[shard_id]}")
|
|
160
|
+
if n_fail > 5:
|
|
161
|
+
lines.append(f" ... and {n_fail - 5} more")
|
|
162
|
+
lines.append("")
|
|
163
|
+
lines.append(
|
|
164
|
+
f"Access .succeeded_shards ({n_ok}) and .failed_shards ({n_fail}) "
|
|
165
|
+
f"to inspect or retry."
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
super().__init__("\n".join(lines))
|
atdata/_helpers.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
"""Helper utilities for numpy array serialization.
|
|
2
2
|
|
|
3
3
|
This module provides utility functions for converting numpy arrays to and from
|
|
4
|
-
bytes for msgpack serialization.
|
|
5
|
-
format to preserve array dtype and shape information.
|
|
4
|
+
bytes for msgpack serialization.
|
|
6
5
|
|
|
7
6
|
Functions:
|
|
8
7
|
- ``array_to_bytes()``: Serialize numpy array to bytes
|
|
@@ -15,10 +14,14 @@ handling of NDArray fields during msgpack packing/unpacking.
|
|
|
15
14
|
##
|
|
16
15
|
# Imports
|
|
17
16
|
|
|
17
|
+
import struct
|
|
18
18
|
from io import BytesIO
|
|
19
19
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
|
|
22
|
+
# .npy format magic prefix (used for backward-compatible deserialization)
|
|
23
|
+
_NPY_MAGIC = b"\x93NUMPY"
|
|
24
|
+
|
|
22
25
|
|
|
23
26
|
##
|
|
24
27
|
|
|
@@ -26,35 +29,58 @@ import numpy as np
|
|
|
26
29
|
def array_to_bytes(x: np.ndarray) -> bytes:
|
|
27
30
|
"""Convert a numpy array to bytes for msgpack serialization.
|
|
28
31
|
|
|
29
|
-
Uses
|
|
32
|
+
Uses a compact binary format: a short header (dtype + shape) followed by
|
|
33
|
+
raw array bytes via ``ndarray.tobytes()``. Falls back to numpy's ``.npy``
|
|
34
|
+
format for object dtypes that cannot be represented as raw bytes.
|
|
30
35
|
|
|
31
36
|
Args:
|
|
32
37
|
x: A numpy array to serialize.
|
|
33
38
|
|
|
34
39
|
Returns:
|
|
35
40
|
Raw bytes representing the serialized array.
|
|
36
|
-
|
|
37
|
-
Note:
|
|
38
|
-
Uses ``allow_pickle=True`` to support object dtypes.
|
|
39
41
|
"""
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
42
|
+
if x.dtype == object:
|
|
43
|
+
buf = BytesIO()
|
|
44
|
+
np.save(buf, x, allow_pickle=True)
|
|
45
|
+
return buf.getvalue()
|
|
46
|
+
|
|
47
|
+
dtype_str = x.dtype.str.encode() # e.g. b'<f4'
|
|
48
|
+
header = struct.pack(f"<B{len(x.shape)}q", len(x.shape), *x.shape)
|
|
49
|
+
return struct.pack("<B", len(dtype_str)) + dtype_str + header + x.tobytes()
|
|
43
50
|
|
|
44
51
|
|
|
45
52
|
def bytes_to_array(b: bytes) -> np.ndarray:
|
|
46
53
|
"""Convert serialized bytes back to a numpy array.
|
|
47
54
|
|
|
48
|
-
|
|
55
|
+
Transparently handles both the compact format produced by the current
|
|
56
|
+
``array_to_bytes()`` and the legacy ``.npy`` format.
|
|
49
57
|
|
|
50
58
|
Args:
|
|
51
59
|
b: Raw bytes from a serialized numpy array.
|
|
52
60
|
|
|
53
61
|
Returns:
|
|
54
62
|
The deserialized numpy array with original dtype and shape.
|
|
55
|
-
|
|
56
|
-
Note:
|
|
57
|
-
Uses ``allow_pickle=True`` to support object dtypes.
|
|
58
63
|
"""
|
|
59
|
-
|
|
60
|
-
|
|
64
|
+
if b[:6] == _NPY_MAGIC:
|
|
65
|
+
return np.load(BytesIO(b), allow_pickle=True)
|
|
66
|
+
|
|
67
|
+
# Compact format: dtype_len(1B) + dtype_str + ndim(1B) + shape(ndim×8B) + data
|
|
68
|
+
if len(b) < 2:
|
|
69
|
+
raise ValueError(f"Array buffer too short ({len(b)} bytes): need at least 2")
|
|
70
|
+
dlen = b[0]
|
|
71
|
+
min_header = 2 + dlen # dtype_len + dtype_str + ndim
|
|
72
|
+
if len(b) < min_header:
|
|
73
|
+
raise ValueError(
|
|
74
|
+
f"Array buffer too short ({len(b)} bytes): need at least {min_header} for header"
|
|
75
|
+
)
|
|
76
|
+
dtype = np.dtype(b[1 : 1 + dlen].decode())
|
|
77
|
+
ndim = b[1 + dlen]
|
|
78
|
+
offset = 2 + dlen
|
|
79
|
+
min_with_shape = offset + ndim * 8
|
|
80
|
+
if len(b) < min_with_shape:
|
|
81
|
+
raise ValueError(
|
|
82
|
+
f"Array buffer too short ({len(b)} bytes): need at least {min_with_shape} for shape"
|
|
83
|
+
)
|
|
84
|
+
shape = struct.unpack_from(f"<{ndim}q", b, offset)
|
|
85
|
+
offset += ndim * 8
|
|
86
|
+
return np.frombuffer(b, dtype=dtype, offset=offset).reshape(shape).copy()
|
atdata/_hf_api.py
CHANGED
|
@@ -29,8 +29,10 @@ Examples:
|
|
|
29
29
|
from __future__ import annotations
|
|
30
30
|
|
|
31
31
|
import re
|
|
32
|
+
import threading
|
|
32
33
|
from pathlib import Path
|
|
33
34
|
from typing import (
|
|
35
|
+
Any,
|
|
34
36
|
TYPE_CHECKING,
|
|
35
37
|
Generic,
|
|
36
38
|
Mapping,
|
|
@@ -40,9 +42,9 @@ from typing import (
|
|
|
40
42
|
overload,
|
|
41
43
|
)
|
|
42
44
|
|
|
43
|
-
from .dataset import Dataset,
|
|
45
|
+
from .dataset import Dataset, DictSample
|
|
44
46
|
from ._sources import URLSource, S3Source
|
|
45
|
-
from ._protocols import DataSource
|
|
47
|
+
from ._protocols import DataSource, Packable
|
|
46
48
|
|
|
47
49
|
if TYPE_CHECKING:
|
|
48
50
|
from ._protocols import AbstractIndex
|
|
@@ -50,7 +52,60 @@ if TYPE_CHECKING:
|
|
|
50
52
|
##
|
|
51
53
|
# Type variables
|
|
52
54
|
|
|
53
|
-
ST = TypeVar("ST", bound=
|
|
55
|
+
ST = TypeVar("ST", bound=Packable)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
##
|
|
59
|
+
# Default Index singleton
|
|
60
|
+
|
|
61
|
+
_default_index: "Index | None" = None # noqa: F821 (forward ref)
|
|
62
|
+
_default_index_lock = threading.Lock()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_default_index() -> "Index": # noqa: F821
|
|
66
|
+
"""Get or create the module-level default Index.
|
|
67
|
+
|
|
68
|
+
The default Index uses Redis for local storage (backwards-compatible
|
|
69
|
+
default) and an anonymous Atmosphere for read-only public data
|
|
70
|
+
resolution.
|
|
71
|
+
|
|
72
|
+
The default is created lazily on first access and cached for the
|
|
73
|
+
lifetime of the process.
|
|
74
|
+
|
|
75
|
+
Returns:
|
|
76
|
+
The default Index instance.
|
|
77
|
+
|
|
78
|
+
Examples:
|
|
79
|
+
>>> index = get_default_index()
|
|
80
|
+
>>> entry = index.get_dataset("local/mnist")
|
|
81
|
+
"""
|
|
82
|
+
global _default_index
|
|
83
|
+
if _default_index is None:
|
|
84
|
+
with _default_index_lock:
|
|
85
|
+
if _default_index is None:
|
|
86
|
+
from .local import Index
|
|
87
|
+
|
|
88
|
+
_default_index = Index()
|
|
89
|
+
return _default_index
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def set_default_index(index: "Index") -> None: # noqa: F821
|
|
93
|
+
"""Override the module-level default Index.
|
|
94
|
+
|
|
95
|
+
Use this to configure a custom default Index with specific repositories,
|
|
96
|
+
an authenticated atmosphere client, or non-default providers.
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
index: The Index instance to use as the default.
|
|
100
|
+
|
|
101
|
+
Examples:
|
|
102
|
+
>>> from atdata.local import Index
|
|
103
|
+
>>> from atdata.providers import create_provider
|
|
104
|
+
>>> custom = Index(provider=create_provider("sqlite"))
|
|
105
|
+
>>> set_default_index(custom)
|
|
106
|
+
"""
|
|
107
|
+
global _default_index
|
|
108
|
+
_default_index = index
|
|
54
109
|
|
|
55
110
|
|
|
56
111
|
##
|
|
@@ -74,10 +129,11 @@ class DatasetDict(Generic[ST], dict):
|
|
|
74
129
|
>>>
|
|
75
130
|
>>> # Iterate over all splits
|
|
76
131
|
>>> for split_name, dataset in ds_dict.items():
|
|
77
|
-
... print(f"{split_name}: {len(dataset.
|
|
132
|
+
... print(f"{split_name}: {len(dataset.list_shards())} shards")
|
|
78
133
|
"""
|
|
79
134
|
|
|
80
|
-
#
|
|
135
|
+
# Note: The docstring uses "Parameters:" for type parameters as a workaround
|
|
136
|
+
# for quartodoc not supporting "Type Parameters:" sections.
|
|
81
137
|
|
|
82
138
|
def __init__(
|
|
83
139
|
self,
|
|
@@ -134,6 +190,37 @@ class DatasetDict(Generic[ST], dict):
|
|
|
134
190
|
"""
|
|
135
191
|
return {name: len(ds.list_shards()) for name, ds in self.items()}
|
|
136
192
|
|
|
193
|
+
# Methods proxied to the sole Dataset when only one split exists.
|
|
194
|
+
_DATASET_METHODS = frozenset(
|
|
195
|
+
{
|
|
196
|
+
"ordered",
|
|
197
|
+
"shuffled",
|
|
198
|
+
"as_type",
|
|
199
|
+
"list_shards",
|
|
200
|
+
"head",
|
|
201
|
+
}
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
def __getattr__(self, name: str) -> Any:
|
|
205
|
+
"""Proxy common Dataset methods when this dict has exactly one split.
|
|
206
|
+
|
|
207
|
+
When a ``DatasetDict`` contains a single split, calling iteration
|
|
208
|
+
methods like ``.ordered()`` or ``.shuffled()`` is forwarded to the
|
|
209
|
+
contained ``Dataset`` for convenience. Multi-split dicts raise
|
|
210
|
+
``AttributeError`` with a hint to select a split explicitly.
|
|
211
|
+
"""
|
|
212
|
+
if name in self._DATASET_METHODS:
|
|
213
|
+
if len(self) == 1:
|
|
214
|
+
return getattr(next(iter(self.values())), name)
|
|
215
|
+
splits = ", ".join(f"'{k}'" for k in self.keys())
|
|
216
|
+
raise AttributeError(
|
|
217
|
+
f"'{type(self).__name__}' has {len(self)} splits ({splits}). "
|
|
218
|
+
f"Select one first, e.g. ds_dict['{next(iter(self.keys()))}'].{name}()"
|
|
219
|
+
)
|
|
220
|
+
raise AttributeError(
|
|
221
|
+
f"'{type(self).__name__}' object has no attribute '{name}'"
|
|
222
|
+
)
|
|
223
|
+
|
|
137
224
|
|
|
138
225
|
##
|
|
139
226
|
# Path resolution utilities
|
|
@@ -459,7 +546,7 @@ def _resolve_indexed_path(
|
|
|
459
546
|
handle_or_did, dataset_name = _parse_indexed_path(path)
|
|
460
547
|
|
|
461
548
|
# For AtmosphereIndex, we need to resolve handle to DID first
|
|
462
|
-
# For
|
|
549
|
+
# For local Index, the handle is ignored and we just look up by name
|
|
463
550
|
entry = index.get_dataset(dataset_name)
|
|
464
551
|
data_urls = entry.data_urls
|
|
465
552
|
|
|
@@ -624,16 +711,13 @@ def load_dataset(
|
|
|
624
711
|
>>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
|
|
625
712
|
>>>
|
|
626
713
|
>>> # Load from index with auto-type resolution
|
|
627
|
-
>>> index =
|
|
714
|
+
>>> index = Index()
|
|
628
715
|
>>> ds = load_dataset("@local/my-dataset", index=index, split="train")
|
|
629
716
|
"""
|
|
630
717
|
# Handle @handle/dataset indexed path resolution
|
|
631
718
|
if _is_indexed_path(path):
|
|
632
719
|
if index is None:
|
|
633
|
-
|
|
634
|
-
f"Index required for indexed path: {path}. "
|
|
635
|
-
"Pass index=LocalIndex() or index=AtmosphereIndex(client)."
|
|
636
|
-
)
|
|
720
|
+
index = get_default_index()
|
|
637
721
|
|
|
638
722
|
source, schema_ref = _resolve_indexed_path(path, index)
|
|
639
723
|
|
atdata/_logging.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Pluggable logging for atdata.
|
|
2
|
+
|
|
3
|
+
Provides a thin abstraction over Python's stdlib ``logging`` module that can
|
|
4
|
+
be replaced with ``structlog`` or any other logger implementing the standard
|
|
5
|
+
``debug``/``info``/``warning``/``error`` interface.
|
|
6
|
+
|
|
7
|
+
Usage::
|
|
8
|
+
|
|
9
|
+
# Default: stdlib logging (no config needed)
|
|
10
|
+
from atdata._logging import get_logger
|
|
11
|
+
log = get_logger()
|
|
12
|
+
log.info("processing shard", extra={"shard": "data-000.tar"})
|
|
13
|
+
|
|
14
|
+
# Plug in structlog (or any compatible logger):
|
|
15
|
+
import structlog
|
|
16
|
+
import atdata
|
|
17
|
+
atdata.configure_logging(structlog.get_logger())
|
|
18
|
+
|
|
19
|
+
The module also exports a lightweight ``LoggerProtocol`` for type checking
|
|
20
|
+
custom logger implementations.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import logging
|
|
26
|
+
from typing import Any, Protocol, runtime_checkable
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@runtime_checkable
|
|
30
|
+
class LoggerProtocol(Protocol):
|
|
31
|
+
"""Minimal interface that a pluggable logger must satisfy."""
|
|
32
|
+
|
|
33
|
+
def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
|
|
34
|
+
def info(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
|
|
35
|
+
def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
|
|
36
|
+
def error(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# ---------------------------------------------------------------------------
|
|
40
|
+
# Module-level state
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
|
|
43
|
+
_logger: LoggerProtocol = logging.getLogger("atdata")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def configure_logging(logger: LoggerProtocol) -> None:
|
|
47
|
+
"""Replace the default logger with a custom implementation.
|
|
48
|
+
|
|
49
|
+
The provided logger must implement ``debug``, ``info``, ``warning``, and
|
|
50
|
+
``error`` methods. Both ``structlog`` bound loggers and stdlib
|
|
51
|
+
``logging.Logger`` instances satisfy this interface.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
logger: A logger instance implementing :class:`LoggerProtocol`.
|
|
55
|
+
|
|
56
|
+
Examples:
|
|
57
|
+
>>> import structlog
|
|
58
|
+
>>> atdata.configure_logging(structlog.get_logger())
|
|
59
|
+
"""
|
|
60
|
+
global _logger
|
|
61
|
+
_logger = logger
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def get_logger() -> LoggerProtocol:
|
|
65
|
+
"""Return the currently configured logger.
|
|
66
|
+
|
|
67
|
+
Returns the stdlib ``logging.getLogger("atdata")`` by default, or
|
|
68
|
+
whatever was last set via :func:`configure_logging`.
|
|
69
|
+
"""
|
|
70
|
+
return _logger
|