atdata 0.2.0a1__py3-none-any.whl → 0.2.2b1__py3-none-any.whl

atdata/_hf_api.py ADDED
@@ -0,0 +1,692 @@
1
+ """HuggingFace Datasets-style API for atdata.
2
+
3
+ This module provides a familiar `load_dataset()` interface inspired by the
4
+ HuggingFace Datasets library, adapted for atdata's typed WebDataset approach.
5
+
6
+ Key differences from HuggingFace Datasets:
7
+ - Takes a `sample_type` parameter (typed dataclass); when omitted, samples load as dict-like `DictSample` objects
8
+ - Returns atdata.Dataset[ST] instead of HF Dataset
9
+ - Built on WebDataset for efficient streaming of large datasets
10
+ - No Arrow caching layer (WebDataset handles remote/local transparently)
11
+
12
+ Example:
13
+ ::
14
+
15
+ >>> import atdata
16
+ >>> from atdata import load_dataset
17
+ >>>
18
+ >>> @atdata.packable
19
+ ... class MyData:
20
+ ... text: str
21
+ ... label: int
22
+ >>>
23
+ >>> # Load a single split
24
+ >>> ds = load_dataset("path/to/train-{000000..000099}.tar", MyData, split="train")
25
+ >>>
26
+ >>> # Load all splits detected from shard filenames (returns DatasetDict)
28
+ >>> ds_dict = load_dataset("path/to/*.tar", MyData)
28
+ >>> train_ds = ds_dict["train"]
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import re
34
+ from pathlib import Path
35
+ from typing import (
36
+ TYPE_CHECKING,
37
+ Generic,
38
+ Mapping,
39
+ Optional,
40
+ Type,
41
+ TypeVar,
42
+ overload,
43
+ )
44
+
45
+ from .dataset import Dataset, PackableSample, DictSample
46
+ from ._sources import URLSource, S3Source
47
+ from ._protocols import DataSource
48
+
49
+ if TYPE_CHECKING:
50
+ from ._protocols import AbstractIndex
51
+ from .local import S3DataStore
52
+
53
+ ##
54
+ # Type variables
55
+
56
+ ST = TypeVar("ST", bound=PackableSample)
57
+
58
+
59
+ ##
60
+ # DatasetDict - container for multiple splits
61
+
62
+
63
+ class DatasetDict(Generic[ST], dict):
64
+ """A dictionary of split names to Dataset instances.
65
+
66
+ Similar to HuggingFace's DatasetDict, this provides a container for
67
+ multiple dataset splits (train, test, validation, etc.) with convenience
68
+ methods that operate across all splits.
69
+
70
+ Parameters:
71
+ ST: The sample type for all datasets in this dict.
72
+
73
+ Example:
74
+ ::
75
+
76
+ >>> ds_dict = load_dataset("path/to/data", MyData)
77
+ >>> train = ds_dict["train"]
78
+ >>> test = ds_dict["test"]
79
+ >>>
80
+ >>> # Iterate over all splits
81
+ >>> for split_name, dataset in ds_dict.items():
82
+ ... print(f"{split_name}: {len(dataset.shard_list)} shards")
83
+ """
84
+ # TODO: The docstring above says "Parameters:" where it should say "Type Parameters:"; this is a temporary workaround for `quartodoc` auto-generation bugs.
85
+
86
+ def __init__(
87
+ self,
88
+ splits: Mapping[str, Dataset[ST]] | None = None,
89
+ sample_type: Type[ST] | None = None,
90
+ streaming: bool = False,
91
+ ) -> None:
92
+ """Create a DatasetDict from a mapping of split names to datasets.
93
+
94
+ Args:
95
+ splits: Mapping of split names to Dataset instances.
96
+ sample_type: The sample type for datasets in this dict. If not
97
+ provided, inferred from the first dataset in splits.
98
+ streaming: Whether this DatasetDict was loaded in streaming mode.
99
+ """
100
+ super().__init__(splits or {})
101
+ self._sample_type = sample_type
102
+ self._streaming = streaming
103
+
104
+ @property
105
+ def sample_type(self) -> Type[ST] | None:
106
+ """The sample type for datasets in this dict."""
107
+ if self._sample_type is not None:
108
+ return self._sample_type
109
+ # Infer from first dataset
110
+ if self:
111
+ first_ds = next(iter(self.values()))
112
+ return first_ds.sample_type
113
+ return None
114
+
115
+ def __getitem__(self, key: str) -> Dataset[ST]:
116
+ """Get a dataset by split name."""
117
+ return super().__getitem__(key)
118
+
119
+ def __setitem__(self, key: str, value: Dataset[ST]) -> None:
120
+ """Set a dataset for a split name."""
121
+ super().__setitem__(key, value)
122
+
123
+ @property
124
+ def streaming(self) -> bool:
125
+ """Whether this DatasetDict was loaded in streaming mode."""
126
+ return self._streaming
127
+
128
+ @property
129
+ def num_shards(self) -> dict[str, int]:
130
+ """Number of shards in each split.
131
+
132
+ Returns:
133
+ Dict mapping split names to shard counts.
134
+
135
+ Note:
136
+ This property accesses the shard list, which may trigger
137
+ shard enumeration for remote datasets.
138
+ """
139
+ return {name: len(ds.list_shards()) for name, ds in self.items()}
140
+
141
+
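For orientation, here is a minimal sketch of assembling a `DatasetDict` by hand, the same way `load_dataset()` does internally. The `Doc` class and the shard paths are illustrative assumptions; the sketch relies only on `atdata.packable`, `atdata.dataset.Dataset`, and the `DatasetDict` API shown in this file, and no I/O is expected until the datasets are iterated.

    import atdata
    from atdata.dataset import Dataset
    from atdata._hf_api import DatasetDict

    @atdata.packable
    class Doc:
        text: str
        label: int

    # One Dataset per split, built directly from shard URLs.
    dd = DatasetDict({
        "train": Dataset[Doc]("shards/train-{000000..000009}.tar"),
        "test": Dataset[Doc]("shards/test-000000.tar"),
    })
    print(dd.sample_type)  # Doc, inferred from the first split
    print(dd.streaming)    # False unless streaming=True was passed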
142
+ ##
143
+ # Path resolution utilities
144
+
145
+
146
+ def _is_brace_pattern(path: str) -> bool:
147
+ """Check if path contains WebDataset brace expansion notation like {000..099}."""
148
+ return bool(re.search(r"\{[^}]+\}", path))
149
+
150
+
151
+ def _is_glob_pattern(path: str) -> bool:
152
+ """Check if path contains glob wildcards (* or ?)."""
153
+ return "*" in path or "?" in path
154
+
155
+
156
+ def _is_remote_url(path: str) -> bool:
157
+ """Check if path is a remote URL (s3://, gs://, http://, https://, az://)."""
158
+ return path.startswith(("s3://", "gs://", "http://", "https://", "az://"))
159
+
160
+
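A quick illustration of how these three classifiers partition path strings; the paths are made up, and the snippet assumes the module is importable as `atdata._hf_api`.

    from atdata._hf_api import _is_brace_pattern, _is_glob_pattern, _is_remote_url

    print(_is_brace_pattern("data-{000000..000099}.tar"))  # True: brace expansion
    print(_is_glob_pattern("shards/*.tar"))                # True: glob wildcard
    print(_is_remote_url("s3://bucket/data.tar"))          # True: remote scheme
    print(_is_remote_url("./local/data.tar"))              # False: local path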
161
+ def _expand_local_glob(pattern: str) -> list[str]:
162
+ """Expand local glob pattern to sorted list of matching file paths."""
163
+ base_path = Path(pattern).parent
164
+ glob_part = Path(pattern).name
165
+
166
+ if not base_path.exists():
167
+ return []
168
+
169
+ matches = sorted(base_path.glob(glob_part))
170
+ return [str(p) for p in matches if p.is_file()]
171
+
172
+
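A self-contained check of the glob helper against a throwaway directory; the file names are arbitrary.

    import pathlib, tempfile
    from atdata._hf_api import _expand_local_glob

    with tempfile.TemporaryDirectory() as d:
        for name in ("b-001.tar", "a-000.tar", "notes.txt"):
            (pathlib.Path(d) / name).touch()
        print(_expand_local_glob(f"{d}/*.tar"))
        # ['<tmp>/a-000.tar', '<tmp>/b-001.tar']  -- sorted, regular files only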
173
+ # Pre-compiled split name patterns (pattern, split_name)
174
+ _SPLIT_PATTERNS: list[tuple[re.Pattern[str], str]] = [
175
+ # Patterns like "dataset-train-000000.tar" (split in middle with delimiters)
176
+ (re.compile(r"[_-](train|training)[_-]"), "train"),
177
+ (re.compile(r"[_-](test|testing)[_-]"), "test"),
178
+ (re.compile(r"[_-](val|valid|validation)[_-]"), "validation"),
179
+ (re.compile(r"[_-](dev|development)[_-]"), "validation"),
180
+ # Patterns at start of filename like "train-000.tar" or "test_data.tar"
181
+ (re.compile(r"^(train|training)[_-]"), "train"),
182
+ (re.compile(r"^(test|testing)[_-]"), "test"),
183
+ (re.compile(r"^(val|valid|validation)[_-]"), "validation"),
184
+ (re.compile(r"^(dev|development)[_-]"), "validation"),
185
+ # Patterns in directory path like "/path/train/shard-000.tar"
186
+ (re.compile(r"[/\\](train|training)[/\\]"), "train"),
187
+ (re.compile(r"[/\\](test|testing)[/\\]"), "test"),
188
+ (re.compile(r"[/\\](val|valid|validation)[/\\]"), "validation"),
189
+ (re.compile(r"[/\\](dev|development)[/\\]"), "validation"),
190
+ # Patterns at start of path like "train/shard-000.tar"
191
+ (re.compile(r"^(train|training)[/\\]"), "train"),
192
+ (re.compile(r"^(test|testing)[/\\]"), "test"),
193
+ (re.compile(r"^(val|valid|validation)[/\\]"), "validation"),
194
+ (re.compile(r"^(dev|development)[/\\]"), "validation"),
195
+ ]
196
+
197
+
198
+ def _detect_split_from_path(path: str) -> str | None:
199
+ """Detect split name (train/test/validation) from file path."""
200
+ filename = Path(path).name
201
+ path_lower = path.lower()
202
+ filename_lower = filename.lower()
203
+
204
+ # Check filename first (more specific)
205
+ for pattern, split_name in _SPLIT_PATTERNS:
206
+ if pattern.search(filename_lower):
207
+ return split_name
208
+
209
+ # Fall back to full path (catches directory patterns)
210
+ for pattern, split_name in _SPLIT_PATTERNS:
211
+ if pattern.search(path_lower):
212
+ return split_name
213
+
214
+ return None
215
+
216
+
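A few representative inputs for the split detector (paths are illustrative):

    from atdata._hf_api import _detect_split_from_path

    print(_detect_split_from_path("dataset-train-000000.tar"))  # 'train'
    print(_detect_split_from_path("val/shard-000.tar"))         # 'validation'
    print(_detect_split_from_path("archive-000.tar"))           # None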
217
+ def _resolve_shards(
218
+ path: str,
219
+ data_files: str | list[str] | dict[str, str | list[str]] | None = None,
220
+ ) -> dict[str, list[str]]:
221
+ """Resolve path specification to dict of split -> shard URLs.
222
+
223
+ Handles:
224
+ - WebDataset brace notation: "path/{train,test}-{000..099}.tar"
225
+ - Glob patterns: "path/*.tar"
226
+ - Explicit data_files mapping
227
+
228
+ Args:
229
+ path: Base path or pattern.
230
+ data_files: Optional explicit mapping of splits to files.
231
+
232
+ Returns:
233
+ Dict mapping split names to lists of shard URLs.
234
+ """
235
+ # If explicit data_files provided, use those
236
+ if data_files is not None:
237
+ return _resolve_data_files(path, data_files)
238
+
239
+ # WebDataset brace notation - pass through as-is
240
+ # WebDataset handles expansion internally
241
+ if _is_brace_pattern(path):
242
+ # Try to detect split from the pattern itself
243
+ split = _detect_split_from_path(path)
244
+ split_name = split or "train"
245
+ return {split_name: [path]}
246
+
247
+ # Local glob pattern
248
+ if not _is_remote_url(path) and _is_glob_pattern(path):
249
+ shards = _expand_local_glob(path)
250
+ return _group_shards_by_split(shards)
251
+
252
+ # Local directory - scan for .tar files
253
+ if not _is_remote_url(path) and Path(path).is_dir():
254
+ shards = _expand_local_glob(str(Path(path) / "*.tar"))
255
+ return _group_shards_by_split(shards)
256
+
257
+ # Single file or remote URL - treat as single shard
258
+ split = _detect_split_from_path(path)
259
+ split_name = split or "train"
260
+ return {split_name: [path]}
261
+
262
+
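Two pure-string cases of the resolver, i.e. ones that touch no filesystem; the URLs are made up.

    from atdata._hf_api import _resolve_shards

    # Brace notation passes through untouched; the split is guessed from the pattern.
    print(_resolve_shards("s3://bucket/mnist-train-{000000..000009}.tar"))
    # {'train': ['s3://bucket/mnist-train-{000000..000009}.tar']}

    # A single remote URL with no recognizable split name defaults to "train".
    print(_resolve_shards("https://example.com/data.tar"))
    # {'train': ['https://example.com/data.tar']}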
263
+ def _resolve_data_files(
264
+ base_path: str,
265
+ data_files: str | list[str] | dict[str, str | list[str]],
266
+ ) -> dict[str, list[str]]:
267
+ """Resolve explicit data_files specification.
268
+
269
+ Args:
270
+ base_path: Base path for relative file references.
271
+ data_files: File specification - can be:
272
+ - str: Single file pattern
273
+ - list[str]: List of file patterns
274
+ - dict[str, ...]: Mapping of split names to patterns
275
+
276
+ Returns:
277
+ Dict mapping split names to lists of resolved file paths.
278
+ """
279
+ base = Path(base_path) if not _is_remote_url(base_path) else None
280
+
281
+ if isinstance(data_files, str):
282
+ # Single pattern -> "train" split
283
+ if base and not Path(data_files).is_absolute():
284
+ data_files = str(base / data_files)
285
+ return {"train": [data_files]}
286
+
287
+ if isinstance(data_files, list):
288
+ # List of patterns -> "train" split
289
+ resolved = []
290
+ for f in data_files:
291
+ if base and not Path(f).is_absolute():
292
+ f = str(base / f)
293
+ resolved.append(f)
294
+ return {"train": resolved}
295
+
296
+ # Dict mapping splits to patterns
297
+ result: dict[str, list[str]] = {}
298
+ for split_name, files in data_files.items():
299
+ if isinstance(files, str):
300
+ files = [files]
301
+ resolved = []
302
+ for f in files:
303
+ if base and not Path(f).is_absolute():
304
+ f = str(base / f)
305
+ resolved.append(f)
306
+ result[split_name] = resolved
307
+
308
+ return result
309
+
310
+
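How the list and dict forms of `data_files` resolve (a bare string behaves like a one-element list assigned to "train"); the POSIX paths are made up.

    from atdata._hf_api import _resolve_data_files

    # A bare list is assigned to "train"; relative entries are joined onto base_path.
    print(_resolve_data_files("/data", ["a-000.tar", "a-001.tar"]))
    # {'train': ['/data/a-000.tar', '/data/a-001.tar']}

    # A dict keeps the caller's split layout as-is.
    print(_resolve_data_files("/data", {"train": "train-000.tar", "test": ["test-000.tar"]}))
    # {'train': ['/data/train-000.tar'], 'test': ['/data/test-000.tar']}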
311
+ def _shards_to_wds_url(shards: list[str]) -> str:
312
+ """Convert a list of shard paths to a WebDataset URL.
313
+
314
+ WebDataset supports brace expansion, so we convert multiple shards
315
+ into brace notation when they share a common prefix/suffix.
316
+
317
+ Args:
318
+ shards: List of shard file paths.
319
+
320
+ Returns:
321
+ WebDataset-compatible URL string.
322
+
323
+ Examples:
324
+ >>> _shards_to_wds_url(["data-000.tar", "data-001.tar", "data-002.tar"])
325
+ "data-{000,001,002}.tar"
326
+ >>> _shards_to_wds_url(["train.tar"])
327
+ "train.tar"
328
+ """
329
+ import os.path
330
+
331
+ if len(shards) == 0:
332
+ raise ValueError("Cannot create URL from empty shard list")
333
+
334
+ if len(shards) == 1:
335
+ return shards[0]
336
+
337
+ # Find common prefix using os.path.commonprefix (O(n) vs O(n²))
338
+ prefix = os.path.commonprefix(shards)
339
+
340
+ # Find common suffix by reversing strings
341
+ reversed_shards = [s[::-1] for s in shards]
342
+ suffix = os.path.commonprefix(reversed_shards)[::-1]
343
+
344
+ prefix_len = len(prefix)
345
+ suffix_len = len(suffix)
346
+
347
+ # Ensure prefix and suffix don't overlap
348
+ min_shard_len = min(len(s) for s in shards)
349
+ if prefix_len + suffix_len > min_shard_len:
350
+ # Overlapping - prefer prefix, reduce suffix
351
+ suffix_len = max(0, min_shard_len - prefix_len)
352
+ suffix = shards[0][-suffix_len:] if suffix_len > 0 else ""
353
+
354
+ if prefix_len > 0 or suffix_len > 0:
355
+ # Extract the varying middle parts
356
+ middles = []
357
+ for s in shards:
358
+ if suffix_len > 0:
359
+ middle = s[prefix_len:-suffix_len]
360
+ else:
361
+ middle = s[prefix_len:]
362
+ middles.append(middle)
363
+
364
+ # Only use brace notation if we have meaningful variation
365
+ if all(middles):
366
+ return f"{prefix}{{{','.join(middles)}}}{suffix}"
367
+
368
+ # Fallback: space-separated URLs for WebDataset
369
+ return " ".join(shards)
370
+
371
+
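Two behaviours worth seeing side by side: the brace-factored form and the space-joined fallback taken when one of the middles would be empty. The shard names are arbitrary.

    from atdata._hf_api import _shards_to_wds_url

    print(_shards_to_wds_url(["epoch1-000.tar", "epoch1-001.tar"]))
    # epoch1-00{0,1}.tar  -- common prefix/suffix factored out
    print(_shards_to_wds_url(["data.tar", "data-001.tar"]))
    # data.tar data-001.tar  -- empty middle, so fall back to space-joined URLs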
372
+ def _group_shards_by_split(shards: list[str]) -> dict[str, list[str]]:
373
+ """Group a list of shard paths by detected split.
374
+
375
+ Args:
376
+ shards: List of shard file paths.
377
+
378
+ Returns:
379
+ Dict mapping split names to lists of shards. Files with no
380
+ detected split are placed in "train".
381
+ """
382
+ result: dict[str, list[str]] = {}
383
+
384
+ for shard in shards:
385
+ split = _detect_split_from_path(shard)
386
+ split_name = split or "train"
387
+ if split_name not in result:
388
+ result[split_name] = []
389
+ result[split_name].append(shard)
390
+
391
+ return result
392
+
393
+
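Grouping a mixed shard list; note that the shard with no recognizable split name lands in "train". File names are illustrative.

    from atdata._hf_api import _group_shards_by_split

    shards = ["ds-train-000.tar", "ds-train-001.tar", "ds-test-000.tar", "extra.tar"]
    print(_group_shards_by_split(shards))
    # {'train': ['ds-train-000.tar', 'ds-train-001.tar', 'extra.tar'],
    #  'test': ['ds-test-000.tar']}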
394
+ ##
395
+ # Index-based path resolution
396
+
397
+
398
+ def _is_indexed_path(path: str) -> bool:
399
+ """Check if path uses @handle/dataset notation for index lookup.
400
+
401
+ Examples:
402
+ >>> _is_indexed_path("@maxine.science/mnist")
403
+ True
404
+ >>> _is_indexed_path("@did:plc:abc123/my-dataset")
405
+ True
406
+ >>> _is_indexed_path("s3://bucket/data.tar")
407
+ False
408
+ """
409
+ return path.startswith("@")
410
+
411
+
412
+ def _parse_indexed_path(path: str) -> tuple[str, str]:
413
+ """Parse @handle/dataset path into (handle_or_did, dataset_name).
414
+
415
+ Args:
416
+ path: Path in format "@handle/dataset" or "@did:plc:xxx/dataset"
417
+
418
+ Returns:
419
+ Tuple of (handle_or_did, dataset_name)
420
+
421
+ Raises:
422
+ ValueError: If path format is invalid.
423
+ """
424
+ if not path.startswith("@"):
425
+ raise ValueError(f"Not an indexed path: {path}")
426
+
427
+ # Remove leading @
428
+ rest = path[1:]
429
+
430
+ # Split on first / (handle can contain . but dataset name is after /)
431
+ if "/" not in rest:
432
+ raise ValueError(
433
+ f"Invalid indexed path format: {path}. "
434
+ "Expected @handle/dataset or @did:plc:xxx/dataset"
435
+ )
436
+
437
+ # Find the split point - for DIDs, the format is did:plc:xxx/dataset
438
+ # For handles, it's handle.domain/dataset
439
+ parts = rest.split("/", 1)
440
+ if len(parts) != 2 or not parts[0] or not parts[1]:
441
+ raise ValueError(f"Invalid indexed path: {path}")
442
+
443
+ return parts[0], parts[1]
444
+
445
+
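The parser simply splits on the first "/" after the leading "@"; DIDs and handles are treated the same way. The handles below are illustrative.

    from atdata._hf_api import _parse_indexed_path

    print(_parse_indexed_path("@maxine.science/mnist"))       # ('maxine.science', 'mnist')
    print(_parse_indexed_path("@did:plc:abc123/my-dataset"))  # ('did:plc:abc123', 'my-dataset')
    # "@no-slash" raises ValueError because there is no "/dataset" part.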
446
+ def _resolve_indexed_path(
447
+ path: str,
448
+ index: "AbstractIndex",
449
+ ) -> tuple[DataSource, str]:
450
+ """Resolve @handle/dataset path to DataSource and schema_ref via index lookup.
451
+
452
+ Args:
453
+ path: Path in @handle/dataset format.
454
+ index: Index to use for lookup.
455
+
456
+ Returns:
457
+ Tuple of (DataSource, schema_ref). The DataSource is configured with
458
+ appropriate credentials when the index has an S3DataStore.
459
+
460
+ Raises:
461
+ KeyError: If dataset not found in index.
462
+ """
463
+ handle_or_did, dataset_name = _parse_indexed_path(path)
464
+
465
+ # For AtmosphereIndex, we need to resolve handle to DID first
466
+ # For LocalIndex, the handle is ignored and we just look up by name
467
+ entry = index.get_dataset(dataset_name)
468
+ data_urls = entry.data_urls
469
+
470
+ # Check if index has a data store
471
+ if hasattr(index, 'data_store') and index.data_store is not None:
472
+ store = index.data_store
473
+
474
+ # Import here to avoid circular imports at module level
475
+ from .local import S3DataStore
476
+
477
+ # For S3DataStore with S3 URLs, create S3Source with credentials
478
+ if isinstance(store, S3DataStore):
479
+ if data_urls and all(url.startswith("s3://") for url in data_urls):
480
+ source = S3Source.from_urls(
481
+ data_urls,
482
+ endpoint=store.credentials.get("AWS_ENDPOINT"),
483
+ access_key=store.credentials.get("AWS_ACCESS_KEY_ID"),
484
+ secret_key=store.credentials.get("AWS_SECRET_ACCESS_KEY"),
485
+ region=store.credentials.get("AWS_REGION"),
486
+ )
487
+ return source, entry.schema_ref
488
+
489
+ # For any data store, use read_url to transform URLs if needed
490
+ # (handles endpoint URL conversion for HTTPS access, etc.)
491
+ transformed_urls = [store.read_url(url) for url in data_urls]
492
+ url = _shards_to_wds_url(transformed_urls)
493
+ return URLSource(url), entry.schema_ref
494
+
495
+ # Default: URL-based source without credentials
496
+ url = _shards_to_wds_url(data_urls)
497
+ return URLSource(url), entry.schema_ref
498
+
499
+
500
+ ##
501
+ # Main load_dataset function
502
+
503
+
504
+ # Overload: explicit type with split -> Dataset[ST]
505
+ @overload
506
+ def load_dataset(
507
+ path: str,
508
+ sample_type: Type[ST],
509
+ *,
510
+ split: str,
511
+ data_files: str | list[str] | dict[str, str | list[str]] | None = None,
512
+ streaming: bool = False,
513
+ index: Optional["AbstractIndex"] = None,
514
+ ) -> Dataset[ST]: ...
515
+
516
+
517
+ # Overload: explicit type without split -> DatasetDict[ST]
518
+ @overload
519
+ def load_dataset(
520
+ path: str,
521
+ sample_type: Type[ST],
522
+ *,
523
+ split: None = None,
524
+ data_files: str | list[str] | dict[str, str | list[str]] | None = None,
525
+ streaming: bool = False,
526
+ index: Optional["AbstractIndex"] = None,
527
+ ) -> DatasetDict[ST]: ...
528
+
529
+
530
+ # Overload: no type with split -> Dataset[DictSample]
531
+ @overload
532
+ def load_dataset(
533
+ path: str,
534
+ sample_type: None = None,
535
+ *,
536
+ split: str,
537
+ data_files: str | list[str] | dict[str, str | list[str]] | None = None,
538
+ streaming: bool = False,
539
+ index: Optional["AbstractIndex"] = None,
540
+ ) -> Dataset[DictSample]: ...
541
+
542
+
543
+ # Overload: no type without split -> DatasetDict[DictSample]
544
+ @overload
545
+ def load_dataset(
546
+ path: str,
547
+ sample_type: None = None,
548
+ *,
549
+ split: None = None,
550
+ data_files: str | list[str] | dict[str, str | list[str]] | None = None,
551
+ streaming: bool = False,
552
+ index: Optional["AbstractIndex"] = None,
553
+ ) -> DatasetDict[DictSample]: ...
554
+
555
+
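Read together, the four overloads give type checkers the mapping sketched below; `MyData` stands for any packable sample type (an assumption, not a name from this module), and `reveal_type` is the usual mypy/pyright probe.

    # reveal_type(load_dataset("x.tar", MyData, split="train"))  # -> Dataset[MyData]
    # reveal_type(load_dataset("x.tar", MyData))                 # -> DatasetDict[MyData]
    # reveal_type(load_dataset("x.tar", split="train"))          # -> Dataset[DictSample]
    # reveal_type(load_dataset("x.tar"))                         # -> DatasetDict[DictSample]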
556
+ def load_dataset(
557
+ path: str,
558
+ sample_type: Type[ST] | None = None,
559
+ *,
560
+ split: str | None = None,
561
+ data_files: str | list[str] | dict[str, str | list[str]] | None = None,
562
+ streaming: bool = False,
563
+ index: Optional["AbstractIndex"] = None,
564
+ ) -> Dataset[ST] | DatasetDict[ST]:
565
+ """Load a dataset from local files, remote URLs, or an index.
566
+
567
+ This function provides a HuggingFace Datasets-style interface for loading
568
+ atdata typed datasets. It handles path resolution, split detection, and
569
+ returns either a single Dataset or a DatasetDict depending on the split
570
+ parameter.
571
+
572
+ When no ``sample_type`` is provided, returns a ``Dataset[DictSample]`` that
573
+ provides dynamic dict-like access to fields. Use ``.as_type(MyType)`` to
574
+ convert to a typed schema.
575
+
576
+ Args:
577
+ path: Path to dataset. Can be:
578
+ - Index lookup: "@handle/dataset-name" or "@local/dataset-name"
579
+ - WebDataset brace notation: "path/to/{train,test}-{000..099}.tar"
580
+ - Local directory: "./data/" (scans for .tar files)
581
+ - Glob pattern: "path/to/*.tar"
582
+ - Remote URL: "s3://bucket/path/data-*.tar"
583
+ - Single file: "path/to/data.tar"
584
+
585
+ sample_type: The PackableSample subclass defining the schema. If None,
586
+ returns ``Dataset[DictSample]`` with dynamic field access. Can also
587
+ be resolved from an index when using @handle/dataset syntax.
588
+
589
+ split: Which split to load. If None, returns a DatasetDict with all
590
+ detected splits. If specified (e.g., "train", "test"), returns
591
+ a single Dataset for that split.
592
+
593
+ data_files: Optional explicit mapping of data files. Can be:
594
+ - str: Single file pattern
595
+ - list[str]: List of file patterns (assigned to "train")
596
+ - dict[str, str | list[str]]: Explicit split -> files mapping
597
+
598
+ streaming: If True, explicitly marks the dataset for streaming mode.
599
+ Note: atdata Datasets are already lazy/streaming via WebDataset
600
+ pipelines, so this parameter primarily signals intent.
601
+
602
+ index: Optional AbstractIndex for dataset lookup. Required when using
603
+ @handle/dataset syntax. When provided with an indexed path, the
604
+ schema can be auto-resolved from the index.
605
+
606
+ Returns:
607
+ If split is None: DatasetDict with all detected splits.
608
+ If split is specified: Dataset for that split.
609
+ Type is ``ST`` if sample_type provided, otherwise ``DictSample``.
610
+
611
+ Raises:
612
+ ValueError: If the specified split is not found.
613
+ FileNotFoundError: If no data files are found at the path.
614
+ KeyError: If dataset not found in index.
615
+
616
+ Example:
617
+ ::
618
+
619
+ >>> # Load without type - get DictSample for exploration
620
+ >>> ds = load_dataset("./data/train.tar", split="train")
621
+ >>> for sample in ds.ordered():
622
+ ... print(sample.keys()) # Explore fields
623
+ ... print(sample["text"]) # Dict-style access
624
+ ... print(sample.label) # Attribute access
625
+ >>>
626
+ >>> # Convert to typed schema
627
+ >>> typed_ds = ds.as_type(TextData)
628
+ >>>
629
+ >>> # Or load with explicit type directly
630
+ >>> train_ds = load_dataset("./data/train-*.tar", TextData, split="train")
631
+ >>>
632
+ >>> # Load from index with auto-type resolution
633
+ >>> index = LocalIndex()
634
+ >>> ds = load_dataset("@local/my-dataset", index=index, split="train")
635
+ """
636
+ # Handle @handle/dataset indexed path resolution
637
+ if _is_indexed_path(path):
638
+ if index is None:
639
+ raise ValueError(
640
+ f"Index required for indexed path: {path}. "
641
+ "Pass index=LocalIndex() or index=AtmosphereIndex(client)."
642
+ )
643
+
644
+ source, schema_ref = _resolve_indexed_path(path, index)
645
+
646
+ # Resolve sample_type from schema if not provided
647
+ resolved_type: Type = sample_type if sample_type is not None else index.decode_schema(schema_ref)
648
+
649
+ # Create dataset from the resolved source (includes credentials if S3)
650
+ ds = Dataset[resolved_type](source)
651
+
652
+ if split is not None:
653
+ # Indexed datasets are single-split by default
654
+ return ds
655
+
656
+ return DatasetDict({"train": ds}, sample_type=resolved_type, streaming=streaming)
657
+
658
+ # Use DictSample as default when no type specified
659
+ resolved_type = sample_type if sample_type is not None else DictSample
660
+
661
+ # Resolve path to split -> shard URL mapping
662
+ splits_shards = _resolve_shards(path, data_files)
663
+
664
+ if not splits_shards:
665
+ raise FileNotFoundError(f"No data files found at path: {path}")
666
+
667
+ # Build Dataset for each split
668
+ datasets: dict[str, Dataset] = {}
669
+ for split_name, shards in splits_shards.items():
670
+ url = _shards_to_wds_url(shards)
671
+ ds = Dataset[resolved_type](url)
672
+ datasets[split_name] = ds
673
+
674
+ # Return single Dataset or DatasetDict
675
+ if split is not None:
676
+ if split not in datasets:
677
+ available = list(datasets.keys())
678
+ raise ValueError(
679
+ f"Split '{split}' not found. Available splits: {available}"
680
+ )
681
+ return datasets[split]
682
+
683
+ return DatasetDict(datasets, sample_type=resolved_type, streaming=streaming)
684
+
685
+
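The docstring above describes `data_files` but does not show it in use; here is a hedged sketch of an explicit split layout. The directory, shard names, and `TextData` type are assumptions for illustration, and the files must exist before the splits can be iterated.

    from atdata import load_dataset

    ds_dict = load_dataset(
        "./shards",
        TextData,
        data_files={
            "train": ["train-000000.tar", "train-000001.tar"],
            "validation": "val-000000.tar",
        },
    )
    val_ds = ds_dict["validation"]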
686
+ ##
687
+ # Convenience re-exports (will be exposed in __init__.py)
688
+
689
+ __all__ = [
690
+ "load_dataset",
691
+ "DatasetDict",
692
+ ]