atdata 0.2.3b1__py3-none-any.whl → 0.3.1b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. atdata/.gitignore +1 -0
  2. atdata/__init__.py +39 -0
  3. atdata/_cid.py +0 -21
  4. atdata/_exceptions.py +168 -0
  5. atdata/_helpers.py +41 -15
  6. atdata/_hf_api.py +95 -11
  7. atdata/_logging.py +70 -0
  8. atdata/_protocols.py +77 -238
  9. atdata/_schema_codec.py +7 -6
  10. atdata/_stub_manager.py +5 -25
  11. atdata/_type_utils.py +28 -2
  12. atdata/atmosphere/__init__.py +31 -20
  13. atdata/atmosphere/_types.py +4 -4
  14. atdata/atmosphere/client.py +64 -12
  15. atdata/atmosphere/lens.py +11 -12
  16. atdata/atmosphere/records.py +12 -12
  17. atdata/atmosphere/schema.py +16 -18
  18. atdata/atmosphere/store.py +6 -7
  19. atdata/cli/__init__.py +161 -175
  20. atdata/cli/diagnose.py +2 -2
  21. atdata/cli/{local.py → infra.py} +11 -11
  22. atdata/cli/inspect.py +69 -0
  23. atdata/cli/preview.py +63 -0
  24. atdata/cli/schema.py +109 -0
  25. atdata/dataset.py +583 -328
  26. atdata/index/__init__.py +54 -0
  27. atdata/index/_entry.py +157 -0
  28. atdata/index/_index.py +1198 -0
  29. atdata/index/_schema.py +380 -0
  30. atdata/lens.py +9 -2
  31. atdata/lexicons/__init__.py +121 -0
  32. atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
  33. atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
  34. atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
  35. atdata/lexicons/ac.foundation.dataset.record.json +96 -0
  36. atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
  37. atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
  38. atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
  39. atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
  40. atdata/lexicons/ndarray_shim.json +16 -0
  41. atdata/local/__init__.py +70 -0
  42. atdata/local/_repo_legacy.py +218 -0
  43. atdata/manifest/__init__.py +28 -0
  44. atdata/manifest/_aggregates.py +156 -0
  45. atdata/manifest/_builder.py +163 -0
  46. atdata/manifest/_fields.py +154 -0
  47. atdata/manifest/_manifest.py +146 -0
  48. atdata/manifest/_query.py +150 -0
  49. atdata/manifest/_writer.py +74 -0
  50. atdata/promote.py +18 -14
  51. atdata/providers/__init__.py +25 -0
  52. atdata/providers/_base.py +140 -0
  53. atdata/providers/_factory.py +69 -0
  54. atdata/providers/_postgres.py +214 -0
  55. atdata/providers/_redis.py +171 -0
  56. atdata/providers/_sqlite.py +191 -0
  57. atdata/repository.py +323 -0
  58. atdata/stores/__init__.py +23 -0
  59. atdata/stores/_disk.py +123 -0
  60. atdata/stores/_s3.py +349 -0
  61. atdata/testing.py +341 -0
  62. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
  63. atdata-0.3.1b1.dist-info/RECORD +67 -0
  64. atdata/local.py +0 -1720
  65. atdata-0.2.3b1.dist-info/RECORD +0 -28
  66. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
  67. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
  68. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,28 @@
1
"""Per-shard manifest and query system.

Provides manifest generation during shard writes and efficient
query execution over large datasets without full scans.

Components:

- ``ManifestField``: Annotation marker for manifest-included fields
- ``ManifestBuilder``: Accumulates sample metadata during writes
- ``ShardManifest``: Loaded manifest representation
- ``ManifestWriter``: Serializes manifests to JSON + parquet
- ``QueryExecutor``: Two-phase query over manifest metadata
- ``SampleLocation``: Address of a sample within a shard
"""

# Re-exports use the ``X as X`` form so type checkers treat them as an
# explicit public API (re-export convention from PEP 484).

# Field annotation and introspection.
from ._fields import ManifestField as ManifestField
from ._fields import AggregateKind as AggregateKind
from ._fields import resolve_manifest_fields as resolve_manifest_fields

