atdata-0.3.0b1-py3-none-any.whl → atdata-0.3.2b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +11 -0
- atdata/_cid.py +0 -21
- atdata/_helpers.py +12 -0
- atdata/_hf_api.py +46 -1
- atdata/_logging.py +43 -0
- atdata/_protocols.py +81 -182
- atdata/_schema_codec.py +2 -2
- atdata/_sources.py +24 -4
- atdata/_stub_manager.py +5 -25
- atdata/atmosphere/__init__.py +60 -21
- atdata/atmosphere/_lexicon_types.py +595 -0
- atdata/atmosphere/_types.py +73 -245
- atdata/atmosphere/client.py +64 -12
- atdata/atmosphere/lens.py +60 -53
- atdata/atmosphere/records.py +291 -100
- atdata/atmosphere/schema.py +91 -65
- atdata/atmosphere/store.py +68 -66
- atdata/cli/__init__.py +16 -16
- atdata/cli/diagnose.py +2 -2
- atdata/cli/{local.py → infra.py} +10 -10
- atdata/dataset.py +266 -47
- atdata/index/__init__.py +54 -0
- atdata/{local → index}/_entry.py +6 -2
- atdata/{local → index}/_index.py +617 -72
- atdata/{local → index}/_schema.py +5 -5
- atdata/lexicons/__init__.py +127 -0
- atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
- atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
- atdata/lexicons/ac.foundation.dataset.lens.json +101 -0
- atdata/lexicons/ac.foundation.dataset.record.json +117 -0
- atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
- atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +46 -0
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
- atdata/lexicons/ac.foundation.dataset.storageHttp.json +45 -0
- atdata/lexicons/ac.foundation.dataset.storageS3.json +61 -0
- atdata/lexicons/ndarray_shim.json +16 -0
- atdata/local/__init__.py +12 -13
- atdata/local/_repo_legacy.py +3 -3
- atdata/manifest/__init__.py +4 -0
- atdata/manifest/_proxy.py +321 -0
- atdata/promote.py +14 -10
- atdata/repository.py +66 -16
- atdata/stores/__init__.py +23 -0
- atdata/stores/_disk.py +131 -0
- atdata/{local → stores}/_s3.py +134 -112
- atdata/testing.py +12 -8
- {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/METADATA +2 -2
- atdata-0.3.2b1.dist-info/RECORD +71 -0
- atdata-0.3.0b1.dist-info/RECORD +0 -54
- {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/WHEEL +0 -0
- {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/licenses/LICENSE +0 -0
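The renames above move the local index models (`_entry.py`, `_index.py`, `_schema.py`) from `atdata/local/` into a new `atdata/index/` package, and the S3 backend into `atdata/stores/` alongside a new disk store. Below is a rough sketch of the import-path migration this implies for downstream code; the 0.3.0b1 import shown is an assumption based on the old file locations, and this listing alone does not say whether `atdata.local` (which still ships in 0.3.2b1) keeps re-exporting the old names.

```python
# Sketch of the import-path change implied by the renames above.
# The "before" path is an assumption from the old file locations
# (atdata/local/_index.py, atdata/local/_entry.py); it is not shown in this diff.

# 0.3.0b1 (assumed old layout):
# from atdata.local import Index, LocalDatasetEntry

# 0.3.2b1 (new layout, exported by atdata/index/__init__.py further down):
from atdata.index import Index, LocalDatasetEntry
```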
atdata/dataset.py
CHANGED
```diff
@@ -47,8 +47,6 @@ from ._protocols import DataSource, Packable
 from ._exceptions import SampleKeyError, PartialFailureError

 import numpy as np
-import pandas as pd
-import requests

 import typing
 from typing import (
@@ -70,6 +68,9 @@ from typing import (
 )

 if TYPE_CHECKING:
+    import pandas
+    import pandas as pd
+    from .manifest._proxy import Predicate
     from .manifest._query import SampleLocation
     from numpy.typing import NDArray

```
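These first two hunks drop the module-level `pandas` and `requests` imports: `pandas` moves under `TYPE_CHECKING` (plus function-local imports later in the diff) and `requests` becomes a function-local import where metadata is actually fetched, so importing `atdata.dataset` no longer loads either dependency eagerly. A minimal, self-contained sketch of that pattern, using a hypothetical `to_frame` helper rather than atdata's own code:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; costs nothing at runtime.
    import pandas as pd


def to_frame(rows: list[dict]) -> "pd.DataFrame":
    # Deferred import: pandas loads on first call, not when the module is imported.
    import pandas as pd

    return pd.DataFrame(rows)
```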
```diff
@@ -99,9 +100,11 @@ DT = TypeVar("DT")


 def _make_packable(x):
-    """Convert numpy arrays to bytes;
+    """Convert numpy arrays to bytes; coerce numpy scalars to Python natives."""
     if isinstance(x, np.ndarray):
         return eh.array_to_bytes(x)
+    if isinstance(x, np.generic):
+        return x.item()
     return x


```
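The new `np.generic` branch converts numpy scalar values (for example `np.float32` or `np.int64` pulled out of an array) into plain Python numbers, which msgpack can serialize directly. A small standalone illustration of the coercion, using numpy only:

```python
import numpy as np

x = np.float32(0.25)
assert isinstance(x, np.generic)          # numpy scalar, not a plain float
assert isinstance(x.item(), float)        # .item() yields the Python native

arr = np.arange(3, dtype=np.int64)
assert isinstance(arr[0], np.generic)     # indexing an array returns numpy scalars
assert arr[0].item() == 0                 # coerced to a plain int
```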
```diff
@@ -280,16 +283,9 @@ class PackableSample(ABC):

     @property
     def packed(self) -> bytes:
-        """Serialize to msgpack bytes. NDArray fields are auto-converted.
-
-        Raises:
-            RuntimeError: If msgpack serialization fails.
-        """
+        """Serialize to msgpack bytes. NDArray fields are auto-converted."""
         o = {k: _make_packable(v) for k, v in vars(self).items()}
-
-        if ret is None:
-            raise RuntimeError(f"Failed to pack sample to bytes: {o}")
-        return ret
+        return msgpack.packb(o)

     @property
     def as_wds(self) -> WDSRawSample:
```
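After the fix, `packed` runs each field through `_make_packable` and returns `msgpack.packb(o)` directly; the old body referenced an undefined `ret` and could never succeed. A rough round-trip sketch of what that serialization amounts to, using msgpack and numpy directly (`eh.array_to_bytes` is atdata's own helper, so plain `.tobytes()` stands in for it here):

```python
import msgpack
import numpy as np

# Stand-in for vars(sample): one ndarray field, one numpy-scalar field.
fields = {"vec": np.arange(4, dtype=np.float32), "score": np.float32(0.9)}

packable_fields = {
    k: (v.tobytes() if isinstance(v, np.ndarray)      # arrays -> raw bytes
        else v.item() if isinstance(v, np.generic)    # numpy scalars -> Python natives
        else v)
    for k, v in fields.items()
}

blob = msgpack.packb(packable_fields)                 # what the fixed `packed` returns
restored = msgpack.unpackb(blob, raw=False)
assert isinstance(restored["vec"], bytes)
assert abs(restored["score"] - 0.9) < 1e-6
```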
```diff
@@ -305,7 +301,7 @@ def _batch_aggregate(xs: Sequence):
     if not xs:
         return []
     if isinstance(xs[0], np.ndarray):
-        return np.
+        return np.stack(xs)
     return list(xs)


@@ -540,6 +536,8 @@ class Dataset(Generic[ST]):
             return None

         if self._metadata is None:
+            import requests
+
             with requests.get(self.metadata_url, stream=True) as response:
                 response.raise_for_status()
                 self._metadata = msgpack.unpackb(response.content, raw=False)
@@ -708,6 +706,8 @@ class Dataset(Generic[ST]):
         fn: Callable[[list[ST]], Any],
         *,
         shards: list[str] | None = None,
+        checkpoint: Path | str | None = None,
+        on_shard_error: Callable[[str, Exception], None] | None = None,
     ) -> dict[str, Any]:
         """Process each shard independently, collecting per-shard results.

@@ -723,6 +723,14 @@ class Dataset(Generic[ST]):
             shards: Optional list of shard identifiers to process. If ``None``,
                 processes all shards in the dataset. Useful for retrying only
                 the failed shards from a previous ``PartialFailureError``.
+            checkpoint: Optional path to a checkpoint file. If provided,
+                already-succeeded shard IDs are loaded from this file and
+                skipped. Each newly succeeded shard is appended. On full
+                success the file is deleted. On partial failure it remains
+                for resume.
+            on_shard_error: Optional callback invoked as
+                ``on_shard_error(shard_id, exception)`` for each failed shard,
+                enabling dead-letter logging or alerting.

         Returns:
             Dict mapping shard identifier to *fn*'s return value for each shard.
@@ -739,45 +747,67 @@ class Dataset(Generic[ST]):
             ...     results = ds.process_shards(expensive_fn)
             ... except PartialFailureError as e:
             ...     retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
+
+            >>> # With checkpoint for crash recovery:
+            >>> results = ds.process_shards(expensive_fn, checkpoint="progress.txt")
         """
-        from ._logging import get_logger
+        from ._logging import get_logger, log_operation

         log = get_logger()
         shard_ids = shards or self.list_shards()
-
+
+        # Load checkpoint: skip already-succeeded shards
+        checkpoint_path: Path | None = None
+        if checkpoint is not None:
+            checkpoint_path = Path(checkpoint)
+            if checkpoint_path.exists():
+                already_done = set(checkpoint_path.read_text().splitlines())
+                log.info(
+                    "process_shards: loaded checkpoint, %d shards already done",
+                    len(already_done),
+                )
+                shard_ids = [s for s in shard_ids if s not in already_done]
+            if not shard_ids:
+                log.info("process_shards: all shards already checkpointed")
+                return {}

         succeeded: list[str] = []
         failed: list[str] = []
         errors: dict[str, Exception] = {}
         results: dict[str, Any] = {}

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with log_operation("process_shards", total_shards=len(shard_ids)):
+            for shard_id in shard_ids:
+                try:
+                    shard_ds = Dataset[self.sample_type](shard_id)
+                    shard_ds._sample_type_cache = self._sample_type_cache
+                    samples = list(shard_ds.ordered())
+                    results[shard_id] = fn(samples)
+                    succeeded.append(shard_id)
+                    log.debug("process_shards: shard ok %s", shard_id)
+                    if checkpoint_path is not None:
+                        with open(checkpoint_path, "a") as f:
+                            f.write(shard_id + "\n")
+                except Exception as exc:
+                    failed.append(shard_id)
+                    errors[shard_id] = exc
+                    log.warning("process_shards: shard failed %s: %s", shard_id, exc)
+                    if on_shard_error is not None:
+                        on_shard_error(shard_id, exc)
+
+        if failed:
+            raise PartialFailureError(
+                succeeded_shards=succeeded,
+                failed_shards=failed,
+                errors=errors,
+                results=results,
+            )
+
+        # All shards succeeded; clean up checkpoint file
+        if checkpoint_path is not None and checkpoint_path.exists():
+            checkpoint_path.unlink()
+            log.debug("process_shards: checkpoint file removed (all shards done)")

-        log.info("process_shards: all %d shards succeeded", len(shard_ids))
         return results

     def select(self, indices: Sequence[int]) -> list[ST]:
```
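A usage sketch for the two new parameters, following the docstring above. It assumes an existing `Dataset` instance `ds` and that `PartialFailureError` is importable in scope (the docstring's own example also uses the bare name); `count_samples` and `failed_shards.log` are placeholders, not part of atdata:

```python
def count_samples(samples):
    # Placeholder per-shard function.
    return len(samples)


def log_failure(shard_id, exc):
    # Placeholder dead-letter callback: record the failed shard for later triage.
    with open("failed_shards.log", "a") as f:
        f.write(f"{shard_id}\t{exc!r}\n")


try:
    results = ds.process_shards(
        count_samples,
        checkpoint="progress.txt",    # succeeded shard IDs are appended here
        on_shard_error=log_failure,   # invoked once per failed shard
    )
except PartialFailureError:
    # progress.txt survives a partial failure, so re-running the same call
    # skips the shards that already succeeded and retries only the rest.
    results = ds.process_shards(
        count_samples, checkpoint="progress.txt", on_shard_error=log_failure
    )
```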
```diff
@@ -809,9 +839,25 @@ class Dataset(Generic[ST]):
                 break
         return [result[i] for i in indices if i in result]

+    @property
+    def fields(self) -> "Any":
+        """Typed field proxy for manifest queries on this dataset.
+
+        Returns an object whose attributes are ``FieldProxy`` instances,
+        one per manifest-eligible field of this dataset's sample type.
+
+        Examples:
+            >>> ds = atdata.Dataset[MySample](url)
+            >>> Q = ds.fields
+            >>> results = ds.query(where=(Q.confidence > 0.9))
+        """
+        from .manifest._proxy import query_fields
+
+        return query_fields(self.sample_type)
+
     def query(
         self,
-        where: "Callable[[pd.DataFrame], pd.Series]",
+        where: "Callable[[pd.DataFrame], pd.Series] | Predicate",
     ) -> "list[SampleLocation]":
         """Query this dataset using per-shard manifest metadata.

@@ -820,10 +866,12 @@ class Dataset(Generic[ST]):
         and executes a two-phase query (shard-level aggregate pruning,
         then sample-level parquet filtering).

+        The *where* argument accepts either a lambda/function that operates
+        on a pandas DataFrame, or a ``Predicate`` built from the proxy DSL.
+
         Args:
-            where: Predicate function
-
-                matching rows.
+            where: Predicate function or ``Predicate`` object that selects
+                matching rows from the per-sample manifest DataFrame.

         Returns:
             List of ``SampleLocation`` for matching samples.
```
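A side-by-side sketch of the two `where` forms the updated signature accepts, staying within what the diff itself shows (only the `>` comparison on `Q.confidence` is confirmed here; other proxy operators are not shown in this diff):

```python
# Callable form (existing behavior): filter the per-sample manifest DataFrame.
locs = ds.query(where=lambda df: df["confidence"] > 0.9)

# Proxy / Predicate form (new): ds.fields exposes FieldProxy attributes, and a
# comparison on one of them builds a Predicate that query() accepts directly.
Q = ds.fields
locs = ds.query(where=(Q.confidence > 0.9))
```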
```diff
@@ -835,6 +883,9 @@ class Dataset(Generic[ST]):
             >>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
             >>> len(locs)
             42
+
+            >>> Q = ds.fields
+            >>> locs = ds.query(where=(Q.confidence > 0.9))
         """
         from .manifest import QueryExecutor

@@ -842,7 +893,7 @@ class Dataset(Generic[ST]):
         executor = QueryExecutor.from_shard_urls(shard_urls)
         return executor.query(where=where)

-    def to_pandas(self, limit: int | None = None) -> "
+    def to_pandas(self, limit: int | None = None) -> "pandas.DataFrame":
         """Materialize the dataset (or first *limit* samples) as a DataFrame.

         Args:
@@ -865,6 +916,8 @@ class Dataset(Generic[ST]):
         rows = [
             asdict(s) if dataclasses.is_dataclass(s) else s.to_dict() for s in samples
         ]
+        import pandas as pd
+
         return pd.DataFrame(rows)

     def to_dict(self, limit: int | None = None) -> dict[str, list[Any]]:
@@ -1059,6 +1112,8 @@ class Dataset(Generic[ST]):
         Examples:
             >>> ds.to_parquet("output.parquet", maxcount=50000)
         """
+        import pandas as pd
+
         path = Path(path)
         if sample_map is None:
             sample_map = asdict
@@ -1127,7 +1182,7 @@ _T = TypeVar("_T")


 @dataclass_transform()
-def packable(cls: type[_T]) -> type[
+def packable(cls: type[_T]) -> type[_T]:
     """Convert a class into a ``PackableSample`` dataclass with msgpack serialization.

     The resulting class gains ``packed``, ``as_wds``, ``from_bytes``, and
@@ -1188,3 +1243,167 @@ def packable(cls: type[_T]) -> type[Packable]:
     ##

     return as_packable
+
+
+# ---------------------------------------------------------------------------
+# write_samples — convenience function for writing samples to tar files
+# ---------------------------------------------------------------------------
+
+
+def write_samples(
+    samples: Iterable[ST],
+    path: str | Path,
+    *,
+    maxcount: int | None = None,
+    maxsize: int | None = None,
+    manifest: bool = False,
+) -> "Dataset[ST]":
+    """Write an iterable of samples to WebDataset tar file(s).
+
+    Args:
+        samples: Iterable of ``PackableSample`` instances. Must be non-empty.
+        path: Output path for the tar file. For sharded output (when
+            *maxcount* or *maxsize* is set), a ``%06d`` pattern is
+            auto-appended if the path does not already contain ``%``.
+        maxcount: Maximum samples per shard. Triggers multi-shard output.
+        maxsize: Maximum bytes per shard. Triggers multi-shard output.
+        manifest: If True, write per-shard manifest sidecar files
+            (``.manifest.json`` + ``.manifest.parquet``) alongside each
+            tar file. Manifests enable metadata queries via
+            ``QueryExecutor`` without opening the tars.
+
+    Returns:
+        A ``Dataset`` wrapping the written file(s), typed to the sample
+        type of the input samples.
+
+    Raises:
+        ValueError: If *samples* is empty.
+
+    Examples:
+        >>> samples = [MySample(key="0", text="hello")]
+        >>> ds = write_samples(samples, "out.tar")
+        >>> list(ds.ordered())
+        [MySample(key='0', text='hello')]
+    """
+    from ._hf_api import _shards_to_wds_url
+    from ._logging import get_logger, log_operation
+
+    if manifest:
+        from .manifest._builder import ManifestBuilder
+        from .manifest._writer import ManifestWriter
+
+    log = get_logger()
+    path = Path(path)
+    path.parent.mkdir(parents=True, exist_ok=True)
+
+    use_shard_writer = maxcount is not None or maxsize is not None
+    sample_type: type | None = None
+    written_paths: list[str] = []
+
+    with log_operation(
+        "write_samples", path=str(path), sharded=use_shard_writer, manifest=manifest
+    ):
+        # Manifest tracking state
+        _current_builder: list = []  # single-element list for nonlocal mutation
+        _builders: list[tuple[str, "ManifestBuilder"]] = []
+        _running_offset: list[int] = [0]
+
+        def _finalize_builder() -> None:
+            """Finalize the current manifest builder and stash it."""
+            if _current_builder:
+                shard_path = written_paths[-1] if written_paths else ""
+                _builders.append((shard_path, _current_builder[0]))
+                _current_builder.clear()
+
+        def _start_builder(shard_path: str) -> None:
+            """Start a new manifest builder for a shard."""
+            _finalize_builder()
+            shard_id = Path(shard_path).stem
+            _current_builder.append(
+                ManifestBuilder(sample_type=sample_type, shard_id=shard_id)
+            )
+            _running_offset[0] = 0
+
+        def _record_sample(sample: "PackableSample", wds_dict: dict) -> None:
+            """Record a sample in the active manifest builder."""
+            if not _current_builder:
+                return
+            packed_bytes = wds_dict["msgpack"]
+            size = len(packed_bytes)
+            _current_builder[0].add_sample(
+                key=wds_dict["__key__"],
+                offset=_running_offset[0],
+                size=size,
+                sample=sample,
+            )
+            _running_offset[0] += size
+
+        if use_shard_writer:
+            # Build shard pattern from path
+            if "%" not in str(path):
+                pattern = str(path.parent / f"{path.stem}-%06d{path.suffix}")
+            else:
+                pattern = str(path)
+
+            writer_kwargs: dict[str, Any] = {}
+            if maxcount is not None:
+                writer_kwargs["maxcount"] = maxcount
+            if maxsize is not None:
+                writer_kwargs["maxsize"] = maxsize
+
+            def _track(p: str) -> None:
+                written_paths.append(str(Path(p).resolve()))
+                if manifest and sample_type is not None:
+                    _start_builder(p)
+
+            with wds.writer.ShardWriter(pattern, post=_track, **writer_kwargs) as sink:
+                for sample in samples:
+                    if sample_type is None:
+                        sample_type = type(sample)
+                    wds_dict = sample.as_wds
+                    sink.write(wds_dict)
+                    if manifest:
+                        # The first sample triggers _track before we get here when
+                        # ShardWriter opens the first shard, but just in case:
+                        if not _current_builder and sample_type is not None:
+                            _start_builder(str(path))
+                        _record_sample(sample, wds_dict)
+        else:
+            with wds.writer.TarWriter(str(path)) as sink:
+                for sample in samples:
+                    if sample_type is None:
+                        sample_type = type(sample)
+                    wds_dict = sample.as_wds
+                    sink.write(wds_dict)
+                    if manifest:
+                        if not _current_builder and sample_type is not None:
+                            _current_builder.append(
+                                ManifestBuilder(
+                                    sample_type=sample_type, shard_id=path.stem
+                                )
+                            )
+                        _record_sample(sample, wds_dict)
+            written_paths.append(str(path.resolve()))
+
+        if sample_type is None:
+            raise ValueError("samples must be non-empty")
+
+        # Finalize and write manifests
+        if manifest:
+            _finalize_builder()
+            for shard_path, builder in _builders:
+                m = builder.build()
+                base = str(Path(shard_path).with_suffix(""))
+                writer = ManifestWriter(base)
+                writer.write(m)
+
+        log.info(
+            "write_samples: wrote %d shard(s), sample_type=%s",
+            len(written_paths),
+            sample_type.__name__,
+        )
+
+    url = _shards_to_wds_url(written_paths)
+    ds: Dataset = Dataset(url)
+    ds._sample_type_cache = sample_type
+    return ds
```
atdata/index/__init__.py
ADDED
```diff
@@ -0,0 +1,54 @@
+"""Index and entry models for atdata datasets.
+
+Key classes:
+
+- ``Index``: Unified index with pluggable providers (SQLite default),
+  named repositories, and optional atmosphere backend.
+- ``LocalDatasetEntry``: Index entry with ATProto-compatible CIDs.
+"""
+
+from atdata.index._entry import (
+    LocalDatasetEntry,
+    BasicIndexEntry,
+    REDIS_KEY_DATASET_ENTRY,
+    REDIS_KEY_SCHEMA,
+)
+from atdata.index._schema import (
+    SchemaNamespace,
+    SchemaFieldType,
+    SchemaField,
+    LocalSchemaRecord,
+    _ATDATA_URI_PREFIX,
+    _LEGACY_URI_PREFIX,
+    _kind_str_for_sample_type,
+    _schema_ref_from_type,
+    _make_schema_ref,
+    _parse_schema_ref,
+    _increment_patch,
+    _python_type_to_field_type,
+    _build_schema_record,
+)
+from atdata.index._index import Index
+
+__all__ = [
+    # Public API
+    "Index",
+    "LocalDatasetEntry",
+    "BasicIndexEntry",
+    "SchemaNamespace",
+    "SchemaFieldType",
+    "SchemaField",
+    "LocalSchemaRecord",
+    "REDIS_KEY_DATASET_ENTRY",
+    "REDIS_KEY_SCHEMA",
+    # Internal helpers (re-exported for backward compatibility)
+    "_ATDATA_URI_PREFIX",
+    "_LEGACY_URI_PREFIX",
+    "_kind_str_for_sample_type",
+    "_schema_ref_from_type",
+    "_make_schema_ref",
+    "_parse_schema_ref",
+    "_increment_patch",
+    "_python_type_to_field_type",
+    "_build_schema_record",
+]
```
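Beyond the public classes, the underscore-prefixed schema helpers are kept in `__all__` deliberately, so code that imported them from the old `atdata.local` modules only needs its import path updated. A small sketch of that compatibility surface:

```python
# Public API of the new package:
from atdata.index import Index, LocalDatasetEntry, SchemaField, SchemaFieldType

# Internal helpers are re-exported for backward compatibility, so existing
# call sites survive with a one-line path change:
from atdata.index import _parse_schema_ref, _make_schema_ref
```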
atdata/{local → index}/_entry.py
RENAMED
```diff
@@ -1,12 +1,16 @@
 """Dataset entry model and Redis key constants."""

+from __future__ import annotations
+
 from atdata._cid import generate_cid

 from dataclasses import dataclass, field
-from typing import Any, cast
+from typing import Any, TYPE_CHECKING, cast

 import msgpack
-
+
+if TYPE_CHECKING:
+    from redis import Redis


 # Redis key prefixes for index entries and schemas
```