atdata 0.2.2b1__py3-none-any.whl → 0.2.3b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atdata/local.py CHANGED
@@ -24,13 +24,12 @@ from atdata import (
 )
 from atdata._cid import generate_cid
 from atdata._type_utils import (
-    numpy_dtype_to_string,
     PRIMITIVE_TYPE_MAP,
     unwrap_optional,
     is_ndarray_type,
    extract_ndarray_dtype,
 )
-from atdata._protocols import IndexEntry, AbstractDataStore, Packable
+from atdata._protocols import AbstractDataStore, Packable
 
 from pathlib import Path
 from uuid import uuid4
@@ -57,7 +56,6 @@ from typing import (
     Generator,
     Iterator,
     BinaryIO,
-    Union,
     Optional,
     Literal,
     cast,
@@ -70,7 +68,7 @@ from datetime import datetime, timezone
 import json
 import warnings
 
-T = TypeVar( 'T', bound = PackableSample )
+T = TypeVar("T", bound=PackableSample)
 
 # Redis key prefixes for index entries and schemas
 REDIS_KEY_DATASET_ENTRY = "LocalDatasetEntry"
@@ -84,12 +82,10 @@ class SchemaNamespace:
     loaded schema types. After calling ``index.load_schema(uri)``, the
     schema's class becomes available as an attribute on this namespace.
 
-    Example:
-    ::
-
-        >>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
-        >>> MyType = index.types.MySample
-        >>> sample = MyType(field1="hello", field2=42)
+    Examples:
+        >>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
+        >>> MyType = index.types.MySample
+        >>> sample = MyType(field1="hello", field2=42)
 
     The namespace supports:
     - Attribute access: ``index.types.MySample``
@@ -357,9 +353,10 @@ class LocalSchemaRecord:
 ##
 # Helpers
 
-def _kind_str_for_sample_type( st: Type[Packable] ) -> str:
+
+def _kind_str_for_sample_type(st: Type[Packable]) -> str:
     """Return fully-qualified 'module.name' string for a sample type."""
-    return f'{st.__module__}.{st.__name__}'
+    return f"{st.__module__}.{st.__name__}"
 
 
 def _create_s3_write_callbacks(
@@ -387,17 +384,17 @@ def _create_s3_write_callbacks(
         import boto3
 
         s3_client_kwargs = {
-            'aws_access_key_id': credentials['AWS_ACCESS_KEY_ID'],
-            'aws_secret_access_key': credentials['AWS_SECRET_ACCESS_KEY']
+            "aws_access_key_id": credentials["AWS_ACCESS_KEY_ID"],
+            "aws_secret_access_key": credentials["AWS_SECRET_ACCESS_KEY"],
         }
-        if 'AWS_ENDPOINT' in credentials:
-            s3_client_kwargs['endpoint_url'] = credentials['AWS_ENDPOINT']
-        s3_client = boto3.client('s3', **s3_client_kwargs)
+        if "AWS_ENDPOINT" in credentials:
+            s3_client_kwargs["endpoint_url"] = credentials["AWS_ENDPOINT"]
+        s3_client = boto3.client("s3", **s3_client_kwargs)
 
         def _writer_opener(p: str):
             local_path = Path(temp_dir) / p
             local_path.parent.mkdir(parents=True, exist_ok=True)
-            return open(local_path, 'wb')
+            return open(local_path, "wb")
 
         def _writer_post(p: str):
             local_path = Path(temp_dir) / p
@@ -405,7 +402,7 @@ def _create_s3_write_callbacks(
             bucket = path_parts[0]
             key = str(Path(*path_parts[1:]))
 
-            with open(local_path, 'rb') as f_in:
+            with open(local_path, "rb") as f_in:
                 s3_client.put_object(Bucket=bucket, Key=key, Body=f_in.read())
 
             local_path.unlink()
@@ -419,7 +416,7 @@ def _create_s3_write_callbacks(
     assert fs is not None, "S3FileSystem required when cache_local=False"
 
     def _direct_opener(s: str):
-        return cast(BinaryIO, fs.open(f's3://{s}', 'wb'))
+        return cast(BinaryIO, fs.open(f"s3://{s}", "wb"))
 
     def _direct_post(s: str):
         if add_s3_prefix:
@@ -429,6 +426,7 @@ def _create_s3_write_callbacks(
 
     return _direct_opener, _direct_post
 
+
 ##
 # Schema helpers
 
@@ -454,9 +452,9 @@ def _parse_schema_ref(ref: str) -> tuple[str, str]:
     and legacy format: 'local://schemas/{module.Class}@{version}'
     """
     if ref.startswith(_ATDATA_URI_PREFIX):
-        path = ref[len(_ATDATA_URI_PREFIX):]
+        path = ref[len(_ATDATA_URI_PREFIX) :]
     elif ref.startswith(_LEGACY_URI_PREFIX):
-        path = ref[len(_LEGACY_URI_PREFIX):]
+        path = ref[len(_LEGACY_URI_PREFIX) :]
     else:
         raise ValueError(f"Invalid schema reference: {ref}")
 
@@ -487,7 +485,10 @@ def _increment_patch(version: str) -> str:
 def _python_type_to_field_type(python_type: Any) -> dict:
     """Convert Python type annotation to schema field type dict."""
     if python_type in PRIMITIVE_TYPE_MAP:
-        return {"$type": "local#primitive", "primitive": PRIMITIVE_TYPE_MAP[python_type]}
+        return {
+            "$type": "local#primitive",
+            "primitive": PRIMITIVE_TYPE_MAP[python_type],
+        }
 
     if is_ndarray_type(python_type):
         return {"$type": "local#ndarray", "dtype": extract_ndarray_dtype(python_type)}
@@ -495,7 +496,11 @@ def _python_type_to_field_type(python_type: Any) -> dict:
     origin = get_origin(python_type)
     if origin is list:
         args = get_args(python_type)
-        items = _python_type_to_field_type(args[0]) if args else {"$type": "local#primitive", "primitive": "str"}
+        items = (
+            _python_type_to_field_type(args[0])
+            if args
+            else {"$type": "local#primitive", "primitive": "str"}
+        )
         return {"$type": "local#array", "items": items}
 
     if is_dataclass(python_type):
@@ -543,11 +548,13 @@ def _build_schema_record(
         field_type, is_optional = unwrap_optional(field_type)
         field_type_dict = _python_type_to_field_type(field_type)
 
-        field_defs.append({
-            "name": f.name,
-            "fieldType": field_type_dict,
-            "optional": is_optional,
-        })
+        field_defs.append(
+            {
+                "name": f.name,
+                "fieldType": field_type_dict,
+                "optional": is_optional,
+            }
+        )
 
     return {
         "name": sample_type.__name__,
@@ -561,6 +568,7 @@ def _build_schema_record(
 ##
 # Redis object model
 
+
 @dataclass
 class LocalDatasetEntry:
     """Index entry for a dataset stored in the local repository.
@@ -579,6 +587,7 @@ class LocalDatasetEntry:
         data_urls: WebDataset URLs for the data.
         metadata: Arbitrary metadata dictionary, or None if not set.
     """
+
     ##
 
     name: str
@@ -640,17 +649,17 @@ class LocalDatasetEntry:
         Args:
             redis: Redis connection to write to.
         """
-        save_key = f'{REDIS_KEY_DATASET_ENTRY}:{self.cid}'
+        save_key = f"{REDIS_KEY_DATASET_ENTRY}:{self.cid}"
         data = {
-            'name': self.name,
-            'schema_ref': self.schema_ref,
-            'data_urls': msgpack.packb(self.data_urls),  # Serialize list
-            'cid': self.cid,
+            "name": self.name,
+            "schema_ref": self.schema_ref,
+            "data_urls": msgpack.packb(self.data_urls),  # Serialize list
+            "cid": self.cid,
         }
         if self.metadata is not None:
-            data['metadata'] = msgpack.packb(self.metadata)
+            data["metadata"] = msgpack.packb(self.metadata)
         if self._legacy_uuid is not None:
-            data['legacy_uuid'] = self._legacy_uuid
+            data["legacy_uuid"] = self._legacy_uuid
 
         redis.hset(save_key, mapping=data)  # type: ignore[arg-type]
 
@@ -668,23 +677,23 @@ class LocalDatasetEntry:
         Raises:
             KeyError: If entry not found.
         """
-        save_key = f'{REDIS_KEY_DATASET_ENTRY}:{cid}'
+        save_key = f"{REDIS_KEY_DATASET_ENTRY}:{cid}"
         raw_data = redis.hgetall(save_key)
         if not raw_data:
             raise KeyError(f"{REDIS_KEY_DATASET_ENTRY} not found: {cid}")
 
         # Decode string fields, keep binary fields as bytes for msgpack
         raw_data_typed = cast(dict[bytes, bytes], raw_data)
-        name = raw_data_typed[b'name'].decode('utf-8')
-        schema_ref = raw_data_typed[b'schema_ref'].decode('utf-8')
-        cid_value = raw_data_typed.get(b'cid', b'').decode('utf-8') or None
-        legacy_uuid = raw_data_typed.get(b'legacy_uuid', b'').decode('utf-8') or None
+        name = raw_data_typed[b"name"].decode("utf-8")
+        schema_ref = raw_data_typed[b"schema_ref"].decode("utf-8")
+        cid_value = raw_data_typed.get(b"cid", b"").decode("utf-8") or None
+        legacy_uuid = raw_data_typed.get(b"legacy_uuid", b"").decode("utf-8") or None
 
         # Deserialize msgpack fields (stored as raw bytes)
-        data_urls = msgpack.unpackb(raw_data_typed[b'data_urls'])
+        data_urls = msgpack.unpackb(raw_data_typed[b"data_urls"])
         metadata = None
-        if b'metadata' in raw_data_typed:
-            metadata = msgpack.unpackb(raw_data_typed[b'metadata'])
+        if b"metadata" in raw_data_typed:
+            metadata = msgpack.unpackb(raw_data_typed[b"metadata"])
 
         return cls(
             name=name,
@@ -699,7 +708,8 @@ class LocalDatasetEntry:
 # Backwards compatibility alias
 BasicIndexEntry = LocalDatasetEntry
 
-def _s3_env( credentials_path: str | Path ) -> dict[str, Any]:
+
+def _s3_env(credentials_path: str | Path) -> dict[str, Any]:
     """Load S3 credentials from .env file.
 
     Args:
@@ -712,28 +722,31 @@ def _s3_env( credentials_path: str | Path ) -> dict[str, Any]:
     Raises:
         ValueError: If any required key is missing from the .env file.
     """
-    credentials_path = Path( credentials_path )
-    env_values = dotenv_values( credentials_path )
+    credentials_path = Path(credentials_path)
+    env_values = dotenv_values(credentials_path)
 
-    required_keys = ('AWS_ENDPOINT', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY')
+    required_keys = ("AWS_ENDPOINT", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
     missing = [k for k in required_keys if k not in env_values]
     if missing:
-        raise ValueError(f"Missing required keys in {credentials_path}: {', '.join(missing)}")
+        raise ValueError(
+            f"Missing required keys in {credentials_path}: {', '.join(missing)}"
+        )
 
     return {k: env_values[k] for k in required_keys}
 
-def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem:
+
+def _s3_from_credentials(creds: str | Path | dict) -> S3FileSystem:
     """Create S3FileSystem from credentials dict or .env file path."""
-    if not isinstance( creds, dict ):
-        creds = _s3_env( creds )
+    if not isinstance(creds, dict):
+        creds = _s3_env(creds)
 
     # Build kwargs, making endpoint_url optional
     kwargs = {
-        'key': creds['AWS_ACCESS_KEY_ID'],
-        'secret': creds['AWS_SECRET_ACCESS_KEY']
+        "key": creds["AWS_ACCESS_KEY_ID"],
+        "secret": creds["AWS_SECRET_ACCESS_KEY"],
     }
-    if 'AWS_ENDPOINT' in creds:
-        kwargs['endpoint_url'] = creds['AWS_ENDPOINT']
+    if "AWS_ENDPOINT" in creds:
+        kwargs["endpoint_url"] = creds["AWS_ENDPOINT"]
 
     return S3FileSystem(**kwargs)
 
@@ -741,6 +754,7 @@ def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem:
 ##
 # Classes
 
+
 class Repo:
     """Repository for storing and managing atdata datasets.
 
@@ -797,20 +811,20 @@ class Repo:
 
         if s3_credentials is None:
             self.s3_credentials = None
-        elif isinstance( s3_credentials, dict ):
+        elif isinstance(s3_credentials, dict):
             self.s3_credentials = s3_credentials
         else:
-            self.s3_credentials = _s3_env( s3_credentials )
+            self.s3_credentials = _s3_env(s3_credentials)
 
         if self.s3_credentials is None:
             self.bucket_fs = None
         else:
-            self.bucket_fs = _s3_from_credentials( self.s3_credentials )
+            self.bucket_fs = _s3_from_credentials(self.s3_credentials)
 
         if self.bucket_fs is not None:
             if hive_path is None:
-                raise ValueError( 'Must specify hive path within bucket' )
-            self.hive_path = Path( hive_path )
+                raise ValueError("Must specify hive path within bucket")
+            self.hive_path = Path(hive_path)
             self.hive_bucket = self.hive_path.parts[0]
         else:
             self.hive_path = None
@@ -818,18 +832,19 @@ class Repo:
 
         #
 
-        self.index = Index( redis = redis )
+        self.index = Index(redis=redis)
 
     ##
 
-    def insert(self,
-        ds: Dataset[T],
-        *,
-        name: str,
-        cache_local: bool = False,
-        schema_ref: str | None = None,
-        **kwargs
-    ) -> tuple[LocalDatasetEntry, Dataset[T]]:
+    def insert(
+        self,
+        ds: Dataset[T],
+        *,
+        name: str,
+        cache_local: bool = False,
+        schema_ref: str | None = None,
+        **kwargs,
+    ) -> tuple[LocalDatasetEntry, Dataset[T]]:
         """Insert a dataset into the repository.
 
         Writes the dataset to S3 as WebDataset tar files, stores metadata,
@@ -853,35 +868,35 @@ class Repo:
             RuntimeError: If no shards were written.
         """
         if self.s3_credentials is None:
-            raise ValueError("S3 credentials required for insert(). Initialize Repo with s3_credentials.")
+            raise ValueError(
+                "S3 credentials required for insert(). Initialize Repo with s3_credentials."
+            )
         if self.hive_bucket is None or self.hive_path is None:
-            raise ValueError("hive_path required for insert(). Initialize Repo with hive_path.")
+            raise ValueError(
+                "hive_path required for insert(). Initialize Repo with hive_path."
+            )
 
-        new_uuid = str( uuid4() )
+        new_uuid = str(uuid4())
 
-        hive_fs = _s3_from_credentials( self.s3_credentials )
+        hive_fs = _s3_from_credentials(self.s3_credentials)
 
         # Write metadata
         metadata_path = (
-            self.hive_path
-            / 'metadata'
-            / f'atdata-metadata--{new_uuid}.msgpack'
+            self.hive_path / "metadata" / f"atdata-metadata--{new_uuid}.msgpack"
         )
         # Note: S3 doesn't need directories created beforehand - s3fs handles this
 
         if ds.metadata is not None:
             # Use s3:// prefix to ensure s3fs treats this as an S3 path
-            with cast( BinaryIO, hive_fs.open( f's3://{metadata_path.as_posix()}', 'wb' ) ) as f:
-                meta_packed = msgpack.packb( ds.metadata )
+            with cast(
+                BinaryIO, hive_fs.open(f"s3://{metadata_path.as_posix()}", "wb")
+            ) as f:
+                meta_packed = msgpack.packb(ds.metadata)
                 assert meta_packed is not None
-                f.write( cast( bytes, meta_packed ) )
-
+                f.write(cast(bytes, meta_packed))
 
         # Write data
-        shard_pattern = (
-            self.hive_path
-            / f'atdata--{new_uuid}--%06d.tar'
-        ).as_posix()
+        shard_pattern = (self.hive_path / f"atdata--{new_uuid}--%06d.tar").as_posix()
 
         written_shards: list[str] = []
         with TemporaryDirectory() as temp_dir:
@@ -904,24 +919,22 @@ class Repo:
                 sink.write(sample.as_wds)
 
         # Make a new Dataset object for the written dataset copy
-        if len( written_shards ) == 0:
-            raise RuntimeError( 'Cannot form new dataset entry -- did not write any shards' )
-
-        elif len( written_shards ) < 2:
+        if len(written_shards) == 0:
+            raise RuntimeError(
+                "Cannot form new dataset entry -- did not write any shards"
+            )
+
+        elif len(written_shards) < 2:
             new_dataset_url = (
-                self.hive_path
-                / ( Path( written_shards[0] ).name )
+                self.hive_path / (Path(written_shards[0]).name)
             ).as_posix()
 
         else:
             shard_s3_format = (
-                (
-                    self.hive_path
-                    / f'atdata--{new_uuid}'
-                ).as_posix()
-            ) + '--{shard_id}.tar'
-            shard_id_braced = '{' + f'{0:06d}..{len( written_shards ) - 1:06d}' + '}'
-            new_dataset_url = shard_s3_format.format( shard_id = shard_id_braced )
+                (self.hive_path / f"atdata--{new_uuid}").as_posix()
+            ) + "--{shard_id}.tar"
+            shard_id_braced = "{" + f"{0:06d}..{len(written_shards) - 1:06d}" + "}"
+            new_dataset_url = shard_s3_format.format(shard_id=shard_id_braced)
 
         new_dataset = Dataset[ds.sample_type](
             url=new_dataset_url,
@@ -995,6 +1008,7 @@ class Index:
         # Providing stub_dir implies auto_stubs=True
         if auto_stubs or stub_dir is not None:
             from ._stub_manager import StubManager
+
             self._stub_manager: StubManager | None = StubManager(stub_dir=stub_dir)
         else:
             self._stub_manager = None
@@ -1027,12 +1041,10 @@ class Index:
         After calling :meth:`load_schema`, schema types become available
         as attributes on this namespace.
 
-        Example:
-        ::
-
-            >>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
-            >>> MyType = index.types.MySample
-            >>> sample = MyType(name="hello", value=42)
+        Examples:
+            >>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
+            >>> MyType = index.types.MySample
+            >>> sample = MyType(name="hello", value=42)
 
         Returns:
             SchemaNamespace containing all loaded schema types.
@@ -1058,16 +1070,14 @@ class Index:
             KeyError: If schema not found.
             ValueError: If schema cannot be decoded.
 
-        Example:
-        ::
-
-            >>> # Load and use immediately
-            >>> MyType = index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
-            >>> sample = MyType(name="hello", value=42)
-            >>>
-            >>> # Or access later via namespace
-            >>> index.load_schema("atdata://local/sampleSchema/OtherType@1.0.0")
-            >>> other = index.types.OtherType(data="test")
+        Examples:
+            >>> # Load and use immediately
+            >>> MyType = index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
+            >>> sample = MyType(name="hello", value=42)
+            >>>
+            >>> # Or access later via namespace
+            >>> index.load_schema("atdata://local/sampleSchema/OtherType@1.0.0")
+            >>> other = index.types.OtherType(data="test")
         """
         # Decode the schema (uses generated module if auto_stubs enabled)
         cls = self.decode_schema(ref)
@@ -1090,16 +1100,14 @@ class Index:
             Import path like "local.MySample_1_0_0", or None if auto_stubs
             is disabled.
 
-        Example:
-        ::
-
-            >>> index = LocalIndex(auto_stubs=True)
-            >>> ref = index.publish_schema(MySample, version="1.0.0")
-            >>> index.load_schema(ref)
-            >>> print(index.get_import_path(ref))
-            local.MySample_1_0_0
-            >>> # Then in your code:
-            >>> # from local.MySample_1_0_0 import MySample
+        Examples:
+            >>> index = LocalIndex(auto_stubs=True)
+            >>> ref = index.publish_schema(MySample, version="1.0.0")
+            >>> index.load_schema(ref)
+            >>> print(index.get_import_path(ref))
+            local.MySample_1_0_0
+            >>> # Then in your code:
+            >>> # from local.MySample_1_0_0 import MySample
         """
         if self._stub_manager is None:
             return None
@@ -1138,19 +1146,20 @@ class Index:
         Yields:
             LocalDatasetEntry objects from the index.
         """
-        prefix = f'{REDIS_KEY_DATASET_ENTRY}:'
-        for key in self._redis.scan_iter(match=f'{prefix}*'):
-            key_str = key.decode('utf-8') if isinstance(key, bytes) else key
-            cid = key_str[len(prefix):]
+        prefix = f"{REDIS_KEY_DATASET_ENTRY}:"
+        for key in self._redis.scan_iter(match=f"{prefix}*"):
+            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+            cid = key_str[len(prefix) :]
             yield LocalDatasetEntry.from_redis(self._redis, cid)
 
-    def add_entry(self,
-        ds: Dataset,
-        *,
-        name: str,
-        schema_ref: str | None = None,
-        metadata: dict | None = None,
-    ) -> LocalDatasetEntry:
+    def add_entry(
+        self,
+        ds: Dataset,
+        *,
+        name: str,
+        schema_ref: str | None = None,
+        metadata: dict | None = None,
+    ) -> LocalDatasetEntry:
         """Add a dataset to the index.
 
         Creates a LocalDatasetEntry for the dataset and persists it to Redis.
@@ -1166,7 +1175,9 @@ class Index:
         """
         ##
         if schema_ref is None:
-            schema_ref = f"local://schemas/{_kind_str_for_sample_type(ds.sample_type)}@1.0.0"
+            schema_ref = (
+                f"local://schemas/{_kind_str_for_sample_type(ds.sample_type)}@1.0.0"
+            )
 
         # Normalize URL to list
         data_urls = [ds.url]
@@ -1245,12 +1256,12 @@ class Index:
         Returns:
             IndexEntry for the inserted dataset.
         """
-        metadata = kwargs.get('metadata')
+        metadata = kwargs.get("metadata")
 
         if self._data_store is not None:
             # Write shards to data store, then index the new URLs
-            prefix = kwargs.get('prefix', name)
-            cache_local = kwargs.get('cache_local', False)
+            prefix = kwargs.get("prefix", name)
+            cache_local = kwargs.get("cache_local", False)
 
             written_urls = self._data_store.write_shards(
                 ds,
@@ -1314,10 +1325,10 @@ class Index:
         latest_version: tuple[int, int, int] | None = None
         latest_version_str: str | None = None
 
-        prefix = f'{REDIS_KEY_SCHEMA}:'
-        for key in self._redis.scan_iter(match=f'{prefix}*'):
-            key_str = key.decode('utf-8') if isinstance(key, bytes) else key
-            schema_id = key_str[len(prefix):]
+        prefix = f"{REDIS_KEY_SCHEMA}:"
+        for key in self._redis.scan_iter(match=f"{prefix}*"):
+            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+            schema_id = key_str[len(prefix) :]
 
             if "@" not in schema_id:
                 continue
@@ -1369,10 +1380,12 @@ class Index:
         # This catches non-packable types early with a clear error message
         try:
             # Check protocol compliance by verifying required methods exist
-            if not (hasattr(sample_type, 'from_data') and
-                    hasattr(sample_type, 'from_bytes') and
-                    callable(getattr(sample_type, 'from_data', None)) and
-                    callable(getattr(sample_type, 'from_bytes', None))):
+            if not (
+                hasattr(sample_type, "from_data")
+                and hasattr(sample_type, "from_bytes")
+                and callable(getattr(sample_type, "from_data", None))
+                and callable(getattr(sample_type, "from_bytes", None))
+            ):
                 raise TypeError(
                     f"{sample_type.__name__} does not satisfy the Packable protocol. "
                     "Use @packable decorator or inherit from PackableSample."
@@ -1430,10 +1443,10 @@ class Index:
             raise KeyError(f"Schema not found: {ref}")
 
         if isinstance(schema_json, bytes):
-            schema_json = schema_json.decode('utf-8')
+            schema_json = schema_json.decode("utf-8")
 
         schema = json.loads(schema_json)
-        schema['$ref'] = _make_schema_ref(name, version)
+        schema["$ref"] = _make_schema_ref(name, version)
 
         # Auto-generate stub if enabled
         if self._stub_manager is not None:
@@ -1468,29 +1481,29 @@ class Index:
         Yields:
             LocalSchemaRecord for each schema.
         """
-        prefix = f'{REDIS_KEY_SCHEMA}:'
-        for key in self._redis.scan_iter(match=f'{prefix}*'):
-            key_str = key.decode('utf-8') if isinstance(key, bytes) else key
+        prefix = f"{REDIS_KEY_SCHEMA}:"
+        for key in self._redis.scan_iter(match=f"{prefix}*"):
+            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
             # Extract name@version from key
-            schema_id = key_str[len(prefix):]
+            schema_id = key_str[len(prefix) :]
 
             schema_json = self._redis.get(key)
             if schema_json is None:
                 continue
 
             if isinstance(schema_json, bytes):
-                schema_json = schema_json.decode('utf-8')
+                schema_json = schema_json.decode("utf-8")
 
             schema = json.loads(schema_json)
             # Handle legacy keys that have module.Class format
             if "." in schema_id.split("@")[0]:
                 name = schema_id.split("@")[0].rsplit(".", 1)[1]
                 version = schema_id.split("@")[1]
-                schema['$ref'] = _make_schema_ref(name, version)
+                schema["$ref"] = _make_schema_ref(name, version)
             else:
                 # schema_id is already "name@version"
                 name, version = schema_id.rsplit("@", 1)
-                schema['$ref'] = _make_schema_ref(name, version)
+                schema["$ref"] = _make_schema_ref(name, version)
             yield LocalSchemaRecord.from_dict(schema)
 
     def list_schemas(self) -> list[dict]:
@@ -1534,6 +1547,7 @@ class Index:
 
         # Fall back to dynamic type generation
         from atdata._schema_codec import schema_to_type
+
         return schema_to_type(schema_dict)
 
     def decode_schema_as(self, ref: str, type_hint: type[T]) -> type[T]:
@@ -1551,15 +1565,13 @@ class Index:
         Returns:
             The decoded type, cast to match the type_hint for IDE support.
 
-        Example:
-        ::
-
-            >>> # After enabling auto_stubs and configuring IDE extraPaths:
-            >>> from local.MySample_1_0_0 import MySample
-            >>>
-            >>> # This gives full IDE autocomplete:
-            >>> DecodedType = index.decode_schema_as(ref, MySample)
-            >>> sample = DecodedType(text="hello", value=42)  # IDE knows signature!
+        Examples:
+            >>> # After enabling auto_stubs and configuring IDE extraPaths:
+            >>> from local.MySample_1_0_0 import MySample
+            >>>
+            >>> # This gives full IDE autocomplete:
+            >>> DecodedType = index.decode_schema_as(ref, MySample)
+            >>> sample = DecodedType(text="hello", value=42)  # IDE knows signature!
 
         Note:
             The type_hint is only used for static type checking - at runtime,
@@ -1567,6 +1579,7 @@ class Index:
             stub matches the schema to avoid runtime surprises.
         """
         from typing import cast
+
         return cast(type[T], self.decode_schema(ref))
 
     def clear_stubs(self) -> int:
@@ -1687,11 +1700,11 @@ class S3DataStore:
             HTTPS URL if custom endpoint is configured, otherwise unchanged.
             Example: 's3://bucket/path' -> 'https://endpoint.com/bucket/path'
         """
-        endpoint = self.credentials.get('AWS_ENDPOINT')
-        if endpoint and url.startswith('s3://'):
+        endpoint = self.credentials.get("AWS_ENDPOINT")
+        if endpoint and url.startswith("s3://"):
             # s3://bucket/path -> https://endpoint/bucket/path
             path = url[5:]  # Remove 's3://' prefix
-            endpoint = endpoint.rstrip('/')
+            endpoint = endpoint.rstrip("/")
             return f"{endpoint}/{path}"
         return url
 
@@ -1704,4 +1717,4 @@ class S3DataStore:
         return True
 
 
-#
+#