# Running statistical collectors fed during shard writes.
from ._aggregates import CategoricalAggregate as CategoricalAggregate
from ._aggregates import NumericAggregate as NumericAggregate
from ._aggregates import SetAggregate as SetAggregate
from ._aggregates import create_aggregate as create_aggregate

# Manifest construction, in-memory representation, and serialization.
from ._builder import ManifestBuilder as ManifestBuilder
from ._manifest import ShardManifest as ShardManifest
from ._manifest import MANIFEST_FORMAT_VERSION as MANIFEST_FORMAT_VERSION
from ._writer import ManifestWriter as ManifestWriter

# Query execution over manifest metadata.
from ._query import QueryExecutor as QueryExecutor
from ._query import SampleLocation as SampleLocation
@@ -0,0 +1,156 @@
1
+ """Statistical aggregate collectors for manifest fields.
2
+
3
+ Each aggregate type tracks running statistics during shard writing and
4
+ produces a summary dict for inclusion in the manifest JSON header.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import Any
11
+
12
+
13
@dataclass
class CategoricalAggregate:
    """Aggregate for categorical (string/enum) fields.

    Maintains a running count per distinct value across all samples in
    a shard; cardinality falls out of the number of distinct keys.

    Examples:
        >>> agg = CategoricalAggregate()
        >>> agg.add("dog")
        >>> agg.add("cat")
        >>> agg.add("dog")
        >>> agg.to_dict()
        {'type': 'categorical', 'cardinality': 2, 'value_counts': {'dog': 2, 'cat': 1}}
    """

    value_counts: dict[str, int] = field(default_factory=dict)

    @property
    def cardinality(self) -> int:
        """Number of distinct values observed."""
        return len(self.value_counts)

    def add(self, value: Any) -> None:
        """Record a value observation (stringified before counting)."""
        counts = self.value_counts
        label = str(value)
        counts[label] = 1 + counts.get(label, 0)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict."""
        summary: dict[str, Any] = {"type": "categorical"}
        summary["cardinality"] = self.cardinality
        # Copy so later add() calls don't mutate an already-exported dict.
        summary["value_counts"] = dict(self.value_counts)
        return summary
47
+
48
+
49
@dataclass
class NumericAggregate:
    """Aggregate for numeric (int/float) fields.

    Tracks min, max, sum, and count for computing summary statistics.

    Examples:
        >>> agg = NumericAggregate()
        >>> agg.add(1.0)
        >>> agg.add(3.0)
        >>> agg.add(2.0)
        >>> agg.to_dict()
        {'type': 'numeric', 'min': 1.0, 'max': 3.0, 'mean': 2.0, 'count': 3}
    """

    # Sentinels chosen so the first add() always replaces them.
    _min: float = field(default=float("inf"))
    _max: float = field(default=float("-inf"))
    _sum: float = 0.0
    count: int = 0

    @property
    def min(self) -> float:
        """Minimum observed value (``inf`` when nothing was added)."""
        return self._min

    @property
    def max(self) -> float:
        """Maximum observed value (``-inf`` when nothing was added)."""
        return self._max

    @property
    def mean(self) -> float:
        """Running mean of observed values (0.0 when empty)."""
        if self.count == 0:
            return 0.0
        return self._sum / self.count

    def add(self, value: float | int) -> None:
        """Record a numeric observation."""
        v = float(value)
        if v < self._min:
            self._min = v
        if v > self._max:
            self._max = v
        self._sum += v
        self.count += 1

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict.

        For an empty aggregate ``min``/``max`` are ``None`` rather than
        the +/-infinity sentinels: ``json.dumps(float("inf"))`` emits
        ``Infinity``, which is not valid JSON and would make the manifest
        header unreadable by strict parsers.
        """
        empty = self.count == 0
        return {
            "type": "numeric",
            "min": None if empty else self._min,
            "max": None if empty else self._max,
            "mean": self.mean,
            "count": self.count,
        }
105
+
106
+
107
@dataclass
class SetAggregate:
    """Aggregate for set/tag (list) fields.

    Maintains the union of every value observed across samples.

    Examples:
        >>> agg = SetAggregate()
        >>> agg.add(["outdoor", "day"])
        >>> agg.add(["indoor"])
        >>> agg.to_dict()
        {'type': 'set', 'all_values': ['day', 'indoor', 'outdoor']}
    """

    all_values: set[str] = field(default_factory=set)

    def add(self, values: list | set | tuple) -> None:
        """Fold a collection of values into the running union."""
        self.all_values.update(str(item) for item in values)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict (sorted for determinism)."""
        return {"type": "set", "all_values": sorted(self.all_values)}
