atdata 0.3.1b1__py3-none-any.whl → 0.3.2b1__py3-none-any.whl

atdata/dataset.py CHANGED
@@ -47,8 +47,6 @@ from ._protocols import DataSource, Packable
  from ._exceptions import SampleKeyError, PartialFailureError
 
  import numpy as np
- import pandas as pd
- import requests
 
  import typing
  from typing import (
@@ -70,6 +68,9 @@ from typing import (
  )
 
  if TYPE_CHECKING:
+ import pandas
+ import pandas as pd
+ from .manifest._proxy import Predicate
  from .manifest._query import SampleLocation
  from numpy.typing import NDArray
 
@@ -282,16 +283,9 @@ class PackableSample(ABC):
 
  @property
  def packed(self) -> bytes:
- """Serialize to msgpack bytes. NDArray fields are auto-converted.
-
- Raises:
- RuntimeError: If msgpack serialization fails.
- """
+ """Serialize to msgpack bytes. NDArray fields are auto-converted."""
  o = {k: _make_packable(v) for k, v in vars(self).items()}
- ret = msgpack.packb(o)
- if ret is None:
- raise RuntimeError(f"Failed to pack sample to bytes: {o}")
- return ret
+ return msgpack.packb(o)
 
  @property
  def as_wds(self) -> WDSRawSample:
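
The simplification above leans on `msgpack.packb` returning `bytes` for ordinary inputs, so the old `None` check was effectively dead code. A minimal round-trip sketch (sample values illustrative, not from the package):

```python
import msgpack

# packb returns bytes directly for normal inputs; no None check required.
payload = msgpack.packb({"key": "0", "score": 0.75})
assert isinstance(payload, bytes)

# raw=False decodes msgpack strings back to str, matching the reader above.
restored = msgpack.unpackb(payload, raw=False)
assert restored == {"key": "0", "score": 0.75}
```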
@@ -542,6 +536,8 @@ class Dataset(Generic[ST]):
  return None
 
  if self._metadata is None:
+ import requests
+
  with requests.get(self.metadata_url, stream=True) as response:
  response.raise_for_status()
  self._metadata = msgpack.unpackb(response.content, raw=False)
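
Together with the new TYPE_CHECKING block at the top of the module, this moves `pandas` and `requests` off the import-time path, so `import atdata` no longer pays for either. A sketch of the deferred-import pattern in isolation; `fetch_rows` is an illustrative name, not part of the package:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas  # resolved only by the type checker, never at runtime


def fetch_rows(url: str) -> "pandas.DataFrame":
    # Deferred imports: the cost is paid on the first call, not at import.
    import pandas as pd
    import requests

    with requests.get(url) as response:
        response.raise_for_status()
        return pd.DataFrame(response.json())
```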
@@ -710,6 +706,8 @@ class Dataset(Generic[ST]):
  fn: Callable[[list[ST]], Any],
  *,
  shards: list[str] | None = None,
+ checkpoint: Path | str | None = None,
+ on_shard_error: Callable[[str, Exception], None] | None = None,
  ) -> dict[str, Any]:
  """Process each shard independently, collecting per-shard results.
 
@@ -725,6 +723,14 @@
  shards: Optional list of shard identifiers to process. If ``None``,
  processes all shards in the dataset. Useful for retrying only
  the failed shards from a previous ``PartialFailureError``.
+ checkpoint: Optional path to a checkpoint file. If provided,
+ already-succeeded shard IDs are loaded from this file and
+ skipped. Each newly succeeded shard is appended. On full
+ success the file is deleted. On partial failure it remains
+ for resume.
+ on_shard_error: Optional callback invoked as
+ ``on_shard_error(shard_id, exception)`` for each failed shard,
+ enabling dead-letter logging or alerting.
 
  Returns:
  Dict mapping shard identifier to *fn*'s return value for each shard.
@@ -741,45 +747,67 @@
  ... results = ds.process_shards(expensive_fn)
  ... except PartialFailureError as e:
  ... retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
+
+ >>> # With checkpoint for crash recovery:
+ >>> results = ds.process_shards(expensive_fn, checkpoint="progress.txt")
  """
- from ._logging import get_logger
+ from ._logging import get_logger, log_operation
 
  log = get_logger()
  shard_ids = shards or self.list_shards()
- log.info("process_shards: starting %d shards", len(shard_ids))
+
+ # Load checkpoint: skip already-succeeded shards
+ checkpoint_path: Path | None = None
+ if checkpoint is not None:
+ checkpoint_path = Path(checkpoint)
+ if checkpoint_path.exists():
+ already_done = set(checkpoint_path.read_text().splitlines())
+ log.info(
+ "process_shards: loaded checkpoint, %d shards already done",
+ len(already_done),
+ )
+ shard_ids = [s for s in shard_ids if s not in already_done]
+ if not shard_ids:
+ log.info("process_shards: all shards already checkpointed")
+ return {}
 
  succeeded: list[str] = []
  failed: list[str] = []
  errors: dict[str, Exception] = {}
  results: dict[str, Any] = {}
 
- for shard_id in shard_ids:
- try:
- shard_ds = Dataset[self.sample_type](shard_id)
- shard_ds._sample_type_cache = self._sample_type_cache
- samples = list(shard_ds.ordered())
- results[shard_id] = fn(samples)
- succeeded.append(shard_id)
- log.debug("process_shards: shard ok %s", shard_id)
- except Exception as exc:
- failed.append(shard_id)
- errors[shard_id] = exc
- log.warning("process_shards: shard failed %s: %s", shard_id, exc)
-
- if failed:
- log.error(
- "process_shards: %d/%d shards failed",
- len(failed),
- len(shard_ids),
- )
- raise PartialFailureError(
- succeeded_shards=succeeded,
- failed_shards=failed,
- errors=errors,
- results=results,
- )
+ with log_operation("process_shards", total_shards=len(shard_ids)):
+ for shard_id in shard_ids:
+ try:
+ shard_ds = Dataset[self.sample_type](shard_id)
+ shard_ds._sample_type_cache = self._sample_type_cache
+ samples = list(shard_ds.ordered())
+ results[shard_id] = fn(samples)
+ succeeded.append(shard_id)
+ log.debug("process_shards: shard ok %s", shard_id)
+ if checkpoint_path is not None:
+ with open(checkpoint_path, "a") as f:
+ f.write(shard_id + "\n")
+ except Exception as exc:
+ failed.append(shard_id)
+ errors[shard_id] = exc
+ log.warning("process_shards: shard failed %s: %s", shard_id, exc)
+ if on_shard_error is not None:
+ on_shard_error(shard_id, exc)
+
+ if failed:
+ raise PartialFailureError(
+ succeeded_shards=succeeded,
+ failed_shards=failed,
+ errors=errors,
+ results=results,
+ )
+
+ # All shards succeeded; clean up checkpoint file
+ if checkpoint_path is not None and checkpoint_path.exists():
+ checkpoint_path.unlink()
+ log.debug("process_shards: checkpoint file removed (all shards done)")
 
- log.info("process_shards: all %d shards succeeded", len(shard_ids))
  return results
 
  def select(self, indices: Sequence[int]) -> list[ST]:
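
A usage sketch combining the two new parameters, following the docstring examples; `ds` is an existing Dataset, and the callback body and filename are illustrative. `PartialFailureError` comes from `atdata._exceptions`, as imported at the top of this module:

```python
from atdata._exceptions import PartialFailureError


def expensive_fn(samples: list) -> int:
    return len(samples)


def dead_letter(shard_id: str, exc: Exception) -> None:
    # Called once per failed shard; a natural hook for alerting.
    print(f"shard {shard_id} failed: {exc!r}")


try:
    results = ds.process_shards(
        expensive_fn,
        checkpoint="progress.txt",  # one succeeded shard ID per line
        on_shard_error=dead_letter,
    )
except PartialFailureError:
    # progress.txt survives a partial failure, so rerunning with the same
    # checkpoint skips every shard that already succeeded.
    results = ds.process_shards(expensive_fn, checkpoint="progress.txt")
```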
@@ -811,9 +839,25 @@
  break
  return [result[i] for i in indices if i in result]
 
+ @property
+ def fields(self) -> "Any":
+ """Typed field proxy for manifest queries on this dataset.
+
+ Returns an object whose attributes are ``FieldProxy`` instances,
+ one per manifest-eligible field of this dataset's sample type.
+
+ Examples:
+ >>> ds = atdata.Dataset[MySample](url)
+ >>> Q = ds.fields
+ >>> results = ds.query(where=(Q.confidence > 0.9))
+ """
+ from .manifest._proxy import query_fields
+
+ return query_fields(self.sample_type)
+
  def query(
  self,
- where: "Callable[[pd.DataFrame], pd.Series]",
+ where: "Callable[[pd.DataFrame], pd.Series] | Predicate",
  ) -> "list[SampleLocation]":
  """Query this dataset using per-shard manifest metadata.
 
@@ -822,10 +866,12 @@
  and executes a two-phase query (shard-level aggregate pruning,
  then sample-level parquet filtering).
 
+ The *where* argument accepts either a lambda/function that operates
+ on a pandas DataFrame, or a ``Predicate`` built from the proxy DSL.
+
  Args:
- where: Predicate function that receives a pandas DataFrame
- of manifest fields and returns a boolean Series selecting
- matching rows.
+ where: Predicate function or ``Predicate`` object that selects
+ matching rows from the per-sample manifest DataFrame.
 
  Returns:
  List of ``SampleLocation`` for matching samples.
@@ -837,6 +883,9 @@
  >>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
  >>> len(locs)
  42
+
+ >>> Q = ds.fields
+ >>> locs = ds.query(where=(Q.confidence > 0.9))
  """
  from .manifest import QueryExecutor
 
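
The two accepted `where` forms side by side, taken directly from the docstring examples (the diff does not show whether `Predicate` objects compose with `&`/`|`, so only a single comparison appears here):

```python
# DataFrame-lambda form: receives the per-sample manifest DataFrame.
locs = ds.query(where=lambda df: df["confidence"] > 0.9)

# Proxy-DSL form: the same filter via the typed field proxy.
Q = ds.fields
locs = ds.query(where=(Q.confidence > 0.9))
```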
@@ -844,7 +893,7 @@
  executor = QueryExecutor.from_shard_urls(shard_urls)
  return executor.query(where=where)
 
- def to_pandas(self, limit: int | None = None) -> "pd.DataFrame":
+ def to_pandas(self, limit: int | None = None) -> "pandas.DataFrame":
  """Materialize the dataset (or first *limit* samples) as a DataFrame.
 
  Args:
@@ -867,6 +916,8 @@
  rows = [
  asdict(s) if dataclasses.is_dataclass(s) else s.to_dict() for s in samples
  ]
+ import pandas as pd
+
  return pd.DataFrame(rows)
 
  def to_dict(self, limit: int | None = None) -> dict[str, list[Any]]:
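
Because `pandas` is now imported inside the method, callers only need it installed when a frame is actually materialized; a small usage sketch:

```python
# pandas is imported lazily inside to_pandas, at the point of this call.
df = ds.to_pandas(limit=100)
print(df.head())
```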
@@ -1061,6 +1112,8 @@
  Examples:
  >>> ds.to_parquet("output.parquet", maxcount=50000)
  """
+ import pandas as pd
+
  path = Path(path)
  if sample_map is None:
  sample_map = asdict
@@ -1129,7 +1182,7 @@ _T = TypeVar("_T")
 
 
  @dataclass_transform()
- def packable(cls: type[_T]) -> type[Packable]:
+ def packable(cls: type[_T]) -> type[_T]:
  """Convert a class into a ``PackableSample`` dataclass with msgpack serialization.
 
  The resulting class gains ``packed``, ``as_wds``, ``from_bytes``, and
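
Annotating the return as `type[_T]` rather than `type[Packable]` lets type checkers keep the decorated class's own constructor and attributes instead of widening it to the protocol. A sketch, assuming `packable` is importable from the package root (the diff does not show its export location):

```python
from atdata import packable  # assumed import path


@packable
class MySample:
    key: str
    text: str


s = MySample(key="0", text="hello")  # the checker now sees this signature
payload = s.packed                   # serialization API added by the decorator
restored = MySample.from_bytes(payload)  # named in the docstring above;
                                         # exact signature not shown here
```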
@@ -1233,11 +1286,13 @@ def write_samples(
  [MySample(key='0', text='hello')]
  """
  from ._hf_api import _shards_to_wds_url
+ from ._logging import get_logger, log_operation
 
  if manifest:
  from .manifest._builder import ManifestBuilder
  from .manifest._writer import ManifestWriter
 
+ log = get_logger()
  path = Path(path)
  path.parent.mkdir(parents=True, exist_ok=True)
 
@@ -1245,97 +1300,108 @@
  sample_type: type | None = None
  written_paths: list[str] = []
 
- # Manifest tracking state
- _current_builder: list = [] # single-element list for nonlocal mutation
- _builders: list[tuple[str, "ManifestBuilder"]] = []
- _running_offset: list[int] = [0]
-
- def _finalize_builder() -> None:
- """Finalize the current manifest builder and stash it."""
- if _current_builder:
- shard_path = written_paths[-1] if written_paths else ""
- _builders.append((shard_path, _current_builder[0]))
- _current_builder.clear()
-
- def _start_builder(shard_path: str) -> None:
- """Start a new manifest builder for a shard."""
- _finalize_builder()
- shard_id = Path(shard_path).stem
- _current_builder.append(
- ManifestBuilder(sample_type=sample_type, shard_id=shard_id)
- )
- _running_offset[0] = 0
-
- def _record_sample(sample: "PackableSample", wds_dict: dict) -> None:
- """Record a sample in the active manifest builder."""
- if not _current_builder:
- return
- packed_bytes = wds_dict["msgpack"]
- size = len(packed_bytes)
- _current_builder[0].add_sample(
- key=wds_dict["__key__"],
- offset=_running_offset[0],
- size=size,
- sample=sample,
- )
- _running_offset[0] += size
-
- if use_shard_writer:
- # Build shard pattern from path
- if "%" not in str(path):
- pattern = str(path.parent / f"{path.stem}-%06d{path.suffix}")
+ with log_operation(
+ "write_samples", path=str(path), sharded=use_shard_writer, manifest=manifest
+ ):
+ # Manifest tracking state
+ _current_builder: list = [] # single-element list for nonlocal mutation
+ _builders: list[tuple[str, "ManifestBuilder"]] = []
+ _running_offset: list[int] = [0]
+
+ def _finalize_builder() -> None:
+ """Finalize the current manifest builder and stash it."""
+ if _current_builder:
+ shard_path = written_paths[-1] if written_paths else ""
+ _builders.append((shard_path, _current_builder[0]))
+ _current_builder.clear()
+
+ def _start_builder(shard_path: str) -> None:
+ """Start a new manifest builder for a shard."""
+ _finalize_builder()
+ shard_id = Path(shard_path).stem
+ _current_builder.append(
+ ManifestBuilder(sample_type=sample_type, shard_id=shard_id)
+ )
+ _running_offset[0] = 0
+
+ def _record_sample(sample: "PackableSample", wds_dict: dict) -> None:
+ """Record a sample in the active manifest builder."""
+ if not _current_builder:
+ return
+ packed_bytes = wds_dict["msgpack"]
+ size = len(packed_bytes)
+ _current_builder[0].add_sample(
+ key=wds_dict["__key__"],
+ offset=_running_offset[0],
+ size=size,
+ sample=sample,
+ )
+ _running_offset[0] += size
+
+ if use_shard_writer:
+ # Build shard pattern from path
+ if "%" not in str(path):
+ pattern = str(path.parent / f"{path.stem}-%06d{path.suffix}")
+ else:
+ pattern = str(path)
+
+ writer_kwargs: dict[str, Any] = {}
+ if maxcount is not None:
+ writer_kwargs["maxcount"] = maxcount
+ if maxsize is not None:
+ writer_kwargs["maxsize"] = maxsize
+
+ def _track(p: str) -> None:
+ written_paths.append(str(Path(p).resolve()))
+ if manifest and sample_type is not None:
+ _start_builder(p)
+
+ with wds.writer.ShardWriter(pattern, post=_track, **writer_kwargs) as sink:
+ for sample in samples:
+ if sample_type is None:
+ sample_type = type(sample)
+ wds_dict = sample.as_wds
+ sink.write(wds_dict)
+ if manifest:
+ # The first sample triggers _track before we get here when
+ # ShardWriter opens the first shard, but just in case:
+ if not _current_builder and sample_type is not None:
+ _start_builder(str(path))
+ _record_sample(sample, wds_dict)
  else:
- pattern = str(path)
-
- writer_kwargs: dict[str, Any] = {}
- if maxcount is not None:
- writer_kwargs["maxcount"] = maxcount
- if maxsize is not None:
- writer_kwargs["maxsize"] = maxsize
-
- def _track(p: str) -> None:
- written_paths.append(str(Path(p).resolve()))
- if manifest and sample_type is not None:
- _start_builder(p)
-
- with wds.writer.ShardWriter(pattern, post=_track, **writer_kwargs) as sink:
- for sample in samples:
- if sample_type is None:
- sample_type = type(sample)
- wds_dict = sample.as_wds
- sink.write(wds_dict)
- if manifest:
- # The first sample triggers _track before we get here when
- # ShardWriter opens the first shard, but just in case:
- if not _current_builder and sample_type is not None:
- _start_builder(str(path))
- _record_sample(sample, wds_dict)
- else:
- with wds.writer.TarWriter(str(path)) as sink:
- for sample in samples:
- if sample_type is None:
- sample_type = type(sample)
- wds_dict = sample.as_wds
- sink.write(wds_dict)
- if manifest:
- if not _current_builder and sample_type is not None:
- _current_builder.append(
- ManifestBuilder(sample_type=sample_type, shard_id=path.stem)
- )
- _record_sample(sample, wds_dict)
- written_paths.append(str(path.resolve()))
-
- if sample_type is None:
- raise ValueError("samples must be non-empty")
-
- # Finalize and write manifests
- if manifest:
- _finalize_builder()
- for shard_path, builder in _builders:
- m = builder.build()
- base = str(Path(shard_path).with_suffix(""))
- writer = ManifestWriter(base)
- writer.write(m)
+ with wds.writer.TarWriter(str(path)) as sink:
+ for sample in samples:
+ if sample_type is None:
+ sample_type = type(sample)
+ wds_dict = sample.as_wds
+ sink.write(wds_dict)
+ if manifest:
+ if not _current_builder and sample_type is not None:
+ _current_builder.append(
+ ManifestBuilder(
+ sample_type=sample_type, shard_id=path.stem
+ )
+ )
+ _record_sample(sample, wds_dict)
+ written_paths.append(str(path.resolve()))
+
+ if sample_type is None:
+ raise ValueError("samples must be non-empty")
+
+ # Finalize and write manifests
+ if manifest:
+ _finalize_builder()
+ for shard_path, builder in _builders:
+ m = builder.build()
+ base = str(Path(shard_path).with_suffix(""))
+ writer = ManifestWriter(base)
+ writer.write(m)
+
+ log.info(
+ "write_samples: wrote %d shard(s), sample_type=%s",
+ len(written_paths),
+ sample_type.__name__,
+ )
 
  url = _shards_to_wds_url(written_paths)
  ds: Dataset = Dataset(url)
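
A usage sketch for the sharded path through `write_samples`, based only on what this hunk shows (the `%`-pattern handling, `maxcount`/`maxsize` forwarding, and the `manifest` flag); `MySample` and the paths are illustrative, and the exact signature is not visible here. The trailing lines suggest the function builds and returns a `Dataset` over the written shards:

```python
samples = [MySample(key=str(i), text="hello") for i in range(10_000)]

# A path without a %-pattern has "-%06d" inserted before the suffix,
# e.g. out/data.tar -> out/data-000000.tar, out/data-000001.tar, ...
ds = write_samples(
    "out/data.tar",
    samples,
    maxcount=1_000,  # roll over to a new shard every 1000 samples
    manifest=True,   # also build and write a per-shard manifest
)
```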
atdata/index/_entry.py CHANGED
@@ -1,12 +1,16 @@
  """Dataset entry model and Redis key constants."""
 
+ from __future__ import annotations
+
  from atdata._cid import generate_cid
 
  from dataclasses import dataclass, field
- from typing import Any, cast
+ from typing import Any, TYPE_CHECKING, cast
 
  import msgpack
- from redis import Redis
+
+ if TYPE_CHECKING:
+ from redis import Redis
 
 
  # Redis key prefixes for index entries and schemas
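
The added `from __future__ import annotations` is what keeps this module importable without `redis` installed: under PEP 563 annotations are stored as strings, so `Redis` can still appear in signatures. A standalone sketch (function name illustrative):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from redis import Redis  # exists only for the type checker


def save_entry(db: Redis, key: str, payload: bytes) -> None:
    # The annotation stays a string at runtime, so the redis package is
    # only needed once a real client instance is actually passed in.
    db.set(key, payload)
```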