atdata 0.2.0a1__py3-none-any.whl → 0.2.3b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atdata/_hf_api.py ADDED
@@ -0,0 +1,690 @@
+ """HuggingFace Datasets-style API for atdata.
+
+ This module provides a familiar `load_dataset()` interface inspired by the
+ HuggingFace Datasets library, adapted for atdata's typed WebDataset approach.
+
+ Key differences from HuggingFace Datasets:
+ - Takes an explicit `sample_type` parameter (typed dataclass); omit it to load samples as `DictSample`
+ - Returns atdata.Dataset[ST] instead of HF Dataset
+ - Built on WebDataset for efficient streaming of large datasets
+ - No Arrow caching layer (WebDataset handles remote/local transparently)
+
+ Examples:
+     >>> import atdata
+     >>> from atdata import load_dataset
+     >>>
+     >>> @atdata.packable
+     ... class MyData:
+     ...     text: str
+     ...     label: int
+     >>>
+     >>> # Load a single split
+     >>> ds = load_dataset("path/to/train-{000000..000099}.tar", MyData, split="train")
+     >>>
+     >>> # Load all splits (returns DatasetDict)
+     >>> ds_dict = load_dataset("path/to/{train,test}-*.tar", MyData)
+     >>> train_ds = ds_dict["train"]
+ """
+
+ from __future__ import annotations
+
+ import re
+ from pathlib import Path
+ from typing import (
+     TYPE_CHECKING,
+     Generic,
+     Mapping,
+     Optional,
+     Type,
+     TypeVar,
+     overload,
+ )
+
+ from .dataset import Dataset, PackableSample, DictSample
+ from ._sources import URLSource, S3Source
+ from ._protocols import DataSource
+
+ if TYPE_CHECKING:
+     from ._protocols import AbstractIndex
+
+ ##
+ # Type variables
+
+ ST = TypeVar("ST", bound=PackableSample)
+
+
+ ##
+ # DatasetDict - container for multiple splits
+
+
+ class DatasetDict(Generic[ST], dict):
+     """A dictionary of split names to Dataset instances.
+
+     Similar to HuggingFace's DatasetDict, this provides a container for
+     multiple dataset splits (train, test, validation, etc.) with convenience
+     methods that operate across all splits.
+
+     Parameters:
+         ST: The sample type for all datasets in this dict.
+
+     Examples:
+         >>> ds_dict = load_dataset("path/to/data", MyData)
+         >>> train = ds_dict["train"]
+         >>> test = ds_dict["test"]
+         >>>
+         >>> # Iterate over all splits
+         >>> for split_name, dataset in ds_dict.items():
+         ...     print(f"{split_name}: {len(dataset.shard_list)} shards")
+     """
+
+     # TODO The above has a line for "Parameters:" that should be "Type Parameters:"; this is a temporary fix for `quartodoc` auto-generation bugs.
+
+     def __init__(
+         self,
+         splits: Mapping[str, Dataset[ST]] | None = None,
+         sample_type: Type[ST] | None = None,
+         streaming: bool = False,
+     ) -> None:
+         """Create a DatasetDict from a mapping of split names to datasets.
+
+         Args:
+             splits: Mapping of split names to Dataset instances.
+             sample_type: The sample type for datasets in this dict. If not
+                 provided, inferred from the first dataset in splits.
+             streaming: Whether this DatasetDict was loaded in streaming mode.
+         """
+         super().__init__(splits or {})
+         self._sample_type = sample_type
+         self._streaming = streaming
+
+     @property
+     def sample_type(self) -> Type[ST] | None:
+         """The sample type for datasets in this dict."""
+         if self._sample_type is not None:
+             return self._sample_type
+         # Infer from first dataset
+         if self:
+             first_ds = next(iter(self.values()))
+             return first_ds.sample_type
+         return None
+
+     def __getitem__(self, key: str) -> Dataset[ST]:
+         """Get a dataset by split name."""
+         return super().__getitem__(key)
+
+     def __setitem__(self, key: str, value: Dataset[ST]) -> None:
+         """Set a dataset for a split name."""
+         super().__setitem__(key, value)
+
+     @property
+     def streaming(self) -> bool:
+         """Whether this DatasetDict was loaded in streaming mode."""
+         return self._streaming
+
+     @property
+     def num_shards(self) -> dict[str, int]:
+         """Number of shards in each split.
+
+         Returns:
+             Dict mapping split names to shard counts.
+
+         Note:
+             This property accesses the shard list, which may trigger
+             shard enumeration for remote datasets.
+         """
+         return {name: len(ds.list_shards()) for name, ds in self.items()}
+
+
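As a rough usage sketch for the class above (the `MyData` type and the `./shards/` directory are placeholders borrowed from the module docstring, not fixtures that ship with the package), a DatasetDict behaves like a plain dict with a few extra properties:

    from atdata import load_dataset

    # Shards named like train-*.tar / test-*.tar are grouped into splits on load.
    ds_dict = load_dataset("./shards/", MyData)

    list(ds_dict.keys())   # e.g. ["train", "test"], depending on the detected splits
    ds_dict.sample_type    # MyData, or inferred from the first split if not passed
    ds_dict.streaming      # False unless streaming=True was requested
    ds_dict.num_shards     # e.g. {"train": 100, "test": 10}; may enumerate remote shards
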
+ ##
+ # Path resolution utilities
+
+
+ def _is_brace_pattern(path: str) -> bool:
+     """Check if path contains WebDataset brace expansion notation like {000..099}."""
+     return bool(re.search(r"\{[^}]+\}", path))
+
+
+ def _is_glob_pattern(path: str) -> bool:
+     """Check if path contains glob wildcards (* or ?)."""
+     return "*" in path or "?" in path
+
+
+ def _is_remote_url(path: str) -> bool:
+     """Check if path is a remote URL (s3://, gs://, http://, https://, az://)."""
+     return path.startswith(("s3://", "gs://", "http://", "https://", "az://"))
+
+
+ def _expand_local_glob(pattern: str) -> list[str]:
+     """Expand local glob pattern to sorted list of matching file paths."""
+     base_path = Path(pattern).parent
+     glob_part = Path(pattern).name
+
+     if not base_path.exists():
+         return []
+
+     matches = sorted(base_path.glob(glob_part))
+     return [str(p) for p in matches if p.is_file()]
+
+
+ # Pre-compiled split name patterns (pattern, split_name)
+ _SPLIT_PATTERNS: list[tuple[re.Pattern[str], str]] = [
+     # Patterns like "dataset-train-000000.tar" (split in middle with delimiters)
+     (re.compile(r"[_-](train|training)[_-]"), "train"),
+     (re.compile(r"[_-](test|testing)[_-]"), "test"),
+     (re.compile(r"[_-](val|valid|validation)[_-]"), "validation"),
+     (re.compile(r"[_-](dev|development)[_-]"), "validation"),
+     # Patterns at start of filename like "train-000.tar" or "test_data.tar"
+     (re.compile(r"^(train|training)[_-]"), "train"),
+     (re.compile(r"^(test|testing)[_-]"), "test"),
+     (re.compile(r"^(val|valid|validation)[_-]"), "validation"),
+     (re.compile(r"^(dev|development)[_-]"), "validation"),
+     # Patterns in directory path like "/path/train/shard-000.tar"
+     (re.compile(r"[/\\](train|training)[/\\]"), "train"),
+     (re.compile(r"[/\\](test|testing)[/\\]"), "test"),
+     (re.compile(r"[/\\](val|valid|validation)[/\\]"), "validation"),
+     (re.compile(r"[/\\](dev|development)[/\\]"), "validation"),
+     # Patterns at start of path like "train/shard-000.tar"
+     (re.compile(r"^(train|training)[/\\]"), "train"),
+     (re.compile(r"^(test|testing)[/\\]"), "test"),
+     (re.compile(r"^(val|valid|validation)[/\\]"), "validation"),
+     (re.compile(r"^(dev|development)[/\\]"), "validation"),
+ ]
+
+
+ def _detect_split_from_path(path: str) -> str | None:
+     """Detect split name (train/test/validation) from file path."""
+     filename = Path(path).name
+     path_lower = path.lower()
+     filename_lower = filename.lower()
+
+     # Check filename first (more specific)
+     for pattern, split_name in _SPLIT_PATTERNS:
+         if pattern.search(filename_lower):
+             return split_name
+
+     # Fall back to full path (catches directory patterns)
+     for pattern, split_name in _SPLIT_PATTERNS:
+         if pattern.search(path_lower):
+             return split_name
+
+     return None
+
+
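To make the pattern table above concrete, this is roughly how a few representative paths would be classified (illustrative calls; the expected results are read off the patterns above rather than taken from a test suite):

    _detect_split_from_path("data/train-000000.tar")           # "train" (filename prefix)
    _detect_split_from_path("shards/dataset-val-000.tar")       # "validation" (delimited "-val-")
    _detect_split_from_path("s3://bucket/test/shard-000.tar")   # "test" (directory component)
    _detect_split_from_path("data/images-000.tar")              # None; callers default to "train"
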
213
+ def _resolve_shards(
214
+ path: str,
215
+ data_files: str | list[str] | dict[str, str | list[str]] | None = None,
216
+ ) -> dict[str, list[str]]:
217
+ """Resolve path specification to dict of split -> shard URLs.
218
+
219
+ Handles:
220
+ - WebDataset brace notation: "path/{train,test}-{000..099}.tar"
221
+ - Glob patterns: "path/*.tar"
222
+ - Explicit data_files mapping
223
+
224
+ Args:
225
+ path: Base path or pattern.
226
+ data_files: Optional explicit mapping of splits to files.
227
+
228
+ Returns:
229
+ Dict mapping split names to lists of shard URLs.
230
+ """
231
+ # If explicit data_files provided, use those
232
+ if data_files is not None:
233
+ return _resolve_data_files(path, data_files)
234
+
235
+ # WebDataset brace notation - pass through as-is
236
+ # WebDataset handles expansion internally
237
+ if _is_brace_pattern(path):
238
+ # Try to detect split from the pattern itself
239
+ split = _detect_split_from_path(path)
240
+ split_name = split or "train"
241
+ return {split_name: [path]}
242
+
243
+ # Local glob pattern
244
+ if not _is_remote_url(path) and _is_glob_pattern(path):
245
+ shards = _expand_local_glob(path)
246
+ return _group_shards_by_split(shards)
247
+
248
+ # Local directory - scan for .tar files
249
+ if not _is_remote_url(path) and Path(path).is_dir():
250
+ shards = _expand_local_glob(str(Path(path) / "*.tar"))
251
+ return _group_shards_by_split(shards)
252
+
253
+ # Single file or remote URL - treat as single shard
254
+ split = _detect_split_from_path(path)
255
+ split_name = split or "train"
256
+ return {split_name: [path]}
257
+
258
+
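A small sketch of the two pass-through branches, using hypothetical paths (brace patterns are handed to WebDataset unexpanded, so they stay a single entry):

    _resolve_shards("path/to/train-{000000..000099}.tar")
    # {"train": ["path/to/train-{000000..000099}.tar"]}

    _resolve_shards("s3://bucket/data.tar")
    # {"train": ["s3://bucket/data.tar"]}  (no split marker, so it defaults to "train")
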
259
+ def _resolve_data_files(
260
+ base_path: str,
261
+ data_files: str | list[str] | dict[str, str | list[str]],
262
+ ) -> dict[str, list[str]]:
263
+ """Resolve explicit data_files specification.
264
+
265
+ Args:
266
+ base_path: Base path for relative file references.
267
+ data_files: File specification - can be:
268
+ - str: Single file pattern
269
+ - list[str]: List of file patterns
270
+ - dict[str, ...]: Mapping of split names to patterns
271
+
272
+ Returns:
273
+ Dict mapping split names to lists of resolved file paths.
274
+ """
275
+ base = Path(base_path) if not _is_remote_url(base_path) else None
276
+
277
+ if isinstance(data_files, str):
278
+ # Single pattern -> "train" split
279
+ if base and not Path(data_files).is_absolute():
280
+ data_files = str(base / data_files)
281
+ return {"train": [data_files]}
282
+
283
+ if isinstance(data_files, list):
284
+ # List of patterns -> "train" split
285
+ resolved = []
286
+ for f in data_files:
287
+ if base and not Path(f).is_absolute():
288
+ f = str(base / f)
289
+ resolved.append(f)
290
+ return {"train": resolved}
291
+
292
+ # Dict mapping splits to patterns
293
+ result: dict[str, list[str]] = {}
294
+ for split_name, files in data_files.items():
295
+ if isinstance(files, str):
296
+ files = [files]
297
+ resolved = []
298
+ for f in files:
299
+ if base and not Path(f).is_absolute():
300
+ f = str(base / f)
301
+ resolved.append(f)
302
+ result[split_name] = resolved
303
+
304
+ return result
305
+
306
+
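The three accepted `data_files` shapes, sketched with hypothetical local paths (joined results shown for a POSIX-style filesystem):

    _resolve_data_files("data", "train-{000..009}.tar")
    # {"train": ["data/train-{000..009}.tar"]}

    _resolve_data_files("data", ["a.tar", "b.tar"])
    # {"train": ["data/a.tar", "data/b.tar"]}

    _resolve_data_files("data", {"train": "train-*.tar", "validation": ["val-0.tar", "val-1.tar"]})
    # {"train": ["data/train-*.tar"], "validation": ["data/val-0.tar", "data/val-1.tar"]}
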
307
+ def _shards_to_wds_url(shards: list[str]) -> str:
308
+ """Convert a list of shard paths to a WebDataset URL.
309
+
310
+ WebDataset supports brace expansion, so we convert multiple shards
311
+ into brace notation when they share a common prefix/suffix.
312
+
313
+ Args:
314
+ shards: List of shard file paths.
315
+
316
+ Returns:
317
+ WebDataset-compatible URL string.
318
+
+     Examples:
+         >>> _shards_to_wds_url(["data-000.tar", "data-001.tar", "data-002.tar"])
+         'data-00{0,1,2}.tar'
+         >>> _shards_to_wds_url(["train.tar"])
+         'train.tar'
+     """
+     import os.path
+
+     if len(shards) == 0:
+         raise ValueError("Cannot create URL from empty shard list")
+
+     if len(shards) == 1:
+         return shards[0]
+
+     # Find common prefix using os.path.commonprefix (O(n) vs O(n²))
+     prefix = os.path.commonprefix(shards)
+
+     # Find common suffix by reversing strings
+     reversed_shards = [s[::-1] for s in shards]
+     suffix = os.path.commonprefix(reversed_shards)[::-1]
+
+     prefix_len = len(prefix)
+     suffix_len = len(suffix)
+
+     # Ensure prefix and suffix don't overlap
+     min_shard_len = min(len(s) for s in shards)
+     if prefix_len + suffix_len > min_shard_len:
+         # Overlapping - prefer prefix, reduce suffix
+         suffix_len = max(0, min_shard_len - prefix_len)
+         suffix = shards[0][-suffix_len:] if suffix_len > 0 else ""
+
+     if prefix_len > 0 or suffix_len > 0:
+         # Extract the varying middle parts
+         middles = []
+         for s in shards:
+             if suffix_len > 0:
+                 middle = s[prefix_len:-suffix_len]
+             else:
+                 middle = s[prefix_len:]
+             middles.append(middle)
+
+         # Only use brace notation if we have meaningful variation
+         if all(middles):
+             return f"{prefix}{{{','.join(middles)}}}{suffix}"
+
+     # Fallback: space-separated URLs for WebDataset
+     return " ".join(shards)
+
+
368
+ def _group_shards_by_split(shards: list[str]) -> dict[str, list[str]]:
369
+ """Group a list of shard paths by detected split.
370
+
371
+ Args:
372
+ shards: List of shard file paths.
373
+
374
+ Returns:
375
+ Dict mapping split names to lists of shards. Files with no
376
+ detected split are placed in "train".
377
+ """
378
+ result: dict[str, list[str]] = {}
379
+
380
+ for shard in shards:
381
+ split = _detect_split_from_path(shard)
382
+ split_name = split or "train"
383
+ if split_name not in result:
384
+ result[split_name] = []
385
+ result[split_name].append(shard)
386
+
387
+ return result
388
+
389
+
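Putting the last two helpers together, a flat shard listing (hypothetical filenames) is first grouped by detected split, and each group is then collapsed into a WebDataset URL:

    _group_shards_by_split(["train-000.tar", "train-001.tar", "test-000.tar", "notes.tar"])
    # {"train": ["train-000.tar", "train-001.tar", "notes.tar"], "test": ["test-000.tar"]}
    # ("notes.tar" carries no split marker, so it lands in "train")

    _shards_to_wds_url(["train-000.tar", "train-001.tar"])
    # "train-00{0,1}.tar"
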
+ ##
+ # Index-based path resolution
+
+
+ def _is_indexed_path(path: str) -> bool:
+     """Check if path uses @handle/dataset notation for index lookup.
+
+     Examples:
+         >>> _is_indexed_path("@maxine.science/mnist")
+         True
+         >>> _is_indexed_path("@did:plc:abc123/my-dataset")
+         True
+         >>> _is_indexed_path("s3://bucket/data.tar")
+         False
+     """
+     return path.startswith("@")
+
+
+ def _parse_indexed_path(path: str) -> tuple[str, str]:
+     """Parse @handle/dataset path into (handle_or_did, dataset_name).
+
+     Args:
+         path: Path in format "@handle/dataset" or "@did:plc:xxx/dataset"
+
+     Returns:
+         Tuple of (handle_or_did, dataset_name)
+
+     Raises:
+         ValueError: If path format is invalid.
+     """
+     if not path.startswith("@"):
+         raise ValueError(f"Not an indexed path: {path}")
+
+     # Remove leading @
+     rest = path[1:]
+
+     # Split on first / (handle can contain . but dataset name is after /)
+     if "/" not in rest:
+         raise ValueError(
+             f"Invalid indexed path format: {path}. "
+             "Expected @handle/dataset or @did:plc:xxx/dataset"
+         )
+
+     # Find the split point - for DIDs, the format is did:plc:xxx/dataset
+     # For handles, it's handle.domain/dataset
+     parts = rest.split("/", 1)
+     if len(parts) != 2 or not parts[0] or not parts[1]:
+         raise ValueError(f"Invalid indexed path: {path}")
+
+     return parts[0], parts[1]
+
+
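For reference, the parser above splits on the first "/" only, so handles and DIDs both survive intact (the identifiers reuse the ones from the `_is_indexed_path` docstring):

    _parse_indexed_path("@maxine.science/mnist")          # ("maxine.science", "mnist")
    _parse_indexed_path("@did:plc:abc123/my-dataset")     # ("did:plc:abc123", "my-dataset")
    _parse_indexed_path("@no-slash")                      # raises ValueError
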
+ def _resolve_indexed_path(
+     path: str,
+     index: "AbstractIndex",
+ ) -> tuple[DataSource, str]:
+     """Resolve @handle/dataset path to DataSource and schema_ref via index lookup.
+
+     Args:
+         path: Path in @handle/dataset format.
+         index: Index to use for lookup.
+
+     Returns:
+         Tuple of (DataSource, schema_ref). The DataSource is configured with
+         appropriate credentials when the index has an S3DataStore.
+
+     Raises:
+         KeyError: If dataset not found in index.
+     """
+     handle_or_did, dataset_name = _parse_indexed_path(path)
+
+     # For AtmosphereIndex, we need to resolve handle to DID first
+     # For LocalIndex, the handle is ignored and we just look up by name
+     entry = index.get_dataset(dataset_name)
+     data_urls = entry.data_urls
+
+     # Check if index has a data store
+     if hasattr(index, "data_store") and index.data_store is not None:
+         store = index.data_store
+
+         # Import here to avoid circular imports at module level
+         from .local import S3DataStore
+
+         # For S3DataStore with S3 URLs, create S3Source with credentials
+         if isinstance(store, S3DataStore):
+             if data_urls and all(url.startswith("s3://") for url in data_urls):
+                 source = S3Source.from_urls(
+                     data_urls,
+                     endpoint=store.credentials.get("AWS_ENDPOINT"),
+                     access_key=store.credentials.get("AWS_ACCESS_KEY_ID"),
+                     secret_key=store.credentials.get("AWS_SECRET_ACCESS_KEY"),
+                     region=store.credentials.get("AWS_REGION"),
+                 )
+                 return source, entry.schema_ref
+
+         # For any data store, use read_url to transform URLs if needed
+         # (handles endpoint URL conversion for HTTPS access, etc.)
+         transformed_urls = [store.read_url(url) for url in data_urls]
+         url = _shards_to_wds_url(transformed_urls)
+         return URLSource(url), entry.schema_ref
+
+     # Default: URL-based source without credentials
+     url = _shards_to_wds_url(data_urls)
+     return URLSource(url), entry.schema_ref
+
+
+ ##
+ # Main load_dataset function
+
+
+ # Overload: explicit type with split -> Dataset[ST]
+ @overload
+ def load_dataset(
+     path: str,
+     sample_type: Type[ST],
+     *,
+     split: str,
+     data_files: str | list[str] | dict[str, str | list[str]] | None = None,
+     streaming: bool = False,
+     index: Optional["AbstractIndex"] = None,
+ ) -> Dataset[ST]: ...
+
+
+ # Overload: explicit type without split -> DatasetDict[ST]
+ @overload
+ def load_dataset(
+     path: str,
+     sample_type: Type[ST],
+     *,
+     split: None = None,
+     data_files: str | list[str] | dict[str, str | list[str]] | None = None,
+     streaming: bool = False,
+     index: Optional["AbstractIndex"] = None,
+ ) -> DatasetDict[ST]: ...
+
+
+ # Overload: no type with split -> Dataset[DictSample]
+ @overload
+ def load_dataset(
+     path: str,
+     sample_type: None = None,
+     *,
+     split: str,
+     data_files: str | list[str] | dict[str, str | list[str]] | None = None,
+     streaming: bool = False,
+     index: Optional["AbstractIndex"] = None,
+ ) -> Dataset[DictSample]: ...
+
+
+ # Overload: no type without split -> DatasetDict[DictSample]
+ @overload
+ def load_dataset(
+     path: str,
+     sample_type: None = None,
+     *,
+     split: None = None,
+     data_files: str | list[str] | dict[str, str | list[str]] | None = None,
+     streaming: bool = False,
+     index: Optional["AbstractIndex"] = None,
+ ) -> DatasetDict[DictSample]: ...
+
+
+ def load_dataset(
+     path: str,
+     sample_type: Type[ST] | None = None,
+     *,
+     split: str | None = None,
+     data_files: str | list[str] | dict[str, str | list[str]] | None = None,
+     streaming: bool = False,
+     index: Optional["AbstractIndex"] = None,
+ ) -> Dataset[ST] | DatasetDict[ST]:
+     """Load a dataset from local files, remote URLs, or an index.
+
+     This function provides a HuggingFace Datasets-style interface for loading
+     atdata typed datasets. It handles path resolution, split detection, and
+     returns either a single Dataset or a DatasetDict depending on the split
+     parameter.
+
+     When no ``sample_type`` is provided, returns a ``Dataset[DictSample]`` that
+     provides dynamic dict-like access to fields. Use ``.as_type(MyType)`` to
+     convert to a typed schema.
+
+     Args:
+         path: Path to dataset. Can be:
+             - Index lookup: "@handle/dataset-name" or "@local/dataset-name"
+             - WebDataset brace notation: "path/to/{train,test}-{000..099}.tar"
+             - Local directory: "./data/" (scans for .tar files)
+             - Glob pattern: "path/to/*.tar"
+             - Remote URL: "s3://bucket/path/data-*.tar"
+             - Single file: "path/to/data.tar"
+
+         sample_type: The PackableSample subclass defining the schema. If None,
+             returns ``Dataset[DictSample]`` with dynamic field access. Can also
+             be resolved from an index when using @handle/dataset syntax.
+
+         split: Which split to load. If None, returns a DatasetDict with all
+             detected splits. If specified (e.g., "train", "test"), returns
+             a single Dataset for that split.
+
+         data_files: Optional explicit mapping of data files. Can be:
+             - str: Single file pattern
+             - list[str]: List of file patterns (assigned to "train")
+             - dict[str, str | list[str]]: Explicit split -> files mapping
+
+         streaming: If True, explicitly marks the dataset for streaming mode.
+             Note: atdata Datasets are already lazy/streaming via WebDataset
+             pipelines, so this parameter primarily signals intent.
+
+         index: Optional AbstractIndex for dataset lookup. Required when using
+             @handle/dataset syntax. When provided with an indexed path, the
+             schema can be auto-resolved from the index.
+
+     Returns:
+         If split is None: DatasetDict with all detected splits.
+         If split is specified: Dataset for that split.
+         Type is ``ST`` if sample_type provided, otherwise ``DictSample``.
+
+     Raises:
+         ValueError: If the specified split is not found.
+         FileNotFoundError: If no data files are found at the path.
+         KeyError: If dataset not found in index.
+
+     Examples:
+         >>> # Load without type - get DictSample for exploration
+         >>> ds = load_dataset("./data/train.tar", split="train")
+         >>> for sample in ds.ordered():
+         ...     print(sample.keys())  # Explore fields
+         ...     print(sample["text"])  # Dict-style access
+         ...     print(sample.label)  # Attribute access
+         >>>
+         >>> # Convert to typed schema
+         >>> typed_ds = ds.as_type(TextData)
+         >>>
+         >>> # Or load with explicit type directly
+         >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
+         >>>
+         >>> # Load from index with auto-type resolution
+         >>> index = LocalIndex()
+         >>> ds = load_dataset("@local/my-dataset", index=index, split="train")
+     """
+     # Handle @handle/dataset indexed path resolution
+     if _is_indexed_path(path):
+         if index is None:
+             raise ValueError(
+                 f"Index required for indexed path: {path}. "
+                 "Pass index=LocalIndex() or index=AtmosphereIndex(client)."
+             )
+
+         source, schema_ref = _resolve_indexed_path(path, index)
+
+         # Resolve sample_type from schema if not provided
+         resolved_type: Type = (
+             sample_type if sample_type is not None else index.decode_schema(schema_ref)
+         )
+
+         # Create dataset from the resolved source (includes credentials if S3)
+         ds = Dataset[resolved_type](source)
+
+         if split is not None:
+             # Indexed datasets are single-split by default
+             return ds
+
+         return DatasetDict(
+             {"train": ds}, sample_type=resolved_type, streaming=streaming
+         )
+
+     # Use DictSample as default when no type specified
+     resolved_type = sample_type if sample_type is not None else DictSample
+
+     # Resolve path to split -> shard URL mapping
+     splits_shards = _resolve_shards(path, data_files)
+
+     if not splits_shards:
+         raise FileNotFoundError(f"No data files found at path: {path}")
+
+     # Build Dataset for each split
+     datasets: dict[str, Dataset] = {}
+     for split_name, shards in splits_shards.items():
+         url = _shards_to_wds_url(shards)
+         ds = Dataset[resolved_type](url)
+         datasets[split_name] = ds
+
+     # Return single Dataset or DatasetDict
+     if split is not None:
+         if split not in datasets:
+             available = list(datasets.keys())
+             raise ValueError(
+                 f"Split '{split}' not found. Available splits: {available}"
+             )
+         return datasets[split]
+
+     return DatasetDict(datasets, sample_type=resolved_type, streaming=streaming)
+
+
+ ##
+ # Convenience re-exports (will be exposed in __init__.py)
+
+ __all__ = [
+     "load_dataset",
+     "DatasetDict",
+ ]
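Assuming the re-exports above do land in the package root (and that `LocalIndex` is importable the way the `load_dataset` docstring example implies), a minimal end-to-end sketch of the new API would look roughly like this; `TextData` and the shard paths are placeholders, not fixtures from the package:

    import atdata
    from atdata import load_dataset

    @atdata.packable
    class TextData:
        text: str
        label: int

    # Untyped exploration first, then conversion to a typed schema.
    ds = load_dataset("./data/train-*.tar", split="train")
    typed = ds.as_type(TextData)

    # Or load every detected split with an explicit type.
    splits = load_dataset("./data/*.tar", TextData)
    train = splits["train"]   # assumes train-*.tar shards exist under ./data/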