134
+
135
+
136
def create_aggregate(
    kind: str,
) -> CategoricalAggregate | NumericAggregate | SetAggregate:
    """Create an aggregate collector for the given kind.

    Args:
        kind: One of ``"categorical"``, ``"numeric"``, ``"set"``.

    Returns:
        A new aggregate collector instance.

    Raises:
        ValueError: If kind is not recognized.
    """
    # Lambdas defer name resolution until the chosen factory is called.
    try:
        factory = {
            "categorical": lambda: CategoricalAggregate(),
            "numeric": lambda: NumericAggregate(),
            "set": lambda: SetAggregate(),
        }[kind]
    except KeyError:
        raise ValueError(f"Unknown aggregate kind: {kind!r}") from None
    return factory()
@@ -0,0 +1,163 @@
1
+ """ManifestBuilder for accumulating sample metadata during shard writes.
2
+
3
+ Creates one ``ManifestBuilder`` per shard. Call ``add_sample()`` for each
4
+ sample written, then ``build()`` to produce a finalized ``ShardManifest``.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from datetime import datetime, timezone
11
+ from typing import Any
12
+
13
+ import pandas as pd
14
+
15
+ from ._aggregates import (
16
+ create_aggregate,
17
+ CategoricalAggregate,
18
+ NumericAggregate,
19
+ SetAggregate,
20
+ )
21
+ from ._fields import resolve_manifest_fields
22
+ from ._manifest import ShardManifest
23
+
24
+
25
@dataclass
class _SampleRow:
    """Internal per-sample metadata row accumulated by ``ManifestBuilder``."""

    # WebDataset ``__key__`` of the sample.
    key: str
    # Byte offset of the sample's entry within the shard tar file.
    offset: int
    # Size in bytes of the sample's tar entry.
    size: int
    # Extracted manifest-field values; fields whose value was None are absent.
    fields: dict[str, Any]
33
+
34
+
35
class ManifestBuilder:
    """Collects per-sample metadata while a shard is being written.

    For every sample written, ``add_sample()`` extracts the
    manifest-annotated fields, feeds the running aggregate collectors,
    and stores a metadata row. Once the shard is complete, ``build()``
    turns the accumulated state into a ``ShardManifest``.

    Args:
        sample_type: The Packable type being written.
        shard_id: Identifier for this shard (e.g., path without extension).
        schema_version: Version string for the schema.
        source_job_id: Optional provenance job identifier.
        parent_shards: Optional list of input shard identifiers.
        pipeline_version: Optional pipeline version string.

    Examples:
        >>> builder = ManifestBuilder(
        ...     sample_type=ImageSample,
        ...     shard_id="train/shard-00042",
        ... )
        >>> builder.add_sample(key="abc", offset=0, size=1024, sample=my_sample)
        >>> manifest = builder.build()
        >>> manifest.num_samples
        1
    """

    def __init__(
        self,
        sample_type: type,
        shard_id: str,
        schema_version: str = "1.0.0",
        source_job_id: str | None = None,
        parent_shards: list[str] | None = None,
        pipeline_version: str | None = None,
    ) -> None:
        self._sample_type = sample_type
        self._shard_id = shard_id
        self._schema_version = schema_version
        self._source_job_id = source_job_id
        self._parent_shards = parent_shards or []
        self._pipeline_version = pipeline_version

        self._manifest_fields = resolve_manifest_fields(sample_type)
        # One running collector per manifest field, keyed by field name.
        self._aggregates: dict[
            str, CategoricalAggregate | NumericAggregate | SetAggregate
        ] = {
            field_name: create_aggregate(descriptor.aggregate)
            for field_name, descriptor in self._manifest_fields.items()
        }
        self._rows: list[_SampleRow] = []
        self._total_size: int = 0

    def add_sample(
        self,
        *,
        key: str,
        offset: int,
        size: int,
        sample: Any,
    ) -> None:
        """Record one sample's metadata.

        Extracts manifest-annotated fields from the sample, updates the
        running aggregates, and appends an internal metadata row.

        Args:
            key: The WebDataset ``__key__`` for this sample.
            offset: Byte offset within the tar file.
            size: Size in bytes of this sample's tar entry.
            sample: The sample instance (dataclass with manifest fields).
        """
        extracted: dict[str, Any] = {}
        for field_name in self._manifest_fields:
            observed = getattr(sample, field_name, None)
            if observed is None:
                # Missing/None fields contribute to neither aggregates nor rows.
                continue
            self._aggregates[field_name].add(observed)
            extracted[field_name] = observed

        self._rows.append(
            _SampleRow(key=key, offset=offset, size=size, fields=extracted)
        )
        self._total_size += size

    def build(self) -> ShardManifest:
        """Finalize aggregates and produce a ``ShardManifest``.

        Returns:
            A complete ``ShardManifest`` with header, aggregates, and
            per-sample DataFrame.
        """
        summaries = {name: agg.to_dict() for name, agg in self._aggregates.items()}

        # Per-sample DataFrame: reserved columns first, then manifest fields.
        records = [
            {"__key__": r.key, "__offset__": r.offset, "__size__": r.size, **r.fields}
            for r in self._rows
        ]
        frame = pd.DataFrame(records) if records else pd.DataFrame()

        # Provenance keeps only the entries that were actually supplied.
        provenance = {
            label: value
            for label, value in (
                ("source_job_id", self._source_job_id),
                ("parent_shards", self._parent_shards),
                ("pipeline_version", self._pipeline_version),
            )
            if value
        }

        return ShardManifest(
            shard_id=self._shard_id,
            schema_type=self._sample_type.__name__,
            schema_version=self._schema_version,
            num_samples=len(self._rows),
            size_bytes=self._total_size,
            created_at=datetime.now(timezone.utc),
            aggregates=summaries,
            samples=frame,
            provenance=provenance,
        )
@@ -0,0 +1,154 @@
1
+ """Manifest field annotation and introspection.
2
+
3
+ Provides the ``ManifestField`` marker for annotating which sample fields
4
+ should appear in per-shard manifests, and ``resolve_manifest_fields()``
5
+ for introspecting a sample type to discover its manifest-included fields.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import dataclasses
11
+ from dataclasses import dataclass
12
+ from typing import Any, Literal, get_args, get_origin, get_type_hints
13
+
14
+ from atdata._type_utils import PRIMITIVE_TYPE_MAP, is_ndarray_type, unwrap_optional
15
+
16
# The statistical summaries a manifest field can carry: "categorical"
# tracks value counts, "numeric" tracks min/max/mean, "set" tracks the
# union of observed values.
AggregateKind = Literal["categorical", "numeric", "set"]


