atdata 0.2.3b1__py3-none-any.whl → 0.3.1b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +39 -0
- atdata/_cid.py +0 -21
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +41 -15
- atdata/_hf_api.py +95 -11
- atdata/_logging.py +70 -0
- atdata/_protocols.py +77 -238
- atdata/_schema_codec.py +7 -6
- atdata/_stub_manager.py +5 -25
- atdata/_type_utils.py +28 -2
- atdata/atmosphere/__init__.py +31 -20
- atdata/atmosphere/_types.py +4 -4
- atdata/atmosphere/client.py +64 -12
- atdata/atmosphere/lens.py +11 -12
- atdata/atmosphere/records.py +12 -12
- atdata/atmosphere/schema.py +16 -18
- atdata/atmosphere/store.py +6 -7
- atdata/cli/__init__.py +161 -175
- atdata/cli/diagnose.py +2 -2
- atdata/cli/{local.py → infra.py} +11 -11
- atdata/cli/inspect.py +69 -0
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +583 -328
- atdata/index/__init__.py +54 -0
- atdata/index/_entry.py +157 -0
- atdata/index/_index.py +1198 -0
- atdata/index/_schema.py +380 -0
- atdata/lens.py +9 -2
- atdata/lexicons/__init__.py +121 -0
- atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
- atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
- atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
- atdata/lexicons/ac.foundation.dataset.record.json +96 -0
- atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
- atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
- atdata/lexicons/ndarray_shim.json +16 -0
- atdata/local/__init__.py +70 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +18 -14
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/stores/__init__.py +23 -0
- atdata/stores/_disk.py +123 -0
- atdata/stores/_s3.py +349 -0
- atdata/testing.py +341 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
- atdata-0.3.1b1.dist-info/RECORD +67 -0
- atdata/local.py +0 -1720
- atdata-0.2.3b1.dist-info/RECORD +0 -28
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Per-shard manifest and query system.
|
|
2
|
+
|
|
3
|
+
Provides manifest generation during shard writes and efficient
|
|
4
|
+
query execution over large datasets without full scans.
|
|
5
|
+
|
|
6
|
+
Components:
|
|
7
|
+
|
|
8
|
+
- ``ManifestField``: Annotation marker for manifest-included fields
|
|
9
|
+
- ``ManifestBuilder``: Accumulates sample metadata during writes
|
|
10
|
+
- ``ShardManifest``: Loaded manifest representation
|
|
11
|
+
- ``ManifestWriter``: Serializes manifests to JSON + parquet
|
|
12
|
+
- ``QueryExecutor``: Two-phase query over manifest metadata
|
|
13
|
+
- ``SampleLocation``: Address of a sample within a shard
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from ._fields import ManifestField as ManifestField
|
|
17
|
+
from ._fields import AggregateKind as AggregateKind
|
|
18
|
+
from ._fields import resolve_manifest_fields as resolve_manifest_fields
|
|
19
|
+
from ._aggregates import CategoricalAggregate as CategoricalAggregate
|
|
20
|
+
from ._aggregates import NumericAggregate as NumericAggregate
|
|
21
|
+
from ._aggregates import SetAggregate as SetAggregate
|
|
22
|
+
from ._aggregates import create_aggregate as create_aggregate
|
|
23
|
+
from ._builder import ManifestBuilder as ManifestBuilder
|
|
24
|
+
from ._manifest import ShardManifest as ShardManifest
|
|
25
|
+
from ._manifest import MANIFEST_FORMAT_VERSION as MANIFEST_FORMAT_VERSION
|
|
26
|
+
from ._writer import ManifestWriter as ManifestWriter
|
|
27
|
+
from ._query import QueryExecutor as QueryExecutor
|
|
28
|
+
from ._query import SampleLocation as SampleLocation
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""Statistical aggregate collectors for manifest fields.
|
|
2
|
+
|
|
3
|
+
Each aggregate type tracks running statistics during shard writing and
|
|
4
|
+
produces a summary dict for inclusion in the manifest JSON header.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class CategoricalAggregate:
    """Running value-count statistics for a categorical (string/enum) field.

    Every observed value is stringified and counted; cardinality is the
    number of distinct values seen so far in the shard.

    Examples:
        >>> agg = CategoricalAggregate()
        >>> agg.add("dog")
        >>> agg.add("cat")
        >>> agg.add("dog")
        >>> agg.to_dict()
        {'type': 'categorical', 'cardinality': 2, 'value_counts': {'dog': 2, 'cat': 1}}
    """

    value_counts: dict[str, int] = field(default_factory=dict)

    @property
    def cardinality(self) -> int:
        """Count of distinct values observed so far."""
        return len(self.value_counts)

    def add(self, value: Any) -> None:
        """Record one observation of *value* (stringified before counting)."""
        label = str(value)
        seen_so_far = self.value_counts.get(label, 0)
        self.value_counts[label] = seen_so_far + 1

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-compatible summary of this aggregate."""
        summary: dict[str, Any] = {
            "type": "categorical",
            "cardinality": self.cardinality,
            "value_counts": dict(self.value_counts),
        }
        return summary
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
|
|
50
|
+
class NumericAggregate:
|
|
51
|
+
"""Aggregate for numeric (int/float) fields.
|
|
52
|
+
|
|
53
|
+
Tracks min, max, sum, and count for computing summary statistics.
|
|
54
|
+
|
|
55
|
+
Examples:
|
|
56
|
+
>>> agg = NumericAggregate()
|
|
57
|
+
>>> agg.add(1.0)
|
|
58
|
+
>>> agg.add(3.0)
|
|
59
|
+
>>> agg.add(2.0)
|
|
60
|
+
>>> agg.to_dict()
|
|
61
|
+
{'type': 'numeric', 'min': 1.0, 'max': 3.0, 'mean': 2.0, 'count': 3}
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
_min: float = field(default=float("inf"))
|
|
65
|
+
_max: float = field(default=float("-inf"))
|
|
66
|
+
_sum: float = 0.0
|
|
67
|
+
count: int = 0
|
|
68
|
+
|
|
69
|
+
@property
|
|
70
|
+
def min(self) -> float:
|
|
71
|
+
"""Minimum observed value."""
|
|
72
|
+
return self._min
|
|
73
|
+
|
|
74
|
+
@property
|
|
75
|
+
def max(self) -> float:
|
|
76
|
+
"""Maximum observed value."""
|
|
77
|
+
return self._max
|
|
78
|
+
|
|
79
|
+
@property
|
|
80
|
+
def mean(self) -> float:
|
|
81
|
+
"""Running mean of observed values."""
|
|
82
|
+
if self.count == 0:
|
|
83
|
+
return 0.0
|
|
84
|
+
return self._sum / self.count
|
|
85
|
+
|
|
86
|
+
def add(self, value: float | int) -> None:
|
|
87
|
+
"""Record a numeric observation."""
|
|
88
|
+
v = float(value)
|
|
89
|
+
if v < self._min:
|
|
90
|
+
self._min = v
|
|
91
|
+
if v > self._max:
|
|
92
|
+
self._max = v
|
|
93
|
+
self._sum += v
|
|
94
|
+
self.count += 1
|
|
95
|
+
|
|
96
|
+
def to_dict(self) -> dict[str, Any]:
|
|
97
|
+
"""Serialize to a JSON-compatible dict."""
|
|
98
|
+
return {
|
|
99
|
+
"type": "numeric",
|
|
100
|
+
"min": self._min,
|
|
101
|
+
"max": self._max,
|
|
102
|
+
"mean": self.mean,
|
|
103
|
+
"count": self.count,
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@dataclass
|
|
108
|
+
class SetAggregate:
|
|
109
|
+
"""Aggregate for set/tag (list) fields.
|
|
110
|
+
|
|
111
|
+
Tracks the union of all observed values across samples.
|
|
112
|
+
|
|
113
|
+
Examples:
|
|
114
|
+
>>> agg = SetAggregate()
|
|
115
|
+
>>> agg.add(["outdoor", "day"])
|
|
116
|
+
>>> agg.add(["indoor"])
|
|
117
|
+
>>> agg.to_dict()
|
|
118
|
+
{'type': 'set', 'all_values': ['day', 'indoor', 'outdoor']}
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
all_values: set[str] = field(default_factory=set)
|
|
122
|
+
|
|
123
|
+
def add(self, values: list | set | tuple) -> None:
|
|
124
|
+
"""Record a collection of values."""
|
|
125
|
+
for v in values:
|
|
126
|
+
self.all_values.add(str(v))
|
|
127
|
+
|
|
128
|
+
def to_dict(self) -> dict[str, Any]:
|
|
129
|
+
"""Serialize to a JSON-compatible dict."""
|
|
130
|
+
return {
|
|
131
|
+
"type": "set",
|
|
132
|
+
"all_values": sorted(self.all_values),
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def create_aggregate(
    kind: str,
) -> CategoricalAggregate | NumericAggregate | SetAggregate:
    """Create an aggregate collector for the given kind.

    Args:
        kind: One of ``"categorical"``, ``"numeric"``, ``"set"``.

    Returns:
        A new aggregate collector instance.

    Raises:
        ValueError: If kind is not recognized.
    """
    # Dispatch table keeps the mapping in one place.
    factories: dict[str, Any] = {
        "categorical": CategoricalAggregate,
        "numeric": NumericAggregate,
        "set": SetAggregate,
    }
    factory = factories.get(kind)
    if factory is None:
        raise ValueError(f"Unknown aggregate kind: {kind!r}")
    return factory()
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""ManifestBuilder for accumulating sample metadata during shard writes.
|
|
2
|
+
|
|
3
|
+
Creates one ``ManifestBuilder`` per shard. Call ``add_sample()`` for each
|
|
4
|
+
sample written, then ``build()`` to produce a finalized ``ShardManifest``.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from datetime import datetime, timezone
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import pandas as pd
|
|
14
|
+
|
|
15
|
+
from ._aggregates import (
|
|
16
|
+
create_aggregate,
|
|
17
|
+
CategoricalAggregate,
|
|
18
|
+
NumericAggregate,
|
|
19
|
+
SetAggregate,
|
|
20
|
+
)
|
|
21
|
+
from ._fields import resolve_manifest_fields
|
|
22
|
+
from ._manifest import ShardManifest
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class _SampleRow:
    """Internal per-sample metadata row.

    One row is appended per ``ManifestBuilder.add_sample()`` call and is
    later flattened into the manifest's per-sample DataFrame by ``build()``.
    """

    # WebDataset ``__key__`` of the sample.
    key: str
    # Byte offset of the sample's entry within the tar file.
    offset: int
    # Size in bytes of the sample's tar entry.
    size: int
    # Manifest-annotated field values captured from the sample
    # (fields whose value was None are absent).
    fields: dict[str, Any]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ManifestBuilder:
    """Collects per-sample metadata while a shard is being written.

    On construction, resolves the manifest-annotated fields of the sample
    type and creates one running aggregate per field. Call
    ``add_sample()`` for every sample written, then ``build()`` to turn
    the accumulated rows and aggregates into a ``ShardManifest``.

    Args:
        sample_type: The Packable type being written.
        shard_id: Identifier for this shard (e.g., path without extension).
        schema_version: Version string for the schema.
        source_job_id: Optional provenance job identifier.
        parent_shards: Optional list of input shard identifiers.
        pipeline_version: Optional pipeline version string.

    Examples:
        >>> builder = ManifestBuilder(
        ...     sample_type=ImageSample,
        ...     shard_id="train/shard-00042",
        ... )
        >>> builder.add_sample(key="abc", offset=0, size=1024, sample=my_sample)
        >>> manifest = builder.build()
        >>> manifest.num_samples
        1
    """

    def __init__(
        self,
        sample_type: type,
        shard_id: str,
        schema_version: str = "1.0.0",
        source_job_id: str | None = None,
        parent_shards: list[str] | None = None,
        pipeline_version: str | None = None,
    ) -> None:
        self._sample_type = sample_type
        self._shard_id = shard_id
        self._schema_version = schema_version
        self._source_job_id = source_job_id
        self._parent_shards = parent_shards or []
        self._pipeline_version = pipeline_version

        # One running aggregate collector per manifest-annotated field.
        self._manifest_fields = resolve_manifest_fields(sample_type)
        self._aggregates: dict[
            str, CategoricalAggregate | NumericAggregate | SetAggregate
        ] = {
            field_name: create_aggregate(descriptor.aggregate)
            for field_name, descriptor in self._manifest_fields.items()
        }
        self._rows: list[_SampleRow] = []
        self._total_size: int = 0

    def add_sample(
        self,
        *,
        key: str,
        offset: int,
        size: int,
        sample: Any,
    ) -> None:
        """Record one sample's metadata.

        Pulls each manifest-annotated field off *sample*, feeds every
        non-``None`` value into the matching running aggregate, and stores
        a per-sample row for the eventual DataFrame.

        Args:
            key: The WebDataset ``__key__`` for this sample.
            offset: Byte offset within the tar file.
            size: Size in bytes of this sample's tar entry.
            sample: The sample instance (dataclass with manifest fields).
        """
        captured: dict[str, Any] = {}
        for field_name in self._manifest_fields:
            observed = getattr(sample, field_name, None)
            if observed is None:
                # Missing / None fields are simply absent from this row.
                continue
            self._aggregates[field_name].add(observed)
            captured[field_name] = observed

        self._rows.append(
            _SampleRow(key=key, offset=offset, size=size, fields=captured)
        )
        self._total_size += size

    def build(self) -> ShardManifest:
        """Finalize aggregates and produce a ``ShardManifest``.

        Returns:
            A complete ``ShardManifest`` with header, aggregates, and
            per-sample DataFrame.
        """
        summaries = {name: agg.to_dict() for name, agg in self._aggregates.items()}

        # Flatten rows into DataFrame records, reserved columns first so
        # manifest field values cannot shadow them.
        flattened = [
            {
                "__key__": row.key,
                "__offset__": row.offset,
                "__size__": row.size,
                **row.fields,
            }
            for row in self._rows
        ]
        samples_df = pd.DataFrame(flattened) if flattened else pd.DataFrame()

        # Only provenance entries that were actually supplied are emitted.
        provenance = {
            label: value
            for label, value in (
                ("source_job_id", self._source_job_id),
                ("parent_shards", self._parent_shards),
                ("pipeline_version", self._pipeline_version),
            )
            if value
        }

        return ShardManifest(
            shard_id=self._shard_id,
            schema_type=self._sample_type.__name__,
            schema_version=self._schema_version,
            num_samples=len(self._rows),
            size_bytes=self._total_size,
            created_at=datetime.now(timezone.utc),
            aggregates=summaries,
            samples=samples_df,
            provenance=provenance,
        )
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""Manifest field annotation and introspection.
|
|
2
|
+
|
|
3
|
+
Provides the ``ManifestField`` marker for annotating which sample fields
|
|
4
|
+
should appear in per-shard manifests, and ``resolve_manifest_fields()``
|
|
5
|
+
for introspecting a sample type to discover its manifest-included fields.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import dataclasses
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Any, Literal, get_args, get_origin, get_type_hints
|
|
13
|
+
|
|
14
|
+
from atdata._type_utils import PRIMITIVE_TYPE_MAP, is_ndarray_type, unwrap_optional
|
|
15
|
+
|
|
16
|
+
AggregateKind = Literal["categorical", "numeric", "set"]


@dataclass(frozen=True)
class ManifestField:
    """Marker attached via ``Annotated`` to control manifest inclusion.

    Fields carrying this marker appear in per-shard manifests with the
    requested aggregate statistic; ``exclude=True`` keeps a field out of
    the manifest even when auto-inference would include it.

    Args:
        aggregate: The type of statistical aggregate to compute.
        exclude: If True, explicitly exclude this field from the manifest
            even if it would be auto-inferred.

    Examples:
        >>> from typing import Annotated
        >>> from numpy.typing import NDArray
        >>> @atdata.packable
        ... class ImageSample:
        ...     image: NDArray
        ...     label: Annotated[str, ManifestField("categorical")]
        ...     confidence: Annotated[float, ManifestField("numeric")]
        ...     tags: Annotated[list[str], ManifestField("set")]
    """

    aggregate: AggregateKind
    exclude: bool = False
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _extract_manifest_field(annotation: Any) -> ManifestField | None:
|
|
47
|
+
"""Extract a ManifestField from an Annotated type, if present."""
|
|
48
|
+
if get_origin(annotation) is not None:
|
|
49
|
+
# Check for Annotated[T, ManifestField(...)]
|
|
50
|
+
# In Python 3.12+, typing.Annotated has __metadata__
|
|
51
|
+
metadata = getattr(annotation, "__metadata__", None)
|
|
52
|
+
if metadata is not None:
|
|
53
|
+
for item in metadata:
|
|
54
|
+
if isinstance(item, ManifestField):
|
|
55
|
+
return item
|
|
56
|
+
return None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _infer_aggregate_kind(python_type: Any) -> AggregateKind | None:
    """Infer the aggregate kind from a Python type annotation.

    Returns:
        The inferred aggregate kind, or None if the type should be excluded.
    """
    # Strip an Optional wrapper first; inference runs on the inner type.
    inner_type, _ = unwrap_optional(python_type)

    # Binary-ish payloads (arrays, raw bytes) never belong in a manifest.
    if is_ndarray_type(inner_type) or inner_type is bytes:
        return None

    # Primitive scalars map by name: strings/bools are categorical,
    # ints/floats are numeric; anything else primitive is excluded.
    if inner_type in PRIMITIVE_TYPE_MAP:
        primitive_name = PRIMITIVE_TYPE_MAP[inner_type]
        if primitive_name in ("str", "bool"):
            return "categorical"
        if primitive_name in ("int", "float"):
            return "numeric"
        return None

    # Lists of values summarize as a set-union aggregate.
    return "set" if get_origin(inner_type) is list else None
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _get_base_type(annotation: Any) -> Any:
|
|
92
|
+
"""Get the base type from an Annotated type or return as-is."""
|
|
93
|
+
args = get_args(annotation)
|
|
94
|
+
metadata = getattr(annotation, "__metadata__", None)
|
|
95
|
+
if metadata is not None and args:
|
|
96
|
+
return args[0]
|
|
97
|
+
return annotation
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def resolve_manifest_fields(sample_type: type) -> dict[str, ManifestField]:
    """Extract manifest field descriptors from a Packable type.

    Inspects type hints for ``Annotated[..., ManifestField(...)]`` markers.
    For fields without explicit markers, applies auto-inference rules:

    - ``str``, ``bool`` -> categorical
    - ``int``, ``float`` -> numeric
    - ``list[T]`` -> set
    - ``NDArray``, ``bytes`` -> excluded

    Args:
        sample_type: A ``@packable`` or ``PackableSample`` subclass.

    Returns:
        Dict mapping field name to ``ManifestField`` descriptor for all
        manifest-included fields.

    Raises:
        TypeError: If *sample_type* is not a dataclass.

    Examples:
        >>> from typing import Annotated
        >>> @atdata.packable
        ... class MySample:
        ...     label: Annotated[str, ManifestField("categorical")]
        ...     score: float
        >>> fields = resolve_manifest_fields(MySample)
        >>> fields["label"].aggregate
        'categorical'
        >>> fields["score"].aggregate
        'numeric'
    """
    if not dataclasses.is_dataclass(sample_type):
        raise TypeError(f"{sample_type} is not a dataclass")

    declared = {f.name for f in dataclasses.fields(sample_type)}
    hints = get_type_hints(sample_type, include_extras=True)
    resolved: dict[str, ManifestField] = {}

    for field_name, annotation in hints.items():
        # Hints can include non-field names (e.g. ClassVar-style hints);
        # only actual dataclass fields are considered.
        if field_name not in declared:
            continue

        marker = _extract_manifest_field(annotation)
        if marker is not None:
            # An explicit annotation wins; honor an exclusion request.
            if not marker.exclude:
                resolved[field_name] = marker
            continue

        # No explicit marker: auto-infer from the unwrapped base type.
        inferred = _infer_aggregate_kind(_get_base_type(annotation))
        if inferred is not None:
            resolved[field_name] = ManifestField(aggregate=inferred)

    return resolved
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""ShardManifest data model.
|
|
2
|
+
|
|
3
|
+
Represents a loaded manifest with JSON header (metadata + aggregates)
|
|
4
|
+
and per-sample metadata (as a pandas DataFrame).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import pandas as pd
|
|
16
|
+
|
|
17
|
+
MANIFEST_FORMAT_VERSION = "1.0.0"


@dataclass
class ShardManifest:
    """In-memory representation of a shard's manifest.

    Contains the JSON header (metadata + aggregates) and per-sample
    metadata stored as a pandas DataFrame for efficient columnar filtering.

    Attributes:
        shard_id: Shard identifier (path without extension).
        schema_type: Schema class name.
        schema_version: Schema version string.
        num_samples: Number of samples in the shard.
        size_bytes: Total shard size in bytes.
        created_at: When the manifest was created.
        aggregates: Dict of field name to aggregate summary dict.
        samples: DataFrame with ``__key__``, ``__offset__``, ``__size__``,
            and manifest field columns.
        provenance: Optional provenance metadata (job ID, parent shards, etc.).

    Examples:
        >>> manifest = ShardManifest.from_files(
        ...     "data/shard-000000.manifest.json",
        ...     "data/shard-000000.manifest.parquet",
        ... )
        >>> manifest.num_samples
        1000
        >>> manifest.aggregates["label"]["cardinality"]
        3
    """

    shard_id: str
    schema_type: str
    schema_version: str
    num_samples: int
    size_bytes: int
    created_at: datetime
    aggregates: dict[str, dict[str, Any]]
    samples: pd.DataFrame
    provenance: dict[str, Any] = field(default_factory=dict)

    def header_dict(self) -> dict[str, Any]:
        """Return the JSON-serializable header including aggregates.

        Returns:
            Dict suitable for writing as the ``.manifest.json`` file.
            ``provenance`` is present only when non-empty.
        """
        header: dict[str, Any] = {
            "manifest_version": MANIFEST_FORMAT_VERSION,
            "shard_id": self.shard_id,
            "schema_type": self.schema_type,
            "schema_version": self.schema_version,
            "num_samples": self.num_samples,
            "size_bytes": self.size_bytes,
            "created_at": self.created_at.isoformat(),
            "aggregates": self.aggregates,
        }
        if self.provenance:
            header["provenance"] = self.provenance
        return header

    @classmethod
    def _load_header(cls, json_path: str | Path) -> dict[str, Any]:
        """Read and parse a ``.manifest.json`` header file."""
        with open(Path(json_path), "r", encoding="utf-8") as f:
            return json.load(f)

    @classmethod
    def _from_header(
        cls, header: dict[str, Any], samples: pd.DataFrame
    ) -> ShardManifest:
        """Construct an instance from a parsed header and samples DataFrame.

        Shared by ``from_files`` and ``from_json_only`` so the header
        field mapping lives in exactly one place.
        """
        return cls(
            shard_id=header["shard_id"],
            schema_type=header["schema_type"],
            schema_version=header["schema_version"],
            num_samples=header["num_samples"],
            size_bytes=header["size_bytes"],
            created_at=datetime.fromisoformat(header["created_at"]),
            aggregates=header.get("aggregates", {}),
            samples=samples,
            provenance=header.get("provenance", {}),
        )

    @classmethod
    def from_files(
        cls, json_path: str | Path, parquet_path: str | Path
    ) -> ShardManifest:
        """Load a manifest from its JSON + parquet companion files.

        Args:
            json_path: Path to the ``.manifest.json`` file.
            parquet_path: Path to the ``.manifest.parquet`` file.

        Returns:
            A fully loaded ``ShardManifest``.

        Raises:
            FileNotFoundError: If either file does not exist.
            json.JSONDecodeError: If the JSON file is malformed.
        """
        header = cls._load_header(json_path)
        # NOTE(review): engine pinned to fastparquet — presumably matches
        # the manifest writer's output; confirm before changing.
        samples = pd.read_parquet(Path(parquet_path), engine="fastparquet")
        return cls._from_header(header, samples)

    @classmethod
    def from_json_only(cls, json_path: str | Path) -> ShardManifest:
        """Load header-only manifest for shard-level filtering.

        Loads just the JSON header without the parquet per-sample data.
        Useful for fast shard pruning via aggregates before loading
        the full parquet file.

        Args:
            json_path: Path to the ``.manifest.json`` file.

        Returns:
            A ``ShardManifest`` with an empty ``samples`` DataFrame.
        """
        header = cls._load_header(json_path)
        return cls._from_header(header, pd.DataFrame())
|