atdata 0.2.3b1__py3-none-any.whl → 0.3.1b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. atdata/.gitignore +1 -0
  2. atdata/__init__.py +39 -0
  3. atdata/_cid.py +0 -21
  4. atdata/_exceptions.py +168 -0
  5. atdata/_helpers.py +41 -15
  6. atdata/_hf_api.py +95 -11
  7. atdata/_logging.py +70 -0
  8. atdata/_protocols.py +77 -238
  9. atdata/_schema_codec.py +7 -6
  10. atdata/_stub_manager.py +5 -25
  11. atdata/_type_utils.py +28 -2
  12. atdata/atmosphere/__init__.py +31 -20
  13. atdata/atmosphere/_types.py +4 -4
  14. atdata/atmosphere/client.py +64 -12
  15. atdata/atmosphere/lens.py +11 -12
  16. atdata/atmosphere/records.py +12 -12
  17. atdata/atmosphere/schema.py +16 -18
  18. atdata/atmosphere/store.py +6 -7
  19. atdata/cli/__init__.py +161 -175
  20. atdata/cli/diagnose.py +2 -2
  21. atdata/cli/{local.py → infra.py} +11 -11
  22. atdata/cli/inspect.py +69 -0
  23. atdata/cli/preview.py +63 -0
  24. atdata/cli/schema.py +109 -0
  25. atdata/dataset.py +583 -328
  26. atdata/index/__init__.py +54 -0
  27. atdata/index/_entry.py +157 -0
  28. atdata/index/_index.py +1198 -0
  29. atdata/index/_schema.py +380 -0
  30. atdata/lens.py +9 -2
  31. atdata/lexicons/__init__.py +121 -0
  32. atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
  33. atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
  34. atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
  35. atdata/lexicons/ac.foundation.dataset.record.json +96 -0
  36. atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
  37. atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
  38. atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
  39. atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
  40. atdata/lexicons/ndarray_shim.json +16 -0
  41. atdata/local/__init__.py +70 -0
  42. atdata/local/_repo_legacy.py +218 -0
  43. atdata/manifest/__init__.py +28 -0
  44. atdata/manifest/_aggregates.py +156 -0
  45. atdata/manifest/_builder.py +163 -0
  46. atdata/manifest/_fields.py +154 -0
  47. atdata/manifest/_manifest.py +146 -0
  48. atdata/manifest/_query.py +150 -0
  49. atdata/manifest/_writer.py +74 -0
  50. atdata/promote.py +18 -14
  51. atdata/providers/__init__.py +25 -0
  52. atdata/providers/_base.py +140 -0
  53. atdata/providers/_factory.py +69 -0
  54. atdata/providers/_postgres.py +214 -0
  55. atdata/providers/_redis.py +171 -0
  56. atdata/providers/_sqlite.py +191 -0
  57. atdata/repository.py +323 -0
  58. atdata/stores/__init__.py +23 -0
  59. atdata/stores/_disk.py +123 -0
  60. atdata/stores/_s3.py +349 -0
  61. atdata/testing.py +341 -0
  62. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
  63. atdata-0.3.1b1.dist-info/RECORD +67 -0
  64. atdata/local.py +0 -1720
  65. atdata-0.2.3b1.dist-info/RECORD +0 -28
  66. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
  67. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
  68. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
atdata/.gitignore ADDED
@@ -0,0 +1 @@
+ !manifest/
atdata/__init__.py CHANGED
@@ -44,6 +44,7 @@ from .dataset import (
     SampleBatch as SampleBatch,
     Dataset as Dataset,
     packable as packable,
+    write_samples as write_samples,
 )
 
 from .lens import (
@@ -55,6 +56,8 @@ from .lens import (
 from ._hf_api import (
     load_dataset as load_dataset,
     DatasetDict as DatasetDict,
+    get_default_index as get_default_index,
+    set_default_index as set_default_index,
 )
 
 from ._protocols import (
@@ -71,10 +74,37 @@ from ._sources import (
     BlobSource as BlobSource,
 )
 
+from ._exceptions import (
+    AtdataError as AtdataError,
+    LensNotFoundError as LensNotFoundError,
+    SchemaError as SchemaError,
+    SampleKeyError as SampleKeyError,
+    ShardError as ShardError,
+    PartialFailureError as PartialFailureError,
+)
+
 from ._schema_codec import (
     schema_to_type as schema_to_type,
 )
 
+from ._logging import (
+    configure_logging as configure_logging,
+    get_logger as get_logger,
+)
+
+from .repository import (
+    Repository as Repository,
+    create_repository as create_repository,
+)
+
+from .index import (
+    Index as Index,
+)
+
+from .stores import (
+    LocalDiskStore as LocalDiskStore,
+)
+
 from ._cid import (
     generate_cid as generate_cid,
     verify_cid as verify_cid,
@@ -84,6 +114,15 @@ from .promote import (
     promote_to_atmosphere as promote_to_atmosphere,
 )
 
+from .manifest import (
+    ManifestField as ManifestField,
+    ManifestBuilder as ManifestBuilder,
+    ShardManifest as ShardManifest,
+    ManifestWriter as ManifestWriter,
+    QueryExecutor as QueryExecutor,
+    SampleLocation as SampleLocation,
+)
+
 # ATProto integration (lazy import to avoid requiring atproto package)
 from . import atmosphere as atmosphere
 
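Taken together, the re-exports above widen the top-level namespace in 0.3.1b1. A minimal sketch of what becomes importable directly from atdata, using only names added in this hunk (the ShardError arguments shown are illustrative):

# Sketch only: every name used here is re-exported by the __init__.py hunk
# above; the shard id and reason passed to ShardError are made up.
import atdata

try:
    raise atdata.ShardError("data-000.tar", "connection reset")
except atdata.AtdataError as err:
    print(err)

atdata.get_logger().info(
    "Repository, Index, LocalDiskStore and manifest types are now top-level"
)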
atdata/_cid.py CHANGED
@@ -116,29 +116,8 @@ def verify_cid(cid: str, data: Any) -> bool:
     return cid == expected_cid
 
 
-def parse_cid(cid: str) -> dict:
-    """Parse a CID string into its components.
-
-    Args:
-        cid: CID string to parse.
-
-    Returns:
-        Dictionary with 'version', 'codec', and 'hash' keys.
-        The 'hash' value is itself a dict with 'code', 'size', and 'digest'.
-
-    Examples:
-        >>> info = parse_cid('bafyrei...')
-        >>> info['version']
-        1
-        >>> info['codec']
-        113  # 0x71 = dag-cbor
-    """
-    return libipld.decode_cid(cid)
-
-
 __all__ = [
     "generate_cid",
     "generate_cid_from_bytes",
     "verify_cid",
-    "parse_cid",
 ]
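The removal above drops parse_cid from the public API. For downstream code that still needs the parsed components, a minimal stand-in mirroring the deleted body (which simply delegated to libipld.decode_cid) might look like the following; the diff does not name an official replacement:

# Hypothetical local replacement for the removed atdata._cid.parse_cid,
# mirroring the deleted implementation's delegation to libipld.
import libipld


def parse_cid(cid: str) -> dict:
    """Return the parsed 'version', 'codec', and 'hash' components of a CID."""
    return libipld.decode_cid(cid)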
atdata/_exceptions.py ADDED
@@ -0,0 +1,168 @@
+"""Custom exception hierarchy for atdata.
+
+Provides actionable error messages with contextual help, available
+alternatives, and suggested fix code snippets.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Type
+
+
+class AtdataError(Exception):
+    """Base exception for all atdata errors."""
+
+
+class LensNotFoundError(AtdataError, ValueError):
+    """No lens registered to transform between two sample types.
+
+    Attributes:
+        source_type: The source sample type.
+        view_type: The target view type.
+        available_targets: Types reachable from the source via registered lenses.
+    """
+
+    def __init__(
+        self,
+        source_type: Type,
+        view_type: Type,
+        available_targets: list[tuple[Type, str]] | None = None,
+    ) -> None:
+        self.source_type = source_type
+        self.view_type = view_type
+        self.available_targets = available_targets or []
+
+        src_name = source_type.__name__
+        view_name = view_type.__name__
+
+        lines = [f"No lens transforms {src_name} \u2192 {view_name}"]
+
+        if self.available_targets:
+            lines.append("")
+            lines.append(f"Available lenses from {src_name}:")
+            for target_type, lens_name in self.available_targets:
+                lines.append(
+                    f" - {src_name} \u2192 {target_type.__name__} (via {lens_name})"
+                )
+
+        lines.append("")
+        lines.append("Did you mean to define:")
+        lines.append(" @lens")
+        lines.append(
+            f" def {src_name.lower()}_to_{view_name.lower()}(source: {src_name}) -> {view_name}:"
+        )
+        lines.append(f" return {view_name}(...)")
+
+        super().__init__("\n".join(lines))
+
+
+class SchemaError(AtdataError):
+    """Schema mismatch during sample deserialization.
+
+    Raised when the data in a shard doesn't match the expected sample type.
+
+    Attributes:
+        expected_fields: Fields expected by the sample type.
+        actual_fields: Fields found in the data.
+        sample_type_name: Name of the target sample type.
+    """
+
+    def __init__(
+        self,
+        sample_type_name: str,
+        expected_fields: list[str],
+        actual_fields: list[str],
+    ) -> None:
+        self.sample_type_name = sample_type_name
+        self.expected_fields = expected_fields
+        self.actual_fields = actual_fields
+
+        missing = sorted(set(expected_fields) - set(actual_fields))
+        extra = sorted(set(actual_fields) - set(expected_fields))
+
+        lines = [f"Schema mismatch for {sample_type_name}"]
+        if missing:
+            lines.append(f" Missing fields: {', '.join(missing)}")
+        if extra:
+            lines.append(f" Unexpected fields: {', '.join(extra)}")
+        lines.append("")
+        lines.append(f"Expected: {', '.join(sorted(expected_fields))}")
+        lines.append(f"Got: {', '.join(sorted(actual_fields))}")
+
+        super().__init__("\n".join(lines))
+
+
+class SampleKeyError(AtdataError, KeyError):
+    """Sample with the given key was not found in the dataset.
+
+    Attributes:
+        key: The key that was not found.
+    """
+
+    def __init__(self, key: str) -> None:
+        self.key = key
+        super().__init__(
+            f"Sample with key '{key}' not found in dataset. "
+            f"Note: key lookup requires scanning all shards and is O(n)."
+        )
+
+
+class ShardError(AtdataError):
+    """Error accessing or reading a dataset shard.
+
+    Attributes:
+        shard_id: Identifier of the shard that failed.
+        reason: Human-readable description of what went wrong.
+    """
+
+    def __init__(self, shard_id: str, reason: str) -> None:
+        self.shard_id = shard_id
+        self.reason = reason
+        super().__init__(f"Failed to read shard '{shard_id}': {reason}")
+
+
+class PartialFailureError(AtdataError):
+    """Some shards succeeded but others failed during processing.
+
+    Raised by :meth:`Dataset.process_shards` when at least one shard fails.
+    Provides access to both the successful results and the per-shard errors,
+    enabling retry of only the failed shards.
+
+    Attributes:
+        succeeded_shards: List of shard identifiers that succeeded.
+        failed_shards: List of shard identifiers that failed.
+        errors: Mapping from shard identifier to the exception that occurred.
+        results: Mapping from shard identifier to the result for succeeded shards.
+    """
+
+    def __init__(
+        self,
+        succeeded_shards: list[str],
+        failed_shards: list[str],
+        errors: dict[str, Exception],
+        results: dict[str, object],
+    ) -> None:
+        self.succeeded_shards = succeeded_shards
+        self.failed_shards = failed_shards
+        self.errors = errors
+        self.results = results
+
+        n_ok = len(succeeded_shards)
+        n_fail = len(failed_shards)
+        total = n_ok + n_fail
+
+        lines = [f"{n_fail}/{total} shards failed during processing"]
+        for shard_id in failed_shards[:5]:
+            lines.append(f" {shard_id}: {errors[shard_id]}")
+        if n_fail > 5:
+            lines.append(f" ... and {n_fail - 5} more")
+        lines.append("")
+        lines.append(
+            f"Access .succeeded_shards ({n_ok}) and .failed_shards ({n_fail}) "
+            f"to inspect or retry."
+        )
+
+        super().__init__("\n".join(lines))
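PartialFailureError keeps both the per-shard results and the per-shard errors, which makes a retry-only-the-failures loop straightforward. A hedged sketch of that pattern follows; the process_shards call shape and handler argument are assumptions for illustration, while the exception attributes come from the code above:

# Sketch of retrying only failed shards; `process_shards(shards, fn)` is an
# assumed call signature, but err.results / err.failed_shards are from above.
from atdata import PartialFailureError


def process_with_retry(dataset, fn, retries: int = 1):
    shards = dataset.list_shards()
    results: dict[str, object] = {}
    for attempt in range(retries + 1):
        try:
            results.update(dataset.process_shards(shards, fn))
            return results
        except PartialFailureError as err:
            results.update(err.results)   # keep what already succeeded
            shards = err.failed_shards    # retry just the failures
            if attempt == retries:
                raise
    return results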
atdata/_helpers.py CHANGED
@@ -1,8 +1,7 @@
 """Helper utilities for numpy array serialization.
 
 This module provides utility functions for converting numpy arrays to and from
-bytes for msgpack serialization. The functions use numpy's native save/load
-format to preserve array dtype and shape information.
+bytes for msgpack serialization.
 
 Functions:
 - ``array_to_bytes()``: Serialize numpy array to bytes
@@ -15,10 +14,14 @@ handling of NDArray fields during msgpack packing/unpacking.
 ##
 # Imports
 
+import struct
 from io import BytesIO
 
 import numpy as np
 
+# .npy format magic prefix (used for backward-compatible deserialization)
+_NPY_MAGIC = b"\x93NUMPY"
+
 
 ##
 
@@ -26,35 +29,58 @@ import numpy as np
 def array_to_bytes(x: np.ndarray) -> bytes:
     """Convert a numpy array to bytes for msgpack serialization.
 
-    Uses numpy's native ``save()`` format to preserve array dtype and shape.
+    Uses a compact binary format: a short header (dtype + shape) followed by
+    raw array bytes via ``ndarray.tobytes()``. Falls back to numpy's ``.npy``
+    format for object dtypes that cannot be represented as raw bytes.
 
     Args:
        x: A numpy array to serialize.
 
     Returns:
        Raw bytes representing the serialized array.
-
-    Note:
-        Uses ``allow_pickle=True`` to support object dtypes.
     """
-    np_bytes = BytesIO()
-    np.save(np_bytes, x, allow_pickle=True)
-    return np_bytes.getvalue()
+    if x.dtype == object:
+        buf = BytesIO()
+        np.save(buf, x, allow_pickle=True)
+        return buf.getvalue()
+
+    dtype_str = x.dtype.str.encode()  # e.g. b'<f4'
+    header = struct.pack(f"<B{len(x.shape)}q", len(x.shape), *x.shape)
+    return struct.pack("<B", len(dtype_str)) + dtype_str + header + x.tobytes()
 
 
 def bytes_to_array(b: bytes) -> np.ndarray:
     """Convert serialized bytes back to a numpy array.
 
-    Reverses the serialization performed by ``array_to_bytes()``.
+    Transparently handles both the compact format produced by the current
+    ``array_to_bytes()`` and the legacy ``.npy`` format.
 
     Args:
        b: Raw bytes from a serialized numpy array.
 
     Returns:
        The deserialized numpy array with original dtype and shape.
-
-    Note:
-        Uses ``allow_pickle=True`` to support object dtypes.
     """
-    np_bytes = BytesIO(b)
-    return np.load(np_bytes, allow_pickle=True)
+    if b[:6] == _NPY_MAGIC:
+        return np.load(BytesIO(b), allow_pickle=True)
+
+    # Compact format: dtype_len(1B) + dtype_str + ndim(1B) + shape(ndim×8B) + data
+    if len(b) < 2:
+        raise ValueError(f"Array buffer too short ({len(b)} bytes): need at least 2")
+    dlen = b[0]
+    min_header = 2 + dlen  # dtype_len + dtype_str + ndim
+    if len(b) < min_header:
+        raise ValueError(
+            f"Array buffer too short ({len(b)} bytes): need at least {min_header} for header"
+        )
+    dtype = np.dtype(b[1 : 1 + dlen].decode())
+    ndim = b[1 + dlen]
+    offset = 2 + dlen
+    min_with_shape = offset + ndim * 8
+    if len(b) < min_with_shape:
+        raise ValueError(
+            f"Array buffer too short ({len(b)} bytes): need at least {min_with_shape} for shape"
+        )
+    shape = struct.unpack_from(f"<{ndim}q", b, offset)
+    offset += ndim * 8
+    return np.frombuffer(b, dtype=dtype, offset=offset).reshape(shape).copy()
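The new compact layout (1-byte dtype length, dtype string, 1-byte ndim, little-endian 64-bit shape, then raw data) round-trips numeric arrays without the .npy header, while object arrays still take the pickle path. A small self-contained check using only the two functions defined above:

# Round-trip check for both serialization paths in atdata._helpers.
import numpy as np
from atdata._helpers import array_to_bytes, bytes_to_array

# Numeric dtype: compact header + raw bytes, no .npy magic prefix.
x = np.arange(12, dtype=np.float32).reshape(3, 4)
buf = array_to_bytes(x)
assert not buf.startswith(b"\x93NUMPY")
np.testing.assert_array_equal(bytes_to_array(buf), x)

# Object dtype: falls back to the legacy .npy format (allow_pickle=True).
y = np.array([{"a": 1}, None], dtype=object)
legacy = array_to_bytes(y)
assert legacy.startswith(b"\x93NUMPY")
assert bytes_to_array(legacy)[0] == {"a": 1}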
atdata/_hf_api.py CHANGED
@@ -29,8 +29,10 @@ Examples:
 from __future__ import annotations
 
 import re
+import threading
 from pathlib import Path
 from typing import (
+    Any,
     TYPE_CHECKING,
     Generic,
     Mapping,
@@ -40,9 +42,9 @@ from typing import (
     overload,
 )
 
-from .dataset import Dataset, PackableSample, DictSample
+from .dataset import Dataset, DictSample
 from ._sources import URLSource, S3Source
-from ._protocols import DataSource
+from ._protocols import DataSource, Packable
 
 if TYPE_CHECKING:
     from ._protocols import AbstractIndex
@@ -50,7 +52,60 @@ if TYPE_CHECKING:
 ##
 # Type variables
 
-ST = TypeVar("ST", bound=PackableSample)
+ST = TypeVar("ST", bound=Packable)
+
+
+##
+# Default Index singleton
+
+_default_index: "Index | None" = None  # noqa: F821 (forward ref)
+_default_index_lock = threading.Lock()
+
+
+def get_default_index() -> "Index":  # noqa: F821
+    """Get or create the module-level default Index.
+
+    The default Index uses Redis for local storage (backwards-compatible
+    default) and an anonymous Atmosphere for read-only public data
+    resolution.
+
+    The default is created lazily on first access and cached for the
+    lifetime of the process.
+
+    Returns:
+        The default Index instance.
+
+    Examples:
+        >>> index = get_default_index()
+        >>> entry = index.get_dataset("local/mnist")
+    """
+    global _default_index
+    if _default_index is None:
+        with _default_index_lock:
+            if _default_index is None:
+                from .local import Index
+
+                _default_index = Index()
+    return _default_index
+
+
+def set_default_index(index: "Index") -> None:  # noqa: F821
+    """Override the module-level default Index.
+
+    Use this to configure a custom default Index with specific repositories,
+    an authenticated atmosphere client, or non-default providers.
+
+    Args:
+        index: The Index instance to use as the default.
+
+    Examples:
+        >>> from atdata.local import Index
+        >>> from atdata.providers import create_provider
+        >>> custom = Index(provider=create_provider("sqlite"))
+        >>> set_default_index(custom)
+    """
+    global _default_index
+    _default_index = index
 
 
 ##
@@ -74,10 +129,11 @@ class DatasetDict(Generic[ST], dict):
         >>>
         >>> # Iterate over all splits
         >>> for split_name, dataset in ds_dict.items():
-        ...     print(f"{split_name}: {len(dataset.shard_list)} shards")
+        ...     print(f"{split_name}: {len(dataset.list_shards())} shards")
     """
 
-    # TODO The above has a line for "Parameters:" that should be "Type Parameters:"; this is a temporary fix for `quartodoc` auto-generation bugs.
+    # Note: The docstring uses "Parameters:" for type parameters as a workaround
+    # for quartodoc not supporting "Type Parameters:" sections.
 
     def __init__(
         self,
@@ -134,6 +190,37 @@ class DatasetDict(Generic[ST], dict):
         """
         return {name: len(ds.list_shards()) for name, ds in self.items()}
 
+    # Methods proxied to the sole Dataset when only one split exists.
+    _DATASET_METHODS = frozenset(
+        {
+            "ordered",
+            "shuffled",
+            "as_type",
+            "list_shards",
+            "head",
+        }
+    )
+
+    def __getattr__(self, name: str) -> Any:
+        """Proxy common Dataset methods when this dict has exactly one split.
+
+        When a ``DatasetDict`` contains a single split, calling iteration
+        methods like ``.ordered()`` or ``.shuffled()`` is forwarded to the
+        contained ``Dataset`` for convenience. Multi-split dicts raise
+        ``AttributeError`` with a hint to select a split explicitly.
+        """
+        if name in self._DATASET_METHODS:
+            if len(self) == 1:
+                return getattr(next(iter(self.values())), name)
+            splits = ", ".join(f"'{k}'" for k in self.keys())
+            raise AttributeError(
+                f"'{type(self).__name__}' has {len(self)} splits ({splits}). "
+                f"Select one first, e.g. ds_dict['{next(iter(self.keys()))}'].{name}()"
+            )
+        raise AttributeError(
+            f"'{type(self).__name__}' object has no attribute '{name}'"
+        )
+
 
 ##
 # Path resolution utilities
@@ -459,7 +546,7 @@ def _resolve_indexed_path(
     handle_or_did, dataset_name = _parse_indexed_path(path)
 
     # For AtmosphereIndex, we need to resolve handle to DID first
-    # For LocalIndex, the handle is ignored and we just look up by name
+    # For local Index, the handle is ignored and we just look up by name
     entry = index.get_dataset(dataset_name)
     data_urls = entry.data_urls
 
@@ -624,16 +711,13 @@ def load_dataset(
         >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
         >>>
        >>> # Load from index with auto-type resolution
-        >>> index = LocalIndex()
+        >>> index = Index()
         >>> ds = load_dataset("@local/my-dataset", index=index, split="train")
     """
     # Handle @handle/dataset indexed path resolution
     if _is_indexed_path(path):
         if index is None:
-            raise ValueError(
-                f"Index required for indexed path: {path}. "
-                "Pass index=LocalIndex() or index=AtmosphereIndex(client)."
-            )
+            index = get_default_index()
 
         source, schema_ref = _resolve_indexed_path(path, index)
 
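With the fallback above, an indexed path no longer requires an explicit index argument; load_dataset reaches for the process-wide default instead. A hedged sketch of the new call patterns, assembled from the docstrings in this file (the dataset name and sqlite provider are placeholders, and the split-less DatasetDict return is an assumption rather than something shown in this diff):

# Illustrative sketch; "@local/my-dataset" and the sqlite provider are
# placeholders, and the split-less DatasetDict return is an assumption.
import atdata
from atdata.local import Index
from atdata.providers import create_provider

# Install a custom process-wide default (otherwise a Redis-backed Index
# is created lazily on first use).
atdata.set_default_index(Index(provider=create_provider("sqlite")))

# Indexed paths now resolve through the default instead of raising ValueError.
train = atdata.load_dataset("@local/my-dataset", split="train")

# A single-split DatasetDict proxies methods such as .shuffled() to its Dataset.
ds_dict = atdata.load_dataset("@local/my-dataset")
for sample in ds_dict.shuffled():
    break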
atdata/_logging.py ADDED
@@ -0,0 +1,70 @@
+"""Pluggable logging for atdata.
+
+Provides a thin abstraction over Python's stdlib ``logging`` module that can
+be replaced with ``structlog`` or any other logger implementing the standard
+``debug``/``info``/``warning``/``error`` interface.
+
+Usage::
+
+    # Default: stdlib logging (no config needed)
+    from atdata._logging import get_logger
+    log = get_logger()
+    log.info("processing shard", extra={"shard": "data-000.tar"})
+
+    # Plug in structlog (or any compatible logger):
+    import structlog
+    import atdata
+    atdata.configure_logging(structlog.get_logger())
+
+The module also exports a lightweight ``LoggerProtocol`` for type checking
+custom logger implementations.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any, Protocol, runtime_checkable
+
+
+@runtime_checkable
+class LoggerProtocol(Protocol):
+    """Minimal interface that a pluggable logger must satisfy."""
+
+    def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
+    def info(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
+    def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
+    def error(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
+
+
+# ---------------------------------------------------------------------------
+# Module-level state
+# ---------------------------------------------------------------------------
+
+_logger: LoggerProtocol = logging.getLogger("atdata")
+
+
+def configure_logging(logger: LoggerProtocol) -> None:
+    """Replace the default logger with a custom implementation.
+
+    The provided logger must implement ``debug``, ``info``, ``warning``, and
+    ``error`` methods. Both ``structlog`` bound loggers and stdlib
+    ``logging.Logger`` instances satisfy this interface.
+
+    Args:
+        logger: A logger instance implementing :class:`LoggerProtocol`.
+
+    Examples:
+        >>> import structlog
+        >>> atdata.configure_logging(structlog.get_logger())
+    """
+    global _logger
+    _logger = logger
+
+
+def get_logger() -> LoggerProtocol:
+    """Return the currently configured logger.
+
+    Returns the stdlib ``logging.getLogger("atdata")`` by default, or
+    whatever was last set via :func:`configure_logging`.
+    """
+    return _logger
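Because LoggerProtocol is runtime-checkable and stdlib loggers already satisfy it, redirecting atdata's log output needs no subclassing. A minimal stdlib-only sketch (the "myapp.atdata" logger name is illustrative):

# Route atdata's logs through a pre-configured stdlib logger.
import logging
from atdata._logging import LoggerProtocol, configure_logging, get_logger

logging.basicConfig(level=logging.INFO, format="%(name)s %(levelname)s %(message)s")
app_logger = logging.getLogger("myapp.atdata")

assert isinstance(app_logger, LoggerProtocol)  # stdlib Logger satisfies the protocol

configure_logging(app_logger)
get_logger().info("atdata log output now flows through myapp.atdata")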