@dataclass(frozen=True)
class ManifestField:
    """Marker for manifest-included fields.

    Use with ``Annotated`` to control which fields appear in per-shard
    manifests and what aggregate statistics to compute.

    Args:
        aggregate: The type of statistical aggregate to compute.
        exclude: If True, explicitly exclude this field from the manifest
            even if it would be auto-inferred.

    Examples:
        >>> from typing import Annotated
        >>> from numpy.typing import NDArray
        >>> @atdata.packable
        ... class ImageSample:
        ...     image: NDArray
        ...     label: Annotated[str, ManifestField("categorical")]
        ...     confidence: Annotated[float, ManifestField("numeric")]
        ...     tags: Annotated[list[str], ManifestField("set")]
    """

    # Which aggregate collector to run over this field's values.
    aggregate: AggregateKind
    # If True, omit the field even though auto-inference would include it.
    exclude: bool = False
44
+
45
+
46
def _extract_manifest_field(annotation: Any) -> ManifestField | None:
    """Pull the ``ManifestField`` marker out of an ``Annotated`` type.

    Returns None when the annotation is not a subscripted generic or
    carries no ``ManifestField`` in its metadata.
    """
    # Plain (unsubscripted) types have no origin and cannot be Annotated.
    if get_origin(annotation) is None:
        return None
    # ``Annotated[T, ...]`` exposes its extras via ``__metadata__``.
    for item in getattr(annotation, "__metadata__", None) or ():
        if isinstance(item, ManifestField):
            return item
    return None
