atdata 0.2.0a1__py3-none-any.whl → 0.2.3b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +43 -10
- atdata/_cid.py +144 -0
- atdata/_helpers.py +7 -5
- atdata/_hf_api.py +690 -0
- atdata/_protocols.py +504 -0
- atdata/_schema_codec.py +438 -0
- atdata/_sources.py +508 -0
- atdata/_stub_manager.py +534 -0
- atdata/_type_utils.py +104 -0
- atdata/atmosphere/__init__.py +269 -1
- atdata/atmosphere/_types.py +4 -2
- atdata/atmosphere/client.py +146 -3
- atdata/atmosphere/lens.py +4 -3
- atdata/atmosphere/records.py +168 -7
- atdata/atmosphere/schema.py +29 -82
- atdata/atmosphere/store.py +204 -0
- atdata/cli/__init__.py +222 -0
- atdata/cli/diagnose.py +169 -0
- atdata/cli/local.py +283 -0
- atdata/dataset.py +615 -257
- atdata/lens.py +53 -54
- atdata/local.py +1456 -228
- atdata/promote.py +195 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/METADATA +106 -14
- atdata-0.2.3b1.dist-info/RECORD +28 -0
- atdata-0.2.0a1.dist-info/RECORD +0 -16
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/licenses/LICENSE +0 -0
atdata/_hf_api.py
ADDED
@@ -0,0 +1,690 @@
"""HuggingFace Datasets-style API for atdata.

This module provides a familiar `load_dataset()` interface inspired by the
HuggingFace Datasets library, adapted for atdata's typed WebDataset approach.

Key differences from HuggingFace Datasets:
- Requires explicit `sample_type` parameter (typed dataclass)
- Returns atdata.Dataset[ST] instead of HF Dataset
- Built on WebDataset for efficient streaming of large datasets
- No Arrow caching layer (WebDataset handles remote/local transparently)

Examples:
    >>> import atdata
    >>> from atdata import load_dataset
    >>>
    >>> @atdata.packable
    ... class MyData:
    ...     text: str
    ...     label: int
    >>>
    >>> # Load a single split
    >>> ds = load_dataset("path/to/train-{000000..000099}.tar", MyData, split="train")
    >>>
    >>> # Load all splits (returns DatasetDict)
    >>> ds_dict = load_dataset("path/to/{train,test}-*.tar", MyData)
    >>> train_ds = ds_dict["train"]
"""

from __future__ import annotations

import re
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Generic,
    Mapping,
    Optional,
    Type,
    TypeVar,
    overload,
)

from .dataset import Dataset, PackableSample, DictSample
from ._sources import URLSource, S3Source
from ._protocols import DataSource

if TYPE_CHECKING:
    from ._protocols import AbstractIndex

##
# Type variables

ST = TypeVar("ST", bound=PackableSample)


##
# DatasetDict - container for multiple splits


class DatasetDict(Generic[ST], dict):
    """A dictionary of split names to Dataset instances.

    Similar to HuggingFace's DatasetDict, this provides a container for
    multiple dataset splits (train, test, validation, etc.) with convenience
    methods that operate across all splits.

    Parameters:
        ST: The sample type for all datasets in this dict.

    Examples:
        >>> ds_dict = load_dataset("path/to/data", MyData)
        >>> train = ds_dict["train"]
        >>> test = ds_dict["test"]
        >>>
        >>> # Iterate over all splits
        >>> for split_name, dataset in ds_dict.items():
        ...     print(f"{split_name}: {len(dataset.shard_list)} shards")
    """

    # TODO The above has a line for "Parameters:" that should be "Type Parameters:"; this is a temporary fix for `quartodoc` auto-generation bugs.

    def __init__(
        self,
        splits: Mapping[str, Dataset[ST]] | None = None,
        sample_type: Type[ST] | None = None,
        streaming: bool = False,
    ) -> None:
        """Create a DatasetDict from a mapping of split names to datasets.

        Args:
            splits: Mapping of split names to Dataset instances.
            sample_type: The sample type for datasets in this dict. If not
                provided, inferred from the first dataset in splits.
            streaming: Whether this DatasetDict was loaded in streaming mode.
        """
        super().__init__(splits or {})
        self._sample_type = sample_type
        self._streaming = streaming

    @property
    def sample_type(self) -> Type[ST] | None:
        """The sample type for datasets in this dict."""
        if self._sample_type is not None:
            return self._sample_type
        # Infer from first dataset
        if self:
            first_ds = next(iter(self.values()))
            return first_ds.sample_type
        return None

    def __getitem__(self, key: str) -> Dataset[ST]:
        """Get a dataset by split name."""
        return super().__getitem__(key)

    def __setitem__(self, key: str, value: Dataset[ST]) -> None:
        """Set a dataset for a split name."""
        super().__setitem__(key, value)

    @property
    def streaming(self) -> bool:
        """Whether this DatasetDict was loaded in streaming mode."""
        return self._streaming

    @property
    def num_shards(self) -> dict[str, int]:
        """Number of shards in each split.

        Returns:
            Dict mapping split names to shard counts.

        Note:
            This property accesses the shard list, which may trigger
            shard enumeration for remote datasets.
        """
        return {name: len(ds.list_shards()) for name, ds in self.items()}


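# A DatasetDict is an ordinary ``dict`` subclass, so it can also be assembled
# by hand from existing datasets. A minimal sketch, assuming ``train_ds`` and
# ``test_ds`` are already-constructed ``Dataset[MyData]`` instances (the names
# are illustrative, not part of this module):
#
#     >>> dd = DatasetDict({"train": train_ds, "test": test_ds}, sample_type=MyData)
#     >>> dd.sample_type is MyData
#     True
#     >>> sorted(dd.keys())
#     ['test', 'train']
#     >>> dd.num_shards  # e.g. {'train': 100, 'test': 10}; may enumerate remote shards
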
##
# Path resolution utilities


def _is_brace_pattern(path: str) -> bool:
    """Check if path contains WebDataset brace expansion notation like {000..099}."""
    return bool(re.search(r"\{[^}]+\}", path))


def _is_glob_pattern(path: str) -> bool:
    """Check if path contains glob wildcards (* or ?)."""
    return "*" in path or "?" in path


def _is_remote_url(path: str) -> bool:
    """Check if path is a remote URL (s3://, gs://, http://, https://, az://)."""
    return path.startswith(("s3://", "gs://", "http://", "https://", "az://"))


def _expand_local_glob(pattern: str) -> list[str]:
    """Expand local glob pattern to sorted list of matching file paths."""
    base_path = Path(pattern).parent
    glob_part = Path(pattern).name

    if not base_path.exists():
        return []

    matches = sorted(base_path.glob(glob_part))
    return [str(p) for p in matches if p.is_file()]


# Pre-compiled split name patterns (pattern, split_name)
_SPLIT_PATTERNS: list[tuple[re.Pattern[str], str]] = [
    # Patterns like "dataset-train-000000.tar" (split in middle with delimiters)
    (re.compile(r"[_-](train|training)[_-]"), "train"),
    (re.compile(r"[_-](test|testing)[_-]"), "test"),
    (re.compile(r"[_-](val|valid|validation)[_-]"), "validation"),
    (re.compile(r"[_-](dev|development)[_-]"), "validation"),
    # Patterns at start of filename like "train-000.tar" or "test_data.tar"
    (re.compile(r"^(train|training)[_-]"), "train"),
    (re.compile(r"^(test|testing)[_-]"), "test"),
    (re.compile(r"^(val|valid|validation)[_-]"), "validation"),
    (re.compile(r"^(dev|development)[_-]"), "validation"),
    # Patterns in directory path like "/path/train/shard-000.tar"
    (re.compile(r"[/\\](train|training)[/\\]"), "train"),
    (re.compile(r"[/\\](test|testing)[/\\]"), "test"),
    (re.compile(r"[/\\](val|valid|validation)[/\\]"), "validation"),
    (re.compile(r"[/\\](dev|development)[/\\]"), "validation"),
    # Patterns at start of path like "train/shard-000.tar"
    (re.compile(r"^(train|training)[/\\]"), "train"),
    (re.compile(r"^(test|testing)[/\\]"), "test"),
    (re.compile(r"^(val|valid|validation)[/\\]"), "validation"),
    (re.compile(r"^(dev|development)[/\\]"), "validation"),
]


def _detect_split_from_path(path: str) -> str | None:
    """Detect split name (train/test/validation) from file path."""
    filename = Path(path).name
    path_lower = path.lower()
    filename_lower = filename.lower()

    # Check filename first (more specific)
    for pattern, split_name in _SPLIT_PATTERNS:
        if pattern.search(filename_lower):
            return split_name

    # Fall back to full path (catches directory patterns)
    for pattern, split_name in _SPLIT_PATTERNS:
        if pattern.search(path_lower):
            return split_name

    return None


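# A quick sketch of how the helpers above classify paths; the example paths
# are made up purely for illustration:
#
#     >>> _is_brace_pattern("data-{000..099}.tar"), _is_glob_pattern("data-*.tar")
#     (True, True)
#     >>> _is_remote_url("s3://bucket/data.tar")
#     True
#     >>> _detect_split_from_path("shards/dataset-train-000000.tar")
#     'train'
#     >>> _detect_split_from_path("data/val/shard-000.tar")
#     'validation'
#     >>> _detect_split_from_path("corpus.tar") is None  # callers fall back to "train"
#     True
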
def _resolve_shards(
    path: str,
    data_files: str | list[str] | dict[str, str | list[str]] | None = None,
) -> dict[str, list[str]]:
    """Resolve path specification to dict of split -> shard URLs.

    Handles:
    - WebDataset brace notation: "path/{train,test}-{000..099}.tar"
    - Glob patterns: "path/*.tar"
    - Explicit data_files mapping

    Args:
        path: Base path or pattern.
        data_files: Optional explicit mapping of splits to files.

    Returns:
        Dict mapping split names to lists of shard URLs.
    """
    # If explicit data_files provided, use those
    if data_files is not None:
        return _resolve_data_files(path, data_files)

    # WebDataset brace notation - pass through as-is
    # WebDataset handles expansion internally
    if _is_brace_pattern(path):
        # Try to detect split from the pattern itself
        split = _detect_split_from_path(path)
        split_name = split or "train"
        return {split_name: [path]}

    # Local glob pattern
    if not _is_remote_url(path) and _is_glob_pattern(path):
        shards = _expand_local_glob(path)
        return _group_shards_by_split(shards)

    # Local directory - scan for .tar files
    if not _is_remote_url(path) and Path(path).is_dir():
        shards = _expand_local_glob(str(Path(path) / "*.tar"))
        return _group_shards_by_split(shards)

    # Single file or remote URL - treat as single shard
    split = _detect_split_from_path(path)
    split_name = split or "train"
    return {split_name: [path]}


def _resolve_data_files(
    base_path: str,
    data_files: str | list[str] | dict[str, str | list[str]],
) -> dict[str, list[str]]:
    """Resolve explicit data_files specification.

    Args:
        base_path: Base path for relative file references.
        data_files: File specification - can be:
            - str: Single file pattern
            - list[str]: List of file patterns
            - dict[str, ...]: Mapping of split names to patterns

    Returns:
        Dict mapping split names to lists of resolved file paths.
    """
    base = Path(base_path) if not _is_remote_url(base_path) else None

    if isinstance(data_files, str):
        # Single pattern -> "train" split
        if base and not Path(data_files).is_absolute():
            data_files = str(base / data_files)
        return {"train": [data_files]}

    if isinstance(data_files, list):
        # List of patterns -> "train" split
        resolved = []
        for f in data_files:
            if base and not Path(f).is_absolute():
                f = str(base / f)
            resolved.append(f)
        return {"train": resolved}

    # Dict mapping splits to patterns
    result: dict[str, list[str]] = {}
    for split_name, files in data_files.items():
        if isinstance(files, str):
            files = [files]
        resolved = []
        for f in files:
            if base and not Path(f).is_absolute():
                f = str(base / f)
            resolved.append(f)
        result[split_name] = resolved

    return result


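# Illustrative resolution results for the two branches that need nothing on
# disk (a remote single shard, and an explicit data_files mapping); the paths
# are hypothetical:
#
#     >>> _resolve_shards("s3://bucket/test-000.tar")
#     {'test': ['s3://bucket/test-000.tar']}
#     >>> _resolve_data_files("/data", {"train": "train-*.tar", "test": ["t0.tar", "t1.tar"]})
#     {'train': ['/data/train-*.tar'], 'test': ['/data/t0.tar', '/data/t1.tar']}
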
def _shards_to_wds_url(shards: list[str]) -> str:
    """Convert a list of shard paths to a WebDataset URL.

    WebDataset supports brace expansion, so we convert multiple shards
    into brace notation when they share a common prefix/suffix.

    Args:
        shards: List of shard file paths.

    Returns:
        WebDataset-compatible URL string.

    Examples:
        >>> _shards_to_wds_url(["data-000.tar", "data-001.tar", "data-002.tar"])
        "data-{000,001,002}.tar"
        >>> _shards_to_wds_url(["train.tar"])
        "train.tar"
    """
    import os.path

    if len(shards) == 0:
        raise ValueError("Cannot create URL from empty shard list")

    if len(shards) == 1:
        return shards[0]

    # Find common prefix using os.path.commonprefix (O(n) vs O(n²))
    prefix = os.path.commonprefix(shards)

    # Find common suffix by reversing strings
    reversed_shards = [s[::-1] for s in shards]
    suffix = os.path.commonprefix(reversed_shards)[::-1]

    prefix_len = len(prefix)
    suffix_len = len(suffix)

    # Ensure prefix and suffix don't overlap
    min_shard_len = min(len(s) for s in shards)
    if prefix_len + suffix_len > min_shard_len:
        # Overlapping - prefer prefix, reduce suffix
        suffix_len = max(0, min_shard_len - prefix_len)
        suffix = shards[0][-suffix_len:] if suffix_len > 0 else ""

    if prefix_len > 0 or suffix_len > 0:
        # Extract the varying middle parts
        middles = []
        for s in shards:
            if suffix_len > 0:
                middle = s[prefix_len:-suffix_len]
            else:
                middle = s[prefix_len:]
            middles.append(middle)

        # Only use brace notation if we have meaningful variation
        if all(middles):
            return f"{prefix}{{{','.join(middles)}}}{suffix}"

    # Fallback: space-separated URLs for WebDataset
    return " ".join(shards)


def _group_shards_by_split(shards: list[str]) -> dict[str, list[str]]:
    """Group a list of shard paths by detected split.

    Args:
        shards: List of shard file paths.

    Returns:
        Dict mapping split names to lists of shards. Files with no
        detected split are placed in "train".
    """
    result: dict[str, list[str]] = {}

    for shard in shards:
        split = _detect_split_from_path(shard)
        split_name = split or "train"
        if split_name not in result:
            result[split_name] = []
        result[split_name].append(shard)

    return result


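# A small sketch of the two helpers above on made-up shard names:
#
#     >>> _shards_to_wds_url(["part-aa.tar", "part-bb.tar"])
#     'part-{aa,bb}.tar'
#     >>> _shards_to_wds_url(["single.tar"])
#     'single.tar'
#     >>> _group_shards_by_split(["train-000.tar", "test-000.tar", "extra.tar"])
#     {'train': ['train-000.tar', 'extra.tar'], 'test': ['test-000.tar']}
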
##
# Index-based path resolution


def _is_indexed_path(path: str) -> bool:
    """Check if path uses @handle/dataset notation for index lookup.

    Examples:
        >>> _is_indexed_path("@maxine.science/mnist")
        True
        >>> _is_indexed_path("@did:plc:abc123/my-dataset")
        True
        >>> _is_indexed_path("s3://bucket/data.tar")
        False
    """
    return path.startswith("@")


def _parse_indexed_path(path: str) -> tuple[str, str]:
    """Parse @handle/dataset path into (handle_or_did, dataset_name).

    Args:
        path: Path in format "@handle/dataset" or "@did:plc:xxx/dataset"

    Returns:
        Tuple of (handle_or_did, dataset_name)

    Raises:
        ValueError: If path format is invalid.
    """
    if not path.startswith("@"):
        raise ValueError(f"Not an indexed path: {path}")

    # Remove leading @
    rest = path[1:]

    # Split on first / (handle can contain . but dataset name is after /)
    if "/" not in rest:
        raise ValueError(
            f"Invalid indexed path format: {path}. "
            "Expected @handle/dataset or @did:plc:xxx/dataset"
        )

    # Find the split point - for DIDs, the format is did:plc:xxx/dataset
    # For handles, it's handle.domain/dataset
    parts = rest.split("/", 1)
    if len(parts) != 2 or not parts[0] or not parts[1]:
        raise ValueError(f"Invalid indexed path: {path}")

    return parts[0], parts[1]


def _resolve_indexed_path(
    path: str,
    index: "AbstractIndex",
) -> tuple[DataSource, str]:
    """Resolve @handle/dataset path to DataSource and schema_ref via index lookup.

    Args:
        path: Path in @handle/dataset format.
        index: Index to use for lookup.

    Returns:
        Tuple of (DataSource, schema_ref). The DataSource is configured with
        appropriate credentials when the index has an S3DataStore.

    Raises:
        KeyError: If dataset not found in index.
    """
    handle_or_did, dataset_name = _parse_indexed_path(path)

    # For AtmosphereIndex, we need to resolve handle to DID first
    # For LocalIndex, the handle is ignored and we just look up by name
    entry = index.get_dataset(dataset_name)
    data_urls = entry.data_urls

    # Check if index has a data store
    if hasattr(index, "data_store") and index.data_store is not None:
        store = index.data_store

        # Import here to avoid circular imports at module level
        from .local import S3DataStore

        # For S3DataStore with S3 URLs, create S3Source with credentials
        if isinstance(store, S3DataStore):
            if data_urls and all(url.startswith("s3://") for url in data_urls):
                source = S3Source.from_urls(
                    data_urls,
                    endpoint=store.credentials.get("AWS_ENDPOINT"),
                    access_key=store.credentials.get("AWS_ACCESS_KEY_ID"),
                    secret_key=store.credentials.get("AWS_SECRET_ACCESS_KEY"),
                    region=store.credentials.get("AWS_REGION"),
                )
                return source, entry.schema_ref

        # For any data store, use read_url to transform URLs if needed
        # (handles endpoint URL conversion for HTTPS access, etc.)
        transformed_urls = [store.read_url(url) for url in data_urls]
        url = _shards_to_wds_url(transformed_urls)
        return URLSource(url), entry.schema_ref

    # Default: URL-based source without credentials
    url = _shards_to_wds_url(data_urls)
    return URLSource(url), entry.schema_ref


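# Parsing is purely syntactic; resolution requires a live index. A sketch,
# assuming ``index`` is an already-constructed AbstractIndex implementation
# (e.g. a LocalIndex) that knows a dataset named "mnist":
#
#     >>> _parse_indexed_path("@maxine.science/mnist")
#     ('maxine.science', 'mnist')
#     >>> _parse_indexed_path("@did:plc:abc123/my-dataset")
#     ('did:plc:abc123', 'my-dataset')
#     >>> source, schema_ref = _resolve_indexed_path("@maxine.science/mnist", index)
#     >>> isinstance(source, (S3Source, URLSource))
#     True
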
##
# Main load_dataset function


# Overload: explicit type with split -> Dataset[ST]
@overload
def load_dataset(
    path: str,
    sample_type: Type[ST],
    *,
    split: str,
    data_files: str | list[str] | dict[str, str | list[str]] | None = None,
    streaming: bool = False,
    index: Optional["AbstractIndex"] = None,
) -> Dataset[ST]: ...


# Overload: explicit type without split -> DatasetDict[ST]
@overload
def load_dataset(
    path: str,
    sample_type: Type[ST],
    *,
    split: None = None,
    data_files: str | list[str] | dict[str, str | list[str]] | None = None,
    streaming: bool = False,
    index: Optional["AbstractIndex"] = None,
) -> DatasetDict[ST]: ...


# Overload: no type with split -> Dataset[DictSample]
@overload
def load_dataset(
    path: str,
    sample_type: None = None,
    *,
    split: str,
    data_files: str | list[str] | dict[str, str | list[str]] | None = None,
    streaming: bool = False,
    index: Optional["AbstractIndex"] = None,
) -> Dataset[DictSample]: ...


# Overload: no type without split -> DatasetDict[DictSample]
@overload
def load_dataset(
    path: str,
    sample_type: None = None,
    *,
    split: None = None,
    data_files: str | list[str] | dict[str, str | list[str]] | None = None,
    streaming: bool = False,
    index: Optional["AbstractIndex"] = None,
) -> DatasetDict[DictSample]: ...


def load_dataset(
    path: str,
    sample_type: Type[ST] | None = None,
    *,
    split: str | None = None,
    data_files: str | list[str] | dict[str, str | list[str]] | None = None,
    streaming: bool = False,
    index: Optional["AbstractIndex"] = None,
) -> Dataset[ST] | DatasetDict[ST]:
    """Load a dataset from local files, remote URLs, or an index.

    This function provides a HuggingFace Datasets-style interface for loading
    atdata typed datasets. It handles path resolution, split detection, and
    returns either a single Dataset or a DatasetDict depending on the split
    parameter.

    When no ``sample_type`` is provided, returns a ``Dataset[DictSample]`` that
    provides dynamic dict-like access to fields. Use ``.as_type(MyType)`` to
    convert to a typed schema.

    Args:
        path: Path to dataset. Can be:
            - Index lookup: "@handle/dataset-name" or "@local/dataset-name"
            - WebDataset brace notation: "path/to/{train,test}-{000..099}.tar"
            - Local directory: "./data/" (scans for .tar files)
            - Glob pattern: "path/to/*.tar"
            - Remote URL: "s3://bucket/path/data-*.tar"
            - Single file: "path/to/data.tar"

        sample_type: The PackableSample subclass defining the schema. If None,
            returns ``Dataset[DictSample]`` with dynamic field access. Can also
            be resolved from an index when using @handle/dataset syntax.

        split: Which split to load. If None, returns a DatasetDict with all
            detected splits. If specified (e.g., "train", "test"), returns
            a single Dataset for that split.

        data_files: Optional explicit mapping of data files. Can be:
            - str: Single file pattern
            - list[str]: List of file patterns (assigned to "train")
            - dict[str, str | list[str]]: Explicit split -> files mapping

        streaming: If True, explicitly marks the dataset for streaming mode.
            Note: atdata Datasets are already lazy/streaming via WebDataset
            pipelines, so this parameter primarily signals intent.

        index: Optional AbstractIndex for dataset lookup. Required when using
            @handle/dataset syntax. When provided with an indexed path, the
            schema can be auto-resolved from the index.

    Returns:
        If split is None: DatasetDict with all detected splits.
        If split is specified: Dataset for that split.
        Type is ``ST`` if sample_type provided, otherwise ``DictSample``.

    Raises:
        ValueError: If the specified split is not found.
        FileNotFoundError: If no data files are found at the path.
        KeyError: If dataset not found in index.

    Examples:
        >>> # Load without type - get DictSample for exploration
        >>> ds = load_dataset("./data/train.tar", split="train")
        >>> for sample in ds.ordered():
        ...     print(sample.keys()) # Explore fields
        ...     print(sample["text"]) # Dict-style access
        ...     print(sample.label) # Attribute access
        >>>
        >>> # Convert to typed schema
        >>> typed_ds = ds.as_type(TextData)
        >>>
        >>> # Or load with explicit type directly
        >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
        >>>
        >>> # Load from index with auto-type resolution
        >>> index = LocalIndex()
        >>> ds = load_dataset("@local/my-dataset", index=index, split="train")
    """
    # Handle @handle/dataset indexed path resolution
    if _is_indexed_path(path):
        if index is None:
            raise ValueError(
                f"Index required for indexed path: {path}. "
                "Pass index=LocalIndex() or index=AtmosphereIndex(client)."
            )

        source, schema_ref = _resolve_indexed_path(path, index)

        # Resolve sample_type from schema if not provided
        resolved_type: Type = (
            sample_type if sample_type is not None else index.decode_schema(schema_ref)
        )

        # Create dataset from the resolved source (includes credentials if S3)
        ds = Dataset[resolved_type](source)

        if split is not None:
            # Indexed datasets are single-split by default
            return ds

        return DatasetDict(
            {"train": ds}, sample_type=resolved_type, streaming=streaming
        )

    # Use DictSample as default when no type specified
    resolved_type = sample_type if sample_type is not None else DictSample

    # Resolve path to split -> shard URL mapping
    splits_shards = _resolve_shards(path, data_files)

    if not splits_shards:
        raise FileNotFoundError(f"No data files found at path: {path}")

    # Build Dataset for each split
    datasets: dict[str, Dataset] = {}
    for split_name, shards in splits_shards.items():
        url = _shards_to_wds_url(shards)
        ds = Dataset[resolved_type](url)
        datasets[split_name] = ds

    # Return single Dataset or DatasetDict
    if split is not None:
        if split not in datasets:
            available = list(datasets.keys())
            raise ValueError(
                f"Split '{split}' not found. Available splits: {available}"
            )
        return datasets[split]

    return DatasetDict(datasets, sample_type=resolved_type, streaming=streaming)


##
# Convenience re-exports (will be exposed in __init__.py)

__all__ = [
    "load_dataset",
    "DatasetDict",
]
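
# End-to-end usage sketch for the module above; the paths, shard names, and
# the MyData type are illustrative assumptions, not files shipped with the
# package:
#
#     >>> from atdata import load_dataset
#     >>> # A directory of shards: splits are detected from the file names
#     >>> ds_dict = load_dataset("./data/", MyData)
#     >>> train = ds_dict["train"]
#     >>>
#     >>> # Explicit split selection with an explicit data_files mapping
#     >>> test = load_dataset(
#     ...     "./data/",
#     ...     MyData,
#     ...     split="test",
#     ...     data_files={"train": ["train-000.tar"], "test": ["test-000.tar"]},
#     ... )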