atdata-0.2.2b1-py3-none-any.whl → atdata-0.3.0b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +31 -1
- atdata/_cid.py +29 -35
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +33 -17
- atdata/_hf_api.py +109 -59
- atdata/_logging.py +70 -0
- atdata/_protocols.py +74 -132
- atdata/_schema_codec.py +38 -41
- atdata/_sources.py +57 -64
- atdata/_stub_manager.py +31 -26
- atdata/_type_utils.py +47 -7
- atdata/atmosphere/__init__.py +31 -24
- atdata/atmosphere/_types.py +11 -11
- atdata/atmosphere/client.py +11 -8
- atdata/atmosphere/lens.py +27 -30
- atdata/atmosphere/records.py +34 -39
- atdata/atmosphere/schema.py +35 -31
- atdata/atmosphere/store.py +16 -20
- atdata/cli/__init__.py +163 -168
- atdata/cli/diagnose.py +12 -8
- atdata/cli/inspect.py +69 -0
- atdata/cli/local.py +5 -2
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +678 -533
- atdata/lens.py +85 -83
- atdata/local/__init__.py +71 -0
- atdata/local/_entry.py +157 -0
- atdata/local/_index.py +940 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/local/_s3.py +349 -0
- atdata/local/_schema.py +380 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +20 -24
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/testing.py +337 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +5 -1
- atdata-0.3.0b1.dist-info/RECORD +54 -0
- atdata/local.py +0 -1707
- atdata-0.2.2b1.dist-info/RECORD +0 -28
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/manifest/_builder.py
@@ -0,0 +1,163 @@
"""ManifestBuilder for accumulating sample metadata during shard writes.

Creates one ``ManifestBuilder`` per shard. Call ``add_sample()`` for each
sample written, then ``build()`` to produce a finalized ``ShardManifest``.
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any

import pandas as pd

from ._aggregates import (
    create_aggregate,
    CategoricalAggregate,
    NumericAggregate,
    SetAggregate,
)
from ._fields import resolve_manifest_fields
from ._manifest import ShardManifest


@dataclass
class _SampleRow:
    """Internal per-sample metadata row."""

    key: str
    offset: int
    size: int
    fields: dict[str, Any]


class ManifestBuilder:
    """Accumulates sample metadata during shard writing.

    Extracts manifest-annotated fields from each sample, feeds running
    aggregate collectors, and accumulates per-sample metadata rows.
    Call ``build()`` after all samples are written to produce a
    ``ShardManifest``.

    Args:
        sample_type: The Packable type being written.
        shard_id: Identifier for this shard (e.g., path without extension).
        schema_version: Version string for the schema.
        source_job_id: Optional provenance job identifier.
        parent_shards: Optional list of input shard identifiers.
        pipeline_version: Optional pipeline version string.

    Examples:
        >>> builder = ManifestBuilder(
        ...     sample_type=ImageSample,
        ...     shard_id="train/shard-00042",
        ... )
        >>> builder.add_sample(key="abc", offset=0, size=1024, sample=my_sample)
        >>> manifest = builder.build()
        >>> manifest.num_samples
        1
    """

    def __init__(
        self,
        sample_type: type,
        shard_id: str,
        schema_version: str = "1.0.0",
        source_job_id: str | None = None,
        parent_shards: list[str] | None = None,
        pipeline_version: str | None = None,
    ) -> None:
        self._sample_type = sample_type
        self._shard_id = shard_id
        self._schema_version = schema_version
        self._source_job_id = source_job_id
        self._parent_shards = parent_shards or []
        self._pipeline_version = pipeline_version

        self._manifest_fields = resolve_manifest_fields(sample_type)
        self._aggregates: dict[
            str, CategoricalAggregate | NumericAggregate | SetAggregate
        ] = {
            name: create_aggregate(mf.aggregate)
            for name, mf in self._manifest_fields.items()
        }
        self._rows: list[_SampleRow] = []
        self._total_size: int = 0

    def add_sample(
        self,
        *,
        key: str,
        offset: int,
        size: int,
        sample: Any,
    ) -> None:
        """Record a sample's metadata.

        Extracts manifest-annotated fields from the sample, updates
        running aggregates, and appends a row to the internal list.

        Args:
            key: The WebDataset ``__key__`` for this sample.
            offset: Byte offset within the tar file.
            size: Size in bytes of this sample's tar entry.
            sample: The sample instance (dataclass with manifest fields).
        """
        field_values: dict[str, Any] = {}
        for name in self._manifest_fields:
            value = getattr(sample, name, None)
            if value is not None:
                self._aggregates[name].add(value)
                field_values[name] = value

        self._rows.append(
            _SampleRow(key=key, offset=offset, size=size, fields=field_values)
        )
        self._total_size += size

    def build(self) -> ShardManifest:
        """Finalize aggregates and produce a ``ShardManifest``.

        Returns:
            A complete ``ShardManifest`` with header, aggregates, and
            per-sample DataFrame.
        """
        # Build aggregates dict
        aggregates = {name: agg.to_dict() for name, agg in self._aggregates.items()}

        # Build per-sample DataFrame
        records: list[dict[str, Any]] = []
        for row in self._rows:
            record: dict[str, Any] = {
                "__key__": row.key,
                "__offset__": row.offset,
                "__size__": row.size,
            }
            record.update(row.fields)
            records.append(record)

        samples_df = pd.DataFrame(records) if records else pd.DataFrame()

        # Build provenance
        provenance: dict[str, Any] = {}
        if self._source_job_id:
            provenance["source_job_id"] = self._source_job_id
        if self._parent_shards:
            provenance["parent_shards"] = self._parent_shards
        if self._pipeline_version:
            provenance["pipeline_version"] = self._pipeline_version

        schema_name = self._sample_type.__name__

        return ShardManifest(
            shard_id=self._shard_id,
            schema_type=schema_name,
            schema_version=self._schema_version,
            num_samples=len(self._rows),
            size_bytes=self._total_size,
            created_at=datetime.now(timezone.utc),
            aggregates=aggregates,
            samples=samples_df,
            provenance=provenance,
        )
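For orientation, here is a minimal sketch of driving the builder during a shard write. It is not part of the diff: the `ImageMeta` type, keys, sizes, and offsets are made up, and a plain dataclass stands in for a `@atdata.packable` type (the builder only requires a dataclass carrying manifest fields).

from dataclasses import dataclass
from typing import Annotated

from atdata.manifest._builder import ManifestBuilder
from atdata.manifest._fields import ManifestField


@dataclass
class ImageMeta:
    label: Annotated[str, ManifestField("categorical")]
    confidence: Annotated[float, ManifestField("numeric")]


builder = ManifestBuilder(sample_type=ImageMeta, shard_id="train/shard-00042")

offset = 0
for i, sample in enumerate([ImageMeta("dog", 0.93), ImageMeta("cat", 0.71)]):
    size = 1024  # hypothetical: a real writer would use the actual tar-entry size
    builder.add_sample(key=f"sample_{i:06d}", offset=offset, size=size, sample=sample)
    offset += size

manifest = builder.build()
assert manifest.num_samples == 2 and manifest.size_bytes == 2048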
atdata/manifest/_fields.py
@@ -0,0 +1,154 @@
"""Manifest field annotation and introspection.

Provides the ``ManifestField`` marker for annotating which sample fields
should appear in per-shard manifests, and ``resolve_manifest_fields()``
for introspecting a sample type to discover its manifest-included fields.
"""

from __future__ import annotations

import dataclasses
from dataclasses import dataclass
from typing import Any, Literal, get_args, get_origin, get_type_hints

from atdata._type_utils import PRIMITIVE_TYPE_MAP, is_ndarray_type, unwrap_optional

AggregateKind = Literal["categorical", "numeric", "set"]


@dataclass(frozen=True)
class ManifestField:
    """Marker for manifest-included fields.

    Use with ``Annotated`` to control which fields appear in per-shard
    manifests and what aggregate statistics to compute.

    Args:
        aggregate: The type of statistical aggregate to compute.
        exclude: If True, explicitly exclude this field from the manifest
            even if it would be auto-inferred.

    Examples:
        >>> from typing import Annotated
        >>> from numpy.typing import NDArray
        >>> @atdata.packable
        ... class ImageSample:
        ...     image: NDArray
        ...     label: Annotated[str, ManifestField("categorical")]
        ...     confidence: Annotated[float, ManifestField("numeric")]
        ...     tags: Annotated[list[str], ManifestField("set")]
    """

    aggregate: AggregateKind
    exclude: bool = False


def _extract_manifest_field(annotation: Any) -> ManifestField | None:
    """Extract a ManifestField from an Annotated type, if present."""
    if get_origin(annotation) is not None:
        # Check for Annotated[T, ManifestField(...)];
        # Annotated types expose their extras via __metadata__ (Python 3.9+)
        metadata = getattr(annotation, "__metadata__", None)
        if metadata is not None:
            for item in metadata:
                if isinstance(item, ManifestField):
                    return item
    return None


def _infer_aggregate_kind(python_type: Any) -> AggregateKind | None:
    """Infer the aggregate kind from a Python type annotation.

    Returns:
        The inferred aggregate kind, or None if the type should be excluded.
    """
    # Unwrap Optional
    inner_type, _ = unwrap_optional(python_type)

    # Exclude NDArray and bytes
    if is_ndarray_type(inner_type):
        return None
    if inner_type is bytes:
        return None

    # Check primitives
    if inner_type in PRIMITIVE_TYPE_MAP:
        type_name = PRIMITIVE_TYPE_MAP[inner_type]
        if type_name in ("str", "bool"):
            return "categorical"
        if type_name in ("int", "float"):
            return "numeric"
        return None

    # Check list types -> set aggregate
    origin = get_origin(inner_type)
    if origin is list:
        return "set"

    return None


def _get_base_type(annotation: Any) -> Any:
    """Get the base type from an Annotated type or return as-is."""
    args = get_args(annotation)
    metadata = getattr(annotation, "__metadata__", None)
    if metadata is not None and args:
        return args[0]
    return annotation


def resolve_manifest_fields(sample_type: type) -> dict[str, ManifestField]:
    """Extract manifest field descriptors from a Packable type.

    Inspects type hints for ``Annotated[..., ManifestField(...)]`` markers.
    For fields without explicit markers, applies auto-inference rules:

    - ``str``, ``bool`` -> categorical
    - ``int``, ``float`` -> numeric
    - ``list[T]`` -> set
    - ``NDArray``, ``bytes`` -> excluded

    Args:
        sample_type: A ``@packable`` or ``PackableSample`` subclass.

    Returns:
        Dict mapping field name to ``ManifestField`` descriptor for all
        manifest-included fields.

    Examples:
        >>> from typing import Annotated
        >>> @atdata.packable
        ... class MySample:
        ...     label: Annotated[str, ManifestField("categorical")]
        ...     score: float
        >>> fields = resolve_manifest_fields(MySample)
        >>> fields["label"].aggregate
        'categorical'
        >>> fields["score"].aggregate
        'numeric'
    """
    if not dataclasses.is_dataclass(sample_type):
        raise TypeError(f"{sample_type} is not a dataclass")

    hints = get_type_hints(sample_type, include_extras=True)
    dc_fields = {f.name for f in dataclasses.fields(sample_type)}
    result: dict[str, ManifestField] = {}

    for name, annotation in hints.items():
        if name not in dc_fields:
            continue

        # Check for explicit ManifestField annotation
        explicit = _extract_manifest_field(annotation)
        if explicit is not None:
            if not explicit.exclude:
                result[name] = explicit
            continue

        # Auto-infer from base type
        base_type = _get_base_type(annotation)
        kind = _infer_aggregate_kind(base_type)
        if kind is not None:
            result[name] = ManifestField(aggregate=kind)

    return result
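A quick sketch of these resolution rules in action, including the `exclude=True` opt-out. The `Sample` type is hypothetical; the `float`, `list[str]`, and `bytes` outcomes follow the auto-inference rules stated in the docstring above.

from dataclasses import dataclass
from typing import Annotated

from atdata.manifest._fields import ManifestField, resolve_manifest_fields


@dataclass
class Sample:
    label: Annotated[str, ManifestField("categorical")]  # explicit marker
    score: float        # auto-inferred -> "numeric"
    tags: list[str]     # auto-inferred -> "set"
    raw: bytes          # excluded by the bytes rule
    note: Annotated[str, ManifestField("categorical", exclude=True)]  # opted out


fields = resolve_manifest_fields(Sample)
assert set(fields) == {"label", "score", "tags"}
assert fields["tags"].aggregate == "set"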
atdata/manifest/_manifest.py
@@ -0,0 +1,146 @@
"""ShardManifest data model.

Represents a loaded manifest with JSON header (metadata + aggregates)
and per-sample metadata (as a pandas DataFrame).
"""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any

import pandas as pd

MANIFEST_FORMAT_VERSION = "1.0.0"


@dataclass
class ShardManifest:
    """In-memory representation of a shard's manifest.

    Contains the JSON header (metadata + aggregates) and per-sample
    metadata stored as a pandas DataFrame for efficient columnar filtering.

    Attributes:
        shard_id: Shard identifier (path without extension).
        schema_type: Schema class name.
        schema_version: Schema version string.
        num_samples: Number of samples in the shard.
        size_bytes: Total shard size in bytes.
        created_at: When the manifest was created.
        aggregates: Dict of field name to aggregate summary dict.
        samples: DataFrame with ``__key__``, ``__offset__``, ``__size__``,
            and manifest field columns.
        provenance: Optional provenance metadata (job ID, parent shards, etc.).

    Examples:
        >>> manifest = ShardManifest.from_files(
        ...     "data/shard-000000.manifest.json",
        ...     "data/shard-000000.manifest.parquet",
        ... )
        >>> manifest.num_samples
        1000
        >>> manifest.aggregates["label"]["cardinality"]
        3
    """

    shard_id: str
    schema_type: str
    schema_version: str
    num_samples: int
    size_bytes: int
    created_at: datetime
    aggregates: dict[str, dict[str, Any]]
    samples: pd.DataFrame
    provenance: dict[str, Any] = field(default_factory=dict)

    def header_dict(self) -> dict[str, Any]:
        """Return the JSON-serializable header including aggregates.

        Returns:
            Dict suitable for writing as the ``.manifest.json`` file.
        """
        header: dict[str, Any] = {
            "manifest_version": MANIFEST_FORMAT_VERSION,
            "shard_id": self.shard_id,
            "schema_type": self.schema_type,
            "schema_version": self.schema_version,
            "num_samples": self.num_samples,
            "size_bytes": self.size_bytes,
            "created_at": self.created_at.isoformat(),
            "aggregates": self.aggregates,
        }
        if self.provenance:
            header["provenance"] = self.provenance
        return header

    @classmethod
    def from_files(
        cls, json_path: str | Path, parquet_path: str | Path
    ) -> ShardManifest:
        """Load a manifest from its JSON + parquet companion files.

        Args:
            json_path: Path to the ``.manifest.json`` file.
            parquet_path: Path to the ``.manifest.parquet`` file.

        Returns:
            A fully loaded ``ShardManifest``.

        Raises:
            FileNotFoundError: If either file does not exist.
            json.JSONDecodeError: If the JSON file is malformed.
        """
        json_path = Path(json_path)
        parquet_path = Path(parquet_path)

        with open(json_path, "r", encoding="utf-8") as f:
            header = json.load(f)

        samples = pd.read_parquet(parquet_path, engine="fastparquet")

        return cls(
            shard_id=header["shard_id"],
            schema_type=header["schema_type"],
            schema_version=header["schema_version"],
            num_samples=header["num_samples"],
            size_bytes=header["size_bytes"],
            created_at=datetime.fromisoformat(header["created_at"]),
            aggregates=header.get("aggregates", {}),
            samples=samples,
            provenance=header.get("provenance", {}),
        )

    @classmethod
    def from_json_only(cls, json_path: str | Path) -> ShardManifest:
        """Load header-only manifest for shard-level filtering.

        Loads just the JSON header without the parquet per-sample data.
        Useful for fast shard pruning via aggregates before loading
        the full parquet file.

        Args:
            json_path: Path to the ``.manifest.json`` file.

        Returns:
            A ``ShardManifest`` with an empty ``samples`` DataFrame.
        """
        json_path = Path(json_path)

        with open(json_path, "r", encoding="utf-8") as f:
            header = json.load(f)

        return cls(
            shard_id=header["shard_id"],
            schema_type=header["schema_type"],
            schema_version=header["schema_version"],
            num_samples=header["num_samples"],
            size_bytes=header["size_bytes"],
            created_at=datetime.fromisoformat(header["created_at"]),
            aggregates=header.get("aggregates", {}),
            samples=pd.DataFrame(),
            provenance=header.get("provenance", {}),
        )
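As a sketch of the header-only path, the loop below prunes shards from their JSON headers before any parquet is read. The `data/` directory and the `label` field are hypothetical; the `cardinality` key is taken from the docstring example above, and other aggregate keys depend on `_aggregates.py`, which this diff does not show.

from pathlib import Path

from atdata.manifest._manifest import ShardManifest

candidates: list[ShardManifest] = []
for json_path in sorted(Path("data/").glob("*.manifest.json")):
    m = ShardManifest.from_json_only(json_path)  # header only; samples stays empty
    label_agg = m.aggregates.get("label", {})
    if label_agg.get("cardinality", 0) > 0:  # shard has at least one label value
        candidates.append(m)

print(f"{len(candidates)} shard(s) worth loading in full")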
atdata/manifest/_query.py
@@ -0,0 +1,150 @@
"""Query executor for manifest-based dataset queries.

Provides two-phase filtering: shard-level pruning via aggregates,
then sample-level filtering via the parquet DataFrame.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Callable

import pandas as pd

from ._manifest import ShardManifest


@dataclass(frozen=True)
class SampleLocation:
    """Location of a sample within a shard.

    Attributes:
        shard: Shard identifier or URL.
        key: WebDataset ``__key__`` for the sample.
        offset: Byte offset within the tar file.

    Examples:
        >>> loc = SampleLocation(shard="data/shard-000000", key="sample_00042", offset=52480)
        >>> loc.shard
        'data/shard-000000'
    """

    shard: str
    key: str
    offset: int


class QueryExecutor:
    """Executes queries over per-shard manifests.

    Performs two-phase filtering:

    1. **Shard-level**: uses aggregates to skip shards that cannot contain
       matching samples (e.g., numeric range exclusion, categorical value absence).
    2. **Sample-level**: applies the predicate to the parquet DataFrame rows.

    Args:
        manifests: List of ``ShardManifest`` objects to query over.

    Examples:
        >>> executor = QueryExecutor(manifests)
        >>> results = executor.query(
        ...     where=lambda df: (df["confidence"] > 0.9) & (df["label"].isin(["dog", "cat"]))
        ... )
        >>> len(results)
        42
    """

    def __init__(self, manifests: list[ShardManifest]) -> None:
        self._manifests = manifests

    def query(
        self,
        where: Callable[[pd.DataFrame], pd.Series],
    ) -> list[SampleLocation]:
        """Execute a query across all manifests.

        The ``where`` callable receives a pandas DataFrame with the per-sample
        manifest columns and must return a boolean Series selecting matching rows.

        Args:
            where: Predicate function. Receives a DataFrame, returns a boolean Series.

        Returns:
            List of ``SampleLocation`` for all matching samples.
        """
        results: list[SampleLocation] = []

        for manifest in self._manifests:
            if manifest.samples.empty:
                continue

            mask = where(manifest.samples)
            matching = manifest.samples[mask]

            for _, row in matching.iterrows():
                results.append(
                    SampleLocation(
                        shard=manifest.shard_id,
                        key=row["__key__"],
                        offset=int(row["__offset__"]),
                    )
                )

        return results

    @classmethod
    def from_directory(cls, directory: str | Path) -> QueryExecutor:
        """Load all manifests from a directory.

        Discovers ``*.manifest.json`` files and loads each with its
        companion parquet file.

        Args:
            directory: Path to scan for manifest files.

        Returns:
            A ``QueryExecutor`` loaded with all discovered manifests.

        Raises:
            FileNotFoundError: If the directory does not exist.
        """
        directory = Path(directory)
        manifests: list[ShardManifest] = []

        for json_path in sorted(directory.glob("*.manifest.json")):
            parquet_path = json_path.with_suffix("").with_suffix(".manifest.parquet")
            if parquet_path.exists():
                manifests.append(ShardManifest.from_files(json_path, parquet_path))
            else:
                manifests.append(ShardManifest.from_json_only(json_path))

        return cls(manifests)

    @classmethod
    def from_shard_urls(cls, shard_urls: list[str]) -> QueryExecutor:
        """Load manifests corresponding to a list of shard URLs.

        Derives manifest paths by replacing the ``.tar`` extension with
        ``.manifest.json`` and ``.manifest.parquet``.

        Args:
            shard_urls: List of shard file paths or URLs.

        Returns:
            A ``QueryExecutor`` with manifests for shards that have them.
        """
        manifests: list[ShardManifest] = []

        for url in shard_urls:
            base = url.removesuffix(".tar")
            json_path = Path(f"{base}.manifest.json")
            parquet_path = Path(f"{base}.manifest.parquet")

            if json_path.exists() and parquet_path.exists():
                manifests.append(ShardManifest.from_files(json_path, parquet_path))
            elif json_path.exists():
                manifests.append(ShardManifest.from_json_only(json_path))

        return cls(manifests)
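To make the flow concrete, a hedged usage sketch follows; the shard paths and the `confidence` column are hypothetical. Note that the `query()` body shown above performs the sample-level pass and skips header-only manifests (empty `samples`); the aggregate-based shard pruning described in the class docstring would sit in front of it.

from collections import defaultdict

from atdata.manifest._query import QueryExecutor

executor = QueryExecutor.from_shard_urls(
    ["data/shard-000000.tar", "data/shard-000001.tar"]
)
locations = executor.query(where=lambda df: df["confidence"] > 0.9)

by_shard: dict[str, list[int]] = defaultdict(list)
for loc in locations:
    by_shard[loc.shard].append(loc.offset)  # byte offsets enable seek-based reads

for shard, offsets in sorted(by_shard.items()):
    print(shard, len(offsets), "matching sample(s)")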
atdata/manifest/_writer.py
@@ -0,0 +1,74 @@
"""ManifestWriter for serializing ShardManifest to JSON + parquet files."""

from __future__ import annotations

import json
from pathlib import Path

from ._manifest import ShardManifest


class ManifestWriter:
    """Writes a ``ShardManifest`` to companion JSON and parquet files.

    Produces two files alongside each shard:

    - ``{base_path}.manifest.json`` -- header with metadata and aggregates
    - ``{base_path}.manifest.parquet`` -- per-sample metadata (columnar)

    Args:
        base_path: The shard path without the ``.tar`` extension.

    Examples:
        >>> writer = ManifestWriter("/data/shard-000000")
        >>> json_path, parquet_path = writer.write(manifest)
    """

    def __init__(self, base_path: str | Path) -> None:
        self._base_path = Path(base_path)

    @property
    def json_path(self) -> Path:
        """Path for the JSON header file."""
        return self._base_path.with_suffix(".manifest.json")

    @property
    def parquet_path(self) -> Path:
        """Path for the parquet per-sample file."""
        return self._base_path.with_suffix(".manifest.parquet")

    def write(self, manifest: ShardManifest) -> tuple[Path, Path]:
        """Write the manifest to JSON + parquet files.

        Args:
            manifest: The ``ShardManifest`` to serialize.

        Returns:
            Tuple of ``(json_path, parquet_path)``.
        """
        json_out = self.json_path
        parquet_out = self.parquet_path

        # Ensure parent directory exists
        json_out.parent.mkdir(parents=True, exist_ok=True)

        # Write JSON header + aggregates
        with open(json_out, "w", encoding="utf-8") as f:
            json.dump(manifest.header_dict(), f, indent=2)

        # Write per-sample parquet
        if not manifest.samples.empty:
            manifest.samples.to_parquet(
                parquet_out,
                engine="fastparquet",
                index=False,
            )
        else:
            # Write an empty parquet with no rows
            manifest.samples.to_parquet(
                parquet_out,
                engine="fastparquet",
                index=False,
            )

        return json_out, parquet_out
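Putting the pieces together, a minimal end-to-end sketch at shard-close time. The `Row` type and paths are hypothetical, and writing requires the `fastparquet` engine the module selects.

from dataclasses import dataclass

from atdata.manifest._builder import ManifestBuilder
from atdata.manifest._writer import ManifestWriter


@dataclass
class Row:
    label: str  # auto-inferred as a categorical manifest field


builder = ManifestBuilder(sample_type=Row, shard_id="data/shard-000000")
builder.add_sample(key="sample_000000", offset=0, size=512, sample=Row(label="dog"))

writer = ManifestWriter("data/shard-000000")  # shard path without ".tar"
json_path, parquet_path = writer.write(builder.build())
print(json_path)     # data/shard-000000.manifest.json
print(parquet_path)  # data/shard-000000.manifest.parquet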