atdata 0.2.2b1__py3-none-any.whl → 0.3.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. atdata/.gitignore +1 -0
  2. atdata/__init__.py +31 -1
  3. atdata/_cid.py +29 -35
  4. atdata/_exceptions.py +168 -0
  5. atdata/_helpers.py +33 -17
  6. atdata/_hf_api.py +109 -59
  7. atdata/_logging.py +70 -0
  8. atdata/_protocols.py +74 -132
  9. atdata/_schema_codec.py +38 -41
  10. atdata/_sources.py +57 -64
  11. atdata/_stub_manager.py +31 -26
  12. atdata/_type_utils.py +47 -7
  13. atdata/atmosphere/__init__.py +31 -24
  14. atdata/atmosphere/_types.py +11 -11
  15. atdata/atmosphere/client.py +11 -8
  16. atdata/atmosphere/lens.py +27 -30
  17. atdata/atmosphere/records.py +34 -39
  18. atdata/atmosphere/schema.py +35 -31
  19. atdata/atmosphere/store.py +16 -20
  20. atdata/cli/__init__.py +163 -168
  21. atdata/cli/diagnose.py +12 -8
  22. atdata/cli/inspect.py +69 -0
  23. atdata/cli/local.py +5 -2
  24. atdata/cli/preview.py +63 -0
  25. atdata/cli/schema.py +109 -0
  26. atdata/dataset.py +678 -533
  27. atdata/lens.py +85 -83
  28. atdata/local/__init__.py +71 -0
  29. atdata/local/_entry.py +157 -0
  30. atdata/local/_index.py +940 -0
  31. atdata/local/_repo_legacy.py +218 -0
  32. atdata/local/_s3.py +349 -0
  33. atdata/local/_schema.py +380 -0
  34. atdata/manifest/__init__.py +28 -0
  35. atdata/manifest/_aggregates.py +156 -0
  36. atdata/manifest/_builder.py +163 -0
  37. atdata/manifest/_fields.py +154 -0
  38. atdata/manifest/_manifest.py +146 -0
  39. atdata/manifest/_query.py +150 -0
  40. atdata/manifest/_writer.py +74 -0
  41. atdata/promote.py +20 -24
  42. atdata/providers/__init__.py +25 -0
  43. atdata/providers/_base.py +140 -0
  44. atdata/providers/_factory.py +69 -0
  45. atdata/providers/_postgres.py +214 -0
  46. atdata/providers/_redis.py +171 -0
  47. atdata/providers/_sqlite.py +191 -0
  48. atdata/repository.py +323 -0
  49. atdata/testing.py +337 -0
  50. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +5 -1
  51. atdata-0.3.0b1.dist-info/RECORD +54 -0
  52. atdata/local.py +0 -1707
  53. atdata-0.2.2b1.dist-info/RECORD +0 -28
  54. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
  55. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
  56. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/.gitignore ADDED
@@ -0,0 +1 @@
+ !manifest/
atdata/__init__.py CHANGED
@@ -55,6 +55,8 @@ from .lens import (
  from ._hf_api import (
  load_dataset as load_dataset,
  DatasetDict as DatasetDict,
+ get_default_index as get_default_index,
+ set_default_index as set_default_index,
  )

  from ._protocols import (
@@ -71,10 +73,29 @@ from ._sources import (
  BlobSource as BlobSource,
  )

+ from ._exceptions import (
+ AtdataError as AtdataError,
+ LensNotFoundError as LensNotFoundError,
+ SchemaError as SchemaError,
+ SampleKeyError as SampleKeyError,
+ ShardError as ShardError,
+ PartialFailureError as PartialFailureError,
+ )
+
  from ._schema_codec import (
  schema_to_type as schema_to_type,
  )

+ from ._logging import (
+ configure_logging as configure_logging,
+ get_logger as get_logger,
+ )
+
+ from .repository import (
+ Repository as Repository,
+ create_repository as create_repository,
+ )
+
  from ._cid import (
  generate_cid as generate_cid,
  verify_cid as verify_cid,
@@ -84,8 +105,17 @@ from .promote import (
  promote_to_atmosphere as promote_to_atmosphere,
  )

+ from .manifest import (
+ ManifestField as ManifestField,
+ ManifestBuilder as ManifestBuilder,
+ ShardManifest as ShardManifest,
+ ManifestWriter as ManifestWriter,
+ QueryExecutor as QueryExecutor,
+ SampleLocation as SampleLocation,
+ )
+
  # ATProto integration (lazy import to avoid requiring atproto package)
  from . import atmosphere as atmosphere

  # CLI entry point
- from .cli import main as main
+ from .cli import main as main
atdata/_cid.py CHANGED
@@ -12,13 +12,11 @@ The CIDs generated here use:
  This ensures compatibility with ATProto's CID requirements and enables
  seamless promotion from local storage to atmosphere (ATProto network).

- Example:
- ::
-
- >>> schema = {"name": "ImageSample", "version": "1.0.0", "fields": [...]}
- >>> cid = generate_cid(schema)
- >>> print(cid)
- bafyreihffx5a2e7k6r5zqgp5iwpjqr2gfyheqhzqtlxagvqjqyxzqpzqaa
+ Examples:
+ >>> schema = {"name": "ImageSample", "version": "1.0.0", "fields": [...]}
+ >>> cid = generate_cid(schema)
+ >>> print(cid)
+ bafyreihffx5a2e7k6r5zqgp5iwpjqr2gfyheqhzqtlxagvqjqyxzqpzqaa
  """

  import hashlib
@@ -50,11 +48,9 @@ def generate_cid(data: Any) -> str:
  Raises:
  ValueError: If the data cannot be encoded as DAG-CBOR.

- Example:
- ::
-
- >>> generate_cid({"name": "test", "value": 42})
- 'bafyrei...'
+ Examples:
+ >>> generate_cid({"name": "test", "value": 42})
+ 'bafyrei...'
  """
  # Encode data as DAG-CBOR
  try:
@@ -68,7 +64,9 @@ def generate_cid(data: Any) -> str:
  # Build raw CID bytes:
  # CIDv1 = version(1) + codec(dag-cbor) + multihash
  # Multihash = code(sha256) + size(32) + digest
- raw_cid_bytes = bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash
+ raw_cid_bytes = (
+ bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash
+ )

  # Encode to base32 multibase string
  return libipld.encode_cid(raw_cid_bytes)
@@ -86,14 +84,14 @@ def generate_cid_from_bytes(data_bytes: bytes) -> str:
  Returns:
  CIDv1 string in base32 multibase format.

- Example:
- ::
-
- >>> cbor_bytes = libipld.encode_dag_cbor({"key": "value"})
- >>> cid = generate_cid_from_bytes(cbor_bytes)
+ Examples:
+ >>> cbor_bytes = libipld.encode_dag_cbor({"key": "value"})
+ >>> cid = generate_cid_from_bytes(cbor_bytes)
  """
  sha256_hash = hashlib.sha256(data_bytes).digest()
- raw_cid_bytes = bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash
+ raw_cid_bytes = (
+ bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + sha256_hash
+ )
  return libipld.encode_cid(raw_cid_bytes)


@@ -107,14 +105,12 @@ def verify_cid(cid: str, data: Any) -> bool:
  Returns:
  True if the CID matches the data, False otherwise.

- Example:
- ::
-
- >>> cid = generate_cid({"name": "test"})
- >>> verify_cid(cid, {"name": "test"})
- True
- >>> verify_cid(cid, {"name": "different"})
- False
+ Examples:
+ >>> cid = generate_cid({"name": "test"})
+ >>> verify_cid(cid, {"name": "test"})
+ True
+ >>> verify_cid(cid, {"name": "different"})
+ False
  """
  expected_cid = generate_cid(data)
  return cid == expected_cid
@@ -130,14 +126,12 @@ def parse_cid(cid: str) -> dict:
  Dictionary with 'version', 'codec', and 'hash' keys.
  The 'hash' value is itself a dict with 'code', 'size', and 'digest'.

- Example:
- ::
-
- >>> info = parse_cid('bafyrei...')
- >>> info['version']
- 1
- >>> info['codec']
- 113 # 0x71 = dag-cbor
+ Examples:
+ >>> info = parse_cid('bafyrei...')
+ >>> info['version']
+ 1
+ >>> info['codec']
+ 113 # 0x71 = dag-cbor
  """
  return libipld.decode_cid(cid)

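Note for readers scanning the hunks above: both `generate_cid()` and `generate_cid_from_bytes()` assemble the same raw CIDv1 prefix before handing it to `libipld.encode_cid()`. A minimal sketch of that byte layout, using the standard multiformats values the module's constants are expected to hold (the constant definitions sit outside these hunks, so the literal values here are an assumption):

    import hashlib

    CID_VERSION_1 = 0x01   # CIDv1 version byte (assumed value)
    CODEC_DAG_CBOR = 0x71  # dag-cbor multicodec (assumed value)
    HASH_SHA256 = 0x12     # sha2-256 multihash code (assumed value)
    SHA256_SIZE = 0x20     # 32-byte digest length (assumed value)

    def raw_cidv1(dag_cbor_bytes: bytes) -> bytes:
        # version || codec || multihash(code, size, digest)
        digest = hashlib.sha256(dag_cbor_bytes).digest()
        return bytes([CID_VERSION_1, CODEC_DAG_CBOR, HASH_SHA256, SHA256_SIZE]) + digest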
atdata/_exceptions.py ADDED
@@ -0,0 +1,168 @@
+ """Custom exception hierarchy for atdata.
+
+ Provides actionable error messages with contextual help, available
+ alternatives, and suggested fix code snippets.
+ """
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+ from typing import Type
+
+
+ class AtdataError(Exception):
+ """Base exception for all atdata errors."""
+
+
+ class LensNotFoundError(AtdataError, ValueError):
+ """No lens registered to transform between two sample types.
+
+ Attributes:
+ source_type: The source sample type.
+ view_type: The target view type.
+ available_targets: Types reachable from the source via registered lenses.
+ """
+
+ def __init__(
+ self,
+ source_type: Type,
+ view_type: Type,
+ available_targets: list[tuple[Type, str]] | None = None,
+ ) -> None:
+ self.source_type = source_type
+ self.view_type = view_type
+ self.available_targets = available_targets or []
+
+ src_name = source_type.__name__
+ view_name = view_type.__name__
+
+ lines = [f"No lens transforms {src_name} \u2192 {view_name}"]
+
+ if self.available_targets:
+ lines.append("")
+ lines.append(f"Available lenses from {src_name}:")
+ for target_type, lens_name in self.available_targets:
+ lines.append(
+ f" - {src_name} \u2192 {target_type.__name__} (via {lens_name})"
+ )
+
+ lines.append("")
+ lines.append("Did you mean to define:")
+ lines.append(" @lens")
+ lines.append(
+ f" def {src_name.lower()}_to_{view_name.lower()}(source: {src_name}) -> {view_name}:"
+ )
+ lines.append(f" return {view_name}(...)")
+
+ super().__init__("\n".join(lines))
+
+
+ class SchemaError(AtdataError):
+ """Schema mismatch during sample deserialization.
+
+ Raised when the data in a shard doesn't match the expected sample type.
+
+ Attributes:
+ expected_fields: Fields expected by the sample type.
+ actual_fields: Fields found in the data.
+ sample_type_name: Name of the target sample type.
+ """
+
+ def __init__(
+ self,
+ sample_type_name: str,
+ expected_fields: list[str],
+ actual_fields: list[str],
+ ) -> None:
+ self.sample_type_name = sample_type_name
+ self.expected_fields = expected_fields
+ self.actual_fields = actual_fields
+
+ missing = sorted(set(expected_fields) - set(actual_fields))
+ extra = sorted(set(actual_fields) - set(expected_fields))
+
+ lines = [f"Schema mismatch for {sample_type_name}"]
+ if missing:
+ lines.append(f" Missing fields: {', '.join(missing)}")
+ if extra:
+ lines.append(f" Unexpected fields: {', '.join(extra)}")
+ lines.append("")
+ lines.append(f"Expected: {', '.join(sorted(expected_fields))}")
+ lines.append(f"Got: {', '.join(sorted(actual_fields))}")
+
+ super().__init__("\n".join(lines))
+
+
+ class SampleKeyError(AtdataError, KeyError):
+ """Sample with the given key was not found in the dataset.
+
+ Attributes:
+ key: The key that was not found.
+ """
+
+ def __init__(self, key: str) -> None:
+ self.key = key
+ super().__init__(
+ f"Sample with key '{key}' not found in dataset. "
+ f"Note: key lookup requires scanning all shards and is O(n)."
+ )
+
+
+ class ShardError(AtdataError):
+ """Error accessing or reading a dataset shard.
+
+ Attributes:
+ shard_id: Identifier of the shard that failed.
+ reason: Human-readable description of what went wrong.
+ """
+
+ def __init__(self, shard_id: str, reason: str) -> None:
+ self.shard_id = shard_id
+ self.reason = reason
+ super().__init__(f"Failed to read shard '{shard_id}': {reason}")
+
+
+ class PartialFailureError(AtdataError):
+ """Some shards succeeded but others failed during processing.
+
+ Raised by :meth:`Dataset.process_shards` when at least one shard fails.
+ Provides access to both the successful results and the per-shard errors,
+ enabling retry of only the failed shards.
+
+ Attributes:
+ succeeded_shards: List of shard identifiers that succeeded.
+ failed_shards: List of shard identifiers that failed.
+ errors: Mapping from shard identifier to the exception that occurred.
+ results: Mapping from shard identifier to the result for succeeded shards.
+ """
+
+ def __init__(
+ self,
+ succeeded_shards: list[str],
+ failed_shards: list[str],
+ errors: dict[str, Exception],
+ results: dict[str, object],
+ ) -> None:
+ self.succeeded_shards = succeeded_shards
+ self.failed_shards = failed_shards
+ self.errors = errors
+ self.results = results
+
+ n_ok = len(succeeded_shards)
+ n_fail = len(failed_shards)
+ total = n_ok + n_fail
+
+ lines = [f"{n_fail}/{total} shards failed during processing"]
+ for shard_id in failed_shards[:5]:
+ lines.append(f" {shard_id}: {errors[shard_id]}")
+ if n_fail > 5:
+ lines.append(f" ... and {n_fail - 5} more")
+ lines.append("")
+ lines.append(
+ f"Access .succeeded_shards ({n_ok}) and .failed_shards ({n_fail}) "
+ f"to inspect or retry."
+ )
+
+ super().__init__("\n".join(lines))
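The attributes on `PartialFailureError` exist to support selective retry rather than all-or-nothing reruns. A minimal sketch of that pattern; only the exception attributes are taken from the file above, while the `process_shards` call signature and the retry step are hypothetical:

    from atdata import PartialFailureError

    try:
        results = ds.process_shards(summarize_shard)   # signature assumed for illustration
    except PartialFailureError as err:
        results = dict(err.results)                    # keep the shards that succeeded
        for shard_id in err.failed_shards:
            print(f"shard {shard_id} failed: {err.errors[shard_id]}")
        # Re-run only err.failed_shards; the retry mechanism itself is not
        # shown in this diff, so it is left as a comment here.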
atdata/_helpers.py CHANGED
@@ -1,8 +1,7 @@
  """Helper utilities for numpy array serialization.

  This module provides utility functions for converting numpy arrays to and from
- bytes for msgpack serialization. The functions use numpy's native save/load
- format to preserve array dtype and shape information.
+ bytes for msgpack serialization.

  Functions:
  - ``array_to_bytes()``: Serialize numpy array to bytes
@@ -15,44 +14,61 @@ handling of NDArray fields during msgpack packing/unpacking.
  ##
  # Imports

+ import struct
  from io import BytesIO

  import numpy as np

+ # .npy format magic prefix (used for backward-compatible deserialization)
+ _NPY_MAGIC = b"\x93NUMPY"
+

  ##

- def array_to_bytes( x: np.ndarray ) -> bytes:
+
+ def array_to_bytes(x: np.ndarray) -> bytes:
  """Convert a numpy array to bytes for msgpack serialization.

- Uses numpy's native ``save()`` format to preserve array dtype and shape.
+ Uses a compact binary format: a short header (dtype + shape) followed by
+ raw array bytes via ``ndarray.tobytes()``. Falls back to numpy's ``.npy``
+ format for object dtypes that cannot be represented as raw bytes.

  Args:
  x: A numpy array to serialize.

  Returns:
  Raw bytes representing the serialized array.
-
- Note:
- Uses ``allow_pickle=True`` to support object dtypes.
  """
- np_bytes = BytesIO()
- np.save( np_bytes, x, allow_pickle = True )
- return np_bytes.getvalue()
+ if x.dtype == object:
+ buf = BytesIO()
+ np.save(buf, x, allow_pickle=True)
+ return buf.getvalue()

- def bytes_to_array( b: bytes ) -> np.ndarray:
+ dtype_str = x.dtype.str.encode() # e.g. b'<f4'
+ header = struct.pack(f"<B{len(x.shape)}q", len(x.shape), *x.shape)
+ return struct.pack("<B", len(dtype_str)) + dtype_str + header + x.tobytes()
+
+
+ def bytes_to_array(b: bytes) -> np.ndarray:
  """Convert serialized bytes back to a numpy array.

- Reverses the serialization performed by ``array_to_bytes()``.
+ Transparently handles both the compact format produced by the current
+ ``array_to_bytes()`` and the legacy ``.npy`` format.

  Args:
  b: Raw bytes from a serialized numpy array.

  Returns:
  The deserialized numpy array with original dtype and shape.
-
- Note:
- Uses ``allow_pickle=True`` to support object dtypes.
  """
- np_bytes = BytesIO( b )
- return np.load( np_bytes, allow_pickle = True )
+ if b[:6] == _NPY_MAGIC:
+ return np.load(BytesIO(b), allow_pickle=True)
+
+ # Compact format: dtype_len(1B) + dtype_str + ndim(1B) + shape(ndim×8B) + data
+ dlen = b[0]
+ dtype = np.dtype(b[1 : 1 + dlen].decode())
+ ndim = b[1 + dlen]
+ offset = 2 + dlen
+ shape = struct.unpack_from(f"<{ndim}q", b, offset)
+ offset += ndim * 8
+ return np.frombuffer(b, dtype=dtype, offset=offset).reshape(shape).copy()
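A quick round-trip sketch of the new compact layout, to make the byte accounting in the header comment concrete (the exact sizes assume a little-endian float32 array; the import path is the private module shown in this hunk):

    import numpy as np
    from atdata._helpers import array_to_bytes, bytes_to_array

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    blob = array_to_bytes(x)
    # 1 (dtype_len) + 3 (b"<f4") + 1 (ndim) + 2*8 (shape) + 24 (data) = 45 bytes,
    # versus 64+ bytes of header alone in the legacy .npy path.
    assert len(blob) == 45
    y = bytes_to_array(blob)
    assert y.dtype == x.dtype and y.shape == (2, 3) and np.array_equal(x, y)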
atdata/_hf_api.py CHANGED
@@ -9,28 +9,27 @@ Key differences from HuggingFace Datasets:
  - Built on WebDataset for efficient streaming of large datasets
  - No Arrow caching layer (WebDataset handles remote/local transparently)

- Example:
- ::
-
- >>> import atdata
- >>> from atdata import load_dataset
- >>>
- >>> @atdata.packable
- ... class MyData:
- ... text: str
- ... label: int
- >>>
- >>> # Load a single split
- >>> ds = load_dataset("path/to/train-{000000..000099}.tar", MyData, split="train")
- >>>
- >>> # Load all splits (returns DatasetDict)
- >>> ds_dict = load_dataset("path/to/{train,test}-*.tar", MyData)
- >>> train_ds = ds_dict["train"]
+ Examples:
+ >>> import atdata
+ >>> from atdata import load_dataset
+ >>>
+ >>> @atdata.packable
+ ... class MyData:
+ ... text: str
+ ... label: int
+ >>>
+ >>> # Load a single split
+ >>> ds = load_dataset("path/to/train-{000000..000099}.tar", MyData, split="train")
+ >>>
+ >>> # Load all splits (returns DatasetDict)
+ >>> ds_dict = load_dataset("path/to/{train,test}-*.tar", MyData)
+ >>> train_ds = ds_dict["train"]
  """

  from __future__ import annotations

  import re
+ import threading
  from pathlib import Path
  from typing import (
  TYPE_CHECKING,
@@ -42,18 +41,70 @@ from typing import (
  overload,
  )

- from .dataset import Dataset, PackableSample, DictSample
+ from .dataset import Dataset, DictSample
  from ._sources import URLSource, S3Source
- from ._protocols import DataSource
+ from ._protocols import DataSource, Packable

  if TYPE_CHECKING:
  from ._protocols import AbstractIndex
- from .local import S3DataStore

  ##
  # Type variables

- ST = TypeVar("ST", bound=PackableSample)
+ ST = TypeVar("ST", bound=Packable)
+
+
+ ##
+ # Default Index singleton
+
+ _default_index: "Index | None" = None # noqa: F821 (forward ref)
+ _default_index_lock = threading.Lock()
+
+
+ def get_default_index() -> "Index": # noqa: F821
+ """Get or create the module-level default Index.
+
+ The default Index uses Redis for local storage (backwards-compatible
+ default) and an anonymous AtmosphereClient for read-only public data
+ resolution.
+
+ The default is created lazily on first access and cached for the
+ lifetime of the process.
+
+ Returns:
+ The default Index instance.
+
+ Examples:
+ >>> index = get_default_index()
+ >>> entry = index.get_dataset("local/mnist")
+ """
+ global _default_index
+ if _default_index is None:
+ with _default_index_lock:
+ if _default_index is None:
+ from .local import Index
+
+ _default_index = Index()
+ return _default_index
+
+
+ def set_default_index(index: "Index") -> None: # noqa: F821
+ """Override the module-level default Index.
+
+ Use this to configure a custom default Index with specific repositories,
+ an authenticated atmosphere client, or non-default providers.
+
+ Args:
+ index: The Index instance to use as the default.
+
+ Examples:
+ >>> from atdata.local import Index
+ >>> from atdata.providers import create_provider
+ >>> custom = Index(provider=create_provider("sqlite"))
+ >>> set_default_index(custom)
+ """
+ global _default_index
+ _default_index = index


  ##
@@ -70,18 +121,18 @@ class DatasetDict(Generic[ST], dict):
  Parameters:
  ST: The sample type for all datasets in this dict.

- Example:
- ::
-
- >>> ds_dict = load_dataset("path/to/data", MyData)
- >>> train = ds_dict["train"]
- >>> test = ds_dict["test"]
- >>>
- >>> # Iterate over all splits
- >>> for split_name, dataset in ds_dict.items():
- ... print(f"{split_name}: {len(dataset.shard_list)} shards")
+ Examples:
+ >>> ds_dict = load_dataset("path/to/data", MyData)
+ >>> train = ds_dict["train"]
+ >>> test = ds_dict["test"]
+ >>>
+ >>> # Iterate over all splits
+ >>> for split_name, dataset in ds_dict.items():
+ ... print(f"{split_name}: {len(dataset.list_shards())} shards")
  """
- # TODO The above has a line for "Parameters:" that should be "Type Parameters:"; this is a temporary fix for `quartodoc` auto-generation bugs.
+
+ # Note: The docstring uses "Parameters:" for type parameters as a workaround
+ # for quartodoc not supporting "Type Parameters:" sections.

  def __init__(
  self,
@@ -463,12 +514,12 @@ def _resolve_indexed_path(
  handle_or_did, dataset_name = _parse_indexed_path(path)

  # For AtmosphereIndex, we need to resolve handle to DID first
- # For LocalIndex, the handle is ignored and we just look up by name
+ # For local Index, the handle is ignored and we just look up by name
  entry = index.get_dataset(dataset_name)
  data_urls = entry.data_urls

  # Check if index has a data store
- if hasattr(index, 'data_store') and index.data_store is not None:
+ if hasattr(index, "data_store") and index.data_store is not None:
  store = index.data_store

  # Import here to avoid circular imports at module level
@@ -613,38 +664,35 @@ def load_dataset(
  FileNotFoundError: If no data files are found at the path.
  KeyError: If dataset not found in index.

- Example:
- ::
-
- >>> # Load without type - get DictSample for exploration
- >>> ds = load_dataset("./data/train.tar", split="train")
- >>> for sample in ds.ordered():
- ... print(sample.keys()) # Explore fields
- ... print(sample["text"]) # Dict-style access
- ... print(sample.label) # Attribute access
- >>>
- >>> # Convert to typed schema
- >>> typed_ds = ds.as_type(TextData)
- >>>
- >>> # Or load with explicit type directly
- >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
- >>>
- >>> # Load from index with auto-type resolution
- >>> index = LocalIndex()
- >>> ds = load_dataset("@local/my-dataset", index=index, split="train")
+ Examples:
+ >>> # Load without type - get DictSample for exploration
+ >>> ds = load_dataset("./data/train.tar", split="train")
+ >>> for sample in ds.ordered():
+ ... print(sample.keys()) # Explore fields
+ ... print(sample["text"]) # Dict-style access
+ ... print(sample.label) # Attribute access
+ >>>
+ >>> # Convert to typed schema
+ >>> typed_ds = ds.as_type(TextData)
+ >>>
+ >>> # Or load with explicit type directly
+ >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
+ >>>
+ >>> # Load from index with auto-type resolution
+ >>> index = Index()
+ >>> ds = load_dataset("@local/my-dataset", index=index, split="train")
  """
  # Handle @handle/dataset indexed path resolution
  if _is_indexed_path(path):
  if index is None:
- raise ValueError(
- f"Index required for indexed path: {path}. "
- "Pass index=LocalIndex() or index=AtmosphereIndex(client)."
- )
+ index = get_default_index()

  source, schema_ref = _resolve_indexed_path(path, index)

  # Resolve sample_type from schema if not provided
- resolved_type: Type = sample_type if sample_type is not None else index.decode_schema(schema_ref)
+ resolved_type: Type = (
+ sample_type if sample_type is not None else index.decode_schema(schema_ref)
+ )

  # Create dataset from the resolved source (includes credentials if S3)
  ds = Dataset[resolved_type](source)
@@ -653,7 +701,9 @@ def load_dataset(
  # Indexed datasets are single-split by default
  return ds

- return DatasetDict({"train": ds}, sample_type=resolved_type, streaming=streaming)
+ return DatasetDict(
+ {"train": ds}, sample_type=resolved_type, streaming=streaming
+ )

  # Use DictSample as default when no type specified
  resolved_type = sample_type if sample_type is not None else DictSample
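The net effect of the `load_dataset` change above: an `@handle/dataset` path no longer requires an explicit `index=` argument, because a missing index now falls back to `get_default_index()`. A hedged sketch of both paths; the "sqlite" provider name is taken from the `set_default_index` docstring above, and whether the lazy Redis-backed default can actually be constructed depends on the local environment:

    import atdata

    # 0.2.x raised ValueError here without index=...; 0.3.0 lazily builds
    # the Redis-backed default Index instead.
    ds = atdata.load_dataset("@local/my-dataset", split="train")

    # Opting out of the default, e.g. to use a SQLite-backed provider:
    from atdata.local import Index
    from atdata.providers import create_provider

    atdata.set_default_index(Index(provider=create_provider("sqlite")))
    ds = atdata.load_dataset("@local/my-dataset", split="train")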