atdata 0.2.2b1-py3-none-any.whl → 0.2.3b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +1 -1
- atdata/_cid.py +29 -35
- atdata/_helpers.py +7 -5
- atdata/_hf_api.py +48 -50
- atdata/_protocols.py +56 -71
- atdata/_schema_codec.py +33 -37
- atdata/_sources.py +57 -64
- atdata/_stub_manager.py +31 -26
- atdata/_type_utils.py +19 -5
- atdata/atmosphere/__init__.py +20 -23
- atdata/atmosphere/_types.py +11 -11
- atdata/atmosphere/client.py +11 -8
- atdata/atmosphere/lens.py +27 -30
- atdata/atmosphere/records.py +31 -37
- atdata/atmosphere/schema.py +33 -29
- atdata/atmosphere/store.py +16 -20
- atdata/cli/__init__.py +12 -3
- atdata/cli/diagnose.py +12 -8
- atdata/cli/local.py +4 -1
- atdata/dataset.py +284 -241
- atdata/lens.py +77 -82
- atdata/local.py +182 -169
- atdata/promote.py +18 -22
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/METADATA +2 -1
- atdata-0.2.3b1.dist-info/RECORD +28 -0
- atdata-0.2.2b1.dist-info/RECORD +0 -28
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.2.3b1.dist-info}/licenses/LICENSE +0 -0
atdata/local.py
CHANGED
@@ -24,13 +24,12 @@ from atdata import (
 )
 from atdata._cid import generate_cid
 from atdata._type_utils import (
-    numpy_dtype_to_string,
     PRIMITIVE_TYPE_MAP,
     unwrap_optional,
     is_ndarray_type,
     extract_ndarray_dtype,
 )
-from atdata._protocols import
+from atdata._protocols import AbstractDataStore, Packable

 from pathlib import Path
 from uuid import uuid4

@@ -57,7 +56,6 @@ from typing import (
     Generator,
     Iterator,
     BinaryIO,
-    Union,
     Optional,
     Literal,
     cast,

@@ -70,7 +68,7 @@ from datetime import datetime, timezone
 import json
 import warnings

-T = TypeVar(
+T = TypeVar("T", bound=PackableSample)

 # Redis key prefixes for index entries and schemas
 REDIS_KEY_DATASET_ENTRY = "LocalDatasetEntry"
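Note on the hunk above: T is now bound to PackableSample, which is what lets signatures later in this diff (for example Repo.insert(ds: Dataset[T]) -> tuple[LocalDatasetEntry, Dataset[T]]) promise that the sample type going in is the sample type coming out. A minimal sketch of the pattern, assuming PackableSample is importable from the top-level atdata package (the import block whose closing paren opens this diff):

    # Sketch only: a generic helper constrained to PackableSample subclasses.
    from typing import TypeVar
    from atdata import PackableSample  # assumed export; see the import block above

    T = TypeVar("T", bound=PackableSample)

    def first_sample(samples: list[T]) -> T:
        # Accepts only PackableSample subclasses and preserves the concrete type.
        return samples[0]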
@@ -84,12 +82,10 @@ class SchemaNamespace:
     loaded schema types. After calling ``index.load_schema(uri)``, the
     schema's class becomes available as an attribute on this namespace.

-
-
-
-
-        >>> MyType = index.types.MySample
-        >>> sample = MyType(field1="hello", field2=42)
+    Examples:
+        >>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
+        >>> MyType = index.types.MySample
+        >>> sample = MyType(field1="hello", field2=42)

     The namespace supports:
     - Attribute access: ``index.types.MySample``

@@ -357,9 +353,10 @@ class LocalSchemaRecord:
 ##
 # Helpers

-
+
+def _kind_str_for_sample_type(st: Type[Packable]) -> str:
     """Return fully-qualified 'module.name' string for a sample type."""
-    return f
+    return f"{st.__module__}.{st.__name__}"


 def _create_s3_write_callbacks(

@@ -387,17 +384,17 @@ def _create_s3_write_callbacks(
         import boto3

         s3_client_kwargs = {
-
-
+            "aws_access_key_id": credentials["AWS_ACCESS_KEY_ID"],
+            "aws_secret_access_key": credentials["AWS_SECRET_ACCESS_KEY"],
         }
-        if
-        s3_client_kwargs[
-        s3_client = boto3.client(
+        if "AWS_ENDPOINT" in credentials:
+            s3_client_kwargs["endpoint_url"] = credentials["AWS_ENDPOINT"]
+        s3_client = boto3.client("s3", **s3_client_kwargs)

         def _writer_opener(p: str):
             local_path = Path(temp_dir) / p
             local_path.parent.mkdir(parents=True, exist_ok=True)
-            return open(local_path,
+            return open(local_path, "wb")

         def _writer_post(p: str):
             local_path = Path(temp_dir) / p

@@ -405,7 +402,7 @@ def _create_s3_write_callbacks(
             bucket = path_parts[0]
             key = str(Path(*path_parts[1:]))

-            with open(local_path,
+            with open(local_path, "rb") as f_in:
                 s3_client.put_object(Bucket=bucket, Key=key, Body=f_in.read())

             local_path.unlink()

@@ -419,7 +416,7 @@ def _create_s3_write_callbacks(
         assert fs is not None, "S3FileSystem required when cache_local=False"

         def _direct_opener(s: str):
-            return cast(BinaryIO, fs.open(f
+            return cast(BinaryIO, fs.open(f"s3://{s}", "wb"))

         def _direct_post(s: str):
             if add_s3_prefix:

@@ -429,6 +426,7 @@ def _create_s3_write_callbacks(

         return _direct_opener, _direct_post

+
 ##
 # Schema helpers

@@ -454,9 +452,9 @@ def _parse_schema_ref(ref: str) -> tuple[str, str]:
     and legacy format: 'local://schemas/{module.Class}@{version}'
     """
     if ref.startswith(_ATDATA_URI_PREFIX):
-        path = ref[len(_ATDATA_URI_PREFIX):]
+        path = ref[len(_ATDATA_URI_PREFIX) :]
     elif ref.startswith(_LEGACY_URI_PREFIX):
-        path = ref[len(_LEGACY_URI_PREFIX):]
+        path = ref[len(_LEGACY_URI_PREFIX) :]
     else:
         raise ValueError(f"Invalid schema reference: {ref}")

@@ -487,7 +485,10 @@ def _increment_patch(version: str) -> str:
 def _python_type_to_field_type(python_type: Any) -> dict:
     """Convert Python type annotation to schema field type dict."""
     if python_type in PRIMITIVE_TYPE_MAP:
-        return {
+        return {
+            "$type": "local#primitive",
+            "primitive": PRIMITIVE_TYPE_MAP[python_type],
+        }

     if is_ndarray_type(python_type):
         return {"$type": "local#ndarray", "dtype": extract_ndarray_dtype(python_type)}

@@ -495,7 +496,11 @@ def _python_type_to_field_type(python_type: Any) -> dict:
     origin = get_origin(python_type)
     if origin is list:
         args = get_args(python_type)
-        items =
+        items = (
+            _python_type_to_field_type(args[0])
+            if args
+            else {"$type": "local#primitive", "primitive": "str"}
+        )
         return {"$type": "local#array", "items": items}

     if is_dataclass(python_type):
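The two hunks above only reflow _python_type_to_field_type; the dict shapes it emits are unchanged. For reference, the shapes implied by this code, assuming PRIMITIVE_TYPE_MAP maps int to the string "int" (the map itself lives in atdata._type_utils and is not part of this diff):

    # Illustration only -- mirrors the literals visible in the hunks above.
    int_field = {"$type": "local#primitive", "primitive": "int"}   # e.g. int
    array_field = {"$type": "local#array", "items": int_field}     # e.g. list[int]
    fallback_array = {                                              # bare `list` with no args
        "$type": "local#array",
        "items": {"$type": "local#primitive", "primitive": "str"},
    }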
@@ -543,11 +548,13 @@ def _build_schema_record(
         field_type, is_optional = unwrap_optional(field_type)
         field_type_dict = _python_type_to_field_type(field_type)

-        field_defs.append(
-
-
-
-
+        field_defs.append(
+            {
+                "name": f.name,
+                "fieldType": field_type_dict,
+                "optional": is_optional,
+            }
+        )

     return {
         "name": sample_type.__name__,

@@ -561,6 +568,7 @@
 ##
 # Redis object model

+
 @dataclass
 class LocalDatasetEntry:
     """Index entry for a dataset stored in the local repository.

@@ -579,6 +587,7 @@ class LocalDatasetEntry:
         data_urls: WebDataset URLs for the data.
         metadata: Arbitrary metadata dictionary, or None if not set.
     """
+
     ##

     name: str

@@ -640,17 +649,17 @@ class LocalDatasetEntry:
         Args:
             redis: Redis connection to write to.
         """
-        save_key = f
+        save_key = f"{REDIS_KEY_DATASET_ENTRY}:{self.cid}"
         data = {
-
-
-
-
+            "name": self.name,
+            "schema_ref": self.schema_ref,
+            "data_urls": msgpack.packb(self.data_urls),  # Serialize list
+            "cid": self.cid,
         }
         if self.metadata is not None:
-            data[
+            data["metadata"] = msgpack.packb(self.metadata)
         if self._legacy_uuid is not None:
-            data[
+            data["legacy_uuid"] = self._legacy_uuid

         redis.hset(save_key, mapping=data)  # type: ignore[arg-type]

@@ -668,23 +677,23 @@ class LocalDatasetEntry:
         Raises:
             KeyError: If entry not found.
         """
-        save_key = f
+        save_key = f"{REDIS_KEY_DATASET_ENTRY}:{cid}"
         raw_data = redis.hgetall(save_key)
         if not raw_data:
             raise KeyError(f"{REDIS_KEY_DATASET_ENTRY} not found: {cid}")

         # Decode string fields, keep binary fields as bytes for msgpack
         raw_data_typed = cast(dict[bytes, bytes], raw_data)
-        name = raw_data_typed[b
-        schema_ref = raw_data_typed[b
-        cid_value = raw_data_typed.get(b
-        legacy_uuid = raw_data_typed.get(b
+        name = raw_data_typed[b"name"].decode("utf-8")
+        schema_ref = raw_data_typed[b"schema_ref"].decode("utf-8")
+        cid_value = raw_data_typed.get(b"cid", b"").decode("utf-8") or None
+        legacy_uuid = raw_data_typed.get(b"legacy_uuid", b"").decode("utf-8") or None

         # Deserialize msgpack fields (stored as raw bytes)
-        data_urls = msgpack.unpackb(raw_data_typed[b
+        data_urls = msgpack.unpackb(raw_data_typed[b"data_urls"])
         metadata = None
-        if b
-            metadata = msgpack.unpackb(raw_data_typed[b
+        if b"metadata" in raw_data_typed:
+            metadata = msgpack.unpackb(raw_data_typed[b"metadata"])

         return cls(
             name=name,
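The to_redis/from_redis hunks above are formatting plus restored string literals; the storage model is still one Redis hash per entry, keyed by CID, with data_urls and metadata msgpack-packed. A round-trip sketch, assuming a constructor that takes the dataclass fields named in this diff (the exact signature is not shown here):

    # Sketch only -- field names follow the hunks above; placeholder values throughout.
    import redis
    from atdata.local import LocalDatasetEntry

    r = redis.Redis()
    entry = LocalDatasetEntry(
        name="my-dataset",
        schema_ref="atdata://local/sampleSchema/MySample@1.0.0",
        data_urls=["s3://bucket/hive/atdata--<uuid>--000000.tar"],
        cid="<cid>",
        metadata={"split": "train"},
    )
    entry.to_redis(r)                                      # writes hash "LocalDatasetEntry:<cid>"
    restored = LocalDatasetEntry.from_redis(r, entry.cid)  # unpacks the msgpack fields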
@@ -699,7 +708,8 @@ class LocalDatasetEntry:
 # Backwards compatibility alias
 BasicIndexEntry = LocalDatasetEntry

-
+
+def _s3_env(credentials_path: str | Path) -> dict[str, Any]:
     """Load S3 credentials from .env file.

     Args:

@@ -712,28 +722,31 @@ def _s3_env( credentials_path: str | Path ) -> dict[str, Any]:
     Raises:
         ValueError: If any required key is missing from the .env file.
     """
-    credentials_path = Path(
-    env_values = dotenv_values(
+    credentials_path = Path(credentials_path)
+    env_values = dotenv_values(credentials_path)

-    required_keys = (
+    required_keys = ("AWS_ENDPOINT", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
     missing = [k for k in required_keys if k not in env_values]
     if missing:
-        raise ValueError(
+        raise ValueError(
+            f"Missing required keys in {credentials_path}: {', '.join(missing)}"
+        )

     return {k: env_values[k] for k in required_keys}

-
+
+def _s3_from_credentials(creds: str | Path | dict) -> S3FileSystem:
     """Create S3FileSystem from credentials dict or .env file path."""
-    if not isinstance(
-    creds = _s3_env(
+    if not isinstance(creds, dict):
+        creds = _s3_env(creds)

     # Build kwargs, making endpoint_url optional
     kwargs = {
-
-
+        "key": creds["AWS_ACCESS_KEY_ID"],
+        "secret": creds["AWS_SECRET_ACCESS_KEY"],
     }
-    if
-    kwargs[
+    if "AWS_ENDPOINT" in creds:
+        kwargs["endpoint_url"] = creds["AWS_ENDPOINT"]

     return S3FileSystem(**kwargs)

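_s3_env and _s3_from_credentials above keep their behaviour: a dotenv file must supply all three keys, while a dict only needs the endpoint when one is actually in use. A usage sketch (both helpers are private to atdata/local.py, so this is illustration rather than public API):

    # Sketch only -- placeholder values; module path taken from this file.
    from atdata.local import _s3_from_credentials

    # From a dict: endpoint_url is added only if AWS_ENDPOINT is present.
    fs = _s3_from_credentials({
        "AWS_ACCESS_KEY_ID": "AKIA...",
        "AWS_SECRET_ACCESS_KEY": "...",
        "AWS_ENDPOINT": "https://s3.example.com",  # optional in the dict form
    })

    # From a dotenv file: _s3_env requires AWS_ENDPOINT, AWS_ACCESS_KEY_ID,
    # and AWS_SECRET_ACCESS_KEY to all be present, or it raises ValueError.
    fs = _s3_from_credentials("~/.atdata-s3.env")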
@@ -741,6 +754,7 @@ def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem:
 ##
 # Classes

+
 class Repo:
     """Repository for storing and managing atdata datasets.


@@ -797,20 +811,20 @@ class Repo:

         if s3_credentials is None:
             self.s3_credentials = None
-        elif isinstance(
+        elif isinstance(s3_credentials, dict):
             self.s3_credentials = s3_credentials
         else:
-            self.s3_credentials = _s3_env(
+            self.s3_credentials = _s3_env(s3_credentials)

         if self.s3_credentials is None:
             self.bucket_fs = None
         else:
-            self.bucket_fs = _s3_from_credentials(
+            self.bucket_fs = _s3_from_credentials(self.s3_credentials)

         if self.bucket_fs is not None:
             if hive_path is None:
-                raise ValueError(
-            self.hive_path = Path(
+                raise ValueError("Must specify hive path within bucket")
+            self.hive_path = Path(hive_path)
             self.hive_bucket = self.hive_path.parts[0]
         else:
             self.hive_path = None

@@ -818,18 +832,19 @@ class Repo:

         #

-        self.index = Index(
+        self.index = Index(redis=redis)

     ##

-    def insert(
-
-
-
-
-
-
-
+    def insert(
+        self,
+        ds: Dataset[T],
+        *,
+        name: str,
+        cache_local: bool = False,
+        schema_ref: str | None = None,
+        **kwargs,
+    ) -> tuple[LocalDatasetEntry, Dataset[T]]:
         """Insert a dataset into the repository.

         Writes the dataset to S3 as WebDataset tar files, stores metadata,

@@ -853,35 +868,35 @@ class Repo:
             RuntimeError: If no shards were written.
         """
         if self.s3_credentials is None:
-            raise ValueError(
+            raise ValueError(
+                "S3 credentials required for insert(). Initialize Repo with s3_credentials."
+            )
         if self.hive_bucket is None or self.hive_path is None:
-            raise ValueError(
+            raise ValueError(
+                "hive_path required for insert(). Initialize Repo with hive_path."
+            )

-        new_uuid = str(
+        new_uuid = str(uuid4())

-        hive_fs = _s3_from_credentials(
+        hive_fs = _s3_from_credentials(self.s3_credentials)

         # Write metadata
         metadata_path = (
-            self.hive_path
-            / 'metadata'
-            / f'atdata-metadata--{new_uuid}.msgpack'
+            self.hive_path / "metadata" / f"atdata-metadata--{new_uuid}.msgpack"
         )
         # Note: S3 doesn't need directories created beforehand - s3fs handles this

         if ds.metadata is not None:
             # Use s3:// prefix to ensure s3fs treats this as an S3 path
-            with cast(
-
+            with cast(
+                BinaryIO, hive_fs.open(f"s3://{metadata_path.as_posix()}", "wb")
+            ) as f:
+                meta_packed = msgpack.packb(ds.metadata)
                 assert meta_packed is not None
-                f.write(
-
+                f.write(cast(bytes, meta_packed))

         # Write data
-        shard_pattern = (
-            self.hive_path
-            / f'atdata--{new_uuid}--%06d.tar'
-        ).as_posix()
+        shard_pattern = (self.hive_path / f"atdata--{new_uuid}--%06d.tar").as_posix()

         written_shards: list[str] = []
         with TemporaryDirectory() as temp_dir:

@@ -904,24 +919,22 @@ class Repo:
                 sink.write(sample.as_wds)

         # Make a new Dataset object for the written dataset copy
-        if len(
-            raise RuntimeError(
-
-
+        if len(written_shards) == 0:
+            raise RuntimeError(
+                "Cannot form new dataset entry -- did not write any shards"
+            )
+
+        elif len(written_shards) < 2:
             new_dataset_url = (
-                self.hive_path
-                / ( Path( written_shards[0] ).name )
+                self.hive_path / (Path(written_shards[0]).name)
             ).as_posix()

         else:
             shard_s3_format = (
-                (
-
-
-
-                ) + '--{shard_id}.tar'
-            shard_id_braced = '{' + f'{0:06d}..{len( written_shards ) - 1:06d}' + '}'
-            new_dataset_url = shard_s3_format.format( shard_id = shard_id_braced )
+                (self.hive_path / f"atdata--{new_uuid}").as_posix()
+            ) + "--{shard_id}.tar"
+            shard_id_braced = "{" + f"{0:06d}..{len(written_shards) - 1:06d}" + "}"
+            new_dataset_url = shard_s3_format.format(shard_id=shard_id_braced)

         new_dataset = Dataset[ds.sample_type](
             url=new_dataset_url,
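Taken together, the Repo hunks above make insert() keyword-only after the dataset argument and spell out its S3 requirements. A minimal usage sketch, assuming constructor parameters named as they are referenced in these hunks (s3_credentials, hive_path, redis); the full constructor signature is not part of this diff:

    # Sketch only -- parameter names inferred from the hunks above.
    import redis
    from atdata.local import Repo

    repo = Repo(
        s3_credentials="~/.atdata-s3.env",  # dict or dotenv path, per _s3_env
        hive_path="my-bucket/datasets",     # required by insert(); first segment is the bucket
        redis=redis.Redis(),
    )
    # `ds` is an existing atdata Dataset; insert() writes WebDataset shards to S3
    # and returns both the index entry and a Dataset pointing at the written copy.
    entry, stored = repo.insert(ds, name="my-dataset", cache_local=False)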
@@ -995,6 +1008,7 @@ class Index:
         # Providing stub_dir implies auto_stubs=True
         if auto_stubs or stub_dir is not None:
             from ._stub_manager import StubManager
+
             self._stub_manager: StubManager | None = StubManager(stub_dir=stub_dir)
         else:
             self._stub_manager = None

@@ -1027,12 +1041,10 @@ class Index:
         After calling :meth:`load_schema`, schema types become available
         as attributes on this namespace.

-
-
-
-
-            >>> MyType = index.types.MySample
-            >>> sample = MyType(name="hello", value=42)
+        Examples:
+            >>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
+            >>> MyType = index.types.MySample
+            >>> sample = MyType(name="hello", value=42)

         Returns:
             SchemaNamespace containing all loaded schema types.

@@ -1058,16 +1070,14 @@ class Index:
             KeyError: If schema not found.
             ValueError: If schema cannot be decoded.

-
-
-
-
-
-
-
-
-            >>> index.load_schema("atdata://local/sampleSchema/OtherType@1.0.0")
-            >>> other = index.types.OtherType(data="test")
+        Examples:
+            >>> # Load and use immediately
+            >>> MyType = index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
+            >>> sample = MyType(name="hello", value=42)
+            >>>
+            >>> # Or access later via namespace
+            >>> index.load_schema("atdata://local/sampleSchema/OtherType@1.0.0")
+            >>> other = index.types.OtherType(data="test")
         """
         # Decode the schema (uses generated module if auto_stubs enabled)
         cls = self.decode_schema(ref)

@@ -1090,16 +1100,14 @@ class Index:
             Import path like "local.MySample_1_0_0", or None if auto_stubs
             is disabled.

-
-
-
-
-
-
-
-
-            >>> # Then in your code:
-            >>> # from local.MySample_1_0_0 import MySample
+        Examples:
+            >>> index = LocalIndex(auto_stubs=True)
+            >>> ref = index.publish_schema(MySample, version="1.0.0")
+            >>> index.load_schema(ref)
+            >>> print(index.get_import_path(ref))
+            local.MySample_1_0_0
+            >>> # Then in your code:
+            >>> # from local.MySample_1_0_0 import MySample
         """
         if self._stub_manager is None:
             return None

@@ -1138,19 +1146,20 @@ class Index:
         Yields:
             LocalDatasetEntry objects from the index.
         """
-        prefix = f
-        for key in self._redis.scan_iter(match=f
-            key_str = key.decode(
-            cid = key_str[len(prefix):]
+        prefix = f"{REDIS_KEY_DATASET_ENTRY}:"
+        for key in self._redis.scan_iter(match=f"{prefix}*"):
+            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+            cid = key_str[len(prefix) :]
             yield LocalDatasetEntry.from_redis(self._redis, cid)

-    def add_entry(
-
-
-
-
-
-
+    def add_entry(
+        self,
+        ds: Dataset,
+        *,
+        name: str,
+        schema_ref: str | None = None,
+        metadata: dict | None = None,
+    ) -> LocalDatasetEntry:
         """Add a dataset to the index.

         Creates a LocalDatasetEntry for the dataset and persists it to Redis.

@@ -1166,7 +1175,9 @@ class Index:
         """
         ##
         if schema_ref is None:
-            schema_ref =
+            schema_ref = (
+                f"local://schemas/{_kind_str_for_sample_type(ds.sample_type)}@1.0.0"
+            )

         # Normalize URL to list
         data_urls = [ds.url]
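add_entry above likewise becomes keyword-only, and when schema_ref is omitted it falls back to the legacy local://schemas/{module.Class}@1.0.0 form derived from ds.sample_type. A short sketch, assuming an Index constructed with a Redis connection as in the Repo hunk earlier:

    # Sketch only -- mirrors the new add_entry signature shown above.
    import redis
    from atdata.local import Index

    index = Index(redis=redis.Redis())
    entry = index.add_entry(
        ds,                           # an existing atdata Dataset
        name="my-dataset",
        metadata={"split": "train"},  # optional; msgpack-packed into the Redis hash
    )
    # schema_ref defaulted to f"local://schemas/{module.Class}@1.0.0" for ds.sample_type.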
@@ -1245,12 +1256,12 @@ class Index:
         Returns:
             IndexEntry for the inserted dataset.
         """
-        metadata = kwargs.get(
+        metadata = kwargs.get("metadata")

         if self._data_store is not None:
             # Write shards to data store, then index the new URLs
-            prefix = kwargs.get(
-            cache_local = kwargs.get(
+            prefix = kwargs.get("prefix", name)
+            cache_local = kwargs.get("cache_local", False)

             written_urls = self._data_store.write_shards(
                 ds,

@@ -1314,10 +1325,10 @@ class Index:
         latest_version: tuple[int, int, int] | None = None
         latest_version_str: str | None = None

-        prefix = f
-        for key in self._redis.scan_iter(match=f
-            key_str = key.decode(
-            schema_id = key_str[len(prefix):]
+        prefix = f"{REDIS_KEY_SCHEMA}:"
+        for key in self._redis.scan_iter(match=f"{prefix}*"):
+            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+            schema_id = key_str[len(prefix) :]

             if "@" not in schema_id:
                 continue

@@ -1369,10 +1380,12 @@ class Index:
         # This catches non-packable types early with a clear error message
         try:
             # Check protocol compliance by verifying required methods exist
-            if not (
-
-
-
+            if not (
+                hasattr(sample_type, "from_data")
+                and hasattr(sample_type, "from_bytes")
+                and callable(getattr(sample_type, "from_data", None))
+                and callable(getattr(sample_type, "from_bytes", None))
+            ):
                 raise TypeError(
                     f"{sample_type.__name__} does not satisfy the Packable protocol. "
                     "Use @packable decorator or inherit from PackableSample."
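The reflowed check above is purely structural: it only asks whether from_data and from_bytes exist and are callable on the sample type. A type that would pass it, purely as an illustration (the real method signatures come from PackableSample/@packable and are not shown in this diff):

    # Illustration only -- just enough surface to satisfy the hasattr/callable check.
    class FakeSample:
        @classmethod
        def from_data(cls, data):
            return cls()

        @classmethod
        def from_bytes(cls, raw):
            return cls()

    assert hasattr(FakeSample, "from_data") and callable(FakeSample.from_data)
    assert hasattr(FakeSample, "from_bytes") and callable(FakeSample.from_bytes)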
@@ -1430,10 +1443,10 @@ class Index:
             raise KeyError(f"Schema not found: {ref}")

         if isinstance(schema_json, bytes):
-            schema_json = schema_json.decode(
+            schema_json = schema_json.decode("utf-8")

         schema = json.loads(schema_json)
-        schema[
+        schema["$ref"] = _make_schema_ref(name, version)

         # Auto-generate stub if enabled
         if self._stub_manager is not None:

@@ -1468,29 +1481,29 @@ class Index:
         Yields:
             LocalSchemaRecord for each schema.
         """
-        prefix = f
-        for key in self._redis.scan_iter(match=f
-            key_str = key.decode(
+        prefix = f"{REDIS_KEY_SCHEMA}:"
+        for key in self._redis.scan_iter(match=f"{prefix}*"):
+            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
             # Extract name@version from key
-            schema_id = key_str[len(prefix):]
+            schema_id = key_str[len(prefix) :]

             schema_json = self._redis.get(key)
             if schema_json is None:
                 continue

             if isinstance(schema_json, bytes):
-                schema_json = schema_json.decode(
+                schema_json = schema_json.decode("utf-8")

             schema = json.loads(schema_json)
             # Handle legacy keys that have module.Class format
             if "." in schema_id.split("@")[0]:
                 name = schema_id.split("@")[0].rsplit(".", 1)[1]
                 version = schema_id.split("@")[1]
-                schema[
+                schema["$ref"] = _make_schema_ref(name, version)
             else:
                 # schema_id is already "name@version"
                 name, version = schema_id.rsplit("@", 1)
-                schema[
+                schema["$ref"] = _make_schema_ref(name, version)
             yield LocalSchemaRecord.from_dict(schema)

     def list_schemas(self) -> list[dict]:

@@ -1534,6 +1547,7 @@ class Index:

         # Fall back to dynamic type generation
         from atdata._schema_codec import schema_to_type
+
         return schema_to_type(schema_dict)

     def decode_schema_as(self, ref: str, type_hint: type[T]) -> type[T]:

@@ -1551,15 +1565,13 @@ class Index:
         Returns:
             The decoded type, cast to match the type_hint for IDE support.

-
-
-
-
-
-
-
-            >>> DecodedType = index.decode_schema_as(ref, MySample)
-            >>> sample = DecodedType(text="hello", value=42)  # IDE knows signature!
+        Examples:
+            >>> # After enabling auto_stubs and configuring IDE extraPaths:
+            >>> from local.MySample_1_0_0 import MySample
+            >>>
+            >>> # This gives full IDE autocomplete:
+            >>> DecodedType = index.decode_schema_as(ref, MySample)
+            >>> sample = DecodedType(text="hello", value=42)  # IDE knows signature!

         Note:
             The type_hint is only used for static type checking - at runtime,

@@ -1567,6 +1579,7 @@ class Index:
             stub matches the schema to avoid runtime surprises.
         """
         from typing import cast
+
         return cast(type[T], self.decode_schema(ref))

     def clear_stubs(self) -> int:

@@ -1687,11 +1700,11 @@ class S3DataStore:
             HTTPS URL if custom endpoint is configured, otherwise unchanged.
             Example: 's3://bucket/path' -> 'https://endpoint.com/bucket/path'
         """
-        endpoint = self.credentials.get(
-        if endpoint and url.startswith(
+        endpoint = self.credentials.get("AWS_ENDPOINT")
+        if endpoint and url.startswith("s3://"):
             # s3://bucket/path -> https://endpoint/bucket/path
             path = url[5:]  # Remove 's3://' prefix
-            endpoint = endpoint.rstrip(
+            endpoint = endpoint.rstrip("/")
             return f"{endpoint}/{path}"
         return url

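The S3DataStore hunk above only restores the string literals of the endpoint rewrite; the logic is unchanged. The same transformation as a standalone sketch (a hypothetical helper, not part of the package):

    # Hypothetical standalone version of the rewrite shown above.
    def rewrite_s3_url(url: str, endpoint: str | None) -> str:
        """Turn 's3://bucket/path' into '<endpoint>/bucket/path' when an endpoint is set."""
        if endpoint and url.startswith("s3://"):
            path = url[5:]  # drop the 's3://' prefix
            return f"{endpoint.rstrip('/')}/{path}"
        return url

    assert rewrite_s3_url("s3://bucket/path", "https://endpoint.com/") == "https://endpoint.com/bucket/path"
    assert rewrite_s3_url("s3://bucket/path", None) == "s3://bucket/path"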
@@ -1704,4 +1717,4 @@ class S3DataStore:
         return True


-#
+#