57
+
58
+
59
def _infer_aggregate_kind(python_type: Any) -> AggregateKind | None:
    """Infer an aggregate kind for a type lacking an explicit marker.

    Returns:
        The inferred aggregate kind, or None if the type should be excluded.
    """
    core, _ = unwrap_optional(python_type)

    # Arrays and raw bytes are never summarized in the manifest.
    if is_ndarray_type(core) or core is bytes:
        return None

    # Primitive scalars map directly onto an aggregate kind.
    if core in PRIMITIVE_TYPE_MAP:
        label = PRIMITIVE_TYPE_MAP[core]
        if label in ("str", "bool"):
            return "categorical"
        return "numeric" if label in ("int", "float") else None

    # Bare or parametrized ``list`` becomes a set-union aggregate.
    return "set" if get_origin(core) is list else None
89
+
90
+
91
def _get_base_type(annotation: Any) -> Any:
    """Return the underlying type of an ``Annotated[...]``, else the input."""
    # Only Annotated types carry __metadata__; anything else passes through.
    if getattr(annotation, "__metadata__", None) is None:
        return annotation
    inner = get_args(annotation)
    return inner[0] if inner else annotation
98
+
99
+
100
def resolve_manifest_fields(sample_type: type) -> dict[str, ManifestField]:
    """Extract manifest field descriptors from a Packable type.

    Inspects type hints for ``Annotated[..., ManifestField(...)]`` markers.
    Fields without an explicit marker are auto-inferred:

    - ``str``, ``bool`` -> categorical
    - ``int``, ``float`` -> numeric
    - ``list[T]`` -> set
    - ``NDArray``, ``bytes`` -> excluded

    Args:
        sample_type: A ``@packable`` or ``PackableSample`` subclass.

    Returns:
        Dict mapping field name to ``ManifestField`` descriptor for all
        manifest-included fields.

    Examples:
        >>> from typing import Annotated
        >>> @atdata.packable
        ... class MySample:
        ...     label: Annotated[str, ManifestField("categorical")]
        ...     score: float
        >>> fields = resolve_manifest_fields(MySample)
        >>> fields["label"].aggregate
        'categorical'
        >>> fields["score"].aggregate
        'numeric'
    """
    if not dataclasses.is_dataclass(sample_type):
        raise TypeError(f"{sample_type} is not a dataclass")

    # Restrict to actual dataclass fields; hints may include ClassVars etc.
    declared = {f.name for f in dataclasses.fields(sample_type)}
    resolved: dict[str, ManifestField] = {}

    hints = get_type_hints(sample_type, include_extras=True)
    for name, annotation in hints.items():
        if name not in declared:
            continue

        marker = _extract_manifest_field(annotation)
        if marker is not None:
            # An explicit marker wins; exclude=True drops the field outright.
            if not marker.exclude:
                resolved[name] = marker
        else:
            # No marker: fall back to type-driven inference.
            inferred = _infer_aggregate_kind(_get_base_type(annotation))
            if inferred is not None:
                resolved[name] = ManifestField(aggregate=inferred)

    return resolved
