atdata-0.2.2b1-py3-none-any.whl → atdata-0.3.0b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. atdata/.gitignore +1 -0
  2. atdata/__init__.py +31 -1
  3. atdata/_cid.py +29 -35
  4. atdata/_exceptions.py +168 -0
  5. atdata/_helpers.py +33 -17
  6. atdata/_hf_api.py +109 -59
  7. atdata/_logging.py +70 -0
  8. atdata/_protocols.py +74 -132
  9. atdata/_schema_codec.py +38 -41
  10. atdata/_sources.py +57 -64
  11. atdata/_stub_manager.py +31 -26
  12. atdata/_type_utils.py +47 -7
  13. atdata/atmosphere/__init__.py +31 -24
  14. atdata/atmosphere/_types.py +11 -11
  15. atdata/atmosphere/client.py +11 -8
  16. atdata/atmosphere/lens.py +27 -30
  17. atdata/atmosphere/records.py +34 -39
  18. atdata/atmosphere/schema.py +35 -31
  19. atdata/atmosphere/store.py +16 -20
  20. atdata/cli/__init__.py +163 -168
  21. atdata/cli/diagnose.py +12 -8
  22. atdata/cli/inspect.py +69 -0
  23. atdata/cli/local.py +5 -2
  24. atdata/cli/preview.py +63 -0
  25. atdata/cli/schema.py +109 -0
  26. atdata/dataset.py +678 -533
  27. atdata/lens.py +85 -83
  28. atdata/local/__init__.py +71 -0
  29. atdata/local/_entry.py +157 -0
  30. atdata/local/_index.py +940 -0
  31. atdata/local/_repo_legacy.py +218 -0
  32. atdata/local/_s3.py +349 -0
  33. atdata/local/_schema.py +380 -0
  34. atdata/manifest/__init__.py +28 -0
  35. atdata/manifest/_aggregates.py +156 -0
  36. atdata/manifest/_builder.py +163 -0
  37. atdata/manifest/_fields.py +154 -0
  38. atdata/manifest/_manifest.py +146 -0
  39. atdata/manifest/_query.py +150 -0
  40. atdata/manifest/_writer.py +74 -0
  41. atdata/promote.py +20 -24
  42. atdata/providers/__init__.py +25 -0
  43. atdata/providers/_base.py +140 -0
  44. atdata/providers/_factory.py +69 -0
  45. atdata/providers/_postgres.py +214 -0
  46. atdata/providers/_redis.py +171 -0
  47. atdata/providers/_sqlite.py +191 -0
  48. atdata/repository.py +323 -0
  49. atdata/testing.py +337 -0
  50. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +5 -1
  51. atdata-0.3.0b1.dist-info/RECORD +54 -0
  52. atdata/local.py +0 -1707
  53. atdata-0.2.2b1.dist-info/RECORD +0 -28
  54. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
  55. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
  56. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/manifest/_builder.py
@@ -0,0 +1,163 @@
+ """ManifestBuilder for accumulating sample metadata during shard writes.
+
+ Creates one ``ManifestBuilder`` per shard. Call ``add_sample()`` for each
+ sample written, then ``build()`` to produce a finalized ``ShardManifest``.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from datetime import datetime, timezone
+ from typing import Any
+
+ import pandas as pd
+
+ from ._aggregates import (
+     create_aggregate,
+     CategoricalAggregate,
+     NumericAggregate,
+     SetAggregate,
+ )
+ from ._fields import resolve_manifest_fields
+ from ._manifest import ShardManifest
+
+
+ @dataclass
+ class _SampleRow:
+     """Internal per-sample metadata row."""
+
+     key: str
+     offset: int
+     size: int
+     fields: dict[str, Any]
+
+
+ class ManifestBuilder:
+     """Accumulates sample metadata during shard writing.
+
+     Extracts manifest-annotated fields from each sample, feeds running
+     aggregate collectors, and accumulates per-sample metadata rows.
+     Call ``build()`` after all samples are written to produce a
+     ``ShardManifest``.
+
+     Args:
+         sample_type: The Packable type being written.
+         shard_id: Identifier for this shard (e.g., path without extension).
+         schema_version: Version string for the schema.
+         source_job_id: Optional provenance job identifier.
+         parent_shards: Optional list of input shard identifiers.
+         pipeline_version: Optional pipeline version string.
+
+     Examples:
+         >>> builder = ManifestBuilder(
+         ...     sample_type=ImageSample,
+         ...     shard_id="train/shard-00042",
+         ... )
+         >>> builder.add_sample(key="abc", offset=0, size=1024, sample=my_sample)
+         >>> manifest = builder.build()
+         >>> manifest.num_samples
+         1
+     """
+
+     def __init__(
+         self,
+         sample_type: type,
+         shard_id: str,
+         schema_version: str = "1.0.0",
+         source_job_id: str | None = None,
+         parent_shards: list[str] | None = None,
+         pipeline_version: str | None = None,
+     ) -> None:
+         self._sample_type = sample_type
+         self._shard_id = shard_id
+         self._schema_version = schema_version
+         self._source_job_id = source_job_id
+         self._parent_shards = parent_shards or []
+         self._pipeline_version = pipeline_version
+
+         self._manifest_fields = resolve_manifest_fields(sample_type)
+         self._aggregates: dict[
+             str, CategoricalAggregate | NumericAggregate | SetAggregate
+         ] = {
+             name: create_aggregate(mf.aggregate)
+             for name, mf in self._manifest_fields.items()
+         }
+         self._rows: list[_SampleRow] = []
+         self._total_size: int = 0
+
+     def add_sample(
+         self,
+         *,
+         key: str,
+         offset: int,
+         size: int,
+         sample: Any,
+     ) -> None:
+         """Record a sample's metadata.
+
+         Extracts manifest-annotated fields from the sample, updates
+         running aggregates, and appends a row to the internal list.
+
+         Args:
+             key: The WebDataset ``__key__`` for this sample.
+             offset: Byte offset within the tar file.
+             size: Size in bytes of this sample's tar entry.
+             sample: The sample instance (dataclass with manifest fields).
+         """
+         field_values: dict[str, Any] = {}
+         for name in self._manifest_fields:
+             value = getattr(sample, name, None)
+             if value is not None:
+                 self._aggregates[name].add(value)
+                 field_values[name] = value
+
+         self._rows.append(
+             _SampleRow(key=key, offset=offset, size=size, fields=field_values)
+         )
+         self._total_size += size
+
+     def build(self) -> ShardManifest:
+         """Finalize aggregates and produce a ``ShardManifest``.
+
+         Returns:
+             A complete ``ShardManifest`` with header, aggregates, and
+             per-sample DataFrame.
+         """
+         # Build aggregates dict
+         aggregates = {name: agg.to_dict() for name, agg in self._aggregates.items()}
+
+         # Build per-sample DataFrame
+         records: list[dict[str, Any]] = []
+         for row in self._rows:
+             record: dict[str, Any] = {
+                 "__key__": row.key,
+                 "__offset__": row.offset,
+                 "__size__": row.size,
+             }
+             record.update(row.fields)
+             records.append(record)
+
+         samples_df = pd.DataFrame(records) if records else pd.DataFrame()
+
+         # Build provenance
+         provenance: dict[str, Any] = {}
+         if self._source_job_id:
+             provenance["source_job_id"] = self._source_job_id
+         if self._parent_shards:
+             provenance["parent_shards"] = self._parent_shards
+         if self._pipeline_version:
+             provenance["pipeline_version"] = self._pipeline_version
+
+         schema_name = self._sample_type.__name__
+
+         return ShardManifest(
+             shard_id=self._shard_id,
+             schema_type=schema_name,
+             schema_version=self._schema_version,
+             num_samples=len(self._rows),
+             size_bytes=self._total_size,
+             created_at=datetime.now(timezone.utc),
+             aggregates=aggregates,
+             samples=samples_df,
+             provenance=provenance,
+         )
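Usage note: the builder is fed by whatever code writes the tar shard, so the offsets and sizes come from the tar writer. A minimal sketch of the intended flow, assuming a hypothetical ImageSample packable type and an encoded_samples iterable of (key, sample, payload_bytes) tuples (neither is part of this module):

    # One builder per shard; offsets/sizes come from the tar writer in real use.
    builder = ManifestBuilder(sample_type=ImageSample, shard_id="train/shard-00042")

    offset = 0
    for key, sample, payload in encoded_samples:  # hypothetical iterable
        builder.add_sample(key=key, offset=offset, size=len(payload), sample=sample)
        offset += len(payload)  # illustrative; real tar offsets include headers/padding

    manifest = builder.build()  # finalized ShardManifest with aggregates + DataFrame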
atdata/manifest/_fields.py
@@ -0,0 +1,154 @@
+ """Manifest field annotation and introspection.
+
+ Provides the ``ManifestField`` marker for annotating which sample fields
+ should appear in per-shard manifests, and ``resolve_manifest_fields()``
+ for introspecting a sample type to discover its manifest-included fields.
+ """
+
+ from __future__ import annotations
+
+ import dataclasses
+ from dataclasses import dataclass
+ from typing import Any, Literal, get_args, get_origin, get_type_hints
+
+ from atdata._type_utils import PRIMITIVE_TYPE_MAP, is_ndarray_type, unwrap_optional
+
+ AggregateKind = Literal["categorical", "numeric", "set"]
+
+
+ @dataclass(frozen=True)
+ class ManifestField:
+     """Marker for manifest-included fields.
+
+     Use with ``Annotated`` to control which fields appear in per-shard
+     manifests and what aggregate statistics to compute.
+
+     Args:
+         aggregate: The type of statistical aggregate to compute.
+         exclude: If True, explicitly exclude this field from the manifest
+             even if it would be auto-inferred.
+
+     Examples:
+         >>> from typing import Annotated
+         >>> from numpy.typing import NDArray
+         >>> @atdata.packable
+         ... class ImageSample:
+         ...     image: NDArray
+         ...     label: Annotated[str, ManifestField("categorical")]
+         ...     confidence: Annotated[float, ManifestField("numeric")]
+         ...     tags: Annotated[list[str], ManifestField("set")]
+     """
+
+     aggregate: AggregateKind
+     exclude: bool = False
+
+
+ def _extract_manifest_field(annotation: Any) -> ManifestField | None:
+     """Extract a ManifestField from an Annotated type, if present."""
+     if get_origin(annotation) is not None:
+         # Check for Annotated[T, ManifestField(...)];
+         # typing.Annotated exposes its extras via __metadata__
+         metadata = getattr(annotation, "__metadata__", None)
+         if metadata is not None:
+             for item in metadata:
+                 if isinstance(item, ManifestField):
+                     return item
+     return None
+
+
+ def _infer_aggregate_kind(python_type: Any) -> AggregateKind | None:
+     """Infer the aggregate kind from a Python type annotation.
+
+     Returns:
+         The inferred aggregate kind, or None if the type should be excluded.
+     """
+     # Unwrap Optional
+     inner_type, _ = unwrap_optional(python_type)
+
+     # Exclude NDArray and bytes
+     if is_ndarray_type(inner_type):
+         return None
+     if inner_type is bytes:
+         return None
+
+     # Check primitives
+     if inner_type in PRIMITIVE_TYPE_MAP:
+         type_name = PRIMITIVE_TYPE_MAP[inner_type]
+         if type_name in ("str", "bool"):
+             return "categorical"
+         if type_name in ("int", "float"):
+             return "numeric"
+         return None
+
+     # Check list types -> set aggregate
+     origin = get_origin(inner_type)
+     if origin is list:
+         return "set"
+
+     return None
+
+
+ def _get_base_type(annotation: Any) -> Any:
+     """Get the base type from an Annotated type or return as-is."""
+     args = get_args(annotation)
+     metadata = getattr(annotation, "__metadata__", None)
+     if metadata is not None and args:
+         return args[0]
+     return annotation
+
+
+ def resolve_manifest_fields(sample_type: type) -> dict[str, ManifestField]:
+     """Extract manifest field descriptors from a Packable type.
+
+     Inspects type hints for ``Annotated[..., ManifestField(...)]`` markers.
+     For fields without explicit markers, applies auto-inference rules:
+
+     - ``str``, ``bool`` -> categorical
+     - ``int``, ``float`` -> numeric
+     - ``list[T]`` -> set
+     - ``NDArray``, ``bytes`` -> excluded
+
+     Args:
+         sample_type: A ``@packable`` or ``PackableSample`` subclass.
+
+     Returns:
+         Dict mapping field name to ``ManifestField`` descriptor for all
+         manifest-included fields.
+
+     Examples:
+         >>> from typing import Annotated
+         >>> @atdata.packable
+         ... class MySample:
+         ...     label: Annotated[str, ManifestField("categorical")]
+         ...     score: float
+         >>> fields = resolve_manifest_fields(MySample)
+         >>> fields["label"].aggregate
+         'categorical'
+         >>> fields["score"].aggregate
+         'numeric'
+     """
+     if not dataclasses.is_dataclass(sample_type):
+         raise TypeError(f"{sample_type} is not a dataclass")
+
+     hints = get_type_hints(sample_type, include_extras=True)
+     dc_fields = {f.name for f in dataclasses.fields(sample_type)}
+     result: dict[str, ManifestField] = {}
+
+     for name, annotation in hints.items():
+         if name not in dc_fields:
+             continue
+
+         # Check for explicit ManifestField annotation
+         explicit = _extract_manifest_field(annotation)
+         if explicit is not None:
+             if not explicit.exclude:
+                 result[name] = explicit
+             continue
+
+         # Auto-infer from base type
+         base_type = _get_base_type(annotation)
+         kind = _infer_aggregate_kind(base_type)
+         if kind is not None:
+             result[name] = ManifestField(aggregate=kind)
+
+     return result
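To make the auto-inference and opt-out rules concrete: resolve_manifest_fields only requires a dataclass, so a plain @dataclass stands in for a @packable type in this sketch; the Clip class and its fields are illustrative, not part of the package:

    from dataclasses import dataclass
    from typing import Annotated

    from atdata.manifest._fields import ManifestField, resolve_manifest_fields

    @dataclass
    class Clip:
        label: Annotated[str, ManifestField("categorical")]  # explicit marker
        duration: float                                      # auto-inferred -> numeric
        tags: list[str]                                      # auto-inferred -> set
        raw: bytes                                           # excluded by rule
        notes: Annotated[str, ManifestField("categorical", exclude=True)]  # opted out

    fields = resolve_manifest_fields(Clip)
    assert sorted(fields) == ["duration", "label", "tags"]
    assert fields["duration"].aggregate == "numeric"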
atdata/manifest/_manifest.py
@@ -0,0 +1,146 @@
+ """ShardManifest data model.
+
+ Represents a loaded manifest with JSON header (metadata + aggregates)
+ and per-sample metadata (as a pandas DataFrame).
+ """
+
+ from __future__ import annotations
+
+ import json
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any
+
+ import pandas as pd
+
+ MANIFEST_FORMAT_VERSION = "1.0.0"
+
+
+ @dataclass
+ class ShardManifest:
+     """In-memory representation of a shard's manifest.
+
+     Contains the JSON header (metadata + aggregates) and per-sample
+     metadata stored as a pandas DataFrame for efficient columnar filtering.
+
+     Attributes:
+         shard_id: Shard identifier (path without extension).
+         schema_type: Schema class name.
+         schema_version: Schema version string.
+         num_samples: Number of samples in the shard.
+         size_bytes: Total shard size in bytes.
+         created_at: When the manifest was created.
+         aggregates: Dict of field name to aggregate summary dict.
+         samples: DataFrame with ``__key__``, ``__offset__``, ``__size__``,
+             and manifest field columns.
+         provenance: Optional provenance metadata (job ID, parent shards, etc.).
+
+     Examples:
+         >>> manifest = ShardManifest.from_files(
+         ...     "data/shard-000000.manifest.json",
+         ...     "data/shard-000000.manifest.parquet",
+         ... )
+         >>> manifest.num_samples
+         1000
+         >>> manifest.aggregates["label"]["cardinality"]
+         3
+     """
+
+     shard_id: str
+     schema_type: str
+     schema_version: str
+     num_samples: int
+     size_bytes: int
+     created_at: datetime
+     aggregates: dict[str, dict[str, Any]]
+     samples: pd.DataFrame
+     provenance: dict[str, Any] = field(default_factory=dict)
+
+     def header_dict(self) -> dict[str, Any]:
+         """Return the JSON-serializable header including aggregates.
+
+         Returns:
+             Dict suitable for writing as the ``.manifest.json`` file.
+         """
+         header: dict[str, Any] = {
+             "manifest_version": MANIFEST_FORMAT_VERSION,
+             "shard_id": self.shard_id,
+             "schema_type": self.schema_type,
+             "schema_version": self.schema_version,
+             "num_samples": self.num_samples,
+             "size_bytes": self.size_bytes,
+             "created_at": self.created_at.isoformat(),
+             "aggregates": self.aggregates,
+         }
+         if self.provenance:
+             header["provenance"] = self.provenance
+         return header
+
+     @classmethod
+     def from_files(
+         cls, json_path: str | Path, parquet_path: str | Path
+     ) -> ShardManifest:
+         """Load a manifest from its JSON + parquet companion files.
+
+         Args:
+             json_path: Path to the ``.manifest.json`` file.
+             parquet_path: Path to the ``.manifest.parquet`` file.
+
+         Returns:
+             A fully loaded ``ShardManifest``.
+
+         Raises:
+             FileNotFoundError: If either file does not exist.
+             json.JSONDecodeError: If the JSON file is malformed.
+         """
+         json_path = Path(json_path)
+         parquet_path = Path(parquet_path)
+
+         with open(json_path, "r", encoding="utf-8") as f:
+             header = json.load(f)
+
+         samples = pd.read_parquet(parquet_path, engine="fastparquet")
+
+         return cls(
+             shard_id=header["shard_id"],
+             schema_type=header["schema_type"],
+             schema_version=header["schema_version"],
+             num_samples=header["num_samples"],
+             size_bytes=header["size_bytes"],
+             created_at=datetime.fromisoformat(header["created_at"]),
+             aggregates=header.get("aggregates", {}),
+             samples=samples,
+             provenance=header.get("provenance", {}),
+         )
+
+     @classmethod
+     def from_json_only(cls, json_path: str | Path) -> ShardManifest:
+         """Load a header-only manifest for shard-level filtering.
+
+         Loads just the JSON header without the parquet per-sample data.
+         Useful for fast shard pruning via aggregates before loading
+         the full parquet file.
+
+         Args:
+             json_path: Path to the ``.manifest.json`` file.
+
+         Returns:
+             A ``ShardManifest`` with an empty ``samples`` DataFrame.
+         """
+         json_path = Path(json_path)
+
+         with open(json_path, "r", encoding="utf-8") as f:
+             header = json.load(f)
+
+         return cls(
+             shard_id=header["shard_id"],
+             schema_type=header["schema_type"],
+             schema_version=header["schema_version"],
+             num_samples=header["num_samples"],
+             size_bytes=header["size_bytes"],
+             created_at=datetime.fromisoformat(header["created_at"]),
+             aggregates=header.get("aggregates", {}),
+             samples=pd.DataFrame(),
+             provenance=header.get("provenance", {}),
+         )
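For reference, a sketch of the header round-trip. The exact aggregate payload is produced by the _aggregates module (a separate hunk in this release); only the "cardinality" key is attested by the docstring above, so the shape of the inner dict here is an assumption:

    import pandas as pd
    from datetime import datetime, timezone

    from atdata.manifest._manifest import ShardManifest

    m = ShardManifest(
        shard_id="data/shard-000000",
        schema_type="ImageSample",  # illustrative schema name
        schema_version="1.0.0",
        num_samples=2,
        size_bytes=2048,
        created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
        aggregates={"label": {"cardinality": 2}},  # inner-dict shape assumed
        samples=pd.DataFrame(
            {"__key__": ["a", "b"], "__offset__": [0, 1024], "__size__": [1024, 1024]}
        ),
    )
    header = m.header_dict()  # JSON-safe: datetime serialized via isoformat()
    assert header["manifest_version"] == "1.0.0"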
atdata/manifest/_query.py
@@ -0,0 +1,150 @@
+ """Query executor for manifest-based dataset queries.
+
+ Provides two-phase filtering: shard-level pruning via aggregates,
+ then sample-level filtering via the parquet DataFrame.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Callable
+
+ import pandas as pd
+
+ from ._manifest import ShardManifest
+
+
+ @dataclass(frozen=True)
+ class SampleLocation:
+     """Location of a sample within a shard.
+
+     Attributes:
+         shard: Shard identifier or URL.
+         key: WebDataset ``__key__`` for the sample.
+         offset: Byte offset within the tar file.
+
+     Examples:
+         >>> loc = SampleLocation(shard="data/shard-000000", key="sample_00042", offset=52480)
+         >>> loc.shard
+         'data/shard-000000'
+     """
+
+     shard: str
+     key: str
+     offset: int
+
+
+ class QueryExecutor:
+     """Executes queries over per-shard manifests.
+
+     Performs two-phase filtering:
+
+     1. **Shard-level**: uses aggregates to skip shards that cannot contain
+        matching samples (e.g., numeric range exclusion, categorical value absence).
+     2. **Sample-level**: applies the predicate to the parquet DataFrame rows.
+
+     Args:
+         manifests: List of ``ShardManifest`` objects to query over.
+
+     Examples:
+         >>> executor = QueryExecutor(manifests)
+         >>> results = executor.query(
+         ...     where=lambda df: (df["confidence"] > 0.9) & (df["label"].isin(["dog", "cat"]))
+         ... )
+         >>> len(results)
+         42
+     """
+
+     def __init__(self, manifests: list[ShardManifest]) -> None:
+         self._manifests = manifests
+
+     def query(
+         self,
+         where: Callable[[pd.DataFrame], pd.Series],
+     ) -> list[SampleLocation]:
+         """Execute a query across all manifests.
+
+         The ``where`` callable receives a pandas DataFrame with the per-sample
+         manifest columns and must return a boolean Series selecting matching rows.
+
+         Args:
+             where: Predicate function. Receives a DataFrame, returns a boolean Series.
+
+         Returns:
+             List of ``SampleLocation`` for all matching samples.
+         """
+         results: list[SampleLocation] = []
+
+         for manifest in self._manifests:
+             if manifest.samples.empty:
+                 continue
+
+             mask = where(manifest.samples)
+             matching = manifest.samples[mask]
+
+             for _, row in matching.iterrows():
+                 results.append(
+                     SampleLocation(
+                         shard=manifest.shard_id,
+                         key=row["__key__"],
+                         offset=int(row["__offset__"]),
+                     )
+                 )
+
+         return results
+
+     @classmethod
+     def from_directory(cls, directory: str | Path) -> QueryExecutor:
+         """Load all manifests from a directory.
+
+         Discovers ``*.manifest.json`` files and loads each with its
+         companion parquet file.
+
+         Args:
+             directory: Path to scan for manifest files.
+
+         Returns:
+             A ``QueryExecutor`` loaded with all discovered manifests.
+
+         Raises:
+             FileNotFoundError: If the directory does not exist.
+         """
+         directory = Path(directory)
+         manifests: list[ShardManifest] = []
+
+         for json_path in sorted(directory.glob("*.manifest.json")):
+             parquet_path = json_path.with_suffix("").with_suffix(".manifest.parquet")
+             if parquet_path.exists():
+                 manifests.append(ShardManifest.from_files(json_path, parquet_path))
+             else:
+                 manifests.append(ShardManifest.from_json_only(json_path))
+
+         return cls(manifests)
+
+     @classmethod
+     def from_shard_urls(cls, shard_urls: list[str]) -> QueryExecutor:
+         """Load manifests corresponding to a list of shard URLs.
+
+         Derives manifest paths by replacing the ``.tar`` extension with
+         ``.manifest.json`` and ``.manifest.parquet``.
+
+         Args:
+             shard_urls: List of shard file paths or URLs.
+
+         Returns:
+             A ``QueryExecutor`` with manifests for shards that have them.
+         """
+         manifests: list[ShardManifest] = []
+
+         for url in shard_urls:
+             base = url.removesuffix(".tar")
+             json_path = Path(f"{base}.manifest.json")
+             parquet_path = Path(f"{base}.manifest.parquet")
+
+             if json_path.exists() and parquet_path.exists():
+                 manifests.append(ShardManifest.from_files(json_path, parquet_path))
+             elif json_path.exists():
+                 manifests.append(ShardManifest.from_json_only(json_path))
+
+         return cls(manifests)
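A sketch of the query path, assuming manifests on local disk and illustrative column names ("label", "confidence"). Note the predicate is applied to every non-empty manifest in turn, so it should only reference columns present in each shard's DataFrame:

    from atdata.manifest._query import QueryExecutor

    executor = QueryExecutor.from_directory("data/")  # loads *.manifest.json (+ parquet)

    locations = executor.query(
        where=lambda df: (df["label"] == "dog") & (df["confidence"] > 0.9)
    )
    for loc in locations:
        print(loc.shard, loc.key, loc.offset)  # enough to seek into the tar shard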
atdata/manifest/_writer.py
@@ -0,0 +1,74 @@
+ """ManifestWriter for serializing ShardManifest to JSON + parquet files."""
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+
+ from ._manifest import ShardManifest
+
+
+ class ManifestWriter:
+     """Writes a ``ShardManifest`` to companion JSON and parquet files.
+
+     Produces two files alongside each shard:
+
+     - ``{base_path}.manifest.json`` -- header with metadata and aggregates
+     - ``{base_path}.manifest.parquet`` -- per-sample metadata (columnar)
+
+     Args:
+         base_path: The shard path without the ``.tar`` extension.
+
+     Examples:
+         >>> writer = ManifestWriter("/data/shard-000000")
+         >>> json_path, parquet_path = writer.write(manifest)
+     """
+
+     def __init__(self, base_path: str | Path) -> None:
+         self._base_path = Path(base_path)
+
+     @property
+     def json_path(self) -> Path:
+         """Path for the JSON header file."""
+         return self._base_path.with_suffix(".manifest.json")
+
+     @property
+     def parquet_path(self) -> Path:
+         """Path for the parquet per-sample file."""
+         return self._base_path.with_suffix(".manifest.parquet")
+
+     def write(self, manifest: ShardManifest) -> tuple[Path, Path]:
+         """Write the manifest to JSON + parquet files.
+
+         Args:
+             manifest: The ``ShardManifest`` to serialize.
+
+         Returns:
+             Tuple of ``(json_path, parquet_path)``.
+         """
+         json_out = self.json_path
+         parquet_out = self.parquet_path
+
+         # Ensure parent directory exists
+         json_out.parent.mkdir(parents=True, exist_ok=True)
+
+         # Write JSON header + aggregates
+         with open(json_out, "w", encoding="utf-8") as f:
+             json.dump(manifest.header_dict(), f, indent=2)
+
+         # Write per-sample parquet
+         if not manifest.samples.empty:
+             manifest.samples.to_parquet(
+                 parquet_out,
+                 engine="fastparquet",
+                 index=False,
+             )
+         else:
+             # Write an empty parquet with no rows
+             manifest.samples.to_parquet(
+                 parquet_out,
+                 engine="fastparquet",
+                 index=False,
+             )
+
+         return json_out, parquet_out
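Putting the pieces together, a hedged end-to-end sketch: build, write, reload. Paths are illustrative, `manifest` is assumed to come from ManifestBuilder.build(), and the flat imports assume these names are re-exported from atdata.manifest (its new __init__.py is part of this release but not shown here):

    from atdata.manifest import ManifestBuilder, ManifestWriter, ShardManifest  # re-exports assumed

    writer = ManifestWriter("/data/train/shard-000000")  # shard path minus ".tar"
    json_path, parquet_path = writer.write(manifest)

    reloaded = ShardManifest.from_files(json_path, parquet_path)
    assert reloaded.num_samples == manifest.num_samples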