atdata 0.3.0b1__py3-none-any.whl → 0.3.1b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +9 -0
- atdata/_cid.py +0 -21
- atdata/_helpers.py +12 -0
- atdata/_hf_api.py +33 -1
- atdata/_protocols.py +64 -182
- atdata/_schema_codec.py +2 -2
- atdata/_stub_manager.py +5 -25
- atdata/atmosphere/__init__.py +12 -11
- atdata/atmosphere/_types.py +4 -4
- atdata/atmosphere/client.py +64 -12
- atdata/atmosphere/lens.py +11 -12
- atdata/atmosphere/records.py +9 -10
- atdata/atmosphere/schema.py +14 -16
- atdata/atmosphere/store.py +6 -7
- atdata/cli/__init__.py +16 -16
- atdata/cli/diagnose.py +2 -2
- atdata/cli/{local.py → infra.py} +10 -10
- atdata/dataset.py +155 -2
- atdata/index/__init__.py +54 -0
- atdata/{local → index}/_index.py +322 -64
- atdata/{local → index}/_schema.py +5 -5
- atdata/lexicons/__init__.py +121 -0
- atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
- atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
- atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
- atdata/lexicons/ac.foundation.dataset.record.json +96 -0
- atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
- atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
- atdata/lexicons/ndarray_shim.json +16 -0
- atdata/local/__init__.py +12 -13
- atdata/local/_repo_legacy.py +3 -3
- atdata/promote.py +14 -10
- atdata/repository.py +7 -7
- atdata/stores/__init__.py +23 -0
- atdata/stores/_disk.py +123 -0
- atdata/testing.py +12 -8
- {atdata-0.3.0b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +2 -2
- atdata-0.3.1b1.dist-info/RECORD +67 -0
- atdata-0.3.0b1.dist-info/RECORD +0 -54
- /atdata/{local → index}/_entry.py +0 -0
- /atdata/{local → stores}/_s3.py +0 -0
- {atdata-0.3.0b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
- {atdata-0.3.0b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.3.0b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
atdata/dataset.py
CHANGED
|
@@ -99,9 +99,11 @@ DT = TypeVar("DT")
|
|
|
99
99
|
|
|
100
100
|
|
|
101
101
|
def _make_packable(x):
    """Return a version of *x* that is safe to pack.

    NumPy arrays are serialized to bytes with ``eh.array_to_bytes``; NumPy
    scalar instances (``np.generic``) are unwrapped to the equivalent Python
    native value via ``.item()``. Any other value is returned unchanged.
    """
    # Fast path: nothing NumPy-specific to convert.
    if not isinstance(x, (np.ndarray, np.generic)):
        return x
    if isinstance(x, np.ndarray):
        return eh.array_to_bytes(x)
    # Remaining case: a NumPy scalar such as np.int64 / np.float32.
    return x.item()
|
|
106
108
|
|
|
107
109
|
|
|
@@ -305,7 +307,7 @@ def _batch_aggregate(xs: Sequence):
|
|
|
305
307
|
if not xs:
|
|
306
308
|
return []
|
|
307
309
|
if isinstance(xs[0], np.ndarray):
|
|
308
|
-
return np.
|
|
310
|
+
return np.stack(xs)
|
|
309
311
|
return list(xs)
|
|
310
312
|
|
|
311
313
|
|
|
@@ -1188,3 +1190,154 @@ def packable(cls: type[_T]) -> type[Packable]:
|
|
|
1188
1190
|
##
|
|
1189
1191
|
|
|
1190
1192
|
return as_packable
|
|
1193
|
+
|
|
1194
|
+
|
|
1195
|
+
# ---------------------------------------------------------------------------
|
|
1196
|
+
# write_samples — convenience function for writing samples to tar files
|
|
1197
|
+
# ---------------------------------------------------------------------------
|
|
1198
|
+
|
|
1199
|
+
|
|
1200
|
+
def write_samples(
    samples: Iterable[ST],
    path: str | Path,
    *,
    maxcount: int | None = None,
    maxsize: int | None = None,
    manifest: bool = False,
) -> "Dataset[ST]":
    """Write an iterable of samples to WebDataset tar file(s).

    Args:
        samples: Iterable of ``PackableSample`` instances. Must be non-empty.
            It is consumed exactly once, so generators are accepted.
        path: Output path for the tar file. For sharded output (when
            *maxcount* or *maxsize* is set), a ``%06d`` pattern is
            auto-appended if the path does not already contain ``%``.
        maxcount: Maximum samples per shard. Triggers multi-shard output.
        maxsize: Maximum bytes per shard. Triggers multi-shard output.
        manifest: If True, write per-shard manifest sidecar files
            (``.manifest.json`` + ``.manifest.parquet``) alongside each
            tar file. Manifests enable metadata queries via
            ``QueryExecutor`` without opening the tars.

    Returns:
        A ``Dataset`` wrapping the written file(s), typed to the sample
        type of the input samples.

    Raises:
        ValueError: If *samples* is empty. This is detected before any
            directory or tar file is created on disk.

    Examples:
        >>> samples = [MySample(key="0", text="hello")]
        >>> ds = write_samples(samples, "out.tar")
        >>> list(ds.ordered())
        [MySample(key='0', text='hello')]
    """
    from itertools import chain

    from ._hf_api import _shards_to_wds_url

    if manifest:
        from .manifest._builder import ManifestBuilder
        from .manifest._writer import ManifestWriter

    # Peek at the first sample so empty input fails fast, before any
    # directory or (empty) tar file is created.
    sample_iter = iter(samples)
    try:
        first_sample = next(sample_iter)
    except StopIteration:
        raise ValueError("samples must be non-empty") from None
    sample_type: type = type(first_sample)
    all_samples = chain((first_sample,), sample_iter)

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    use_shard_writer = maxcount is not None or maxsize is not None

    # Resolved paths of completed shards, in the order they were closed.
    written_paths: list[str] = []
    # (resolved shard path, manifest builder) for every completed shard.
    finished_builders: list[tuple[str, Any]] = []
    # Manifest builder for the shard currently being written, if any.
    current_builder: Any = None
    # Running offset within the current shard's manifest.
    running_offset = 0

    def _close_shard(shard_path: str) -> None:
        """Register a finished shard and hand off its manifest builder."""
        nonlocal current_builder
        resolved = str(Path(shard_path).resolve())
        written_paths.append(resolved)
        if current_builder is not None:
            finished_builders.append((resolved, current_builder))
            current_builder = None

    def _record_sample(sample, wds_dict: dict, shard_path: str) -> None:
        """Add *sample* to the manifest of the shard at *shard_path*."""
        nonlocal current_builder, running_offset
        if current_builder is None:
            current_builder = ManifestBuilder(
                sample_type=sample_type, shard_id=Path(shard_path).stem
            )
            running_offset = 0
        size = len(wds_dict["msgpack"])
        # NOTE(review): offsets accumulate msgpack payload sizes only (no tar
        # header/padding bytes) — confirm this matches ManifestBuilder's contract.
        current_builder.add_sample(
            key=wds_dict["__key__"],
            offset=running_offset,
            size=size,
            sample=sample,
        )
        running_offset += size

    if use_shard_writer:
        # Build shard pattern from path
        if "%" not in str(path):
            pattern = str(path.parent / f"{path.stem}-%06d{path.suffix}")
        else:
            pattern = str(path)

        writer_kwargs: dict[str, Any] = {}
        if maxcount is not None:
            writer_kwargs["maxcount"] = maxcount
        if maxsize is not None:
            writer_kwargs["maxsize"] = maxsize

        # ShardWriter calls ``post`` with a shard's filename *after* that
        # shard has been closed (on rollover and on exit), not when it opens.
        with wds.writer.ShardWriter(
            pattern, post=_close_shard, **writer_kwargs
        ) as sink:
            for sample in all_samples:
                wds_dict = sample.as_wds
                sink.write(wds_dict)
                if manifest:
                    # Shards are numbered from 0, and every earlier shard has
                    # already been closed (and counted) by the time write()
                    # returns, so this names the shard the sample went into.
                    _record_sample(
                        sample, wds_dict, pattern % len(written_paths)
                    )
    else:
        with wds.writer.TarWriter(str(path)) as sink:
            for sample in all_samples:
                wds_dict = sample.as_wds
                sink.write(wds_dict)
                if manifest:
                    _record_sample(sample, wds_dict, str(path))
        _close_shard(str(path))

    # Write one manifest per completed shard, next to its tar file.
    if manifest:
        for shard_path, builder in finished_builders:
            base = str(Path(shard_path).with_suffix(""))
            ManifestWriter(base).write(builder.build())

    url = _shards_to_wds_url(written_paths)
    ds: Dataset = Dataset(url)
    ds._sample_type_cache = sample_type
    return ds
|
atdata/index/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Index and entry models for atdata datasets.
|
|
2
|
+
|
|
3
|
+
Key classes:
|
|
4
|
+
|
|
5
|
+
- ``Index``: Unified index with pluggable providers (SQLite default),
|
|
6
|
+
named repositories, and optional atmosphere backend.
|
|
7
|
+
- ``LocalDatasetEntry``: Index entry with ATProto-compatible CIDs.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from atdata.index._entry import (
|
|
11
|
+
LocalDatasetEntry,
|
|
12
|
+
BasicIndexEntry,
|
|
13
|
+
REDIS_KEY_DATASET_ENTRY,
|
|
14
|
+
REDIS_KEY_SCHEMA,
|
|
15
|
+
)
|
|
16
|
+
from atdata.index._schema import (
|
|
17
|
+
SchemaNamespace,
|
|
18
|
+
SchemaFieldType,
|
|
19
|
+
SchemaField,
|
|
20
|
+
LocalSchemaRecord,
|
|
21
|
+
_ATDATA_URI_PREFIX,
|
|
22
|
+
_LEGACY_URI_PREFIX,
|
|
23
|
+
_kind_str_for_sample_type,
|
|
24
|
+
_schema_ref_from_type,
|
|
25
|
+
_make_schema_ref,
|
|
26
|
+
_parse_schema_ref,
|
|
27
|
+
_increment_patch,
|
|
28
|
+
_python_type_to_field_type,
|
|
29
|
+
_build_schema_record,
|
|
30
|
+
)
|
|
31
|
+
from atdata.index._index import Index
|
|
32
|
+
|
|
33
|
+
__all__ = [
    # --- Public API --------------------------------------------------------
    "Index",
    "LocalDatasetEntry",
    "BasicIndexEntry",
    "SchemaNamespace",
    "SchemaFieldType",
    "SchemaField",
    "LocalSchemaRecord",
    "REDIS_KEY_DATASET_ENTRY",
    "REDIS_KEY_SCHEMA",
    # --- Private helpers, re-exported only for backward compatibility -----
    "_ATDATA_URI_PREFIX",
    "_LEGACY_URI_PREFIX",
    "_kind_str_for_sample_type",
    "_schema_ref_from_type",
    "_make_schema_ref",
    "_parse_schema_ref",
    "_increment_patch",
    "_python_type_to_field_type",
    "_build_schema_record",
]
|