atdata 0.2.0a1__py3-none-any.whl → 0.2.2b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +43 -10
- atdata/_cid.py +150 -0
- atdata/_hf_api.py +692 -0
- atdata/_protocols.py +519 -0
- atdata/_schema_codec.py +442 -0
- atdata/_sources.py +515 -0
- atdata/_stub_manager.py +529 -0
- atdata/_type_utils.py +90 -0
- atdata/atmosphere/__init__.py +278 -7
- atdata/atmosphere/_types.py +9 -7
- atdata/atmosphere/client.py +146 -6
- atdata/atmosphere/lens.py +29 -25
- atdata/atmosphere/records.py +197 -30
- atdata/atmosphere/schema.py +41 -98
- atdata/atmosphere/store.py +208 -0
- atdata/cli/__init__.py +213 -0
- atdata/cli/diagnose.py +165 -0
- atdata/cli/local.py +280 -0
- atdata/dataset.py +482 -167
- atdata/lens.py +61 -57
- atdata/local.py +1400 -185
- atdata/promote.py +199 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/METADATA +105 -14
- atdata-0.2.2b1.dist-info/RECORD +28 -0
- atdata-0.2.0a1.dist-info/RECORD +0 -16
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/licenses/LICENSE +0 -0
atdata/atmosphere/records.py
CHANGED
@@ -32,18 +32,20 @@ class DatasetPublisher:
     external storage (WebDataset URLs) or ATProto blobs.
 
     Example:
-
-
-
-
-
-
-
-
-
-
-
-
+        ::
+
+            >>> dataset = atdata.Dataset[MySample]("s3://bucket/data-{000000..000009}.tar")
+            >>>
+            >>> client = AtmosphereClient()
+            >>> client.login("handle", "password")
+            >>>
+            >>> publisher = DatasetPublisher(client)
+            >>> uri = publisher.publish(
+            ...     dataset,
+            ...     name="My Training Data",
+            ...     description="Training data for my model",
+            ...     tags=["computer-vision", "training"],
+            ... )
     """
 
     def __init__(self, client: AtmosphereClient):
@@ -187,6 +189,76 @@ class DatasetPublisher:
             validate=False,
         )
 
+    def publish_with_blobs(
+        self,
+        blobs: list[bytes],
+        schema_uri: str,
+        *,
+        name: str,
+        description: Optional[str] = None,
+        tags: Optional[list[str]] = None,
+        license: Optional[str] = None,
+        metadata: Optional[dict] = None,
+        mime_type: str = "application/x-tar",
+        rkey: Optional[str] = None,
+    ) -> AtUri:
+        """Publish a dataset with data stored as ATProto blobs.
+
+        This method uploads the provided data as blobs to the PDS and creates
+        a dataset record referencing them. Suitable for smaller datasets that
+        fit within blob size limits (typically 50MB per blob, configurable).
+
+        Args:
+            blobs: List of binary data (e.g., tar shards) to upload as blobs.
+            schema_uri: AT URI of the schema record.
+            name: Human-readable dataset name.
+            description: Human-readable description.
+            tags: Searchable tags for discovery.
+            license: SPDX license identifier.
+            metadata: Arbitrary metadata dictionary.
+            mime_type: MIME type for the blobs (default: application/x-tar).
+            rkey: Optional explicit record key.
+
+        Returns:
+            The AT URI of the created dataset record.
+
+        Note:
+            Blobs are only retained by the PDS when referenced in a committed
+            record. This method handles that automatically.
+        """
+        # Upload all blobs
+        blob_refs = []
+        for blob_data in blobs:
+            blob_ref = self.client.upload_blob(blob_data, mime_type=mime_type)
+            blob_refs.append(blob_ref)
+
+        # Create storage location with blob references
+        storage = StorageLocation(
+            kind="blobs",
+            blob_refs=blob_refs,
+        )
+
+        metadata_bytes: Optional[bytes] = None
+        if metadata is not None:
+            metadata_bytes = msgpack.packb(metadata)
+
+        dataset_record = DatasetRecord(
+            name=name,
+            schema_ref=schema_uri,
+            storage=storage,
+            description=description,
+            tags=tags or [],
+            license=license,
+            metadata=metadata_bytes,
+        )
+
+        return self.client.create_record(
+            collection=f"{LEXICON_NAMESPACE}.record",
+            record=dataset_record.to_record(),
+            rkey=rkey,
+            validate=False,
+        )
+
 
 class DatasetLoader:
     """Loads dataset records from ATProto.
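The new publish_with_blobs path uploads shard bytes to the PDS and commits a record referencing them in one call. A minimal usage sketch, assuming an authenticated client, a previously published schema at schema_uri, and in-memory tar shards; shard_bytes, schema_uri, and the credentials below are illustrative placeholders, not values from this diff:

    from atdata.atmosphere import AtmosphereClient
    from atdata.atmosphere.records import DatasetPublisher

    client = AtmosphereClient()
    client.login("handle.bsky.social", "app-password")

    publisher = DatasetPublisher(client)

    # schema_uri and shard_bytes are assumed to exist already; each list entry
    # becomes one ATProto blob (keep each under the PDS per-blob size limit).
    uri = publisher.publish_with_blobs(
        [shard_bytes],
        schema_uri,
        name="My Blob-Backed Dataset",
        description="Shards stored directly on the PDS",
        tags=["example"],
        license="CC-BY-4.0",
    )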
@@ -196,16 +268,18 @@ class DatasetLoader:
     Python class for the sample type.
 
     Example:
-
-
-
-
-
-
-
-
-
-
+        ::
+
+            >>> client = AtmosphereClient()
+            >>> loader = DatasetLoader(client)
+            >>>
+            >>> # List available datasets
+            >>> datasets = loader.list()
+            >>> for ds in datasets:
+            ...     print(ds["name"], ds["schemaRef"])
+            >>>
+            >>> # Get a specific dataset record
+            >>> record = loader.get("at://did:plc:abc/ac.foundation.dataset.record/xyz")
     """
 
     def __init__(self, client: AtmosphereClient):
@@ -255,6 +329,29 @@ class DatasetLoader:
         """
         return self.client.list_datasets(repo=repo, limit=limit)
 
+    def get_storage_type(self, uri: str | AtUri) -> str:
+        """Get the storage type of a dataset record.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            Either "external" or "blobs".
+
+        Raises:
+            ValueError: If storage type is unknown.
+        """
+        record = self.get(uri)
+        storage = record.get("storage", {})
+        storage_type = storage.get("$type", "")
+
+        if "storageExternal" in storage_type:
+            return "external"
+        elif "storageBlobs" in storage_type:
+            return "blobs"
+        else:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
     def get_urls(self, uri: str | AtUri) -> list[str]:
         """Get the WebDataset URLs from a dataset record.
 
@@ -276,11 +373,71 @@ class DatasetLoader:
         elif "storageBlobs" in storage_type:
             raise ValueError(
                 "Dataset uses blob storage, not external URLs. "
-                "Use
+                "Use get_blob_urls() instead."
             )
         else:
             raise ValueError(f"Unknown storage type: {storage_type}")
 
+    def get_blobs(self, uri: str | AtUri) -> list[dict]:
+        """Get the blob references from a dataset record.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            List of blob reference dicts with keys: $type, ref, mimeType, size.
+
+        Raises:
+            ValueError: If the storage type is not blobs.
+        """
+        record = self.get(uri)
+        storage = record.get("storage", {})
+
+        storage_type = storage.get("$type", "")
+        if "storageBlobs" in storage_type:
+            return storage.get("blobs", [])
+        elif "storageExternal" in storage_type:
+            raise ValueError(
+                "Dataset uses external URL storage, not blobs. "
+                "Use get_urls() instead."
+            )
+        else:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
+    def get_blob_urls(self, uri: str | AtUri) -> list[str]:
+        """Get fetchable URLs for blob-stored dataset shards.
+
+        This resolves the PDS endpoint and constructs URLs that can be
+        used to fetch the blob data directly.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            List of URLs for fetching the blob data.
+
+        Raises:
+            ValueError: If storage type is not blobs or PDS cannot be resolved.
+        """
+        if isinstance(uri, str):
+            parsed_uri = AtUri.parse(uri)
+        else:
+            parsed_uri = uri
+
+        blobs = self.get_blobs(uri)
+        did = parsed_uri.authority
+
+        urls = []
+        for blob in blobs:
+            # Extract CID from blob reference
+            ref = blob.get("ref", {})
+            cid = ref.get("$link") if isinstance(ref, dict) else str(ref)
+            if cid:
+                url = self.client.get_blob_url(did, cid)
+                urls.append(url)
+
+        return urls
+
     def get_metadata(self, uri: str | AtUri) -> Optional[dict]:
         """Get the metadata from a dataset record.
 
@@ -309,6 +466,8 @@ class DatasetLoader:
         You must provide the sample type class, which should match the
         schema referenced by the record.
 
+        Supports both external URL storage and ATProto blob storage.
+
         Args:
            uri: The AT URI of the dataset record.
            sample_type: The Python class for the sample type.
@@ -317,20 +476,28 @@ class DatasetLoader:
             A Dataset instance configured from the record.
 
         Raises:
-            ValueError: If
+            ValueError: If no storage URLs can be resolved.
 
         Example:
-
-
-
-
+            ::
+
+                >>> loader = DatasetLoader(client)
+                >>> dataset = loader.to_dataset(uri, MySampleType)
+                >>> for batch in dataset.shuffled(batch_size=32):
+                ...     process(batch)
         """
         # Import here to avoid circular import
         from ..dataset import Dataset
 
-
+        storage_type = self.get_storage_type(uri)
+
+        if storage_type == "external":
+            urls = self.get_urls(uri)
+        else:
+            urls = self.get_blob_urls(uri)
+
         if not urls:
-            raise ValueError("Dataset record has no URLs")
+            raise ValueError("Dataset record has no storage URLs")
 
         # Use the first URL (multi-URL support could be added later)
         url = urls[0]
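On the consumer side, the loader now dispatches on the record's storage variant before building a Dataset. A hedged sketch of reading a blob-backed record end to end, using only methods shown in the diff above; the record URI and the MySample type are illustrative:

    from atdata.atmosphere import AtmosphereClient
    from atdata.atmosphere.records import DatasetLoader

    client = AtmosphereClient()
    client.login("handle.bsky.social", "app-password")

    loader = DatasetLoader(client)
    uri = "at://did:plc:abc/ac.foundation.dataset.record/xyz"  # e.g. picked from loader.list()

    # Inspect the storage variant recorded on the dataset.
    if loader.get_storage_type(uri) == "blobs":
        shard_urls = loader.get_blob_urls(uri)   # HTTP URLs served by the owner's PDS
    else:
        shard_urls = loader.get_urls(uri)        # external WebDataset URLs

    # Or let to_dataset() handle the dispatch and stream from the first shard URL.
    dataset = loader.to_dataset(uri, MySample)
    for batch in dataset.shuffled(batch_size=32):
        ...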
atdata/atmosphere/schema.py
CHANGED
@@ -6,8 +6,7 @@ records.
 """
 
 from dataclasses import fields, is_dataclass
-from typing import Type, TypeVar, Optional,
-import types
+from typing import Type, TypeVar, Optional, get_type_hints, get_origin, get_args
 
 from .client import AtmosphereClient
 from ._types import (
@@ -17,6 +16,12 @@ from ._types import (
     FieldType,
     LEXICON_NAMESPACE,
 )
+from .._type_utils import (
+    numpy_dtype_to_string,
+    unwrap_optional,
+    is_ndarray_type,
+    extract_ndarray_dtype,
+)
 
 # Import for type checking only to avoid circular imports
 from typing import TYPE_CHECKING
@@ -33,18 +38,20 @@ class SchemaPublisher:
     definitions and publishes them as an ATProto schema record.
 
     Example:
-
-
-
-
-
-
-
-
-
-
-
-
+        ::
+
+            >>> @atdata.packable
+            ... class MySample:
+            ...     image: NDArray
+            ...     label: str
+            ...
+            >>> client = AtmosphereClient()
+            >>> client.login("handle", "password")
+            >>>
+            >>> publisher = SchemaPublisher(client)
+            >>> uri = publisher.publish(MySample, version="1.0.0")
+            >>> print(uri)
+            at://did:plc:.../ac.foundation.dataset.sampleSchema/...
     """
 
     def __init__(self, client: AtmosphereClient):
@@ -130,71 +137,32 @@ class SchemaPublisher:
 
     def _field_to_def(self, name: str, python_type) -> FieldDef:
         """Convert a Python field to a FieldDef."""
-
-        is_optional = False
-        origin = get_origin(python_type)
-
-        # Handle Union types (including Optional which is Union[T, None])
-        if origin is Union or isinstance(python_type, types.UnionType):
-            args = get_args(python_type)
-            non_none_args = [a for a in args if a is not type(None)]
-            if type(None) in args or len(non_none_args) < len(args):
-                is_optional = True
-            if len(non_none_args) == 1:
-                python_type = non_none_args[0]
-            elif len(non_none_args) > 1:
-                # Complex union type - not fully supported yet
-                raise TypeError(f"Complex union types not supported: {python_type}")
-
+        python_type, is_optional = unwrap_optional(python_type)
         field_type = self._python_type_to_field_type(python_type)
-
-        return FieldDef(
-            name=name,
-            field_type=field_type,
-            optional=is_optional,
-        )
+        return FieldDef(name=name, field_type=field_type, optional=is_optional)
 
     def _python_type_to_field_type(self, python_type) -> FieldType:
         """Map a Python type to a FieldType."""
-        # Handle primitives
         if python_type is str:
             return FieldType(kind="primitive", primitive="str")
-
+        if python_type is int:
             return FieldType(kind="primitive", primitive="int")
-
+        if python_type is float:
             return FieldType(kind="primitive", primitive="float")
-
+        if python_type is bool:
             return FieldType(kind="primitive", primitive="bool")
-
+        if python_type is bytes:
             return FieldType(kind="primitive", primitive="bytes")
 
-
-
-        type_str = str(python_type)
-        if "NDArray" in type_str or "ndarray" in type_str.lower():
-            # Try to extract dtype info if available
-            dtype = "float32"  # Default
-            args = get_args(python_type)
-            if args:
-                # NDArray[np.float64] or similar
-                dtype_arg = args[-1] if args else None
-                if dtype_arg is not None:
-                    dtype = self._numpy_dtype_to_string(dtype_arg)
-
-            return FieldType(kind="ndarray", dtype=dtype, shape=None)
+        if is_ndarray_type(python_type):
+            return FieldType(kind="ndarray", dtype=extract_ndarray_dtype(python_type), shape=None)
 
-        # Check for list/array types
         origin = get_origin(python_type)
         if origin is list:
             args = get_args(python_type)
-            if args
-
-
-            else:
-                # Untyped list
-                return FieldType(kind="array", items=FieldType(kind="primitive", primitive="str"))
-
-        # Check for nested PackableSample (not yet supported)
+            items = self._python_type_to_field_type(args[0]) if args else FieldType(kind="primitive", primitive="str")
+            return FieldType(kind="array", items=items)
+
         if is_dataclass(python_type):
             raise TypeError(
                 f"Nested dataclass types not yet supported: {python_type.__name__}. "
@@ -203,33 +171,6 @@ class SchemaPublisher:
 
         raise TypeError(f"Unsupported type for schema field: {python_type}")
 
-    def _numpy_dtype_to_string(self, dtype) -> str:
-        """Convert a numpy dtype annotation to a string."""
-        dtype_str = str(dtype)
-        # Handle common numpy dtypes
-        dtype_map = {
-            "float16": "float16",
-            "float32": "float32",
-            "float64": "float64",
-            "int8": "int8",
-            "int16": "int16",
-            "int32": "int32",
-            "int64": "int64",
-            "uint8": "uint8",
-            "uint16": "uint16",
-            "uint32": "uint32",
-            "uint64": "uint64",
-            "bool": "bool",
-            "complex64": "complex64",
-            "complex128": "complex128",
-        }
-
-        for key, value in dtype_map.items():
-            if key in dtype_str:
-                return value
-
-        return "float32"  # Default fallback
-
 
 class SchemaLoader:
     """Loads PackableSample schemas from ATProto.
@@ -238,13 +179,15 @@ class SchemaLoader:
     schemas from a repository.
 
     Example:
-
-
-
-
-
-
-
+        ::
+
+            >>> client = AtmosphereClient()
+            >>> client.login("handle", "password")
+            >>>
+            >>> loader = SchemaLoader(client)
+            >>> schema = loader.get("at://did:plc:.../ac.foundation.dataset.sampleSchema/...")
+            >>> print(schema["name"])
+            'MySample'
     """
 
    def __init__(self, client: AtmosphereClient):
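The Optional-unwrapping logic removed above now lives in the new atdata/_type_utils.py (also added in this release); its implementation is not shown in this diff. The call site only fixes the contract: unwrap_optional(tp) returns the inner type plus an is-optional flag. An illustrative sketch consistent with the deleted Union handling, not the library's actual source:

    import types
    from typing import Optional, Union, get_args, get_origin


    def unwrap_optional(python_type):
        """Return (inner_type, is_optional) for Optional[T] / T | None annotations.

        Sketch only: mirrors the Union handling deleted from SchemaPublisher above.
        """
        origin = get_origin(python_type)
        if origin is Union or isinstance(python_type, types.UnionType):
            args = get_args(python_type)
            non_none = [a for a in args if a is not type(None)]
            is_optional = len(non_none) < len(args)
            if len(non_none) == 1:
                return non_none[0], is_optional
            # Multi-member unions remain unsupported, as in the old code.
            raise TypeError(f"Complex union types not supported: {python_type}")
        return python_type, False


    assert unwrap_optional(Optional[int]) == (int, True)
    assert unwrap_optional(str | None) == (str, True)
    assert unwrap_optional(bytes) == (bytes, False)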
atdata/atmosphere/store.py
ADDED
@@ -0,0 +1,208 @@
+"""PDS blob storage for dataset shards.
+
+This module provides ``PDSBlobStore``, an implementation of the AbstractDataStore
+protocol that stores dataset shards as ATProto blobs in a Personal Data Server.
+
+This enables fully decentralized dataset storage where both metadata (records)
+and data (blobs) live on the AT Protocol network.
+
+Example:
+    ::
+
+        >>> from atdata.atmosphere import AtmosphereClient, PDSBlobStore
+        >>>
+        >>> client = AtmosphereClient()
+        >>> client.login("handle.bsky.social", "app-password")
+        >>>
+        >>> store = PDSBlobStore(client)
+        >>> urls = store.write_shards(dataset, prefix="mnist/v1")
+        >>> print(urls)
+        ['at://did:plc:.../blob/bafyrei...', ...]
+"""
+
+from __future__ import annotations
+
+import io
+import tempfile
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+import webdataset as wds
+
+if TYPE_CHECKING:
+    from ..dataset import Dataset
+    from .client import AtmosphereClient
+
+
+@dataclass
+class PDSBlobStore:
+    """PDS blob store implementing AbstractDataStore protocol.
+
+    Stores dataset shards as ATProto blobs, enabling decentralized dataset
+    storage on the AT Protocol network.
+
+    Each shard is written to a temporary tar file, then uploaded as a blob
+    to the user's PDS. The returned URLs are AT URIs that can be resolved
+    to HTTP URLs for streaming.
+
+    Attributes:
+        client: Authenticated AtmosphereClient instance.
+
+    Example:
+        ::
+
+            >>> store = PDSBlobStore(client)
+            >>> urls = store.write_shards(dataset, prefix="training/v1")
+            >>> # Returns AT URIs like:
+            >>> # ['at://did:plc:abc/blob/bafyrei...', ...]
+    """
+
+    client: "AtmosphereClient"
+
+    def write_shards(
+        self,
+        ds: "Dataset",
+        *,
+        prefix: str,
+        maxcount: int = 10000,
+        maxsize: float = 3e9,
+        **kwargs: Any,
+    ) -> list[str]:
+        """Write dataset shards as PDS blobs.
+
+        Creates tar archives from the dataset and uploads each as a blob
+        to the authenticated user's PDS.
+
+        Args:
+            ds: The Dataset to write.
+            prefix: Logical path prefix for naming (used in shard names only).
+            maxcount: Maximum samples per shard (default: 10000).
+            maxsize: Maximum shard size in bytes (default: 3GB, PDS limit).
+            **kwargs: Additional args passed to wds.ShardWriter.
+
+        Returns:
+            List of AT URIs for the written blobs, in format:
+            ``at://{did}/blob/{cid}``
+
+        Raises:
+            ValueError: If not authenticated.
+            RuntimeError: If no shards were written.
+
+        Note:
+            PDS blobs have size limits (typically 50MB-5GB depending on PDS).
+            Adjust maxcount/maxsize to stay within limits.
+        """
+        if not self.client.did:
+            raise ValueError("Client must be authenticated to upload blobs")
+
+        did = self.client.did
+        blob_urls: list[str] = []
+
+        # Write shards to temp files, upload each as blob
+        with tempfile.TemporaryDirectory() as temp_dir:
+            shard_pattern = f"{temp_dir}/shard-%06d.tar"
+            written_files: list[str] = []
+
+            # Track written files via custom post callback
+            def track_file(fname: str) -> None:
+                written_files.append(fname)
+
+            with wds.writer.ShardWriter(
+                shard_pattern,
+                maxcount=maxcount,
+                maxsize=maxsize,
+                post=track_file,
+                **kwargs,
+            ) as sink:
+                for sample in ds.ordered(batch_size=None):
+                    sink.write(sample.as_wds)
+
+            if not written_files:
+                raise RuntimeError("No shards written")
+
+            # Upload each shard as a blob
+            for shard_path in written_files:
+                with open(shard_path, "rb") as f:
+                    shard_data = f.read()
+
+                blob_ref = self.client.upload_blob(
+                    shard_data,
+                    mime_type="application/x-tar",
+                )
+
+                # Extract CID from blob reference
+                cid = blob_ref["ref"]["$link"]
+                at_uri = f"at://{did}/blob/{cid}"
+                blob_urls.append(at_uri)
+
+        return blob_urls
+
+    def read_url(self, url: str) -> str:
+        """Resolve an AT URI blob reference to an HTTP URL.
+
+        Transforms ``at://did/blob/cid`` URIs to HTTP URLs that can be
+        streamed by WebDataset.
+
+        Args:
+            url: AT URI in format ``at://{did}/blob/{cid}``.
+
+        Returns:
+            HTTP URL for fetching the blob via PDS API.
+
+        Raises:
+            ValueError: If URL format is invalid or PDS cannot be resolved.
+        """
+        if not url.startswith("at://"):
+            # Not an AT URI, return unchanged
+            return url
+
+        # Parse at://did/blob/cid
+        parts = url[5:].split("/")  # Remove 'at://'
+        if len(parts) != 3 or parts[1] != "blob":
+            raise ValueError(f"Invalid blob AT URI format: {url}")
+
+        did, _, cid = parts
+        return self.client.get_blob_url(did, cid)
+
+    def supports_streaming(self) -> bool:
+        """PDS blobs support streaming via HTTP.
+
+        Returns:
+            True.
+        """
+        return True
+
+    def create_source(self, urls: list[str]) -> "BlobSource":
+        """Create a BlobSource for reading these AT URIs.
+
+        This is a convenience method for creating a DataSource that can
+        stream the blobs written by this store.
+
+        Args:
+            urls: List of AT URIs from write_shards().
+
+        Returns:
+            BlobSource configured for the given URLs.
+
+        Raises:
+            ValueError: If URLs are not valid AT URIs.
+        """
+        from .._sources import BlobSource
+
+        blob_refs: list[dict[str, str]] = []
+
+        for url in urls:
+            if not url.startswith("at://"):
+                raise ValueError(f"Not an AT URI: {url}")
+
+            parts = url[5:].split("/")
+            if len(parts) != 3 or parts[1] != "blob":
+                raise ValueError(f"Invalid blob AT URI: {url}")
+
+            did, _, cid = parts
+            blob_refs.append({"did": did, "cid": cid})
+
+        return BlobSource(blob_refs=blob_refs)
+
+
+__all__ = ["PDSBlobStore"]