atdata 0.2.3b1__py3-none-any.whl → 0.3.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +30 -0
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +29 -15
- atdata/_hf_api.py +63 -11
- atdata/_logging.py +70 -0
- atdata/_protocols.py +19 -62
- atdata/_schema_codec.py +5 -4
- atdata/_type_utils.py +28 -2
- atdata/atmosphere/__init__.py +19 -9
- atdata/atmosphere/records.py +3 -2
- atdata/atmosphere/schema.py +2 -2
- atdata/cli/__init__.py +157 -171
- atdata/cli/inspect.py +69 -0
- atdata/cli/local.py +1 -1
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +428 -326
- atdata/lens.py +9 -2
- atdata/local/__init__.py +71 -0
- atdata/local/_entry.py +157 -0
- atdata/local/_index.py +940 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/local/_s3.py +349 -0
- atdata/local/_schema.py +380 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +4 -4
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/testing.py +337 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +4 -1
- atdata-0.3.0b1.dist-info/RECORD +54 -0
- atdata/local.py +0 -1720
- atdata-0.2.3b1.dist-info/RECORD +0 -28
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
!manifest/
|
atdata/__init__.py
CHANGED
|
@@ -55,6 +55,8 @@ from .lens import (
|
|
|
55
55
|
from ._hf_api import (
|
|
56
56
|
load_dataset as load_dataset,
|
|
57
57
|
DatasetDict as DatasetDict,
|
|
58
|
+
get_default_index as get_default_index,
|
|
59
|
+
set_default_index as set_default_index,
|
|
58
60
|
)
|
|
59
61
|
|
|
60
62
|
from ._protocols import (
|
|
@@ -71,10 +73,29 @@ from ._sources import (
|
|
|
71
73
|
BlobSource as BlobSource,
|
|
72
74
|
)
|
|
73
75
|
|
|
76
|
+
from ._exceptions import (
|
|
77
|
+
AtdataError as AtdataError,
|
|
78
|
+
LensNotFoundError as LensNotFoundError,
|
|
79
|
+
SchemaError as SchemaError,
|
|
80
|
+
SampleKeyError as SampleKeyError,
|
|
81
|
+
ShardError as ShardError,
|
|
82
|
+
PartialFailureError as PartialFailureError,
|
|
83
|
+
)
|
|
84
|
+
|
|
74
85
|
from ._schema_codec import (
|
|
75
86
|
schema_to_type as schema_to_type,
|
|
76
87
|
)
|
|
77
88
|
|
|
89
|
+
from ._logging import (
|
|
90
|
+
configure_logging as configure_logging,
|
|
91
|
+
get_logger as get_logger,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
from .repository import (
|
|
95
|
+
Repository as Repository,
|
|
96
|
+
create_repository as create_repository,
|
|
97
|
+
)
|
|
98
|
+
|
|
78
99
|
from ._cid import (
|
|
79
100
|
generate_cid as generate_cid,
|
|
80
101
|
verify_cid as verify_cid,
|
|
@@ -84,6 +105,15 @@ from .promote import (
|
|
|
84
105
|
promote_to_atmosphere as promote_to_atmosphere,
|
|
85
106
|
)
|
|
86
107
|
|
|
108
|
+
from .manifest import (
|
|
109
|
+
ManifestField as ManifestField,
|
|
110
|
+
ManifestBuilder as ManifestBuilder,
|
|
111
|
+
ShardManifest as ShardManifest,
|
|
112
|
+
ManifestWriter as ManifestWriter,
|
|
113
|
+
QueryExecutor as QueryExecutor,
|
|
114
|
+
SampleLocation as SampleLocation,
|
|
115
|
+
)
|
|
116
|
+
|
|
87
117
|
# ATProto integration (lazy import to avoid requiring atproto package)
|
|
88
118
|
from . import atmosphere as atmosphere
|
|
89
119
|
|
atdata/_exceptions.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""Custom exception hierarchy for atdata.
|
|
2
|
+
|
|
3
|
+
Provides actionable error messages with contextual help, available
|
|
4
|
+
alternatives, and suggested fix code snippets.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from typing import Type
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AtdataError(Exception):
|
|
16
|
+
"""Base exception for all atdata errors."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LensNotFoundError(AtdataError, ValueError):
|
|
20
|
+
"""No lens registered to transform between two sample types.
|
|
21
|
+
|
|
22
|
+
Attributes:
|
|
23
|
+
source_type: The source sample type.
|
|
24
|
+
view_type: The target view type.
|
|
25
|
+
available_targets: Types reachable from the source via registered lenses.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
source_type: Type,
|
|
31
|
+
view_type: Type,
|
|
32
|
+
available_targets: list[tuple[Type, str]] | None = None,
|
|
33
|
+
) -> None:
|
|
34
|
+
self.source_type = source_type
|
|
35
|
+
self.view_type = view_type
|
|
36
|
+
self.available_targets = available_targets or []
|
|
37
|
+
|
|
38
|
+
src_name = source_type.__name__
|
|
39
|
+
view_name = view_type.__name__
|
|
40
|
+
|
|
41
|
+
lines = [f"No lens transforms {src_name} \u2192 {view_name}"]
|
|
42
|
+
|
|
43
|
+
if self.available_targets:
|
|
44
|
+
lines.append("")
|
|
45
|
+
lines.append(f"Available lenses from {src_name}:")
|
|
46
|
+
for target_type, lens_name in self.available_targets:
|
|
47
|
+
lines.append(
|
|
48
|
+
f" - {src_name} \u2192 {target_type.__name__} (via {lens_name})"
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
lines.append("")
|
|
52
|
+
lines.append("Did you mean to define:")
|
|
53
|
+
lines.append(" @lens")
|
|
54
|
+
lines.append(
|
|
55
|
+
f" def {src_name.lower()}_to_{view_name.lower()}(source: {src_name}) -> {view_name}:"
|
|
56
|
+
)
|
|
57
|
+
lines.append(f" return {view_name}(...)")
|
|
58
|
+
|
|
59
|
+
super().__init__("\n".join(lines))
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class SchemaError(AtdataError):
|
|
63
|
+
"""Schema mismatch during sample deserialization.
|
|
64
|
+
|
|
65
|
+
Raised when the data in a shard doesn't match the expected sample type.
|
|
66
|
+
|
|
67
|
+
Attributes:
|
|
68
|
+
expected_fields: Fields expected by the sample type.
|
|
69
|
+
actual_fields: Fields found in the data.
|
|
70
|
+
sample_type_name: Name of the target sample type.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
def __init__(
|
|
74
|
+
self,
|
|
75
|
+
sample_type_name: str,
|
|
76
|
+
expected_fields: list[str],
|
|
77
|
+
actual_fields: list[str],
|
|
78
|
+
) -> None:
|
|
79
|
+
self.sample_type_name = sample_type_name
|
|
80
|
+
self.expected_fields = expected_fields
|
|
81
|
+
self.actual_fields = actual_fields
|
|
82
|
+
|
|
83
|
+
missing = sorted(set(expected_fields) - set(actual_fields))
|
|
84
|
+
extra = sorted(set(actual_fields) - set(expected_fields))
|
|
85
|
+
|
|
86
|
+
lines = [f"Schema mismatch for {sample_type_name}"]
|
|
87
|
+
if missing:
|
|
88
|
+
lines.append(f" Missing fields: {', '.join(missing)}")
|
|
89
|
+
if extra:
|
|
90
|
+
lines.append(f" Unexpected fields: {', '.join(extra)}")
|
|
91
|
+
lines.append("")
|
|
92
|
+
lines.append(f"Expected: {', '.join(sorted(expected_fields))}")
|
|
93
|
+
lines.append(f"Got: {', '.join(sorted(actual_fields))}")
|
|
94
|
+
|
|
95
|
+
super().__init__("\n".join(lines))
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class SampleKeyError(AtdataError, KeyError):
|
|
99
|
+
"""Sample with the given key was not found in the dataset.
|
|
100
|
+
|
|
101
|
+
Attributes:
|
|
102
|
+
key: The key that was not found.
|
|
103
|
+
"""
|
|
104
|
+
|
|
105
|
+
def __init__(self, key: str) -> None:
|
|
106
|
+
self.key = key
|
|
107
|
+
super().__init__(
|
|
108
|
+
f"Sample with key '{key}' not found in dataset. "
|
|
109
|
+
f"Note: key lookup requires scanning all shards and is O(n)."
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class ShardError(AtdataError):
|
|
114
|
+
"""Error accessing or reading a dataset shard.
|
|
115
|
+
|
|
116
|
+
Attributes:
|
|
117
|
+
shard_id: Identifier of the shard that failed.
|
|
118
|
+
reason: Human-readable description of what went wrong.
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
def __init__(self, shard_id: str, reason: str) -> None:
|
|
122
|
+
self.shard_id = shard_id
|
|
123
|
+
self.reason = reason
|
|
124
|
+
super().__init__(f"Failed to read shard '{shard_id}': {reason}")
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class PartialFailureError(AtdataError):
|
|
128
|
+
"""Some shards succeeded but others failed during processing.
|
|
129
|
+
|
|
130
|
+
Raised by :meth:`Dataset.process_shards` when at least one shard fails.
|
|
131
|
+
Provides access to both the successful results and the per-shard errors,
|
|
132
|
+
enabling retry of only the failed shards.
|
|
133
|
+
|
|
134
|
+
Attributes:
|
|
135
|
+
succeeded_shards: List of shard identifiers that succeeded.
|
|
136
|
+
failed_shards: List of shard identifiers that failed.
|
|
137
|
+
errors: Mapping from shard identifier to the exception that occurred.
|
|
138
|
+
results: Mapping from shard identifier to the result for succeeded shards.
|
|
139
|
+
"""
|
|
140
|
+
|
|
141
|
+
def __init__(
|
|
142
|
+
self,
|
|
143
|
+
succeeded_shards: list[str],
|
|
144
|
+
failed_shards: list[str],
|
|
145
|
+
errors: dict[str, Exception],
|
|
146
|
+
results: dict[str, object],
|
|
147
|
+
) -> None:
|
|
148
|
+
self.succeeded_shards = succeeded_shards
|
|
149
|
+
self.failed_shards = failed_shards
|
|
150
|
+
self.errors = errors
|
|
151
|
+
self.results = results
|
|
152
|
+
|
|
153
|
+
n_ok = len(succeeded_shards)
|
|
154
|
+
n_fail = len(failed_shards)
|
|
155
|
+
total = n_ok + n_fail
|
|
156
|
+
|
|
157
|
+
lines = [f"{n_fail}/{total} shards failed during processing"]
|
|
158
|
+
for shard_id in failed_shards[:5]:
|
|
159
|
+
lines.append(f" {shard_id}: {errors[shard_id]}")
|
|
160
|
+
if n_fail > 5:
|
|
161
|
+
lines.append(f" ... and {n_fail - 5} more")
|
|
162
|
+
lines.append("")
|
|
163
|
+
lines.append(
|
|
164
|
+
f"Access .succeeded_shards ({n_ok}) and .failed_shards ({n_fail}) "
|
|
165
|
+
f"to inspect or retry."
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
super().__init__("\n".join(lines))
|
atdata/_helpers.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
"""Helper utilities for numpy array serialization.
|
|
2
2
|
|
|
3
3
|
This module provides utility functions for converting numpy arrays to and from
|
|
4
|
-
bytes for msgpack serialization.
|
|
5
|
-
format to preserve array dtype and shape information.
|
|
4
|
+
bytes for msgpack serialization.
|
|
6
5
|
|
|
7
6
|
Functions:
|
|
8
7
|
- ``array_to_bytes()``: Serialize numpy array to bytes
|
|
@@ -15,10 +14,14 @@ handling of NDArray fields during msgpack packing/unpacking.
|
|
|
15
14
|
##
|
|
16
15
|
# Imports
|
|
17
16
|
|
|
17
|
+
import struct
|
|
18
18
|
from io import BytesIO
|
|
19
19
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
|
|
22
|
+
# .npy format magic prefix (used for backward-compatible deserialization)
|
|
23
|
+
_NPY_MAGIC = b"\x93NUMPY"
|
|
24
|
+
|
|
22
25
|
|
|
23
26
|
##
|
|
24
27
|
|
|
@@ -26,35 +29,46 @@ import numpy as np
|
|
|
26
29
|
def array_to_bytes(x: np.ndarray) -> bytes:
|
|
27
30
|
"""Convert a numpy array to bytes for msgpack serialization.
|
|
28
31
|
|
|
29
|
-
Uses
|
|
32
|
+
Uses a compact binary format: a short header (dtype + shape) followed by
|
|
33
|
+
raw array bytes via ``ndarray.tobytes()``. Falls back to numpy's ``.npy``
|
|
34
|
+
format for object dtypes that cannot be represented as raw bytes.
|
|
30
35
|
|
|
31
36
|
Args:
|
|
32
37
|
x: A numpy array to serialize.
|
|
33
38
|
|
|
34
39
|
Returns:
|
|
35
40
|
Raw bytes representing the serialized array.
|
|
36
|
-
|
|
37
|
-
Note:
|
|
38
|
-
Uses ``allow_pickle=True`` to support object dtypes.
|
|
39
41
|
"""
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
42
|
+
if x.dtype == object:
|
|
43
|
+
buf = BytesIO()
|
|
44
|
+
np.save(buf, x, allow_pickle=True)
|
|
45
|
+
return buf.getvalue()
|
|
46
|
+
|
|
47
|
+
dtype_str = x.dtype.str.encode() # e.g. b'<f4'
|
|
48
|
+
header = struct.pack(f"<B{len(x.shape)}q", len(x.shape), *x.shape)
|
|
49
|
+
return struct.pack("<B", len(dtype_str)) + dtype_str + header + x.tobytes()
|
|
43
50
|
|
|
44
51
|
|
|
45
52
|
def bytes_to_array(b: bytes) -> np.ndarray:
|
|
46
53
|
"""Convert serialized bytes back to a numpy array.
|
|
47
54
|
|
|
48
|
-
|
|
55
|
+
Transparently handles both the compact format produced by the current
|
|
56
|
+
``array_to_bytes()`` and the legacy ``.npy`` format.
|
|
49
57
|
|
|
50
58
|
Args:
|
|
51
59
|
b: Raw bytes from a serialized numpy array.
|
|
52
60
|
|
|
53
61
|
Returns:
|
|
54
62
|
The deserialized numpy array with original dtype and shape.
|
|
55
|
-
|
|
56
|
-
Note:
|
|
57
|
-
Uses ``allow_pickle=True`` to support object dtypes.
|
|
58
63
|
"""
|
|
59
|
-
|
|
60
|
-
|
|
64
|
+
if b[:6] == _NPY_MAGIC:
|
|
65
|
+
return np.load(BytesIO(b), allow_pickle=True)
|
|
66
|
+
|
|
67
|
+
# Compact format: dtype_len(1B) + dtype_str + ndim(1B) + shape(ndim×8B) + data
|
|
68
|
+
dlen = b[0]
|
|
69
|
+
dtype = np.dtype(b[1 : 1 + dlen].decode())
|
|
70
|
+
ndim = b[1 + dlen]
|
|
71
|
+
offset = 2 + dlen
|
|
72
|
+
shape = struct.unpack_from(f"<{ndim}q", b, offset)
|
|
73
|
+
offset += ndim * 8
|
|
74
|
+
return np.frombuffer(b, dtype=dtype, offset=offset).reshape(shape).copy()
|
atdata/_hf_api.py
CHANGED
|
@@ -29,6 +29,7 @@ Examples:
|
|
|
29
29
|
from __future__ import annotations
|
|
30
30
|
|
|
31
31
|
import re
|
|
32
|
+
import threading
|
|
32
33
|
from pathlib import Path
|
|
33
34
|
from typing import (
|
|
34
35
|
TYPE_CHECKING,
|
|
@@ -40,9 +41,9 @@ from typing import (
|
|
|
40
41
|
overload,
|
|
41
42
|
)
|
|
42
43
|
|
|
43
|
-
from .dataset import Dataset,
|
|
44
|
+
from .dataset import Dataset, DictSample
|
|
44
45
|
from ._sources import URLSource, S3Source
|
|
45
|
-
from ._protocols import DataSource
|
|
46
|
+
from ._protocols import DataSource, Packable
|
|
46
47
|
|
|
47
48
|
if TYPE_CHECKING:
|
|
48
49
|
from ._protocols import AbstractIndex
|
|
@@ -50,7 +51,60 @@ if TYPE_CHECKING:
|
|
|
50
51
|
##
|
|
51
52
|
# Type variables
|
|
52
53
|
|
|
53
|
-
ST = TypeVar("ST", bound=
|
|
54
|
+
ST = TypeVar("ST", bound=Packable)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
##
|
|
58
|
+
# Default Index singleton
|
|
59
|
+
|
|
60
|
+
_default_index: "Index | None" = None # noqa: F821 (forward ref)
|
|
61
|
+
_default_index_lock = threading.Lock()
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def get_default_index() -> "Index": # noqa: F821
|
|
65
|
+
"""Get or create the module-level default Index.
|
|
66
|
+
|
|
67
|
+
The default Index uses Redis for local storage (backwards-compatible
|
|
68
|
+
default) and an anonymous AtmosphereClient for read-only public data
|
|
69
|
+
resolution.
|
|
70
|
+
|
|
71
|
+
The default is created lazily on first access and cached for the
|
|
72
|
+
lifetime of the process.
|
|
73
|
+
|
|
74
|
+
Returns:
|
|
75
|
+
The default Index instance.
|
|
76
|
+
|
|
77
|
+
Examples:
|
|
78
|
+
>>> index = get_default_index()
|
|
79
|
+
>>> entry = index.get_dataset("local/mnist")
|
|
80
|
+
"""
|
|
81
|
+
global _default_index
|
|
82
|
+
if _default_index is None:
|
|
83
|
+
with _default_index_lock:
|
|
84
|
+
if _default_index is None:
|
|
85
|
+
from .local import Index
|
|
86
|
+
|
|
87
|
+
_default_index = Index()
|
|
88
|
+
return _default_index
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def set_default_index(index: "Index") -> None: # noqa: F821
|
|
92
|
+
"""Override the module-level default Index.
|
|
93
|
+
|
|
94
|
+
Use this to configure a custom default Index with specific repositories,
|
|
95
|
+
an authenticated atmosphere client, or non-default providers.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
index: The Index instance to use as the default.
|
|
99
|
+
|
|
100
|
+
Examples:
|
|
101
|
+
>>> from atdata.local import Index
|
|
102
|
+
>>> from atdata.providers import create_provider
|
|
103
|
+
>>> custom = Index(provider=create_provider("sqlite"))
|
|
104
|
+
>>> set_default_index(custom)
|
|
105
|
+
"""
|
|
106
|
+
global _default_index
|
|
107
|
+
_default_index = index
|
|
54
108
|
|
|
55
109
|
|
|
56
110
|
##
|
|
@@ -74,10 +128,11 @@ class DatasetDict(Generic[ST], dict):
|
|
|
74
128
|
>>>
|
|
75
129
|
>>> # Iterate over all splits
|
|
76
130
|
>>> for split_name, dataset in ds_dict.items():
|
|
77
|
-
... print(f"{split_name}: {len(dataset.
|
|
131
|
+
... print(f"{split_name}: {len(dataset.list_shards())} shards")
|
|
78
132
|
"""
|
|
79
133
|
|
|
80
|
-
#
|
|
134
|
+
# Note: The docstring uses "Parameters:" for type parameters as a workaround
|
|
135
|
+
# for quartodoc not supporting "Type Parameters:" sections.
|
|
81
136
|
|
|
82
137
|
def __init__(
|
|
83
138
|
self,
|
|
@@ -459,7 +514,7 @@ def _resolve_indexed_path(
|
|
|
459
514
|
handle_or_did, dataset_name = _parse_indexed_path(path)
|
|
460
515
|
|
|
461
516
|
# For AtmosphereIndex, we need to resolve handle to DID first
|
|
462
|
-
# For
|
|
517
|
+
# For local Index, the handle is ignored and we just look up by name
|
|
463
518
|
entry = index.get_dataset(dataset_name)
|
|
464
519
|
data_urls = entry.data_urls
|
|
465
520
|
|
|
@@ -624,16 +679,13 @@ def load_dataset(
|
|
|
624
679
|
>>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
|
|
625
680
|
>>>
|
|
626
681
|
>>> # Load from index with auto-type resolution
|
|
627
|
-
>>> index =
|
|
682
|
+
>>> index = Index()
|
|
628
683
|
>>> ds = load_dataset("@local/my-dataset", index=index, split="train")
|
|
629
684
|
"""
|
|
630
685
|
# Handle @handle/dataset indexed path resolution
|
|
631
686
|
if _is_indexed_path(path):
|
|
632
687
|
if index is None:
|
|
633
|
-
|
|
634
|
-
f"Index required for indexed path: {path}. "
|
|
635
|
-
"Pass index=LocalIndex() or index=AtmosphereIndex(client)."
|
|
636
|
-
)
|
|
688
|
+
index = get_default_index()
|
|
637
689
|
|
|
638
690
|
source, schema_ref = _resolve_indexed_path(path, index)
|
|
639
691
|
|
atdata/_logging.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Pluggable logging for atdata.
|
|
2
|
+
|
|
3
|
+
Provides a thin abstraction over Python's stdlib ``logging`` module that can
|
|
4
|
+
be replaced with ``structlog`` or any other logger implementing the standard
|
|
5
|
+
``debug``/``info``/``warning``/``error`` interface.
|
|
6
|
+
|
|
7
|
+
Usage::
|
|
8
|
+
|
|
9
|
+
# Default: stdlib logging (no config needed)
|
|
10
|
+
from atdata._logging import get_logger
|
|
11
|
+
log = get_logger()
|
|
12
|
+
log.info("processing shard", extra={"shard": "data-000.tar"})
|
|
13
|
+
|
|
14
|
+
# Plug in structlog (or any compatible logger):
|
|
15
|
+
import structlog
|
|
16
|
+
import atdata
|
|
17
|
+
atdata.configure_logging(structlog.get_logger())
|
|
18
|
+
|
|
19
|
+
The module also exports a lightweight ``LoggerProtocol`` for type checking
|
|
20
|
+
custom logger implementations.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import logging
|
|
26
|
+
from typing import Any, Protocol, runtime_checkable
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@runtime_checkable
|
|
30
|
+
class LoggerProtocol(Protocol):
|
|
31
|
+
"""Minimal interface that a pluggable logger must satisfy."""
|
|
32
|
+
|
|
33
|
+
def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
|
|
34
|
+
def info(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
|
|
35
|
+
def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
|
|
36
|
+
def error(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# ---------------------------------------------------------------------------
|
|
40
|
+
# Module-level state
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
|
|
43
|
+
_logger: LoggerProtocol = logging.getLogger("atdata")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def configure_logging(logger: LoggerProtocol) -> None:
|
|
47
|
+
"""Replace the default logger with a custom implementation.
|
|
48
|
+
|
|
49
|
+
The provided logger must implement ``debug``, ``info``, ``warning``, and
|
|
50
|
+
``error`` methods. Both ``structlog`` bound loggers and stdlib
|
|
51
|
+
``logging.Logger`` instances satisfy this interface.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
logger: A logger instance implementing :class:`LoggerProtocol`.
|
|
55
|
+
|
|
56
|
+
Examples:
|
|
57
|
+
>>> import structlog
|
|
58
|
+
>>> atdata.configure_logging(structlog.get_logger())
|
|
59
|
+
"""
|
|
60
|
+
global _logger
|
|
61
|
+
_logger = logger
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def get_logger() -> LoggerProtocol:
|
|
65
|
+
"""Return the currently configured logger.
|
|
66
|
+
|
|
67
|
+
Returns the stdlib ``logging.getLogger("atdata")`` by default, or
|
|
68
|
+
whatever was last set via :func:`configure_logging`.
|
|
69
|
+
"""
|
|
70
|
+
return _logger
|
atdata/_protocols.py
CHANGED
|
@@ -10,7 +10,7 @@ formalize that common interface.
|
|
|
10
10
|
Note:
|
|
11
11
|
Protocol methods use ``...`` (Ellipsis) as the body per PEP 544. This is
|
|
12
12
|
the standard Python syntax for Protocol definitions - these are interface
|
|
13
|
-
specifications, not stub implementations. Concrete classes (
|
|
13
|
+
specifications, not stub implementations. Concrete classes (Index,
|
|
14
14
|
AtmosphereIndex, etc.) provide the actual implementations.
|
|
15
15
|
|
|
16
16
|
Protocols:
|
|
@@ -24,7 +24,7 @@ Examples:
|
|
|
24
24
|
... for entry in index.list_datasets():
|
|
25
25
|
... print(f"{entry.name}: {entry.data_urls}")
|
|
26
26
|
...
|
|
27
|
-
>>> # Works with either
|
|
27
|
+
>>> # Works with either Index or AtmosphereIndex
|
|
28
28
|
>>> process_datasets(local_index)
|
|
29
29
|
>>> process_datasets(atmosphere_index)
|
|
30
30
|
"""
|
|
@@ -77,24 +77,16 @@ class Packable(Protocol):
|
|
|
77
77
|
"""
|
|
78
78
|
|
|
79
79
|
@classmethod
|
|
80
|
-
def from_data(cls, data: dict[str, Any]) -> "Packable":
|
|
81
|
-
"""Create instance from unpacked msgpack data dictionary."""
|
|
82
|
-
...
|
|
80
|
+
def from_data(cls, data: dict[str, Any]) -> "Packable": ...
|
|
83
81
|
|
|
84
82
|
@classmethod
|
|
85
|
-
def from_bytes(cls, bs: bytes) -> "Packable":
|
|
86
|
-
"""Create instance from raw msgpack bytes."""
|
|
87
|
-
...
|
|
83
|
+
def from_bytes(cls, bs: bytes) -> "Packable": ...
|
|
88
84
|
|
|
89
85
|
@property
|
|
90
|
-
def packed(self) -> bytes:
|
|
91
|
-
"""Pack this sample's data into msgpack bytes."""
|
|
92
|
-
...
|
|
86
|
+
def packed(self) -> bytes: ...
|
|
93
87
|
|
|
94
88
|
@property
|
|
95
|
-
def as_wds(self) -> dict[str, Any]:
|
|
96
|
-
"""WebDataset-compatible representation with __key__ and msgpack."""
|
|
97
|
-
...
|
|
89
|
+
def as_wds(self) -> dict[str, Any]: ...
|
|
98
90
|
|
|
99
91
|
|
|
100
92
|
##
|
|
@@ -116,16 +108,14 @@ class IndexEntry(Protocol):
|
|
|
116
108
|
"""
|
|
117
109
|
|
|
118
110
|
@property
|
|
119
|
-
def name(self) -> str:
|
|
120
|
-
"""Human-readable dataset name."""
|
|
121
|
-
...
|
|
111
|
+
def name(self) -> str: ...
|
|
122
112
|
|
|
123
113
|
@property
|
|
124
114
|
def schema_ref(self) -> str:
|
|
125
|
-
"""
|
|
115
|
+
"""Schema reference string.
|
|
126
116
|
|
|
127
|
-
|
|
128
|
-
|
|
117
|
+
Local: ``local://schemas/{module.Class}@{version}``
|
|
118
|
+
Atmosphere: ``at://did:plc:.../ac.foundation.dataset.sampleSchema/...``
|
|
129
119
|
"""
|
|
130
120
|
...
|
|
131
121
|
|
|
@@ -139,9 +129,7 @@ class IndexEntry(Protocol):
|
|
|
139
129
|
...
|
|
140
130
|
|
|
141
131
|
@property
|
|
142
|
-
def metadata(self) -> Optional[dict]:
|
|
143
|
-
"""Arbitrary metadata dictionary, or None if not set."""
|
|
144
|
-
...
|
|
132
|
+
def metadata(self) -> Optional[dict]: ...
|
|
145
133
|
|
|
146
134
|
|
|
147
135
|
##
|
|
@@ -149,7 +137,7 @@ class IndexEntry(Protocol):
|
|
|
149
137
|
|
|
150
138
|
|
|
151
139
|
class AbstractIndex(Protocol):
|
|
152
|
-
"""Protocol for index operations - implemented by
|
|
140
|
+
"""Protocol for index operations - implemented by Index and AtmosphereIndex.
|
|
153
141
|
|
|
154
142
|
This protocol defines the common interface for managing dataset metadata:
|
|
155
143
|
- Publishing and retrieving schemas
|
|
@@ -239,21 +227,9 @@ class AbstractIndex(Protocol):
|
|
|
239
227
|
...
|
|
240
228
|
|
|
241
229
|
@property
|
|
242
|
-
def datasets(self) -> Iterator[IndexEntry]:
|
|
243
|
-
"""Lazily iterate over all dataset entries in this index.
|
|
230
|
+
def datasets(self) -> Iterator[IndexEntry]: ...
|
|
244
231
|
|
|
245
|
-
|
|
246
|
-
IndexEntry for each dataset (may be of different sample types).
|
|
247
|
-
"""
|
|
248
|
-
...
|
|
249
|
-
|
|
250
|
-
def list_datasets(self) -> list[IndexEntry]:
|
|
251
|
-
"""Get all dataset entries as a materialized list.
|
|
252
|
-
|
|
253
|
-
Returns:
|
|
254
|
-
List of IndexEntry for each dataset.
|
|
255
|
-
"""
|
|
256
|
-
...
|
|
232
|
+
def list_datasets(self) -> list[IndexEntry]: ...
|
|
257
233
|
|
|
258
234
|
# Schema operations
|
|
259
235
|
|
|
@@ -299,21 +275,9 @@ class AbstractIndex(Protocol):
|
|
|
299
275
|
...
|
|
300
276
|
|
|
301
277
|
@property
|
|
302
|
-
def schemas(self) -> Iterator[dict]:
|
|
303
|
-
"""Lazily iterate over all schema records in this index.
|
|
278
|
+
def schemas(self) -> Iterator[dict]: ...
|
|
304
279
|
|
|
305
|
-
|
|
306
|
-
Schema records as dictionaries.
|
|
307
|
-
"""
|
|
308
|
-
...
|
|
309
|
-
|
|
310
|
-
def list_schemas(self) -> list[dict]:
|
|
311
|
-
"""Get all schema records as a materialized list.
|
|
312
|
-
|
|
313
|
-
Returns:
|
|
314
|
-
List of schema records as dictionaries.
|
|
315
|
-
"""
|
|
316
|
-
...
|
|
280
|
+
def list_schemas(self) -> list[dict]: ...
|
|
317
281
|
|
|
318
282
|
def decode_schema(self, ref: str) -> Type[Packable]:
|
|
319
283
|
"""Reconstruct a Python Packable type from a stored schema.
|
|
@@ -401,14 +365,7 @@ class AbstractDataStore(Protocol):
|
|
|
401
365
|
"""
|
|
402
366
|
...
|
|
403
367
|
|
|
404
|
-
def supports_streaming(self) -> bool:
|
|
405
|
-
"""Whether this store supports streaming reads.
|
|
406
|
-
|
|
407
|
-
Returns:
|
|
408
|
-
True if the store supports efficient streaming (like S3),
|
|
409
|
-
False if data must be fully downloaded first.
|
|
410
|
-
"""
|
|
411
|
-
...
|
|
368
|
+
def supports_streaming(self) -> bool: ...
|
|
412
369
|
|
|
413
370
|
|
|
414
371
|
##
|
|
@@ -481,13 +438,13 @@ class DataSource(Protocol):
|
|
|
481
438
|
only its assigned shards rather than iterating all shards.
|
|
482
439
|
|
|
483
440
|
Args:
|
|
484
|
-
shard_id: Shard identifier from
|
|
441
|
+
shard_id: Shard identifier from list_shards().
|
|
485
442
|
|
|
486
443
|
Returns:
|
|
487
444
|
File-like stream for reading the shard.
|
|
488
445
|
|
|
489
446
|
Raises:
|
|
490
|
-
KeyError: If shard_id is not in
|
|
447
|
+
KeyError: If shard_id is not in list_shards().
|
|
491
448
|
"""
|
|
492
449
|
...
|
|
493
450
|
|