@@ -0,0 +1,146 @@
1
+ """ShardManifest data model.
2
+
3
+ Represents a loaded manifest with JSON header (metadata + aggregates)
4
+ and per-sample metadata (as a pandas DataFrame).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from dataclasses import dataclass, field
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+ from typing import Any
14
+
15
+ import pandas as pd
16
+
17
MANIFEST_FORMAT_VERSION = "1.0.0"


@dataclass
class ShardManifest:
    """In-memory representation of a shard's manifest.

    Contains the JSON header (metadata + aggregates) and per-sample
    metadata stored as a pandas DataFrame for efficient columnar filtering.

    Attributes:
        shard_id: Shard identifier (path without extension).
        schema_type: Schema class name.
        schema_version: Schema version string.
        num_samples: Number of samples in the shard.
        size_bytes: Total shard size in bytes.
        created_at: When the manifest was created.
        aggregates: Dict of field name to aggregate summary dict.
        samples: DataFrame with ``__key__``, ``__offset__``, ``__size__``,
            and manifest field columns.
        provenance: Optional provenance metadata (job ID, parent shards, etc.).

    Examples:
        >>> manifest = ShardManifest.from_files(
        ...     "data/shard-000000.manifest.json",
        ...     "data/shard-000000.manifest.parquet",
        ... )
        >>> manifest.num_samples
        1000
        >>> manifest.aggregates["label"]["cardinality"]
        3
    """

    shard_id: str
    schema_type: str
    schema_version: str
    num_samples: int
    size_bytes: int
    created_at: datetime
    aggregates: dict[str, dict[str, Any]]
    samples: pd.DataFrame
    provenance: dict[str, Any] = field(default_factory=dict)

    def header_dict(self) -> dict[str, Any]:
        """Return the JSON-serializable header including aggregates.

        Returns:
            Dict suitable for writing as the ``.manifest.json`` file.
            The ``provenance`` key is present only when non-empty.
        """
        header: dict[str, Any] = {
            "manifest_version": MANIFEST_FORMAT_VERSION,
            "shard_id": self.shard_id,
            "schema_type": self.schema_type,
            "schema_version": self.schema_version,
            "num_samples": self.num_samples,
            "size_bytes": self.size_bytes,
            "created_at": self.created_at.isoformat(),
            "aggregates": self.aggregates,
        }
        if self.provenance:
            header["provenance"] = self.provenance
        return header

    @staticmethod
    def _load_header(json_path: str | Path) -> dict[str, Any]:
        """Read and parse the ``.manifest.json`` header file."""
        with open(Path(json_path), "r", encoding="utf-8") as f:
            return json.load(f)

    @classmethod
    def _from_header(
        cls, header: dict[str, Any], samples: pd.DataFrame
    ) -> ShardManifest:
        """Build an instance from a parsed JSON header plus a samples frame.

        Shared by ``from_files`` and ``from_json_only`` so the header
        schema is decoded in exactly one place.
        """
        return cls(
            shard_id=header["shard_id"],
            schema_type=header["schema_type"],
            schema_version=header["schema_version"],
            num_samples=header["num_samples"],
            size_bytes=header["size_bytes"],
            created_at=datetime.fromisoformat(header["created_at"]),
            aggregates=header.get("aggregates", {}),
            samples=samples,
            provenance=header.get("provenance", {}),
        )

    @classmethod
    def from_files(
        cls, json_path: str | Path, parquet_path: str | Path
    ) -> ShardManifest:
        """Load a manifest from its JSON + parquet companion files.

        Args:
            json_path: Path to the ``.manifest.json`` file.
            parquet_path: Path to the ``.manifest.parquet`` file.

        Returns:
            A fully loaded ``ShardManifest``.

        Raises:
            FileNotFoundError: If either file does not exist.
            json.JSONDecodeError: If the JSON file is malformed.
        """
        header = cls._load_header(json_path)
        # NOTE(review): engine is pinned to fastparquet — presumably to match
        # the writer's output format; confirm before switching to pyarrow/auto.
        samples = pd.read_parquet(Path(parquet_path), engine="fastparquet")
        return cls._from_header(header, samples)

    @classmethod
    def from_json_only(cls, json_path: str | Path) -> ShardManifest:
        """Load a header-only manifest for shard-level filtering.

        Skips the parquet per-sample data; useful for fast shard pruning
        via aggregates before paying the cost of loading the full parquet.

        Args:
            json_path: Path to the ``.manifest.json`` file.

        Returns:
            A ``ShardManifest`` with an empty ``samples`` DataFrame.
        """
        return cls._from_header(cls._load_header(json_path), pd.DataFrame())