atdata 0.3.1b1-py3-none-any.whl → 0.3.2b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +2 -0
- atdata/_hf_api.py +13 -0
- atdata/_logging.py +43 -0
- atdata/_protocols.py +18 -1
- atdata/_sources.py +24 -4
- atdata/atmosphere/__init__.py +48 -10
- atdata/atmosphere/_lexicon_types.py +595 -0
- atdata/atmosphere/_types.py +71 -243
- atdata/atmosphere/lens.py +49 -41
- atdata/atmosphere/records.py +282 -90
- atdata/atmosphere/schema.py +78 -50
- atdata/atmosphere/store.py +62 -59
- atdata/dataset.py +201 -135
- atdata/index/_entry.py +6 -2
- atdata/index/_index.py +396 -109
- atdata/lexicons/__init__.py +9 -3
- atdata/lexicons/ac.foundation.dataset.lens.json +2 -0
- atdata/lexicons/ac.foundation.dataset.record.json +22 -1
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +26 -4
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +1 -1
- atdata/lexicons/ac.foundation.dataset.storageHttp.json +45 -0
- atdata/lexicons/ac.foundation.dataset.storageS3.json +61 -0
- atdata/manifest/__init__.py +4 -0
- atdata/manifest/_proxy.py +321 -0
- atdata/repository.py +59 -9
- atdata/stores/_disk.py +19 -11
- atdata/stores/_s3.py +134 -112
- {atdata-0.3.1b1.dist-info → atdata-0.3.2b1.dist-info}/METADATA +1 -1
- {atdata-0.3.1b1.dist-info → atdata-0.3.2b1.dist-info}/RECORD +37 -33
- {atdata-0.3.1b1.dist-info → atdata-0.3.2b1.dist-info}/WHEEL +0 -0
- {atdata-0.3.1b1.dist-info → atdata-0.3.2b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.3.1b1.dist-info → atdata-0.3.2b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py
CHANGED
@@ -47,8 +47,6 @@ from ._protocols import DataSource, Packable
 from ._exceptions import SampleKeyError, PartialFailureError
 
 import numpy as np
-import pandas as pd
-import requests
 
 import typing
 from typing import (
@@ -70,6 +68,9 @@ from typing import (
 )
 
 if TYPE_CHECKING:
+    import pandas
+    import pandas as pd
+    from .manifest._proxy import Predicate
     from .manifest._query import SampleLocation
     from numpy.typing import NDArray
 
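Both import hunks make the same change: pandas and requests are no longer imported at module load time. Type annotations see them only under TYPE_CHECKING, and runtime code imports them inside the function that needs them. A minimal sketch of that pattern, with generic names that are not part of atdata:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime.
    import pandas as pd


def rows_to_frame(rows: list[dict]) -> "pd.DataFrame":
    # Deferred import: the pandas cost is paid only when this is called.
    import pandas as pd

    return pd.DataFrame(rows)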
@@ -282,16 +283,9 @@ class PackableSample(ABC):
 
     @property
     def packed(self) -> bytes:
-        """Serialize to msgpack bytes. NDArray fields are auto-converted.
-
-        Raises:
-            RuntimeError: If msgpack serialization fails.
-        """
+        """Serialize to msgpack bytes. NDArray fields are auto-converted."""
         o = {k: _make_packable(v) for k, v in vars(self).items()}
-
-        if ret is None:
-            raise RuntimeError(f"Failed to pack sample to bytes: {o}")
-        return ret
+        return msgpack.packb(o)
 
     @property
     def as_wds(self) -> WDSRawSample:
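The rewritten packed property now returns msgpack.packb(o) directly; the removed branch checked a ret name that is never assigned in the shown body, so it could only have failed. A small round-trip sketch using the public msgpack API (the payload dict is a stand-in, not an atdata type):

import msgpack

payload = {"key": "0", "text": "hello"}        # stand-in for the field dict built from vars(self)
packed = msgpack.packb(payload)                # bytes, as .packed now returns
restored = msgpack.unpackb(packed, raw=False)  # same unpackb(..., raw=False) form used elsewhere in this file
assert restored == payload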
@@ -542,6 +536,8 @@ class Dataset(Generic[ST]):
             return None
 
         if self._metadata is None:
+            import requests
+
             with requests.get(self.metadata_url, stream=True) as response:
                 response.raise_for_status()
                 self._metadata = msgpack.unpackb(response.content, raw=False)
@@ -710,6 +706,8 @@ class Dataset(Generic[ST]):
         fn: Callable[[list[ST]], Any],
         *,
         shards: list[str] | None = None,
+        checkpoint: Path | str | None = None,
+        on_shard_error: Callable[[str, Exception], None] | None = None,
     ) -> dict[str, Any]:
         """Process each shard independently, collecting per-shard results.
 
@@ -725,6 +723,14 @@
             shards: Optional list of shard identifiers to process. If ``None``,
                 processes all shards in the dataset. Useful for retrying only
                 the failed shards from a previous ``PartialFailureError``.
+            checkpoint: Optional path to a checkpoint file. If provided,
+                already-succeeded shard IDs are loaded from this file and
+                skipped. Each newly succeeded shard is appended. On full
+                success the file is deleted. On partial failure it remains
+                for resume.
+            on_shard_error: Optional callback invoked as
+                ``on_shard_error(shard_id, exception)`` for each failed shard,
+                enabling dead-letter logging or alerting.
 
         Returns:
             Dict mapping shard identifier to *fn*'s return value for each shard.
@@ -741,45 +747,67 @@
             ...     results = ds.process_shards(expensive_fn)
             ... except PartialFailureError as e:
             ...     retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
+
+            >>> # With checkpoint for crash recovery:
+            >>> results = ds.process_shards(expensive_fn, checkpoint="progress.txt")
         """
-        from ._logging import get_logger
+        from ._logging import get_logger, log_operation
 
         log = get_logger()
        shard_ids = shards or self.list_shards()
-
+
+        # Load checkpoint: skip already-succeeded shards
+        checkpoint_path: Path | None = None
+        if checkpoint is not None:
+            checkpoint_path = Path(checkpoint)
+            if checkpoint_path.exists():
+                already_done = set(checkpoint_path.read_text().splitlines())
+                log.info(
+                    "process_shards: loaded checkpoint, %d shards already done",
+                    len(already_done),
+                )
+                shard_ids = [s for s in shard_ids if s not in already_done]
+                if not shard_ids:
+                    log.info("process_shards: all shards already checkpointed")
+                    return {}
 
         succeeded: list[str] = []
         failed: list[str] = []
         errors: dict[str, Exception] = {}
         results: dict[str, Any] = {}
 
  [… 25 removed lines (previous implementation) are not rendered in this extract …]
+        with log_operation("process_shards", total_shards=len(shard_ids)):
+            for shard_id in shard_ids:
+                try:
+                    shard_ds = Dataset[self.sample_type](shard_id)
+                    shard_ds._sample_type_cache = self._sample_type_cache
+                    samples = list(shard_ds.ordered())
+                    results[shard_id] = fn(samples)
+                    succeeded.append(shard_id)
+                    log.debug("process_shards: shard ok %s", shard_id)
+                    if checkpoint_path is not None:
+                        with open(checkpoint_path, "a") as f:
+                            f.write(shard_id + "\n")
+                except Exception as exc:
+                    failed.append(shard_id)
+                    errors[shard_id] = exc
+                    log.warning("process_shards: shard failed %s: %s", shard_id, exc)
+                    if on_shard_error is not None:
+                        on_shard_error(shard_id, exc)
+
+            if failed:
+                raise PartialFailureError(
+                    succeeded_shards=succeeded,
+                    failed_shards=failed,
+                    errors=errors,
+                    results=results,
+                )
+
+            # All shards succeeded; clean up checkpoint file
+            if checkpoint_path is not None and checkpoint_path.exists():
+                checkpoint_path.unlink()
+                log.debug("process_shards: checkpoint file removed (all shards done)")
 
-        log.info("process_shards: all %d shards succeeded", len(shard_ids))
         return results
 
     def select(self, indices: Sequence[int]) -> list[ST]:
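Taken together, the new parameters make process_shards resumable: the checkpoint file is plain text with one succeeded shard ID per line, appended after each success and deleted once every shard completes, while on_shard_error gives a per-shard failure hook. A usage sketch under assumptions the diff does not confirm: MySample is a placeholder @packable sample type, the shard URL is invented, and PartialFailureError is imported from the module path shown at the top of this file.

import atdata
from atdata._exceptions import PartialFailureError  # import path as shown in this diff


def count_samples(samples: list) -> int:
    # Any per-shard computation; the return value is collected per shard ID.
    return len(samples)


def to_dead_letter(shard_id: str, exc: Exception) -> None:
    # Invoked once per failed shard via the new on_shard_error hook.
    print(f"dead-letter: {shard_id}: {exc}")


ds = atdata.Dataset[MySample]("https://example.com/data-{000000..000009}.tar")
try:
    results = ds.process_shards(
        count_samples,
        checkpoint="progress.txt",      # shard IDs already listed here are skipped
        on_shard_error=to_dead_letter,
    )
except PartialFailureError:
    # progress.txt still holds the succeeded shard IDs; rerunning the same
    # call resumes with only the remaining shards.
    results = ds.process_shards(count_samples, checkpoint="progress.txt")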
@@ -811,9 +839,25 @@ class Dataset(Generic[ST]):
                 break
         return [result[i] for i in indices if i in result]
 
+    @property
+    def fields(self) -> "Any":
+        """Typed field proxy for manifest queries on this dataset.
+
+        Returns an object whose attributes are ``FieldProxy`` instances,
+        one per manifest-eligible field of this dataset's sample type.
+
+        Examples:
+            >>> ds = atdata.Dataset[MySample](url)
+            >>> Q = ds.fields
+            >>> results = ds.query(where=(Q.confidence > 0.9))
+        """
+        from .manifest._proxy import query_fields
+
+        return query_fields(self.sample_type)
+
     def query(
         self,
-        where: "Callable[[pd.DataFrame], pd.Series]",
+        where: "Callable[[pd.DataFrame], pd.Series] | Predicate",
     ) -> "list[SampleLocation]":
         """Query this dataset using per-shard manifest metadata.
 
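The new fields property and the widened where annotation make two query styles interchangeable; this sketch restates the docstring examples with both forms side by side (MySample and the URL are placeholders):

ds = atdata.Dataset[MySample]("https://example.com/data-{000000..000009}.tar")

# Lambda form: a callable applied to the per-sample manifest DataFrame.
locs = ds.query(where=lambda df: df["confidence"] > 0.9)

# Proxy form: the same filter expressed as a Predicate built from ds.fields.
Q = ds.fields
locs = ds.query(where=(Q.confidence > 0.9))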
@@ -822,10 +866,12 @@ class Dataset(Generic[ST]):
         and executes a two-phase query (shard-level aggregate pruning,
         then sample-level parquet filtering).
 
+        The *where* argument accepts either a lambda/function that operates
+        on a pandas DataFrame, or a ``Predicate`` built from the proxy DSL.
+
         Args:
-            where: Predicate function
-
-                matching rows.
+            where: Predicate function or ``Predicate`` object that selects
+                matching rows from the per-sample manifest DataFrame.
 
         Returns:
             List of ``SampleLocation`` for matching samples.
@@ -837,6 +883,9 @@ class Dataset(Generic[ST]):
             >>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
             >>> len(locs)
             42
+
+            >>> Q = ds.fields
+            >>> locs = ds.query(where=(Q.confidence > 0.9))
         """
         from .manifest import QueryExecutor
 
@@ -844,7 +893,7 @@ class Dataset(Generic[ST]):
         executor = QueryExecutor.from_shard_urls(shard_urls)
         return executor.query(where=where)
 
-    def to_pandas(self, limit: int | None = None) -> "
+    def to_pandas(self, limit: int | None = None) -> "pandas.DataFrame":
         """Materialize the dataset (or first *limit* samples) as a DataFrame.
 
         Args:
@@ -867,6 +916,8 @@ class Dataset(Generic[ST]):
         rows = [
             asdict(s) if dataclasses.is_dataclass(s) else s.to_dict() for s in samples
         ]
+        import pandas as pd
+
         return pd.DataFrame(rows)
 
     def to_dict(self, limit: int | None = None) -> dict[str, list[Any]]:
@@ -1061,6 +1112,8 @@ class Dataset(Generic[ST]):
        Examples:
            >>> ds.to_parquet("output.parquet", maxcount=50000)
        """
+        import pandas as pd
+
        path = Path(path)
        if sample_map is None:
            sample_map = asdict
@@ -1129,7 +1182,7 @@ _T = TypeVar("_T")
 
 
 @dataclass_transform()
-def packable(cls: type[_T]) -> type[
+def packable(cls: type[_T]) -> type[_T]:
    """Convert a class into a ``PackableSample`` dataclass with msgpack serialization.
 
    The resulting class gains ``packed``, ``as_wds``, ``from_bytes``, and
@@ -1233,11 +1286,13 @@ def write_samples(
        [MySample(key='0', text='hello')]
    """
    from ._hf_api import _shards_to_wds_url
+    from ._logging import get_logger, log_operation
 
    if manifest:
        from .manifest._builder import ManifestBuilder
        from .manifest._writer import ManifestWriter
 
+    log = get_logger()
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
 
@@ -1245,97 +1300,108 @@ def write_samples(
     sample_type: type | None = None
     written_paths: list[str] = []
 
  [… 39 removed lines (previous manifest-builder setup and shard-writing loop) are not rendered in this extract …]
+    with log_operation(
+        "write_samples", path=str(path), sharded=use_shard_writer, manifest=manifest
+    ):
+        # Manifest tracking state
+        _current_builder: list = []  # single-element list for nonlocal mutation
+        _builders: list[tuple[str, "ManifestBuilder"]] = []
+        _running_offset: list[int] = [0]
+
+        def _finalize_builder() -> None:
+            """Finalize the current manifest builder and stash it."""
+            if _current_builder:
+                shard_path = written_paths[-1] if written_paths else ""
+                _builders.append((shard_path, _current_builder[0]))
+                _current_builder.clear()
+
+        def _start_builder(shard_path: str) -> None:
+            """Start a new manifest builder for a shard."""
+            _finalize_builder()
+            shard_id = Path(shard_path).stem
+            _current_builder.append(
+                ManifestBuilder(sample_type=sample_type, shard_id=shard_id)
+            )
+            _running_offset[0] = 0
+
+        def _record_sample(sample: "PackableSample", wds_dict: dict) -> None:
+            """Record a sample in the active manifest builder."""
+            if not _current_builder:
+                return
+            packed_bytes = wds_dict["msgpack"]
+            size = len(packed_bytes)
+            _current_builder[0].add_sample(
+                key=wds_dict["__key__"],
+                offset=_running_offset[0],
+                size=size,
+                sample=sample,
+            )
+            _running_offset[0] += size
+
+        if use_shard_writer:
+            # Build shard pattern from path
+            if "%" not in str(path):
+                pattern = str(path.parent / f"{path.stem}-%06d{path.suffix}")
+            else:
+                pattern = str(path)
+
+            writer_kwargs: dict[str, Any] = {}
+            if maxcount is not None:
+                writer_kwargs["maxcount"] = maxcount
+            if maxsize is not None:
+                writer_kwargs["maxsize"] = maxsize
+
+            def _track(p: str) -> None:
+                written_paths.append(str(Path(p).resolve()))
+                if manifest and sample_type is not None:
+                    _start_builder(p)
+
+            with wds.writer.ShardWriter(pattern, post=_track, **writer_kwargs) as sink:
+                for sample in samples:
+                    if sample_type is None:
+                        sample_type = type(sample)
+                    wds_dict = sample.as_wds
+                    sink.write(wds_dict)
+                    if manifest:
+                        # The first sample triggers _track before we get here when
+                        # ShardWriter opens the first shard, but just in case:
+                        if not _current_builder and sample_type is not None:
+                            _start_builder(str(path))
+                        _record_sample(sample, wds_dict)
         else:
  [… 33 removed lines are not rendered in this extract …]
-                    if not _current_builder and sample_type is not None:
-                        _current_builder.append(
-                            ManifestBuilder(sample_type=sample_type, shard_id=path.stem)
-                        )
-                    _record_sample(sample, wds_dict)
-        written_paths.append(str(path.resolve()))
-
-    if sample_type is None:
-        raise ValueError("samples must be non-empty")
-
-    # Finalize and write manifests
-    if manifest:
-        _finalize_builder()
-        for shard_path, builder in _builders:
-            m = builder.build()
-            base = str(Path(shard_path).with_suffix(""))
-            writer = ManifestWriter(base)
-            writer.write(m)
+            with wds.writer.TarWriter(str(path)) as sink:
+                for sample in samples:
+                    if sample_type is None:
+                        sample_type = type(sample)
+                    wds_dict = sample.as_wds
+                    sink.write(wds_dict)
+                    if manifest:
+                        if not _current_builder and sample_type is not None:
+                            _current_builder.append(
+                                ManifestBuilder(
+                                    sample_type=sample_type, shard_id=path.stem
+                                )
+                            )
+                        _record_sample(sample, wds_dict)
+            written_paths.append(str(path.resolve()))
+
+        if sample_type is None:
+            raise ValueError("samples must be non-empty")
+
+        # Finalize and write manifests
+        if manifest:
+            _finalize_builder()
+            for shard_path, builder in _builders:
+                m = builder.build()
+                base = str(Path(shard_path).with_suffix(""))
+                writer = ManifestWriter(base)
+                writer.write(m)
+
+        log.info(
+            "write_samples: wrote %d shard(s), sample_type=%s",
+            len(written_paths),
+            sample_type.__name__,
+        )
 
     url = _shards_to_wds_url(written_paths)
     ds: Dataset = Dataset(url)
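The restructured write_samples keeps both writer paths — a webdataset ShardWriter with a derived %06d pattern when sharded output is requested, and a single TarWriter otherwise — and writes one manifest per shard when manifest=True. The sketch below is only indicative: the full signature of write_samples is not visible in this hunk, so the keyword names, the top-level packable export, and the rule that maxcount/maxsize selects the sharded path are assumptions.

import atdata


@atdata.packable                 # assumed top-level export of the packable decorator
class MySample:
    key: str
    text: str


samples = (MySample(key=str(i), text="hello") for i in range(100_000))

# With sharding, output lands in out-000000.tar, out-000001.tar, ... and a
# manifest sidecar is written next to each shard (via ManifestWriter above).
ds = atdata.write_samples(
    samples,
    path="data/out.tar",         # keyword name assumed; pattern derived as out-%06d.tar
    maxcount=50_000,
    manifest=True,
)
print(ds)                        # a Dataset over the written shard URLs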
atdata/index/_entry.py
CHANGED
@@ -1,12 +1,16 @@
 """Dataset entry model and Redis key constants."""
 
+from __future__ import annotations
+
 from atdata._cid import generate_cid
 
 from dataclasses import dataclass, field
-from typing import Any, cast
+from typing import Any, TYPE_CHECKING, cast
 
 import msgpack
-
+
+if TYPE_CHECKING:
+    from redis import Redis
 
 
 # Redis key prefixes for index entries and schemas