atdata 0.3.0b1-py3-none-any.whl → 0.3.2b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +11 -0
- atdata/_cid.py +0 -21
- atdata/_helpers.py +12 -0
- atdata/_hf_api.py +46 -1
- atdata/_logging.py +43 -0
- atdata/_protocols.py +81 -182
- atdata/_schema_codec.py +2 -2
- atdata/_sources.py +24 -4
- atdata/_stub_manager.py +5 -25
- atdata/atmosphere/__init__.py +60 -21
- atdata/atmosphere/_lexicon_types.py +595 -0
- atdata/atmosphere/_types.py +73 -245
- atdata/atmosphere/client.py +64 -12
- atdata/atmosphere/lens.py +60 -53
- atdata/atmosphere/records.py +291 -100
- atdata/atmosphere/schema.py +91 -65
- atdata/atmosphere/store.py +68 -66
- atdata/cli/__init__.py +16 -16
- atdata/cli/diagnose.py +2 -2
- atdata/cli/{local.py → infra.py} +10 -10
- atdata/dataset.py +266 -47
- atdata/index/__init__.py +54 -0
- atdata/{local → index}/_entry.py +6 -2
- atdata/{local → index}/_index.py +617 -72
- atdata/{local → index}/_schema.py +5 -5
- atdata/lexicons/__init__.py +127 -0
- atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
- atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
- atdata/lexicons/ac.foundation.dataset.lens.json +101 -0
- atdata/lexicons/ac.foundation.dataset.record.json +117 -0
- atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
- atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +46 -0
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
- atdata/lexicons/ac.foundation.dataset.storageHttp.json +45 -0
- atdata/lexicons/ac.foundation.dataset.storageS3.json +61 -0
- atdata/lexicons/ndarray_shim.json +16 -0
- atdata/local/__init__.py +12 -13
- atdata/local/_repo_legacy.py +3 -3
- atdata/manifest/__init__.py +4 -0
- atdata/manifest/_proxy.py +321 -0
- atdata/promote.py +14 -10
- atdata/repository.py +66 -16
- atdata/stores/__init__.py +23 -0
- atdata/stores/_disk.py +131 -0
- atdata/{local → stores}/_s3.py +134 -112
- atdata/testing.py +12 -8
- {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/METADATA +2 -2
- atdata-0.3.2b1.dist-info/RECORD +71 -0
- atdata-0.3.0b1.dist-info/RECORD +0 -54
- {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/WHEEL +0 -0
- {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/licenses/LICENSE +0 -0
atdata/repository.py
CHANGED
```diff
@@ -5,7 +5,7 @@ optional ``AbstractDataStore`` (shard storage), forming a named storage unit
 that can be mounted into an ``Index``.
 
 The ``_AtmosphereBackend`` is an internal adapter that wraps an
-``AtmosphereClient`` to present the same operational surface as a repository,
+``Atmosphere`` to present the same operational surface as a repository,
 but routes through the ATProto network instead of a local provider.
 
 Examples:
@@ -105,7 +105,7 @@ def create_repository(
 
 
 class _AtmosphereBackend:
-    """Internal adapter wrapping AtmosphereClient for Index routing.
+    """Internal adapter wrapping Atmosphere for Index routing.
 
     This class extracts the operational logic from ``AtmosphereIndex`` into an
     internal component that the unified ``Index`` uses for ATProto resolution.
@@ -117,15 +117,15 @@ class _AtmosphereBackend:
 
     def __init__(
         self,
-        client: Any,  # AtmosphereClient, typed as Any to avoid hard import
+        client: Any,  # Atmosphere, typed as Any to avoid hard import
         *,
         data_store: Optional[AbstractDataStore] = None,
     ) -> None:
-        from .atmosphere.client import AtmosphereClient
+        from .atmosphere.client import Atmosphere
 
-        if not isinstance(client, AtmosphereClient):
-            raise TypeError(f"Expected AtmosphereClient, got {type(client).__name__}")
-        self.client: AtmosphereClient = client
+        if not isinstance(client, Atmosphere):
+            raise TypeError(f"Expected Atmosphere, got {type(client).__name__}")
+        self.client: Atmosphere = client
         self._data_store = data_store
         self._schema_publisher: Any = None
         self._schema_loader: Any = None
@@ -210,14 +210,26 @@ class _AtmosphereBackend:
         *,
         name: str,
         schema_ref: str | None = None,
+        data_urls: list[str] | None = None,
+        blob_refs: list[dict] | None = None,
         **kwargs: Any,
     ) -> Any:
         """Insert a dataset into ATProto.
 
+        When *blob_refs* is provided the record uses ``storageBlobs`` with
+        embedded blob reference objects so the PDS retains the uploaded blobs.
+
+        When *data_urls* is provided (without *blob_refs*) the record uses
+        ``storageExternal`` with those URLs.
+
         Args:
             ds: The Dataset to publish.
             name: Human-readable name.
             schema_ref: Optional schema AT URI. If None, auto-publishes schema.
+            data_urls: Explicit shard URLs to store in the record. When
+                provided, these replace whatever ``ds.url`` contains.
+            blob_refs: Pre-uploaded blob reference dicts from
+                ``PDSBlobStore``. Takes precedence over *data_urls*.
             **kwargs: Additional options (description, tags, license).
 
         Returns:
@@ -226,15 +238,53 @@
         self._ensure_loaders()
         from .atmosphere import AtmosphereIndexEntry
 
-        uri = self._dataset_publisher.publish(
-            ds,
-            name=name,
-            schema_uri=schema_ref,
-            description=kwargs.get("description"),
-            tags=kwargs.get("tags"),
-            license=kwargs.get("license"),
-            auto_publish_schema=(schema_ref is None),
-        )
+        if blob_refs is not None or data_urls is not None:
+            # Ensure schema is published first
+            if schema_ref is None:
+                from .atmosphere import SchemaPublisher
+
+                sp = SchemaPublisher(self.client)
+                schema_uri_obj = sp.publish(
+                    ds.sample_type,
+                    version=kwargs.get("schema_version", "1.0.0"),
+                )
+                schema_ref = str(schema_uri_obj)
+
+            metadata = kwargs.get("metadata")
+            if metadata is None and hasattr(ds, "_metadata"):
+                metadata = ds._metadata
+
+            if blob_refs is not None:
+                uri = self._dataset_publisher.publish_with_blob_refs(
+                    blob_refs=blob_refs,
+                    schema_uri=schema_ref,
+                    name=name,
+                    description=kwargs.get("description"),
+                    tags=kwargs.get("tags"),
+                    license=kwargs.get("license"),
+                    metadata=metadata,
+                )
+            else:
+                uri = self._dataset_publisher.publish_with_urls(
+                    urls=data_urls,
+                    schema_uri=schema_ref,
+                    name=name,
+                    description=kwargs.get("description"),
+                    tags=kwargs.get("tags"),
+                    license=kwargs.get("license"),
+                    metadata=metadata,
+                )
+        else:
+            uri = self._dataset_publisher.publish(
+                ds,
+                name=name,
+                schema_uri=schema_ref,
+                description=kwargs.get("description"),
+                tags=kwargs.get("tags"),
+                license=kwargs.get("license"),
+                auto_publish_schema=(schema_ref is None),
+            )
+
         record = self._dataset_loader.get(uri)
         return AtmosphereIndexEntry(str(uri), record)
 
```
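The new branching in the insert path can be exercised roughly as below. This is a minimal sketch, not code from the package: the method name `insert_dataset`, the `Atmosphere()` constructor and login call, and the exact shape of a blob-reference dict are assumptions layered on top of what the hunks above show (the diff only shows the method body and its keyword arguments).

```python
# Hypothetical sketch; names flagged in comments are assumptions, not from the diff.
from atdata.atmosphere.client import Atmosphere
from atdata.repository import _AtmosphereBackend  # internal adapter shown above

client = Atmosphere()                              # assumed constructor
client.login("alice.example", "app-password")      # assumed to mirror MockAtmosphere.login

backend = _AtmosphereBackend(client)

# storageExternal: the record points at externally hosted shard URLs.
entry = backend.insert_dataset(                    # "insert_dataset" is an assumed method name
    ds,                                            # an existing atdata Dataset (not built here)
    name="mnist-v1",
    data_urls=["https://example.org/shards/data--000000.tar"],
)

# storageBlobs: shards already uploaded to the PDS; pass their blob refs instead,
# which take precedence over data_urls.
entry = backend.insert_dataset(
    ds,
    name="mnist-v1",
    blob_refs=[{
        "$type": "blob",                           # illustrative ATProto blob-ref shape
        "ref": {"$link": "bafkrei..."},
        "mimeType": "application/x-tar",
        "size": 1048576,
    }],
)
```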
atdata/stores/__init__.py
ADDED
```diff
@@ -0,0 +1,23 @@
+"""Data stores for atdata datasets.
+
+Key classes:
+
+- ``LocalDiskStore``: Local filesystem data store.
+- ``S3DataStore``: S3-compatible object storage.
+"""
+
+from atdata.stores._disk import LocalDiskStore
+from atdata.stores._s3 import (
+    S3DataStore,
+    _s3_env,
+    _s3_from_credentials,
+    _create_s3_write_callbacks,
+)
+
+__all__ = [
+    "LocalDiskStore",
+    "S3DataStore",
+    "_s3_env",
+    "_s3_from_credentials",
+    "_create_s3_write_callbacks",
+]
```
atdata/stores/_disk.py
ADDED
```diff
@@ -0,0 +1,131 @@
+"""Local filesystem data store for WebDataset shards.
+
+Writes and reads WebDataset tar archives on the local filesystem,
+implementing the ``AbstractDataStore`` protocol.
+
+Examples:
+    >>> store = LocalDiskStore(root="~/.atdata/data")
+    >>> urls = store.write_shards(dataset, prefix="mnist/v1")
+    >>> print(urls[0])
+    /home/user/.atdata/data/mnist/v1/data--a1b2c3--000000.tar
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+from uuid import uuid4
+
+import webdataset as wds
+
+if TYPE_CHECKING:
+    from atdata.dataset import Dataset
+
+
+class LocalDiskStore:
+    """Local filesystem data store.
+
+    Writes WebDataset shards to a directory on disk. Implements the
+    ``AbstractDataStore`` protocol for use with ``Index``.
+
+    Args:
+        root: Root directory for shard storage. Defaults to
+            ``~/.atdata/data/``. Created automatically if it does
+            not exist.
+
+    Examples:
+        >>> store = LocalDiskStore()
+        >>> urls = store.write_shards(dataset, prefix="my-dataset")
+    """
+
+    def __init__(self, root: str | Path | None = None) -> None:
+        if root is None:
+            root = Path.home() / ".atdata" / "data"
+        self._root = Path(root).expanduser().resolve()
+        self._root.mkdir(parents=True, exist_ok=True)
+
+    @property
+    def root(self) -> Path:
+        """Root directory for shard storage."""
+        return self._root
+
+    def write_shards(
+        self,
+        ds: "Dataset",
+        *,
+        prefix: str,
+        **kwargs: Any,
+    ) -> list[str]:
+        """Write dataset shards to the local filesystem.
+
+        Args:
+            ds: The Dataset to write.
+            prefix: Path prefix within root (e.g., ``'datasets/mnist/v1'``).
+            **kwargs: Additional args passed to ``wds.writer.ShardWriter``
+                (e.g., ``maxcount``, ``maxsize``).
+
+        Returns:
+            List of absolute file paths for the written shards.
+
+        Raises:
+            RuntimeError: If no shards were written.
+        """
+        from atdata._logging import get_logger, log_operation
+
+        log = get_logger()
+        shard_dir = self._root / prefix
+        shard_dir.mkdir(parents=True, exist_ok=True)
+
+        new_uuid = str(uuid4())[:8]
+        shard_pattern = str(shard_dir / f"data--{new_uuid}--%06d.tar")
+
+        written_shards: list[str] = []
+
+        def _track_shard(path: str) -> None:
+            written_shards.append(str(Path(path).resolve()))
+
+        # Filter out kwargs that are specific to other stores (e.g. S3)
+        # and not understood by wds.writer.ShardWriter / TarWriter.
+        writer_kwargs = {k: v for k, v in kwargs.items() if k not in ("cache_local",)}
+
+        with log_operation("LocalDiskStore.write_shards", prefix=prefix):
+            with wds.writer.ShardWriter(
+                shard_pattern,
+                post=_track_shard,
+                **writer_kwargs,
+            ) as sink:
+                for sample in ds.ordered(batch_size=None):
+                    sink.write(sample.as_wds)
+
+        if not written_shards:
+            raise RuntimeError(
+                f"No shards written for prefix {prefix!r} in {self._root}"
+            )
+
+        log.info(
+            "LocalDiskStore.write_shards: wrote %d shard(s)", len(written_shards)
+        )
+
+        return written_shards
+
+    def read_url(self, url: str) -> str:
+        """Resolve a storage URL for reading.
+
+        Local filesystem paths are returned as-is since WebDataset
+        can read them directly.
+
+        Args:
+            url: Absolute file path to a shard.
+
+        Returns:
+            The same path, unchanged.
+        """
+        return url
+
+    def supports_streaming(self) -> bool:
+        """Whether this store supports streaming reads.
+
+        Returns:
+            ``True`` — local filesystem supports streaming.
+        """
+        return True
```
atdata/{local → stores}/_s3.py
RENAMED
```diff
@@ -189,129 +189,151 @@ class S3DataStore:
         Raises:
             RuntimeError: If no shards were written.
         """
+        from atdata._logging import get_logger, log_operation
+
+        log = get_logger()
+
         new_uuid = str(uuid4())
         shard_pattern = f"{self.bucket}/{prefix}/data--{new_uuid}--%06d.tar"
 
         written_shards: list[str] = []
 
-                return ManifestBuilder(
-                    sample_type=ds.sample_type,
-                    shard_id=shard_id,
-                    schema_version=schema_version,
-                    source_job_id=source_job_id,
-                    parent_shards=parent_shards,
-                    pipeline_version=pipeline_version,
-                )
-
-            current_builder[0] = _make_builder(0)
-
-        with TemporaryDirectory() as temp_dir:
-            writer_opener, writer_post_orig = _create_s3_write_callbacks(
-                credentials=self.credentials,
-                temp_dir=temp_dir,
-                written_shards=written_shards,
-                fs=self._fs,
-                cache_local=cache_local,
-                add_s3_prefix=True,
-            )
+        with log_operation(
+            "S3DataStore.write_shards",
+            prefix=prefix,
+            bucket=self.bucket,
+            manifest=manifest,
+        ):
+            # Manifest tracking state shared with the post callback
+            manifest_builders: list = []
+            current_builder: list = [None]  # mutable ref for closure
+            shard_counter: list[int] = [0]
 
             if manifest:
+                from atdata.manifest import ManifestBuilder, ManifestWriter
+
+                def _make_builder(shard_idx: int) -> ManifestBuilder:
+                    shard_id = (
+                        f"{self.bucket}/{prefix}/data--{new_uuid}--{shard_idx:06d}"
+                    )
+                    return ManifestBuilder(
+                        sample_type=ds.sample_type,
+                        shard_id=shard_id,
+                        schema_version=schema_version,
+                        source_job_id=source_job_id,
+                        parent_shards=parent_shards,
+                        pipeline_version=pipeline_version,
+                    )
+
+                current_builder[0] = _make_builder(0)
+
+            with TemporaryDirectory() as temp_dir:
+                writer_opener, writer_post_orig = _create_s3_write_callbacks(
+                    credentials=self.credentials,
+                    temp_dir=temp_dir,
+                    written_shards=written_shards,
+                    fs=self._fs,
+                    cache_local=cache_local,
+                    add_s3_prefix=True,
+                )
 
-                        current_builder[0]
-                        offset += 512 + packed_size + (512 - packed_size % 512) % 512
-
-            # Finalize the last shard's builder (post isn't called for the last shard
-            # until ShardWriter closes, but we handle it here for safety)
-            if manifest and current_builder[0] is not None:
-                builder = current_builder[0]
-                if builder._rows:  # Only if samples were added
-                    manifest_builders.append(builder)
-
-            # Write all manifest files
-            if manifest:
-                for builder in manifest_builders:
-                    built = builder.build()
-                    writer = ManifestWriter(Path(temp_dir) / Path(built.shard_id))
-                    json_path, parquet_path = writer.write(built)
-
-                    # Upload manifest files to S3 alongside shards
-                    shard_id = built.shard_id
-                    json_key = f"{shard_id}.manifest.json"
-                    parquet_key = f"{shard_id}.manifest.parquet"
-
-                    if cache_local:
-                        import boto3
-
-                        s3_kwargs = {
-                            "aws_access_key_id": self.credentials["AWS_ACCESS_KEY_ID"],
-                            "aws_secret_access_key": self.credentials[
-                                "AWS_SECRET_ACCESS_KEY"
-                            ],
-                        }
-                        if "AWS_ENDPOINT" in self.credentials:
-                            s3_kwargs["endpoint_url"] = self.credentials["AWS_ENDPOINT"]
-                        s3_client = boto3.client("s3", **s3_kwargs)
-
-                        bucket_name = Path(shard_id).parts[0]
-                        json_s3_key = str(Path(*Path(json_key).parts[1:]))
-                        parquet_s3_key = str(Path(*Path(parquet_key).parts[1:]))
-
-                        with open(json_path, "rb") as f:
-                            s3_client.put_object(
-                                Bucket=bucket_name, Key=json_s3_key, Body=f.read()
+                if manifest:
+
+                    def writer_post(p: str):
+                        # Finalize the current manifest builder when a shard completes
+                        builder = current_builder[0]
+                        if builder is not None:
+                            manifest_builders.append(builder)
+                        # Advance to the next shard's builder
+                        shard_counter[0] += 1
+                        current_builder[0] = _make_builder(shard_counter[0])
+                        # Call original post callback
+                        writer_post_orig(p)
+                else:
+                    writer_post = writer_post_orig
+
+                offset = 0
+                with wds.writer.ShardWriter(
+                    shard_pattern,
+                    opener=writer_opener,
+                    post=writer_post,
+                    **kwargs,
+                ) as sink:
+                    for sample in ds.ordered(batch_size=None):
+                        wds_dict = sample.as_wds
+                        sink.write(wds_dict)
+
+                        if manifest and current_builder[0] is not None:
+                            packed_size = len(wds_dict.get("msgpack", b""))
+                            current_builder[0].add_sample(
+                                key=wds_dict["__key__"],
+                                offset=offset,
+                                size=packed_size,
+                                sample=sample,
                             )
-                        with open(parquet_path, "rb") as f:
-                            s3_client.put_object(
-                                Bucket=bucket_name, Key=parquet_s3_key, Body=f.read()
+                            # Approximate tar entry: 512-byte header + data rounded to 512
+                            offset += (
+                                512 + packed_size + (512 - packed_size % 512) % 512
                             )
-                    else:
-                        self._fs.put(str(json_path), f"s3://{json_key}")
-                        self._fs.put(str(parquet_path), f"s3://{parquet_key}")
 
-        if len(written_shards) == 0:
-            raise RuntimeError("No shards written")
+                # Finalize the last shard's builder (post isn't called for the last shard
+                # until ShardWriter closes, but we handle it here for safety)
+                if manifest and current_builder[0] is not None:
+                    builder = current_builder[0]
+                    if builder._rows:  # Only if samples were added
+                        manifest_builders.append(builder)
+
+                # Write all manifest files
+                if manifest:
+                    for builder in manifest_builders:
+                        built = builder.build()
+                        writer = ManifestWriter(Path(temp_dir) / Path(built.shard_id))
+                        json_path, parquet_path = writer.write(built)
+
+                        # Upload manifest files to S3 alongside shards
+                        shard_id = built.shard_id
+                        json_key = f"{shard_id}.manifest.json"
+                        parquet_key = f"{shard_id}.manifest.parquet"
+
+                        if cache_local:
+                            import boto3
+
+                            s3_kwargs = {
+                                "aws_access_key_id": self.credentials[
+                                    "AWS_ACCESS_KEY_ID"
+                                ],
+                                "aws_secret_access_key": self.credentials[
+                                    "AWS_SECRET_ACCESS_KEY"
+                                ],
+                            }
+                            if "AWS_ENDPOINT" in self.credentials:
+                                s3_kwargs["endpoint_url"] = self.credentials[
+                                    "AWS_ENDPOINT"
+                                ]
+                            s3_client = boto3.client("s3", **s3_kwargs)
+
+                            bucket_name = Path(shard_id).parts[0]
+                            json_s3_key = str(Path(*Path(json_key).parts[1:]))
+                            parquet_s3_key = str(Path(*Path(parquet_key).parts[1:]))
+
+                            with open(json_path, "rb") as f:
+                                s3_client.put_object(
+                                    Bucket=bucket_name, Key=json_s3_key, Body=f.read()
+                                )
+                            with open(parquet_path, "rb") as f:
+                                s3_client.put_object(
+                                    Bucket=bucket_name,
+                                    Key=parquet_s3_key,
+                                    Body=f.read(),
+                                )
+                        else:
+                            self._fs.put(str(json_path), f"s3://{json_key}")
+                            self._fs.put(str(parquet_path), f"s3://{parquet_key}")
+
+            if len(written_shards) == 0:
+                raise RuntimeError("No shards written")
+
+            log.info("S3DataStore.write_shards: wrote %d shard(s)", len(written_shards))
 
         return written_shards
 
```
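The offset bookkeeping in the new write loop treats each tar member as a 512-byte header plus the payload rounded up to the next 512-byte boundary (end-of-archive blocks are ignored, which is why the diff calls it an approximation). A small self-contained illustration of that arithmetic:

```python
def approx_tar_entry_size(payload_size: int) -> int:
    """512-byte header + payload padded up to a multiple of 512."""
    padding = (512 - payload_size % 512) % 512
    return 512 + payload_size + padding

offset = 0
for payload in (100, 512, 1000):
    # The manifest records a sample's offset before advancing it,
    # mirroring the add_sample(...) then offset += ... order above.
    print(f"payload={payload:5d} starts at offset {offset}")
    offset += approx_tar_entry_size(payload)
# payload=  100 starts at offset 0
# payload=  512 starts at offset 1024
# payload= 1000 starts at offset 2048
```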
atdata/testing.py
CHANGED
```diff
@@ -14,7 +14,7 @@ Usage::
     samples = at_test.make_samples(MyType, n=100)
 
     # Use mock atmosphere client
-    client = at_test.MockAtmosphereClient()
+    client = at_test.MockAtmosphere()
 
     # Use in-memory index (SQLite backed, temporary)
     index = at_test.mock_index(tmp_path)
@@ -40,7 +40,7 @@ import webdataset as wds
 
 import atdata
 from atdata import Dataset, PackableSample
-from atdata.local._index import Index
+from atdata.index._index import Index
 from atdata.providers._sqlite import SqliteProvider
 
 ST = TypeVar("ST")
@@ -51,14 +51,14 @@ ST = TypeVar("ST")
 # ---------------------------------------------------------------------------
 
 
-class MockAtmosphereClient:
-    """In-memory mock of ``AtmosphereClient`` for testing.
+class MockAtmosphere:
+    """In-memory mock of ``Atmosphere`` for testing.
 
     Simulates login, schema publishing, dataset publishing, and record
     retrieval without requiring a live ATProto PDS.
 
     Examples:
-        >>> client = MockAtmosphereClient()
+        >>> client = MockAtmosphere()
         >>> client.login("alice.test", "password")
         >>> client.did
         'did:plc:mock000000000000'
@@ -294,8 +294,8 @@ try:
 
     @pytest.fixture
    def mock_atmosphere():
-        """Provide a fresh ``MockAtmosphereClient`` for each test."""
-        client = MockAtmosphereClient()
+        """Provide a fresh ``MockAtmosphere`` for each test."""
+        client = MockAtmosphere()
        client.login("test.mock.social", "test-password")
        yield client
        client.reset()
@@ -329,8 +329,12 @@ except ImportError:
 # Public API
 # ---------------------------------------------------------------------------
 
+# Deprecated alias for backward compatibility
+MockAtmosphereClient = MockAtmosphere
+
 __all__ = [
-    "MockAtmosphereClient",
+    "MockAtmosphere",
+    "MockAtmosphereClient",  # deprecated alias
     "make_dataset",
     "make_samples",
     "mock_index",
```
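A short sketch of the renamed mock in a pytest suite. Only behaviour the diff documents is used (login, `did`, `reset()`, the `mock_atmosphere` fixture, and the deprecated alias); how the fixture gets registered in a real test session is assumed.

```python
import atdata.testing as at_test

def test_mock_atmosphere_directly():
    client = at_test.MockAtmosphere()
    client.login("alice.test", "password")
    assert client.did == "did:plc:mock000000000000"
    client.reset()

def test_deprecated_alias_still_importable():
    # Old code importing MockAtmosphereClient keeps working.
    assert at_test.MockAtmosphereClient is at_test.MockAtmosphere

def test_with_fixture(mock_atmosphere):
    # The fixture logs in before the test and calls reset() afterwards.
    assert mock_atmosphere.did.startswith("did:plc:")
```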
{atdata-0.3.0b1.dist-info → atdata-0.3.2b1.dist-info}/METADATA
CHANGED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: atdata
-Version: 0.3.0b1
+Version: 0.3.2b1
 Summary: A loose federation of distributed, typed datasets
 Author-email: Maxine Levesque <hello@maxine.science>, "Maxine @ Forecast Bio" <maxine@forecast.bio>
 License-File: LICENSE
@@ -30,7 +30,7 @@ Description-Content-Type: text/markdown
 
 # atdata
 
-[](https://codecov.io/gh/forecast-bio/atdata)
 
 A loose federation of distributed, typed datasets built on WebDataset.
 
```