atdata 0.2.0a1__py3-none-any.whl → 0.2.3b1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
atdata/local.py CHANGED
@@ -6,11 +6,13 @@ This module provides a local storage backend for atdata datasets using:
6
6
 
7
7
  The main classes are:
8
8
  - Repo: Manages dataset storage in S3 with Redis indexing
9
- - Index: Redis-backed index for tracking dataset metadata
10
- - BasicIndexEntry: Index entry representing a stored dataset
9
+ - LocalIndex: Redis-backed index for tracking dataset metadata
10
+ - LocalDatasetEntry: Index entry representing a stored dataset
11
11
 
12
12
  This is intended for development and small-scale deployment before
13
- migrating to the full atproto PDS infrastructure.
13
+ migrating to the full atproto PDS infrastructure. The implementation
14
+ uses ATProto-compatible CIDs for content addressing, enabling seamless
15
+ promotion from local storage to the atmosphere (ATProto network).
14
16
  """
15
17
 
16
18
  ##
@@ -20,8 +22,15 @@ from atdata import (
20
22
  PackableSample,
21
23
  Dataset,
22
24
  )
25
+ from atdata._cid import generate_cid
26
+ from atdata._type_utils import (
27
+ PRIMITIVE_TYPE_MAP,
28
+ unwrap_optional,
29
+ is_ndarray_type,
30
+ extract_ndarray_dtype,
31
+ )
32
+ from atdata._protocols import AbstractDataStore, Packable
23
33
 
24
- import os
25
34
  from pathlib import Path
26
35
  from uuid import uuid4
27
36
  from tempfile import TemporaryDirectory
@@ -38,142 +47,706 @@ import webdataset as wds
38
47
 
39
48
  from dataclasses import (
40
49
  dataclass,
41
- asdict,
42
50
  field,
43
51
  )
44
52
  from typing import (
45
53
  Any,
46
- Optional,
47
- Dict,
48
54
  Type,
49
55
  TypeVar,
50
56
  Generator,
57
+ Iterator,
51
58
  BinaryIO,
59
+ Optional,
60
+ Literal,
52
61
  cast,
62
+ get_type_hints,
63
+ get_origin,
64
+ get_args,
53
65
  )
66
+ from dataclasses import fields, is_dataclass
67
+ from datetime import datetime, timezone
68
+ import json
69
+ import warnings
70
+
71
+ T = TypeVar("T", bound=PackableSample)
72
+
73
+ # Redis key prefixes for index entries and schemas
74
+ REDIS_KEY_DATASET_ENTRY = "LocalDatasetEntry"
75
+ REDIS_KEY_SCHEMA = "LocalSchema"
76
+
54
77
 
55
- T = TypeVar( 'T', bound = PackableSample )
78
+ class SchemaNamespace:
79
+ """Namespace for accessing loaded schema types as attributes.
80
+
81
+ This class provides a module-like interface for accessing dynamically
82
+ loaded schema types. After calling ``index.load_schema(uri)``, the
83
+ schema's class becomes available as an attribute on this namespace.
84
+
85
+ Examples:
86
+ >>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
87
+ >>> MyType = index.types.MySample
88
+ >>> sample = MyType(field1="hello", field2=42)
89
+
90
+ The namespace supports:
91
+ - Attribute access: ``index.types.MySample``
92
+ - Iteration: ``for name in index.types: ...``
93
+ - Length: ``len(index.types)``
94
+ - Contains check: ``"MySample" in index.types``
95
+
96
+ Note:
97
+ For full IDE autocomplete support, import from the generated module::
98
+
99
+ # After load_schema with auto_stubs=True
100
+ from local.MySample_1_0_0 import MySample
101
+ sample = MySample(name="hello", value=42) # IDE knows signature!
102
+
103
+ Add ``index.stub_dir`` to your IDE's extraPaths for imports to resolve.
104
+ """
105
+
106
+ def __init__(self) -> None:
107
+ self._types: dict[str, Type[Packable]] = {}
108
+
109
+ def _register(self, name: str, cls: Type[Packable]) -> None:
110
+ """Register a schema type in the namespace."""
111
+ self._types[name] = cls
112
+
113
+ def __getattr__(self, name: str) -> Any:
114
+ # Returns Any to avoid IDE complaints about unknown attributes.
115
+ # For full IDE support, import from the generated module instead.
116
+ if name.startswith("_"):
117
+ raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
118
+ if name not in self._types:
119
+ raise AttributeError(
120
+ f"Schema '{name}' not loaded. "
121
+ f"Call index.load_schema() first to load the schema."
122
+ )
123
+ return self._types[name]
124
+
125
+ def __dir__(self) -> list[str]:
126
+ return list(self._types.keys()) + ["_types", "_register", "get"]
127
+
128
+ def __iter__(self) -> Iterator[str]:
129
+ return iter(self._types)
130
+
131
+ def __len__(self) -> int:
132
+ return len(self._types)
133
+
134
+ def __contains__(self, name: str) -> bool:
135
+ return name in self._types
136
+
137
+ def __repr__(self) -> str:
138
+ if not self._types:
139
+ return "SchemaNamespace(empty)"
140
+ names = ", ".join(sorted(self._types.keys()))
141
+ return f"SchemaNamespace({names})"
142
+
143
+ def get(self, name: str, default: T | None = None) -> Type[Packable] | T | None:
144
+ """Get a type by name, returning default if not found.
145
+
146
+ Args:
147
+ name: The schema class name to look up.
148
+ default: Value to return if not found (default: None).
149
+
150
+ Returns:
151
+ The schema class, or default if not loaded.
152
+ """
153
+ return self._types.get(name, default)
154
+
155
+
156
+ ##
157
+ # Schema types
158
+
159
+
160
+ @dataclass
161
+ class SchemaFieldType:
162
+ """Schema field type definition for local storage.
163
+
164
+ Represents a type in the schema type system, supporting primitives,
165
+ ndarrays, arrays, and references to other schemas.
166
+ """
167
+
168
+ kind: Literal["primitive", "ndarray", "ref", "array"]
169
+ """The category of type."""
170
+
171
+ primitive: Optional[str] = None
172
+ """For kind='primitive': one of 'str', 'int', 'float', 'bool', 'bytes'."""
173
+
174
+ dtype: Optional[str] = None
175
+ """For kind='ndarray': numpy dtype string (e.g., 'float32')."""
176
+
177
+ ref: Optional[str] = None
178
+ """For kind='ref': URI of referenced schema."""
179
+
180
+ items: Optional["SchemaFieldType"] = None
181
+ """For kind='array': type of array elements."""
182
+
183
+ @classmethod
184
+ def from_dict(cls, data: dict) -> "SchemaFieldType":
185
+ """Create from a dictionary (e.g., from Redis storage)."""
186
+ type_str = data.get("$type", "")
187
+ if "#" in type_str:
188
+ kind = type_str.split("#")[-1]
189
+ else:
190
+ kind = data.get("kind", "primitive")
191
+
192
+ items = None
193
+ if "items" in data and data["items"]:
194
+ items = cls.from_dict(data["items"])
195
+
196
+ return cls(
197
+ kind=kind, # type: ignore[arg-type]
198
+ primitive=data.get("primitive"),
199
+ dtype=data.get("dtype"),
200
+ ref=data.get("ref"),
201
+ items=items,
202
+ )
203
+
204
+ def to_dict(self) -> dict:
205
+ """Convert to dictionary for storage."""
206
+ result: dict[str, Any] = {"$type": f"local#{self.kind}"}
207
+ if self.kind == "primitive":
208
+ result["primitive"] = self.primitive
209
+ elif self.kind == "ndarray":
210
+ result["dtype"] = self.dtype
211
+ elif self.kind == "ref":
212
+ result["ref"] = self.ref
213
+ elif self.kind == "array" and self.items:
214
+ result["items"] = self.items.to_dict()
215
+ return result
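A quick sketch of how these field-type dicts round-trip through the `from_dict`/`to_dict` methods above; the nested array-of-float32 field is a hypothetical example::

    from atdata.local import SchemaFieldType

    # Hypothetical field: an array whose items are float32 ndarrays.
    ft = SchemaFieldType(
        kind="array",
        items=SchemaFieldType(kind="ndarray", dtype="float32"),
    )
    d = ft.to_dict()
    # d == {"$type": "local#array",
    #       "items": {"$type": "local#ndarray", "dtype": "float32"}}
    assert SchemaFieldType.from_dict(d) == ft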
216
+
217
+
218
+ @dataclass
219
+ class SchemaField:
220
+ """Schema field definition for local storage."""
221
+
222
+ name: str
223
+ """Field name."""
224
+
225
+ field_type: SchemaFieldType
226
+ """Type of this field."""
227
+
228
+ optional: bool = False
229
+ """Whether this field can be None."""
230
+
231
+ @classmethod
232
+ def from_dict(cls, data: dict) -> "SchemaField":
233
+ """Create from a dictionary."""
234
+ return cls(
235
+ name=data["name"],
236
+ field_type=SchemaFieldType.from_dict(data["fieldType"]),
237
+ optional=data.get("optional", False),
238
+ )
239
+
240
+ def to_dict(self) -> dict:
241
+ """Convert to dictionary for storage."""
242
+ return {
243
+ "name": self.name,
244
+ "fieldType": self.field_type.to_dict(),
245
+ "optional": self.optional,
246
+ }
247
+
248
+ def __getitem__(self, key: str) -> Any:
249
+ """Dict-style access for backwards compatibility."""
250
+ if key == "name":
251
+ return self.name
252
+ elif key == "fieldType":
253
+ return self.field_type.to_dict()
254
+ elif key == "optional":
255
+ return self.optional
256
+ raise KeyError(key)
257
+
258
+ def get(self, key: str, default: Any = None) -> Any:
259
+ """Dict-style get() for backwards compatibility."""
260
+ try:
261
+ return self[key]
262
+ except KeyError:
263
+ return default
264
+
265
+
266
+ @dataclass
267
+ class LocalSchemaRecord:
268
+ """Schema record for local storage.
269
+
270
+ Represents a PackableSample schema stored in the local index.
271
+ Aligns with the atmosphere SchemaRecord structure for seamless promotion.
272
+ """
273
+
274
+ name: str
275
+ """Schema name (typically the class name)."""
276
+
277
+ version: str
278
+ """Semantic version string (e.g., '1.0.0')."""
279
+
280
+ fields: list[SchemaField]
281
+ """List of field definitions."""
282
+
283
+ ref: str
284
+ """Schema reference URI (atdata://local/sampleSchema/{name}@{version})."""
285
+
286
+ description: Optional[str] = None
287
+ """Human-readable description."""
288
+
289
+ created_at: Optional[datetime] = None
290
+ """When this schema was published."""
291
+
292
+ @classmethod
293
+ def from_dict(cls, data: dict) -> "LocalSchemaRecord":
294
+ """Create from a dictionary (e.g., from Redis storage)."""
295
+ created_at = None
296
+ if "createdAt" in data:
297
+ try:
298
+ created_at = datetime.fromisoformat(data["createdAt"])
299
+ except (ValueError, TypeError):
300
+ created_at = None # Invalid datetime format, leave as None
301
+
302
+ return cls(
303
+ name=data["name"],
304
+ version=data["version"],
305
+ fields=[SchemaField.from_dict(f) for f in data.get("fields", [])],
306
+ ref=data.get("$ref", ""),
307
+ description=data.get("description"),
308
+ created_at=created_at,
309
+ )
310
+
311
+ def to_dict(self) -> dict:
312
+ """Convert to dictionary for storage."""
313
+ result: dict[str, Any] = {
314
+ "name": self.name,
315
+ "version": self.version,
316
+ "fields": [f.to_dict() for f in self.fields],
317
+ "$ref": self.ref,
318
+ }
319
+ if self.description:
320
+ result["description"] = self.description
321
+ if self.created_at:
322
+ result["createdAt"] = self.created_at.isoformat()
323
+ return result
324
+
325
+ def __getitem__(self, key: str) -> Any:
326
+ """Dict-style access for backwards compatibility."""
327
+ if key == "name":
328
+ return self.name
329
+ elif key == "version":
330
+ return self.version
331
+ elif key == "fields":
332
+ return self.fields # Returns list of SchemaField (also subscriptable)
333
+ elif key == "$ref":
334
+ return self.ref
335
+ elif key == "description":
336
+ return self.description
337
+ elif key == "createdAt":
338
+ return self.created_at.isoformat() if self.created_at else None
339
+ raise KeyError(key)
340
+
341
+ def __contains__(self, key: str) -> bool:
342
+ """Support 'in' operator for backwards compatibility."""
343
+ return key in ("name", "version", "fields", "$ref", "description", "createdAt")
344
+
345
+ def get(self, key: str, default: Any = None) -> Any:
346
+ """Dict-style get() for backwards compatibility."""
347
+ try:
348
+ return self[key]
349
+ except KeyError:
350
+ return default
56
351
 
57
352
 
58
353
  ##
59
354
  # Helpers
60
355
 
61
- def _kind_str_for_sample_type( st: Type[PackableSample] ) -> str:
62
- """Convert a sample type to a fully-qualified string identifier.
356
+
357
+ def _kind_str_for_sample_type(st: Type[Packable]) -> str:
358
+ """Return fully-qualified 'module.name' string for a sample type."""
359
+ return f"{st.__module__}.{st.__name__}"
360
+
361
+
362
+ def _create_s3_write_callbacks(
363
+ credentials: dict[str, Any],
364
+ temp_dir: str,
365
+ written_shards: list[str],
366
+ fs: S3FileSystem | None,
367
+ cache_local: bool,
368
+ add_s3_prefix: bool = False,
369
+ ) -> tuple:
370
+ """Create opener and post callbacks for ShardWriter with S3 upload.
63
371
 
64
372
  Args:
65
- st: The sample type class.
373
+ credentials: S3 credentials dict.
374
+ temp_dir: Temporary directory for local caching.
375
+ written_shards: List to append written shard paths to.
376
+ fs: S3FileSystem for direct writes (used when cache_local=False).
377
+ cache_local: If True, write locally then copy to S3.
378
+ add_s3_prefix: If True, prepend 's3://' to shard paths.
66
379
 
67
380
  Returns:
68
- A string in the format 'module.name' identifying the sample type.
381
+ Tuple of (writer_opener, writer_post) callbacks.
382
+ """
383
+ if cache_local:
384
+ import boto3
385
+
386
+ s3_client_kwargs = {
387
+ "aws_access_key_id": credentials["AWS_ACCESS_KEY_ID"],
388
+ "aws_secret_access_key": credentials["AWS_SECRET_ACCESS_KEY"],
389
+ }
390
+ if "AWS_ENDPOINT" in credentials:
391
+ s3_client_kwargs["endpoint_url"] = credentials["AWS_ENDPOINT"]
392
+ s3_client = boto3.client("s3", **s3_client_kwargs)
393
+
394
+ def _writer_opener(p: str):
395
+ local_path = Path(temp_dir) / p
396
+ local_path.parent.mkdir(parents=True, exist_ok=True)
397
+ return open(local_path, "wb")
398
+
399
+ def _writer_post(p: str):
400
+ local_path = Path(temp_dir) / p
401
+ path_parts = Path(p).parts
402
+ bucket = path_parts[0]
403
+ key = str(Path(*path_parts[1:]))
404
+
405
+ with open(local_path, "rb") as f_in:
406
+ s3_client.put_object(Bucket=bucket, Key=key, Body=f_in.read())
407
+
408
+ local_path.unlink()
409
+ if add_s3_prefix:
410
+ written_shards.append(f"s3://{p}")
411
+ else:
412
+ written_shards.append(p)
413
+
414
+ return _writer_opener, _writer_post
415
+ else:
416
+ assert fs is not None, "S3FileSystem required when cache_local=False"
417
+
418
+ def _direct_opener(s: str):
419
+ return cast(BinaryIO, fs.open(f"s3://{s}", "wb"))
420
+
421
+ def _direct_post(s: str):
422
+ if add_s3_prefix:
423
+ written_shards.append(f"s3://{s}")
424
+ else:
425
+ written_shards.append(s)
426
+
427
+ return _direct_opener, _direct_post
428
+
429
+
430
+ ##
431
+ # Schema helpers
432
+
433
+ # URI scheme prefixes
434
+ _ATDATA_URI_PREFIX = "atdata://local/sampleSchema/"
435
+ _LEGACY_URI_PREFIX = "local://schemas/"
436
+
437
+
438
+ def _schema_ref_from_type(sample_type: Type[Packable], version: str) -> str:
439
+ """Generate 'atdata://local/sampleSchema/{name}@{version}' reference."""
440
+ return _make_schema_ref(sample_type.__name__, version)
441
+
442
+
443
+ def _make_schema_ref(name: str, version: str) -> str:
444
+ """Generate schema reference URI from name and version."""
445
+ return f"{_ATDATA_URI_PREFIX}{name}@{version}"
446
+
447
+
448
+ def _parse_schema_ref(ref: str) -> tuple[str, str]:
449
+ """Parse schema reference into (name, version).
450
+
451
+ Supports both new format: 'atdata://local/sampleSchema/{name}@{version}'
452
+ and legacy format: 'local://schemas/{module.Class}@{version}'
69
453
  """
70
- return f'{st.__module__}.{st.__name__}'
454
+ if ref.startswith(_ATDATA_URI_PREFIX):
455
+ path = ref[len(_ATDATA_URI_PREFIX) :]
456
+ elif ref.startswith(_LEGACY_URI_PREFIX):
457
+ path = ref[len(_LEGACY_URI_PREFIX) :]
458
+ else:
459
+ raise ValueError(f"Invalid schema reference: {ref}")
460
+
461
+ if "@" not in path:
462
+ raise ValueError(f"Schema reference must include version (@version): {ref}")
463
+
464
+ name, version = path.rsplit("@", 1)
465
+ # For legacy format, extract just the class name from module.Class
466
+ if "." in name:
467
+ name = name.rsplit(".", 1)[1]
468
+ return name, version
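Both URI forms resolve to the same (name, version) pair; a small illustration of the parser above, with `my_pkg` as a hypothetical module name::

    from atdata.local import _parse_schema_ref

    # New-style and legacy references parse identically.
    assert _parse_schema_ref("atdata://local/sampleSchema/MySample@1.2.3") == ("MySample", "1.2.3")
    assert _parse_schema_ref("local://schemas/my_pkg.MySample@1.2.3") == ("MySample", "1.2.3")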
469
+
470
+
471
+ def _parse_semver(version: str) -> tuple[int, int, int]:
472
+ """Parse semantic version string into (major, minor, patch) tuple."""
473
+ parts = version.split(".")
474
+ if len(parts) != 3:
475
+ raise ValueError(f"Invalid semver format: {version}")
476
+ return int(parts[0]), int(parts[1]), int(parts[2])
477
+
478
+
479
+ def _increment_patch(version: str) -> str:
480
+ """Increment patch version: 1.0.0 -> 1.0.1"""
481
+ major, minor, patch = _parse_semver(version)
482
+ return f"{major}.{minor}.{patch + 1}"
483
+
484
+
485
+ def _python_type_to_field_type(python_type: Any) -> dict:
486
+ """Convert Python type annotation to schema field type dict."""
487
+ if python_type in PRIMITIVE_TYPE_MAP:
488
+ return {
489
+ "$type": "local#primitive",
490
+ "primitive": PRIMITIVE_TYPE_MAP[python_type],
491
+ }
492
+
493
+ if is_ndarray_type(python_type):
494
+ return {"$type": "local#ndarray", "dtype": extract_ndarray_dtype(python_type)}
495
+
496
+ origin = get_origin(python_type)
497
+ if origin is list:
498
+ args = get_args(python_type)
499
+ items = (
500
+ _python_type_to_field_type(args[0])
501
+ if args
502
+ else {"$type": "local#primitive", "primitive": "str"}
503
+ )
504
+ return {"$type": "local#array", "items": items}
505
+
506
+ if is_dataclass(python_type):
507
+ raise TypeError(
508
+ f"Nested dataclass types not yet supported: {python_type.__name__}. "
509
+ "Publish nested types separately and use references."
510
+ )
71
511
 
72
- def _decode_bytes_dict( d: dict[bytes, bytes] ) -> dict[str, str]:
73
- """Decode a dictionary with byte keys and values to strings.
512
+ raise TypeError(f"Unsupported type for schema field: {python_type}")
74
513
 
75
- Redis returns dictionaries with bytes keys/values, this converts them to strings.
514
+
515
+ def _build_schema_record(
516
+ sample_type: Type[Packable],
517
+ *,
518
+ version: str,
519
+ description: str | None = None,
520
+ ) -> dict:
521
+ """Build a schema record dict from a PackableSample type.
76
522
 
77
523
  Args:
78
- d: Dictionary with bytes keys and values.
524
+ sample_type: The PackableSample subclass to introspect.
525
+ version: Semantic version string.
526
+ description: Optional human-readable description. If None, uses the
527
+ class docstring.
79
528
 
80
529
  Returns:
81
- Dictionary with UTF-8 decoded string keys and values.
530
+ Schema record dict suitable for Redis storage.
531
+
532
+ Raises:
533
+ ValueError: If sample_type is not a dataclass.
534
+ TypeError: If a field type is not supported.
82
535
  """
536
+ if not is_dataclass(sample_type):
537
+ raise ValueError(f"{sample_type.__name__} must be a dataclass (use @packable)")
538
+
539
+ # Use docstring as fallback for description
540
+ if description is None:
541
+ description = sample_type.__doc__
542
+
543
+ field_defs = []
544
+ type_hints = get_type_hints(sample_type)
545
+
546
+ for f in fields(sample_type):
547
+ field_type = type_hints.get(f.name, f.type)
548
+ field_type, is_optional = unwrap_optional(field_type)
549
+ field_type_dict = _python_type_to_field_type(field_type)
550
+
551
+ field_defs.append(
552
+ {
553
+ "name": f.name,
554
+ "fieldType": field_type_dict,
555
+ "optional": is_optional,
556
+ }
557
+ )
558
+
83
559
  return {
84
- k.decode('utf-8'): v.decode('utf-8')
85
- for k, v in d.items()
560
+ "name": sample_type.__name__,
561
+ "version": version,
562
+ "fields": field_defs,
563
+ "description": description,
564
+ "createdAt": datetime.now(timezone.utc).isoformat(),
86
565
  }
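For a concrete picture of the record this builds, a minimal sketch with a hypothetical two-field dataclass (a plain dataclass is enough for this helper; real sample types would also be @packable)::

    from dataclasses import dataclass
    from atdata.local import _build_schema_record

    @dataclass
    class TextSample:
        """A labelled text snippet."""
        text: str
        label: int

    record = _build_schema_record(TextSample, version="1.0.0")
    # record["name"] == "TextSample"; record["fields"][0] is roughly:
    #   {"name": "text",
    #    "fieldType": {"$type": "local#primitive", "primitive": "str"},
    #    "optional": False}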
87
566
 
88
567
 
89
568
  ##
90
569
  # Redis object model
91
570
 
571
+
92
572
  @dataclass
93
- class BasicIndexEntry:
94
- """Index entry for a dataset stored in the repository.
573
+ class LocalDatasetEntry:
574
+ """Index entry for a dataset stored in the local repository.
575
+
576
+ Implements the IndexEntry protocol for compatibility with AbstractIndex.
577
+ Uses dual identity: a content-addressable CID (ATProto-compatible) and
578
+ a human-readable name.
579
+
580
+ The CID is generated from the entry's content (schema_ref + data_urls),
581
+ ensuring the same data produces the same CID whether stored locally or
582
+ in the atmosphere. This enables seamless promotion from local to ATProto.
95
583
 
96
- Tracks metadata about a dataset stored in S3, including its location,
97
- type, and unique identifier.
584
+ Attributes:
585
+ name: Human-readable name for this dataset.
586
+ schema_ref: Reference to the schema for this dataset.
587
+ data_urls: WebDataset URLs for the data.
588
+ metadata: Arbitrary metadata dictionary, or None if not set.
98
589
  """
590
+
99
591
  ##
100
592
 
101
- wds_url: str
102
- """WebDataset URL for the dataset tar files, for use with atdata.Dataset."""
593
+ name: str
594
+ """Human-readable name for this dataset."""
595
+
596
+ schema_ref: str
597
+ """Reference to the schema for this dataset."""
103
598
 
104
- sample_kind: str
105
- """Fully-qualified sample type name (e.g., 'module.ClassName')."""
599
+ data_urls: list[str]
600
+ """WebDataset URLs for the data."""
106
601
 
107
- metadata_url: str | None
108
- """S3 URL to the dataset's metadata msgpack file, if any."""
602
+ metadata: dict | None = None
603
+ """Arbitrary metadata dictionary, or None if not set."""
604
+
605
+ _cid: str | None = field(default=None, repr=False)
606
+ """Content identifier (ATProto-compatible CID). Generated from content if not provided."""
607
+
608
+ # Legacy field for backwards compatibility during migration
609
+ _legacy_uuid: str | None = field(default=None, repr=False)
610
+ """Legacy UUID for backwards compatibility with existing Redis entries."""
611
+
612
+ def __post_init__(self):
613
+ """Generate CID from content if not provided."""
614
+ if self._cid is None:
615
+ self._cid = self._generate_cid()
616
+
617
+ def _generate_cid(self) -> str:
618
+ """Generate ATProto-compatible CID from entry content."""
619
+ # CID is based on schema_ref and data_urls - the identity of the dataset
620
+ content = {
621
+ "schema_ref": self.schema_ref,
622
+ "data_urls": self.data_urls,
623
+ }
624
+ return generate_cid(content)
625
+
626
+ @property
627
+ def cid(self) -> str:
628
+ """Content identifier (ATProto-compatible CID)."""
629
+ assert self._cid is not None
630
+ return self._cid
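Because only `schema_ref` and `data_urls` feed the CID, two entries that describe the same shards share an identifier even under different names; a minimal sketch with hypothetical URLs::

    from atdata.local import LocalDatasetEntry

    a = LocalDatasetEntry(
        name="train-v1",
        schema_ref="atdata://local/sampleSchema/TextSample@1.0.0",
        data_urls=["my-bucket/hive/atdata--abc--000000.tar"],
    )
    b = LocalDatasetEntry(
        name="a-different-name",
        schema_ref="atdata://local/sampleSchema/TextSample@1.0.0",
        data_urls=["my-bucket/hive/atdata--abc--000000.tar"],
    )
    assert a.cid == b.cid  # content-addressed: the display name does not affect identity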
109
631
 
110
- uuid: str = field( default_factory = lambda: str( uuid4() ) )
111
- """Unique identifier for this dataset entry. Defaults to a new UUID if not provided."""
632
+ # Legacy compatibility
112
633
 
113
- def write_to( self, redis: Redis ):
634
+ @property
635
+ def wds_url(self) -> str:
636
+ """Legacy property: returns first data URL for backwards compatibility."""
637
+ return self.data_urls[0] if self.data_urls else ""
638
+
639
+ @property
640
+ def sample_kind(self) -> str:
641
+ """Legacy property: returns schema_ref for backwards compatibility."""
642
+ return self.schema_ref
643
+
644
+ def write_to(self, redis: Redis):
114
645
  """Persist this index entry to Redis.
115
646
 
116
- Stores the entry as a Redis hash with key 'BasicIndexEntry:{uuid}'.
647
+ Stores the entry as a Redis hash with key '{REDIS_KEY_DATASET_ENTRY}:{cid}'.
117
648
 
118
649
  Args:
119
650
  redis: Redis connection to write to.
120
651
  """
121
- save_key = f'BasicIndexEntry:{self.uuid}'
122
- # Filter out None values - Redis doesn't accept None
123
- data = {k: v for k, v in asdict(self).items() if v is not None}
124
- # redis-py typing uses untyped dict, so type checker complains about dict[str, Any]
125
- redis.hset( save_key, mapping = data ) # type: ignore[arg-type]
652
+ save_key = f"{REDIS_KEY_DATASET_ENTRY}:{self.cid}"
653
+ data = {
654
+ "name": self.name,
655
+ "schema_ref": self.schema_ref,
656
+ "data_urls": msgpack.packb(self.data_urls), # Serialize list
657
+ "cid": self.cid,
658
+ }
659
+ if self.metadata is not None:
660
+ data["metadata"] = msgpack.packb(self.metadata)
661
+ if self._legacy_uuid is not None:
662
+ data["legacy_uuid"] = self._legacy_uuid
663
+
664
+ redis.hset(save_key, mapping=data) # type: ignore[arg-type]
665
+
666
+ @classmethod
667
+ def from_redis(cls, redis: Redis, cid: str) -> "LocalDatasetEntry":
668
+ """Load an entry from Redis by CID.
669
+
670
+ Args:
671
+ redis: Redis connection to read from.
672
+ cid: Content identifier of the entry to load.
126
673
 
127
- def _s3_env( credentials_path: str | Path ) -> dict[str, Any]:
128
- """Load S3 credentials from a .env file.
674
+ Returns:
675
+ LocalDatasetEntry loaded from Redis.
676
+
677
+ Raises:
678
+ KeyError: If entry not found.
679
+ """
680
+ save_key = f"{REDIS_KEY_DATASET_ENTRY}:{cid}"
681
+ raw_data = redis.hgetall(save_key)
682
+ if not raw_data:
683
+ raise KeyError(f"{REDIS_KEY_DATASET_ENTRY} not found: {cid}")
684
+
685
+ # Decode string fields, keep binary fields as bytes for msgpack
686
+ raw_data_typed = cast(dict[bytes, bytes], raw_data)
687
+ name = raw_data_typed[b"name"].decode("utf-8")
688
+ schema_ref = raw_data_typed[b"schema_ref"].decode("utf-8")
689
+ cid_value = raw_data_typed.get(b"cid", b"").decode("utf-8") or None
690
+ legacy_uuid = raw_data_typed.get(b"legacy_uuid", b"").decode("utf-8") or None
691
+
692
+ # Deserialize msgpack fields (stored as raw bytes)
693
+ data_urls = msgpack.unpackb(raw_data_typed[b"data_urls"])
694
+ metadata = None
695
+ if b"metadata" in raw_data_typed:
696
+ metadata = msgpack.unpackb(raw_data_typed[b"metadata"])
697
+
698
+ return cls(
699
+ name=name,
700
+ schema_ref=schema_ref,
701
+ data_urls=data_urls,
702
+ metadata=metadata,
703
+ _cid=cid_value,
704
+ _legacy_uuid=legacy_uuid,
705
+ )
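A round-trip sketch of this Redis layout, assuming a reachable Redis server (a test double such as fakeredis should behave the same for these hash commands)::

    from redis import Redis
    from atdata.local import LocalDatasetEntry

    r = Redis()  # assumes Redis on localhost

    entry = LocalDatasetEntry(
        name="train-v1",
        schema_ref="atdata://local/sampleSchema/TextSample@1.0.0",
        data_urls=["my-bucket/hive/atdata--abc--000000.tar"],
        metadata={"split": "train"},
    )
    entry.write_to(r)  # hash stored at "LocalDatasetEntry:<cid>"

    loaded = LocalDatasetEntry.from_redis(r, entry.cid)
    assert loaded.data_urls == entry.data_urls
    assert loaded.metadata == {"split": "train"}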
706
+
707
+
708
+ # Backwards compatibility alias
709
+ BasicIndexEntry = LocalDatasetEntry
710
+
711
+
712
+ def _s3_env(credentials_path: str | Path) -> dict[str, Any]:
713
+ """Load S3 credentials from .env file.
129
714
 
130
715
  Args:
131
- credentials_path: Path to .env file containing S3 credentials.
716
+ credentials_path: Path to .env file containing AWS_ENDPOINT,
717
+ AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.
132
718
 
133
719
  Returns:
134
- Dictionary with AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.
720
+ Dict with the three required credential keys.
135
721
 
136
722
  Raises:
137
- AssertionError: If required credentials are missing from the file.
723
+ ValueError: If any required key is missing from the .env file.
138
724
  """
139
- ##
140
- credentials_path = Path( credentials_path )
141
- env_values = dotenv_values( credentials_path )
142
- assert 'AWS_ENDPOINT' in env_values
143
- assert 'AWS_ACCESS_KEY_ID' in env_values
144
- assert 'AWS_SECRET_ACCESS_KEY' in env_values
145
-
146
- return {
147
- k: env_values[k]
148
- for k in (
149
- 'AWS_ENDPOINT',
150
- 'AWS_ACCESS_KEY_ID',
151
- 'AWS_SECRET_ACCESS_KEY',
725
+ credentials_path = Path(credentials_path)
726
+ env_values = dotenv_values(credentials_path)
727
+
728
+ required_keys = ("AWS_ENDPOINT", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
729
+ missing = [k for k in required_keys if k not in env_values]
730
+ if missing:
731
+ raise ValueError(
732
+ f"Missing required keys in {credentials_path}: {', '.join(missing)}"
152
733
  )
153
- }
154
734
 
155
- def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem:
156
- """Create an S3FileSystem from credentials.
735
+ return {k: env_values[k] for k in required_keys}
157
736
 
158
- Args:
159
- creds: Either a path to a .env file with credentials, or a dict
160
- containing AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and optionally
161
- AWS_ENDPOINT.
162
737
 
163
- Returns:
164
- Configured S3FileSystem instance.
165
- """
166
- ##
167
- if not isinstance( creds, dict ):
168
- creds = _s3_env( creds )
738
+ def _s3_from_credentials(creds: str | Path | dict) -> S3FileSystem:
739
+ """Create S3FileSystem from credentials dict or .env file path."""
740
+ if not isinstance(creds, dict):
741
+ creds = _s3_env(creds)
169
742
 
170
743
  # Build kwargs, making endpoint_url optional
171
744
  kwargs = {
172
- 'key': creds['AWS_ACCESS_KEY_ID'],
173
- 'secret': creds['AWS_SECRET_ACCESS_KEY']
745
+ "key": creds["AWS_ACCESS_KEY_ID"],
746
+ "secret": creds["AWS_SECRET_ACCESS_KEY"],
174
747
  }
175
- if 'AWS_ENDPOINT' in creds:
176
- kwargs['endpoint_url'] = creds['AWS_ENDPOINT']
748
+ if "AWS_ENDPOINT" in creds:
749
+ kwargs["endpoint_url"] = creds["AWS_ENDPOINT"]
177
750
 
178
751
  return S3FileSystem(**kwargs)
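For example, a .env file with the three required keys is enough to build the filesystem; endpoint and bucket names below are placeholders::

    from atdata.local import _s3_from_credentials

    # credentials.env (placeholder values):
    #   AWS_ENDPOINT=https://s3.example.com
    #   AWS_ACCESS_KEY_ID=...
    #   AWS_SECRET_ACCESS_KEY=...
    fs = _s3_from_credentials("credentials.env")
    fs.ls("s3://my-bucket/")  # hypothetical bucket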
179
752
 
@@ -181,9 +754,17 @@ def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem:
181
754
  ##
182
755
  # Classes
183
756
 
757
+
184
758
  class Repo:
185
759
  """Repository for storing and managing atdata datasets.
186
760
 
761
+ .. deprecated::
762
+ Use :class:`Index` with :class:`S3DataStore` instead::
763
+
764
+ store = S3DataStore(credentials, bucket="my-bucket")
765
+ index = Index(redis=redis, data_store=store)
766
+ entry = index.insert_dataset(ds, name="my-dataset")
767
+
187
768
  Provides storage of datasets in S3-compatible object storage with Redis-based
188
769
  indexing. Datasets are stored as WebDataset tar files with optional metadata.
189
770
 
@@ -197,17 +778,17 @@ class Repo:
197
778
 
198
779
  ##
199
780
 
200
- def __init__( self,
201
- #
202
- s3_credentials: str | Path | dict[str, Any] | None = None,
203
- hive_path: str | Path | None = None,
204
- redis: Redis | None = None,
205
- #
206
- #
207
- **kwargs
208
- ) -> None:
781
+ def __init__(
782
+ self,
783
+ s3_credentials: str | Path | dict[str, Any] | None = None,
784
+ hive_path: str | Path | None = None,
785
+ redis: Redis | None = None,
786
+ ) -> None:
209
787
  """Initialize a repository.
210
788
 
789
+ .. deprecated::
790
+ Use Index with S3DataStore instead.
791
+
211
792
  Args:
212
793
  s3_credentials: Path to .env file with S3 credentials, or dict with
213
794
  AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.
@@ -215,45 +796,55 @@ class Repo:
215
796
  hive_path: Path within the S3 bucket to store datasets.
216
797
  Required if s3_credentials is provided.
217
798
  redis: Redis connection for indexing. If None, creates a new connection.
218
- **kwargs: Additional arguments (reserved for future use).
219
799
 
220
800
  Raises:
221
801
  ValueError: If hive_path is not provided when s3_credentials is set.
222
802
  """
803
+ warnings.warn(
804
+ "Repo is deprecated. Use Index with S3DataStore instead:\n"
805
+ " store = S3DataStore(credentials, bucket='my-bucket')\n"
806
+ " index = Index(redis=redis, data_store=store)\n"
807
+ " entry = index.insert_dataset(ds, name='my-dataset')",
808
+ DeprecationWarning,
809
+ stacklevel=2,
810
+ )
223
811
 
224
812
  if s3_credentials is None:
225
813
  self.s3_credentials = None
226
- elif isinstance( s3_credentials, dict ):
814
+ elif isinstance(s3_credentials, dict):
227
815
  self.s3_credentials = s3_credentials
228
816
  else:
229
- self.s3_credentials = _s3_env( s3_credentials )
817
+ self.s3_credentials = _s3_env(s3_credentials)
230
818
 
231
819
  if self.s3_credentials is None:
232
820
  self.bucket_fs = None
233
821
  else:
234
- self.bucket_fs = _s3_from_credentials( self.s3_credentials )
822
+ self.bucket_fs = _s3_from_credentials(self.s3_credentials)
235
823
 
236
824
  if self.bucket_fs is not None:
237
825
  if hive_path is None:
238
- raise ValueError( 'Must specify hive path within bucket' )
239
- self.hive_path = Path( hive_path )
826
+ raise ValueError("Must specify hive path within bucket")
827
+ self.hive_path = Path(hive_path)
240
828
  self.hive_bucket = self.hive_path.parts[0]
241
829
  else:
242
830
  self.hive_path = None
243
831
  self.hive_bucket = None
244
-
832
+
245
833
  #
246
834
 
247
- self.index = Index( redis = redis )
835
+ self.index = Index(redis=redis)
248
836
 
249
837
  ##
250
838
 
251
- def insert( self, ds: Dataset[T],
252
- #
253
- cache_local: bool = False,
254
- #
255
- **kwargs
256
- ) -> tuple[BasicIndexEntry, Dataset[T]]:
839
+ def insert(
840
+ self,
841
+ ds: Dataset[T],
842
+ *,
843
+ name: str,
844
+ cache_local: bool = False,
845
+ schema_ref: str | None = None,
846
+ **kwargs,
847
+ ) -> tuple[LocalDatasetEntry, Dataset[T]]:
257
848
  """Insert a dataset into the repository.
258
849
 
259
850
  Writes the dataset to S3 as WebDataset tar files, stores metadata,
@@ -261,130 +852,102 @@ class Repo:
261
852
 
262
853
  Args:
263
854
  ds: The dataset to insert.
855
+ name: Human-readable name for the dataset.
264
856
  cache_local: If True, write to local temporary storage first, then
265
857
  copy to S3. This can be faster for some workloads.
858
+ schema_ref: Optional schema reference. If None, generates from sample type.
266
859
  **kwargs: Additional arguments passed to wds.ShardWriter.
267
860
 
268
861
  Returns:
269
862
  A tuple of (index_entry, new_dataset) where:
270
- - index_entry: BasicIndexEntry for the stored dataset
863
+ - index_entry: LocalDatasetEntry for the stored dataset
271
864
  - new_dataset: Dataset object pointing to the stored copy
272
865
 
273
866
  Raises:
274
- AssertionError: If S3 credentials or hive_path are not configured.
867
+ ValueError: If S3 credentials or hive_path are not configured.
275
868
  RuntimeError: If no shards were written.
276
869
  """
277
-
278
- assert self.s3_credentials is not None
279
- assert self.hive_bucket is not None
280
- assert self.hive_path is not None
870
+ if self.s3_credentials is None:
871
+ raise ValueError(
872
+ "S3 credentials required for insert(). Initialize Repo with s3_credentials."
873
+ )
874
+ if self.hive_bucket is None or self.hive_path is None:
875
+ raise ValueError(
876
+ "hive_path required for insert(). Initialize Repo with hive_path."
877
+ )
281
878
 
282
- new_uuid = str( uuid4() )
879
+ new_uuid = str(uuid4())
283
880
 
284
- hive_fs = _s3_from_credentials( self.s3_credentials )
881
+ hive_fs = _s3_from_credentials(self.s3_credentials)
285
882
 
286
883
  # Write metadata
287
884
  metadata_path = (
288
- self.hive_path
289
- / 'metadata'
290
- / f'atdata-metadata--{new_uuid}.msgpack'
885
+ self.hive_path / "metadata" / f"atdata-metadata--{new_uuid}.msgpack"
291
886
  )
292
887
  # Note: S3 doesn't need directories created beforehand - s3fs handles this
293
888
 
294
889
  if ds.metadata is not None:
295
890
  # Use s3:// prefix to ensure s3fs treats this as an S3 path
296
- with cast( BinaryIO, hive_fs.open( f's3://{metadata_path.as_posix()}', 'wb' ) ) as f:
297
- meta_packed = msgpack.packb( ds.metadata )
891
+ with cast(
892
+ BinaryIO, hive_fs.open(f"s3://{metadata_path.as_posix()}", "wb")
893
+ ) as f:
894
+ meta_packed = msgpack.packb(ds.metadata)
298
895
  assert meta_packed is not None
299
- f.write( cast( bytes, meta_packed ) )
300
-
896
+ f.write(cast(bytes, meta_packed))
301
897
 
302
898
  # Write data
303
- shard_pattern = (
304
- self.hive_path
305
- / f'atdata--{new_uuid}--%06d.tar'
306
- ).as_posix()
899
+ shard_pattern = (self.hive_path / f"atdata--{new_uuid}--%06d.tar").as_posix()
307
900
 
901
+ written_shards: list[str] = []
308
902
  with TemporaryDirectory() as temp_dir:
903
+ writer_opener, writer_post = _create_s3_write_callbacks(
904
+ credentials=self.s3_credentials,
905
+ temp_dir=temp_dir,
906
+ written_shards=written_shards,
907
+ fs=hive_fs,
908
+ cache_local=cache_local,
909
+ add_s3_prefix=False,
910
+ )
309
911
 
310
- if cache_local:
311
- # For cache_local, we need to use boto3 directly to avoid s3fs async issues with moto
312
- import boto3
313
-
314
- # Create boto3 client from credentials
315
- s3_client_kwargs = {
316
- 'aws_access_key_id': self.s3_credentials['AWS_ACCESS_KEY_ID'],
317
- 'aws_secret_access_key': self.s3_credentials['AWS_SECRET_ACCESS_KEY']
318
- }
319
- if 'AWS_ENDPOINT' in self.s3_credentials:
320
- s3_client_kwargs['endpoint_url'] = self.s3_credentials['AWS_ENDPOINT']
321
- s3_client = boto3.client('s3', **s3_client_kwargs)
322
-
323
- def _writer_opener( p: str ):
324
- local_cache_path = Path( temp_dir ) / p
325
- local_cache_path.parent.mkdir( parents = True, exist_ok = True )
326
- return open( local_cache_path, 'wb' )
327
- writer_opener = _writer_opener
328
-
329
- def _writer_post( p: str ):
330
- local_cache_path = Path( temp_dir ) / p
331
-
332
- # Copy to S3 using boto3 client (avoids s3fs async issues)
333
- path_parts = Path( p ).parts
334
- bucket = path_parts[0]
335
- key = str( Path( *path_parts[1:] ) )
336
-
337
- with open( local_cache_path, 'rb' ) as f_in:
338
- s3_client.put_object( Bucket=bucket, Key=key, Body=f_in.read() )
339
-
340
- # Delete local cache file
341
- local_cache_path.unlink()
342
-
343
- written_shards.append( p )
344
- writer_post = _writer_post
345
-
346
- else:
347
- # Use s3:// prefix to ensure s3fs treats paths as S3 paths
348
- writer_opener = lambda s: cast( BinaryIO, hive_fs.open( f's3://{s}', 'wb' ) )
349
- writer_post = lambda s: written_shards.append( s )
350
-
351
- written_shards = []
352
912
  with wds.writer.ShardWriter(
353
913
  shard_pattern,
354
- opener = writer_opener,
355
- post = writer_post,
914
+ opener=writer_opener,
915
+ post=writer_post,
356
916
  **kwargs,
357
917
  ) as sink:
358
- for sample in ds.ordered( batch_size = None ):
359
- sink.write( sample.as_wds )
918
+ for sample in ds.ordered(batch_size=None):
919
+ sink.write(sample.as_wds)
360
920
 
361
921
  # Make a new Dataset object for the written dataset copy
362
- if len( written_shards ) == 0:
363
- raise RuntimeError( 'Cannot form new dataset entry -- did not write any shards' )
364
-
365
- elif len( written_shards ) < 2:
922
+ if len(written_shards) == 0:
923
+ raise RuntimeError(
924
+ "Cannot form new dataset entry -- did not write any shards"
925
+ )
926
+
927
+ elif len(written_shards) < 2:
366
928
  new_dataset_url = (
367
- self.hive_path
368
- / ( Path( written_shards[0] ).name )
929
+ self.hive_path / (Path(written_shards[0]).name)
369
930
  ).as_posix()
370
931
 
371
932
  else:
372
933
  shard_s3_format = (
373
- (
374
- self.hive_path
375
- / f'atdata--{new_uuid}'
376
- ).as_posix()
377
- ) + '--{shard_id}.tar'
378
- shard_id_braced = '{' + f'{0:06d}..{len( written_shards ) - 1:06d}' + '}'
379
- new_dataset_url = shard_s3_format.format( shard_id = shard_id_braced )
934
+ (self.hive_path / f"atdata--{new_uuid}").as_posix()
935
+ ) + "--{shard_id}.tar"
936
+ shard_id_braced = "{" + f"{0:06d}..{len(written_shards) - 1:06d}" + "}"
937
+ new_dataset_url = shard_s3_format.format(shard_id=shard_id_braced)
380
938
 
381
939
  new_dataset = Dataset[ds.sample_type](
382
- url = new_dataset_url,
383
- metadata_url = metadata_path.as_posix(),
940
+ url=new_dataset_url,
941
+ metadata_url=metadata_path.as_posix(),
384
942
  )
385
943
 
386
- # Add to index
387
- new_entry = self.index.add_entry( new_dataset, uuid = new_uuid )
944
+ # Add to index (use ds._metadata to avoid network requests)
945
+ new_entry = self.index.add_entry(
946
+ new_dataset,
947
+ name=name,
948
+ schema_ref=schema_ref,
949
+ metadata=ds._metadata,
950
+ )
388
951
 
389
952
  return new_entry, new_dataset
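The multi-shard branch of insert() above collapses the written shards into one WebDataset brace URL; reproducing just that construction with hypothetical values::

    shard_s3_format = "my-bucket/hive/atdata--1234--{shard_id}.tar"
    shard_id_braced = "{" + f"{0:06d}..{3:06d}" + "}"   # 4 shards written
    print(shard_s3_format.format(shard_id=shard_id_braced))
    # my-bucket/hive/atdata--1234--{000000..000003}.tar  (expanded back by WebDataset)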
390
953
 
@@ -392,24 +955,43 @@ class Repo:
392
955
  class Index:
393
956
  """Redis-backed index for tracking datasets in a repository.
394
957
 
395
- Maintains a registry of BasicIndexEntry objects in Redis, allowing
396
- enumeration and lookup of stored datasets.
958
+ Implements the AbstractIndex protocol. Maintains a registry of
959
+ LocalDatasetEntry objects in Redis, allowing enumeration and lookup
960
+ of stored datasets.
961
+
962
+ When initialized with a data_store, insert_dataset() will write dataset
963
+ shards to storage before indexing. Without a data_store, insert_dataset()
964
+ only indexes existing URLs.
397
965
 
398
966
  Attributes:
399
967
  _redis: Redis connection for index storage.
968
+ _data_store: Optional AbstractDataStore for writing dataset shards.
400
969
  """
401
970
 
402
971
  ##
403
972
 
404
- def __init__( self,
405
- redis: Redis | None = None,
406
- **kwargs
407
- ) -> None:
973
+ def __init__(
974
+ self,
975
+ redis: Redis | None = None,
976
+ data_store: AbstractDataStore | None = None,
977
+ auto_stubs: bool = False,
978
+ stub_dir: Path | str | None = None,
979
+ **kwargs,
980
+ ) -> None:
408
981
  """Initialize an index.
409
982
 
410
983
  Args:
411
984
  redis: Redis connection to use. If None, creates a new connection
412
985
  using the provided kwargs.
986
+ data_store: Optional data store for writing dataset shards.
987
+ If provided, insert_dataset() will write shards to this store.
988
+ If None, insert_dataset() only indexes existing URLs.
989
+ auto_stubs: If True, automatically generate .pyi stub files when
990
+ schemas are accessed via get_schema() or decode_schema().
991
+ This enables IDE autocomplete for dynamically decoded types.
992
+ stub_dir: Directory to write stub files. Providing this parameter
993
+ implies auto_stubs=True.
994
+ Defaults to ~/.atdata/stubs/ if not specified.
413
995
  **kwargs: Additional arguments passed to Redis() constructor if
414
996
  redis is None.
415
997
  """
@@ -418,75 +1000,721 @@ class Index:
418
1000
  if redis is not None:
419
1001
  self._redis = redis
420
1002
  else:
421
- self._redis: Redis = Redis( **kwargs )
1003
+ self._redis: Redis = Redis(**kwargs)
1004
+
1005
+ self._data_store = data_store
1006
+
1007
+ # Initialize stub manager if auto-stubs enabled
1008
+ # Providing stub_dir implies auto_stubs=True
1009
+ if auto_stubs or stub_dir is not None:
1010
+ from ._stub_manager import StubManager
1011
+
1012
+ self._stub_manager: StubManager | None = StubManager(stub_dir=stub_dir)
1013
+ else:
1014
+ self._stub_manager = None
1015
+
1016
+ # Initialize schema namespace for load_schema/schemas API
1017
+ self._schema_namespace = SchemaNamespace()
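A minimal construction sketch for the two modes this constructor supports (index-only versus store-backed); `S3DataStore` usage follows the Repo deprecation note above, but its import path is not shown in this diff::

    from redis import Redis
    from atdata.local import Index

    r = Redis()

    # Index-only: insert_dataset() just records a dataset's existing URL.
    idx = Index(redis=r)

    # Store-backed (sketch): insert_dataset() writes shards first, then indexes them.
    # store = S3DataStore(credentials, bucket="my-bucket")
    # idx = Index(redis=r, data_store=store, auto_stubs=True)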
1018
+
1019
+ @property
1020
+ def data_store(self) -> AbstractDataStore | None:
1021
+ """The data store for writing shards, or None if index-only."""
1022
+ return self._data_store
422
1023
 
423
1024
  @property
424
- def all_entries( self ) -> list[BasicIndexEntry]:
425
- """Get all index entries as a list.
1025
+ def stub_dir(self) -> Path | None:
1026
+ """Directory where stub files are written, or None if auto-stubs disabled.
1027
+
1028
+ Use this path to configure your IDE for type checking support:
1029
+ - VS Code/Pylance: Add to python.analysis.extraPaths in settings.json
1030
+ - PyCharm: Mark as Sources Root
1031
+ - mypy: Add to mypy_path in mypy.ini
1032
+ """
1033
+ if self._stub_manager is not None:
1034
+ return self._stub_manager.stub_dir
1035
+ return None
1036
+
1037
+ @property
1038
+ def types(self) -> SchemaNamespace:
1039
+ """Namespace for accessing loaded schema types.
1040
+
1041
+ After calling :meth:`load_schema`, schema types become available
1042
+ as attributes on this namespace.
1043
+
1044
+ Examples:
1045
+ >>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
1046
+ >>> MyType = index.types.MySample
1047
+ >>> sample = MyType(name="hello", value=42)
1048
+
1049
+ Returns:
1050
+ SchemaNamespace containing all loaded schema types.
1051
+ """
1052
+ return self._schema_namespace
1053
+
1054
+ def load_schema(self, ref: str) -> Type[Packable]:
1055
+ """Load a schema and make it available in the types namespace.
1056
+
1057
+ This method decodes the schema, optionally generates a Python module
1058
+ for IDE support (if auto_stubs is enabled), and registers the type
1059
+ in the :attr:`types` namespace for easy access.
1060
+
1061
+ Args:
1062
+ ref: Schema reference string (atdata://local/sampleSchema/... or
1063
+ legacy local://schemas/...).
426
1064
 
427
1065
  Returns:
428
- List of all BasicIndexEntry objects in the index.
1066
+ The decoded PackableSample subclass. Also available via
1067
+ ``index.types.<ClassName>`` after this call.
1068
+
1069
+ Raises:
1070
+ KeyError: If schema not found.
1071
+ ValueError: If schema cannot be decoded.
1072
+
1073
+ Examples:
1074
+ >>> # Load and use immediately
1075
+ >>> MyType = index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
1076
+ >>> sample = MyType(name="hello", value=42)
1077
+ >>>
1078
+ >>> # Or access later via namespace
1079
+ >>> index.load_schema("atdata://local/sampleSchema/OtherType@1.0.0")
1080
+ >>> other = index.types.OtherType(data="test")
429
1081
  """
430
- return list( self.entries )
1082
+ # Decode the schema (uses generated module if auto_stubs enabled)
1083
+ cls = self.decode_schema(ref)
1084
+
1085
+ # Register in namespace using the class name
1086
+ self._schema_namespace._register(cls.__name__, cls)
1087
+
1088
+ return cls
1089
+
1090
+ def get_import_path(self, ref: str) -> str | None:
1091
+ """Get the import path for a schema's generated module.
1092
+
1093
+ When auto_stubs is enabled, this returns the import path that can
1094
+ be used to import the schema type with full IDE support.
1095
+
1096
+ Args:
1097
+ ref: Schema reference string.
1098
+
1099
+ Returns:
1100
+ Import path like "local.MySample_1_0_0", or None if auto_stubs
1101
+ is disabled.
1102
+
1103
+ Examples:
1104
+ >>> index = LocalIndex(auto_stubs=True)
1105
+ >>> ref = index.publish_schema(MySample, version="1.0.0")
1106
+ >>> index.load_schema(ref)
1107
+ >>> print(index.get_import_path(ref))
1108
+ local.MySample_1_0_0
1109
+ >>> # Then in your code:
1110
+ >>> # from local.MySample_1_0_0 import MySample
1111
+ """
1112
+ if self._stub_manager is None:
1113
+ return None
1114
+
1115
+ from ._stub_manager import _extract_authority
1116
+
1117
+ name, version = _parse_schema_ref(ref)
1118
+ schema_dict = self.get_schema(ref)
1119
+ authority = _extract_authority(schema_dict.get("$ref"))
1120
+
1121
+ safe_version = version.replace(".", "_")
1122
+ module_name = f"{name}_{safe_version}"
1123
+
1124
+ return f"{authority}.{module_name}"
1125
+
1126
+ def list_entries(self) -> list[LocalDatasetEntry]:
1127
+ """Get all index entries as a materialized list.
1128
+
1129
+ Returns:
1130
+ List of all LocalDatasetEntry objects in the index.
1131
+ """
1132
+ return list(self.entries)
1133
+
1134
+ # Legacy alias for backwards compatibility
1135
+ @property
1136
+ def all_entries(self) -> list[LocalDatasetEntry]:
1137
+ """Get all index entries as a list (deprecated, use list_entries())."""
1138
+ return self.list_entries()
431
1139
 
432
1140
  @property
433
- def entries( self ) -> Generator[BasicIndexEntry, None, None]:
1141
+ def entries(self) -> Generator[LocalDatasetEntry, None, None]:
434
1142
  """Iterate over all index entries.
435
1143
 
436
- Scans Redis for all BasicIndexEntry keys and yields them one at a time.
1144
+ Scans Redis for LocalDatasetEntry keys and yields them one at a time.
437
1145
 
438
1146
  Yields:
439
- BasicIndexEntry objects from the index.
1147
+ LocalDatasetEntry objects from the index.
440
1148
  """
441
- ##
442
- for key in self._redis.scan_iter( match = 'BasicIndexEntry:*' ):
443
- # hgetall returns dict[bytes, bytes] which we decode to dict[str, str]
444
- cur_entry_data = _decode_bytes_dict( cast(dict[bytes, bytes], self._redis.hgetall( key )) )
445
-
446
- # Provide default None for optional fields that may be missing
447
- # Type checker complains about None in dict[str, str], but BasicIndexEntry accepts it
448
- cur_entry_data: dict[str, Any] = dict( **cur_entry_data )
449
- cur_entry_data.setdefault('metadata_url', None)
450
-
451
- cur_entry = BasicIndexEntry( **cur_entry_data )
452
- yield cur_entry
453
-
454
- return
455
-
456
- def add_entry( self, ds: Dataset,
457
- uuid: str | None = None,
458
- ) -> BasicIndexEntry:
1149
+ prefix = f"{REDIS_KEY_DATASET_ENTRY}:"
1150
+ for key in self._redis.scan_iter(match=f"{prefix}*"):
1151
+ key_str = key.decode("utf-8") if isinstance(key, bytes) else key
1152
+ cid = key_str[len(prefix) :]
1153
+ yield LocalDatasetEntry.from_redis(self._redis, cid)
1154
+
1155
+ def add_entry(
1156
+ self,
1157
+ ds: Dataset,
1158
+ *,
1159
+ name: str,
1160
+ schema_ref: str | None = None,
1161
+ metadata: dict | None = None,
1162
+ ) -> LocalDatasetEntry:
459
1163
  """Add a dataset to the index.
460
1164
 
461
- Creates a BasicIndexEntry for the dataset and persists it to Redis.
1165
+ Creates a LocalDatasetEntry for the dataset and persists it to Redis.
462
1166
 
463
1167
  Args:
464
1168
  ds: The dataset to add to the index.
465
- uuid: Optional UUID for the entry. If None, a new UUID is generated.
1169
+ name: Human-readable name for the dataset.
1170
+ schema_ref: Optional schema reference. If None, generates from sample type.
1171
+ metadata: Optional metadata dictionary. If None, uses ds._metadata if available.
466
1172
 
467
1173
  Returns:
468
- The created BasicIndexEntry object.
1174
+ The created LocalDatasetEntry object.
469
1175
  """
470
1176
  ##
471
- temp_sample_kind = _kind_str_for_sample_type( ds.sample_type )
1177
+ if schema_ref is None:
1178
+ schema_ref = (
1179
+ f"local://schemas/{_kind_str_for_sample_type(ds.sample_type)}@1.0.0"
1180
+ )
1181
+
1182
+ # Normalize URL to list
1183
+ data_urls = [ds.url]
1184
+
1185
+ # Use provided metadata, or fall back to dataset's cached metadata
1186
+ # (avoid triggering network requests via ds.metadata property)
1187
+ entry_metadata = metadata if metadata is not None else ds._metadata
1188
+
1189
+ entry = LocalDatasetEntry(
1190
+ name=name,
1191
+ schema_ref=schema_ref,
1192
+ data_urls=data_urls,
1193
+ metadata=entry_metadata,
1194
+ )
1195
+
1196
+ entry.write_to(self._redis)
1197
+
1198
+ return entry
1199
+
1200
+ def get_entry(self, cid: str) -> LocalDatasetEntry:
1201
+ """Get an entry by its CID.
1202
+
1203
+ Args:
1204
+ cid: Content identifier of the entry.
472
1205
 
473
- if uuid is None:
474
- ret_data = BasicIndexEntry(
475
- wds_url = ds.url,
476
- sample_kind = temp_sample_kind,
477
- metadata_url = ds.metadata_url,
1206
+ Returns:
1207
+ LocalDatasetEntry for the given CID.
1208
+
1209
+ Raises:
1210
+ KeyError: If entry not found.
1211
+ """
1212
+ return LocalDatasetEntry.from_redis(self._redis, cid)
1213
+
1214
+ def get_entry_by_name(self, name: str) -> LocalDatasetEntry:
1215
+ """Get an entry by its human-readable name.
1216
+
1217
+ Args:
1218
+ name: Human-readable name of the entry.
1219
+
1220
+ Returns:
1221
+ LocalDatasetEntry with the given name.
1222
+
1223
+ Raises:
1224
+ KeyError: If no entry with that name exists.
1225
+ """
1226
+ for entry in self.entries:
1227
+ if entry.name == name:
1228
+ return entry
1229
+ raise KeyError(f"No entry with name: {name}")
1230
+
1231
+ # AbstractIndex protocol methods
1232
+
1233
+ def insert_dataset(
1234
+ self,
1235
+ ds: Dataset,
1236
+ *,
1237
+ name: str,
1238
+ schema_ref: str | None = None,
1239
+ **kwargs,
1240
+ ) -> LocalDatasetEntry:
1241
+ """Insert a dataset into the index (AbstractIndex protocol).
1242
+
1243
+ If a data_store was provided at initialization, writes dataset shards
1244
+ to storage first, then indexes the new URLs. Otherwise, indexes the
1245
+ dataset's existing URL.
1246
+
1247
+ Args:
1248
+ ds: The Dataset to register.
1249
+ name: Human-readable name for the dataset.
1250
+ schema_ref: Optional schema reference.
1251
+ **kwargs: Additional options:
1252
+ - metadata: Optional metadata dict
1253
+ - prefix: Storage prefix (default: dataset name)
1254
+ - cache_local: If True, cache writes locally first
1255
+
1256
+ Returns:
1257
+ IndexEntry for the inserted dataset.
1258
+ """
1259
+ metadata = kwargs.get("metadata")
1260
+
1261
+ if self._data_store is not None:
1262
+ # Write shards to data store, then index the new URLs
1263
+ prefix = kwargs.get("prefix", name)
1264
+ cache_local = kwargs.get("cache_local", False)
1265
+
1266
+ written_urls = self._data_store.write_shards(
1267
+ ds,
1268
+ prefix=prefix,
1269
+ cache_local=cache_local,
478
1270
  )
1271
+
1272
+ # Generate schema_ref if not provided
1273
+ if schema_ref is None:
1274
+ schema_ref = _schema_ref_from_type(ds.sample_type, version="1.0.0")
1275
+
1276
+ # Create entry with the written URLs
1277
+ entry_metadata = metadata if metadata is not None else ds._metadata
1278
+ entry = LocalDatasetEntry(
1279
+ name=name,
1280
+ schema_ref=schema_ref,
1281
+ data_urls=written_urls,
1282
+ metadata=entry_metadata,
1283
+ )
1284
+ entry.write_to(self._redis)
1285
+ return entry
1286
+
1287
+ # No data store - just index the existing URL
1288
+ return self.add_entry(ds, name=name, schema_ref=schema_ref, metadata=metadata)
1289
+
1290
+ def get_dataset(self, ref: str) -> LocalDatasetEntry:
1291
+ """Get a dataset entry by name (AbstractIndex protocol).
1292
+
1293
+ Args:
1294
+ ref: Dataset name.
1295
+
1296
+ Returns:
1297
+ IndexEntry for the dataset.
1298
+
1299
+ Raises:
1300
+ KeyError: If dataset not found.
1301
+ """
1302
+ return self.get_entry_by_name(ref)
1303
+
1304
+ @property
1305
+ def datasets(self) -> Generator[LocalDatasetEntry, None, None]:
1306
+ """Lazily iterate over all dataset entries (AbstractIndex protocol).
1307
+
1308
+ Yields:
1309
+ IndexEntry for each dataset.
1310
+ """
1311
+ return self.entries
1312
+
1313
+ def list_datasets(self) -> list[LocalDatasetEntry]:
1314
+ """Get all dataset entries as a materialized list (AbstractIndex protocol).
1315
+
1316
+ Returns:
1317
+ List of IndexEntry for each dataset.
1318
+ """
1319
+ return self.list_entries()
1320
+
1321
+ # Schema operations
1322
+
1323
+ def _get_latest_schema_version(self, name: str) -> str | None:
1324
+ """Get the latest version for a schema by name, or None if not found."""
1325
+ latest_version: tuple[int, int, int] | None = None
1326
+ latest_version_str: str | None = None
1327
+
1328
+ prefix = f"{REDIS_KEY_SCHEMA}:"
1329
+ for key in self._redis.scan_iter(match=f"{prefix}*"):
1330
+ key_str = key.decode("utf-8") if isinstance(key, bytes) else key
1331
+ schema_id = key_str[len(prefix) :]
1332
+
1333
+ if "@" not in schema_id:
1334
+ continue
1335
+
1336
+ schema_name, version_str = schema_id.rsplit("@", 1)
1337
+ # Handle legacy format: module.Class -> Class
1338
+ if "." in schema_name:
1339
+ schema_name = schema_name.rsplit(".", 1)[1]
1340
+
1341
+ if schema_name != name:
1342
+ continue
1343
+
1344
+ try:
1345
+ version_tuple = _parse_semver(version_str)
1346
+ if latest_version is None or version_tuple > latest_version:
1347
+ latest_version = version_tuple
1348
+ latest_version_str = version_str
1349
+ except ValueError:
1350
+ continue
1351
+
1352
+ return latest_version_str
1353
+
1354
+ def publish_schema(
1355
+ self,
1356
+ sample_type: type,
1357
+ *,
1358
+ version: str | None = None,
1359
+ description: str | None = None,
1360
+ ) -> str:
1361
+ """Publish a schema for a sample type to Redis.
1362
+
1363
+ Args:
1364
+ sample_type: A Packable type (@packable-decorated or PackableSample subclass).
1365
+ version: Semantic version string (e.g., '1.0.0'). If None,
1366
+ auto-increments from the latest published version (patch bump),
1367
+ or starts at '1.0.0' if no previous version exists.
1368
+ description: Optional human-readable description. If None, uses
1369
+ the class docstring.
1370
+
1371
+ Returns:
1372
+ Schema reference string: 'atdata://local/sampleSchema/{name}@{version}'.
1373
+
1374
+ Raises:
1375
+ ValueError: If sample_type is not a dataclass.
1376
+ TypeError: If sample_type doesn't satisfy the Packable protocol,
1377
+ or if a field type is not supported.
1378
+ """
1379
+ # Validate that sample_type satisfies Packable protocol at runtime
1380
+ # This catches non-packable types early with a clear error message
1381
+ try:
1382
+ # Check protocol compliance by verifying required methods exist
1383
+ if not (
1384
+ hasattr(sample_type, "from_data")
1385
+ and hasattr(sample_type, "from_bytes")
1386
+ and callable(getattr(sample_type, "from_data", None))
1387
+ and callable(getattr(sample_type, "from_bytes", None))
1388
+ ):
1389
+ raise TypeError(
1390
+ f"{sample_type.__name__} does not satisfy the Packable protocol. "
1391
+ "Use @packable decorator or inherit from PackableSample."
1392
+ )
1393
+ except AttributeError:
1394
+ raise TypeError(
1395
+ f"sample_type must be a class, got {type(sample_type).__name__}"
1396
+ )
1397
+
1398
+ # Auto-increment version if not specified
1399
+ if version is None:
1400
+ latest = self._get_latest_schema_version(sample_type.__name__)
1401
+ if latest is None:
1402
+ version = "1.0.0"
1403
+ else:
1404
+ version = _increment_patch(latest)
1405
+
1406
+ schema_record = _build_schema_record(
1407
+ sample_type,
1408
+ version=version,
1409
+ description=description,
1410
+ )
1411
+
1412
+ schema_ref = _schema_ref_from_type(sample_type, version)
1413
+ name, _ = _parse_schema_ref(schema_ref)
1414
+
1415
+ # Store in Redis
1416
+ redis_key = f"{REDIS_KEY_SCHEMA}:{name}@{version}"
1417
+ schema_json = json.dumps(schema_record)
1418
+ self._redis.set(redis_key, schema_json)
1419
+
1420
+ return schema_ref
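A short publish-then-fetch sketch, assuming a Packable-compliant sample type (e.g. @packable-decorated, so it passes the from_data/from_bytes check above) and the Index from earlier::

    ref = idx.publish_schema(TextSample, version="1.0.0")
    # ref == "atdata://local/sampleSchema/TextSample@1.0.0"

    # Omitting the version bumps the patch of the latest published version.
    assert idx.publish_schema(TextSample) == "atdata://local/sampleSchema/TextSample@1.0.1"

    # The stored record is retrievable as a plain dict.
    assert idx.get_schema(ref)["version"] == "1.0.0"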
1421
+
1422
+ def get_schema(self, ref: str) -> dict:
1423
+ """Get a schema record by reference (AbstractIndex protocol).
1424
+
1425
+ Args:
1426
+ ref: Schema reference string. Supports both new format
1427
+ (atdata://local/sampleSchema/{name}@{version}) and legacy
1428
+ format (local://schemas/{module.Class}@{version}).
1429
+
1430
+ Returns:
1431
+ Schema record as a dictionary with keys 'name', 'version',
1432
+ 'fields', '$ref', etc.
1433
+
1434
+ Raises:
1435
+ KeyError: If schema not found.
1436
+ ValueError: If reference format is invalid.
1437
+ """
1438
+ name, version = _parse_schema_ref(ref)
1439
+ redis_key = f"{REDIS_KEY_SCHEMA}:{name}@{version}"
1440
+
1441
+ schema_json = self._redis.get(redis_key)
1442
+ if schema_json is None:
1443
+ raise KeyError(f"Schema not found: {ref}")
1444
+
1445
+ if isinstance(schema_json, bytes):
1446
+ schema_json = schema_json.decode("utf-8")
1447
+
1448
+ schema = json.loads(schema_json)
1449
+ schema["$ref"] = _make_schema_ref(name, version)
1450
+
1451
+ # Auto-generate stub if enabled
1452
+ if self._stub_manager is not None:
1453
+ record = LocalSchemaRecord.from_dict(schema)
1454
+ self._stub_manager.ensure_stub(record)
1455
+
1456
+ return schema
1457
+
1458
+ def get_schema_record(self, ref: str) -> LocalSchemaRecord:
1459
+ """Get a schema record as LocalSchemaRecord object.
1460
+
1461
+ Use this when you need the full LocalSchemaRecord with typed properties.
1462
+ For Protocol-compliant dict access, use get_schema() instead.
1463
+
1464
+ Args:
1465
+ ref: Schema reference string.
1466
+
1467
+ Returns:
1468
+ LocalSchemaRecord with schema details.
1469
+
1470
+ Raises:
1471
+ KeyError: If schema not found.
1472
+ ValueError: If reference format is invalid.
1473
+ """
1474
+ schema = self.get_schema(ref)
1475
+ return LocalSchemaRecord.from_dict(schema)
1476
+
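A short retrieval sketch, assuming `ref` is the reference returned by publish_schema above; the dictionary keys follow get_schema()'s documented shape:

    schema = index.get_schema(ref)            # plain dict (AbstractIndex protocol)
    print(schema["name"], schema["version"])  # e.g. MySample 1.0.0
    record = index.get_schema_record(ref)     # typed LocalSchemaRecord wrapper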
1477
+ @property
1478
+ def schemas(self) -> Generator[LocalSchemaRecord, None, None]:
1479
+ """Iterate over all schema records in this index.
1480
+
1481
+ Yields:
1482
+ LocalSchemaRecord for each schema.
1483
+ """
1484
+ prefix = f"{REDIS_KEY_SCHEMA}:"
1485
+ for key in self._redis.scan_iter(match=f"{prefix}*"):
1486
+ key_str = key.decode("utf-8") if isinstance(key, bytes) else key
1487
+ # Extract name@version from key
1488
+ schema_id = key_str[len(prefix) :]
1489
+
1490
+ schema_json = self._redis.get(key)
1491
+ if schema_json is None:
1492
+ continue
1493
+
1494
+ if isinstance(schema_json, bytes):
1495
+ schema_json = schema_json.decode("utf-8")
1496
+
1497
+ schema = json.loads(schema_json)
1498
+ # Handle legacy keys that have module.Class format
1499
+ if "." in schema_id.split("@")[0]:
1500
+ name = schema_id.split("@")[0].rsplit(".", 1)[1]
1501
+ version = schema_id.split("@")[1]
1502
+ schema["$ref"] = _make_schema_ref(name, version)
1503
+ else:
1504
+ # schema_id is already "name@version"
1505
+ name, version = schema_id.rsplit("@", 1)
1506
+ schema["$ref"] = _make_schema_ref(name, version)
1507
+ yield LocalSchemaRecord.from_dict(schema)
1508
+
1509
+ def list_schemas(self) -> list[dict]:
1510
+ """Get all schema records as a materialized list (AbstractIndex protocol).
1511
+
1512
+ Returns:
1513
+ List of schema records as dictionaries.
1514
+ """
1515
+ return [record.to_dict() for record in self.schemas]
1516
+
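For completeness, a sketch of enumerating everything published so far, assuming each materialized record round-trips the documented 'name' and 'version' keys:

    for schema in index.list_schemas():
        print(schema["name"], schema["version"])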
1517
+ def decode_schema(self, ref: str) -> Type[Packable]:
1518
+ """Reconstruct a Python PackableSample type from a stored schema.
1519
+
1520
+ This method enables loading datasets without knowing the sample type
1521
+ ahead of time. The index retrieves the schema record and dynamically
1522
+ generates a PackableSample subclass matching the schema definition.
1523
+
1524
+ If auto_stubs is enabled, a Python module will be generated and the
1525
+ class will be imported from it, providing full IDE autocomplete support.
1526
+ The returned class has proper type information that IDEs can understand.
1527
+
1528
+ Args:
1529
+ ref: Schema reference string (atdata://local/sampleSchema/... or
1530
+ legacy local://schemas/...).
1531
+
1532
+ Returns:
1533
+ A PackableSample subclass - either imported from a generated module
1534
+ (if auto_stubs is enabled) or dynamically created.
1535
+
1536
+ Raises:
1537
+ KeyError: If schema not found.
1538
+ ValueError: If schema cannot be decoded.
1539
+ """
1540
+ schema_dict = self.get_schema(ref)
1541
+
1542
+ # If auto_stubs is enabled, generate module and import class from it
1543
+ if self._stub_manager is not None:
1544
+ cls = self._stub_manager.ensure_module(schema_dict)
1545
+ if cls is not None:
1546
+ return cls
1547
+
1548
+ # Fall back to dynamic type generation
1549
+ from atdata._schema_codec import schema_to_type
1550
+
1551
+ return schema_to_type(schema_dict)
1552
+
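A sketch of round-tripping a stored schema back into a usable class without importing its original module, assuming `ref` points at the MySample schema published above:

    DecodedSample = index.decode_schema(ref)
    sample = DecodedSample(text="hello", value=42)  # fields come from the stored schema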
1553
+ def decode_schema_as(self, ref: str, type_hint: type[T]) -> type[T]:
1554
+ """Decode a schema with explicit type hint for IDE support.
1555
+
1556
+ This is a typed wrapper around decode_schema() that preserves the
1557
+ type information for IDE autocomplete. Use this when you have a
1558
+ stub file for the schema and want full IDE support.
1559
+
1560
+ Args:
1561
+ ref: Schema reference string.
1562
+ type_hint: The stub type to use for type hints. Import this from
1563
+ the generated stub file.
1564
+
1565
+ Returns:
1566
+ The decoded type, cast to match the type_hint for IDE support.
1567
+
1568
+ Examples:
1569
+ >>> # After enabling auto_stubs and configuring IDE extraPaths:
1570
+ >>> from local.MySample_1_0_0 import MySample
1571
+ >>>
1572
+ >>> # This gives full IDE autocomplete:
1573
+ >>> DecodedType = index.decode_schema_as(ref, MySample)
1574
+ >>> sample = DecodedType(text="hello", value=42) # IDE knows signature!
1575
+
1576
+ Note:
1577
+ The type_hint is only used for static type checking; at runtime,
1578
+ the actual decoded type from the schema is returned. Ensure the
1579
+ stub matches the schema to avoid runtime surprises.
1580
+ """
1581
+ from typing import cast
1582
+
1583
+ return cast(type[T], self.decode_schema(ref))
1584
+
1585
+ def clear_stubs(self) -> int:
1586
+ """Remove all auto-generated stub files.
1587
+
1588
+ Only works if auto_stubs was enabled when creating the Index.
1589
+
1590
+ Returns:
1591
+ Number of stub files removed, or 0 if auto_stubs is disabled.
1592
+ """
1593
+ if self._stub_manager is not None:
1594
+ return self._stub_manager.clear_stubs()
1595
+ return 0
1596
+
1597
+
1598
+ # Backwards compatibility alias
1599
+ LocalIndex = Index
1600
+
1601
+
1602
+ class S3DataStore:
1603
+ """S3-compatible data store implementing AbstractDataStore protocol.
1604
+
1605
+ Handles writing dataset shards to S3-compatible object storage and
1606
+ resolving URLs for reading.
1607
+
1608
+ Attributes:
1609
+ credentials: S3 credentials dictionary.
1610
+ bucket: Target bucket name.
1611
+ _fs: S3FileSystem instance.
1612
+ """
1613
+
1614
+ def __init__(
1615
+ self,
1616
+ credentials: str | Path | dict[str, Any],
1617
+ *,
1618
+ bucket: str,
1619
+ ) -> None:
1620
+ """Initialize an S3 data store.
1621
+
1622
+ Args:
1623
+ credentials: Path to .env file or dict with AWS_ACCESS_KEY_ID,
1624
+ AWS_SECRET_ACCESS_KEY, and optionally AWS_ENDPOINT.
1625
+ bucket: Name of the S3 bucket for storage.
1626
+ """
1627
+ if isinstance(credentials, dict):
1628
+ self.credentials = credentials
479
1629
  else:
480
- ret_data = BasicIndexEntry(
481
- wds_url = ds.url,
482
- sample_kind = temp_sample_kind,
483
- metadata_url = ds.metadata_url,
484
- uuid = uuid,
1630
+ self.credentials = _s3_env(credentials)
1631
+
1632
+ self.bucket = bucket
1633
+ self._fs = _s3_from_credentials(self.credentials)
1634
+
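A minimal construction sketch; the credential keys match the docstring above, while the endpoint and bucket values are placeholders:

    store = S3DataStore(
        {
            "AWS_ACCESS_KEY_ID": "<access-key>",
            "AWS_SECRET_ACCESS_KEY": "<secret-key>",
            "AWS_ENDPOINT": "https://example-endpoint.invalid",  # optional; for R2/MinIO-style stores
        },
        bucket="my-datasets",
    )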
1635
+ def write_shards(
1636
+ self,
1637
+ ds: Dataset,
1638
+ *,
1639
+ prefix: str,
1640
+ cache_local: bool = False,
1641
+ **kwargs,
1642
+ ) -> list[str]:
1643
+ """Write dataset shards to S3.
1644
+
1645
+ Args:
1646
+ ds: The Dataset to write.
1647
+ prefix: Path prefix within bucket (e.g., 'datasets/mnist/v1').
1648
+ cache_local: If True, write locally first then copy to S3.
1649
+ **kwargs: Additional args passed to wds.ShardWriter (e.g., maxcount).
1650
+
1651
+ Returns:
1652
+ List of S3 URLs for the written shards.
1653
+
1654
+ Raises:
1655
+ RuntimeError: If no shards were written.
1656
+ """
1657
+ new_uuid = str(uuid4())
1658
+ shard_pattern = f"{self.bucket}/{prefix}/data--{new_uuid}--%06d.tar"
1659
+
1660
+ written_shards: list[str] = []
1661
+
1662
+ with TemporaryDirectory() as temp_dir:
1663
+ writer_opener, writer_post = _create_s3_write_callbacks(
1664
+ credentials=self.credentials,
1665
+ temp_dir=temp_dir,
1666
+ written_shards=written_shards,
1667
+ fs=self._fs,
1668
+ cache_local=cache_local,
1669
+ add_s3_prefix=True,
485
1670
  )
486
1671
 
487
- ret_data.write_to( self._redis )
1672
+ with wds.writer.ShardWriter(
1673
+ shard_pattern,
1674
+ opener=writer_opener,
1675
+ post=writer_post,
1676
+ **kwargs,
1677
+ ) as sink:
1678
+ for sample in ds.ordered(batch_size=None):
1679
+ sink.write(sample.as_wds)
1680
+
1681
+ if len(written_shards) == 0:
1682
+ raise RuntimeError("No shards written")
1683
+
1684
+ return written_shards
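A usage sketch, assuming `ds` is an existing atdata Dataset; maxcount is a standard wds.ShardWriter option forwarded through **kwargs:

    shard_urls = store.write_shards(
        ds,
        prefix="datasets/mnist/v1",
        maxcount=10_000,  # forwarded to wds.ShardWriter
    )
    # shard_urls is a list of S3 URLs, one per written shard.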
488
1685
 
489
- return ret_data
1686
+ def read_url(self, url: str) -> str:
1687
+ """Resolve an S3 URL for reading/streaming.
1688
+
1689
+ For S3-compatible stores with custom endpoints (like Cloudflare R2,
1690
+ MinIO, etc.), converts s3:// URLs to HTTPS URLs that WebDataset can
1691
+ stream directly.
1692
+
1693
+ For standard AWS S3 (no custom endpoint), URLs are returned unchanged
1694
+ since WebDataset's built-in s3fs integration handles them.
1695
+
1696
+ Args:
1697
+ url: S3 URL to resolve (e.g., 's3://bucket/path/file.tar').
1698
+
1699
+ Returns:
1700
+ HTTPS URL if custom endpoint is configured, otherwise unchanged.
1701
+ Example: 's3://bucket/path' -> 'https://endpoint.com/bucket/path'
1702
+ """
1703
+ endpoint = self.credentials.get("AWS_ENDPOINT")
1704
+ if endpoint and url.startswith("s3://"):
1705
+ # s3://bucket/path -> https://endpoint/bucket/path
1706
+ path = url[5:] # Remove 's3://' prefix
1707
+ endpoint = endpoint.rstrip("/")
1708
+ return f"{endpoint}/{path}"
1709
+ return url
1710
+
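A sketch of the endpoint rewrite described above, using a made-up endpoint:

    url = store.read_url("s3://my-datasets/datasets/mnist/v1/data--abc--000000.tar")
    # With AWS_ENDPOINT set to 'https://r2.example.com', this yields
    #   'https://r2.example.com/my-datasets/datasets/mnist/v1/data--abc--000000.tar'.
    # Without a custom endpoint, the s3:// URL is returned unchanged.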
1711
+ def supports_streaming(self) -> bool:
1712
+ """S3 supports streaming reads.
1713
+
1714
+ Returns:
1715
+ True.
1716
+ """
1717
+ return True
490
1718
 
491
1719
 
492
- #
1720
+ #