atdata 0.2.0a1__py3-none-any.whl → 0.2.3b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +43 -10
- atdata/_cid.py +144 -0
- atdata/_helpers.py +7 -5
- atdata/_hf_api.py +690 -0
- atdata/_protocols.py +504 -0
- atdata/_schema_codec.py +438 -0
- atdata/_sources.py +508 -0
- atdata/_stub_manager.py +534 -0
- atdata/_type_utils.py +104 -0
- atdata/atmosphere/__init__.py +269 -1
- atdata/atmosphere/_types.py +4 -2
- atdata/atmosphere/client.py +146 -3
- atdata/atmosphere/lens.py +4 -3
- atdata/atmosphere/records.py +168 -7
- atdata/atmosphere/schema.py +29 -82
- atdata/atmosphere/store.py +204 -0
- atdata/cli/__init__.py +222 -0
- atdata/cli/diagnose.py +169 -0
- atdata/cli/local.py +283 -0
- atdata/dataset.py +615 -257
- atdata/lens.py +53 -54
- atdata/local.py +1456 -228
- atdata/promote.py +195 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/METADATA +106 -14
- atdata-0.2.3b1.dist-info/RECORD +28 -0
- atdata-0.2.0a1.dist-info/RECORD +0 -16
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.3b1.dist-info}/licenses/LICENSE +0 -0
atdata/atmosphere/records.py
CHANGED
@@ -19,6 +19,7 @@ from ._types import (
 
 # Import for type checking only to avoid circular imports
 from typing import TYPE_CHECKING
+
 if TYPE_CHECKING:
     from ..dataset import PackableSample, Dataset
 
@@ -31,7 +32,7 @@ class DatasetPublisher:
     This class creates dataset records that reference a schema and point to
     external storage (WebDataset URLs) or ATProto blobs.
 
-
+    Examples:
         >>> dataset = atdata.Dataset[MySample]("s3://bucket/data-{000000..000009}.tar")
         >>>
         >>> client = AtmosphereClient()
@@ -187,6 +188,76 @@ class DatasetPublisher:
             validate=False,
         )
 
+    def publish_with_blobs(
+        self,
+        blobs: list[bytes],
+        schema_uri: str,
+        *,
+        name: str,
+        description: Optional[str] = None,
+        tags: Optional[list[str]] = None,
+        license: Optional[str] = None,
+        metadata: Optional[dict] = None,
+        mime_type: str = "application/x-tar",
+        rkey: Optional[str] = None,
+    ) -> AtUri:
+        """Publish a dataset with data stored as ATProto blobs.
+
+        This method uploads the provided data as blobs to the PDS and creates
+        a dataset record referencing them. Suitable for smaller datasets that
+        fit within blob size limits (typically 50MB per blob, configurable).
+
+        Args:
+            blobs: List of binary data (e.g., tar shards) to upload as blobs.
+            schema_uri: AT URI of the schema record.
+            name: Human-readable dataset name.
+            description: Human-readable description.
+            tags: Searchable tags for discovery.
+            license: SPDX license identifier.
+            metadata: Arbitrary metadata dictionary.
+            mime_type: MIME type for the blobs (default: application/x-tar).
+            rkey: Optional explicit record key.
+
+        Returns:
+            The AT URI of the created dataset record.
+
+        Note:
+            Blobs are only retained by the PDS when referenced in a committed
+            record. This method handles that automatically.
+        """
+        # Upload all blobs
+        blob_refs = []
+        for blob_data in blobs:
+            blob_ref = self.client.upload_blob(blob_data, mime_type=mime_type)
+            blob_refs.append(blob_ref)
+
+        # Create storage location with blob references
+        storage = StorageLocation(
+            kind="blobs",
+            blob_refs=blob_refs,
+        )
+
+        metadata_bytes: Optional[bytes] = None
+        if metadata is not None:
+            metadata_bytes = msgpack.packb(metadata)
+
+        dataset_record = DatasetRecord(
+            name=name,
+            schema_ref=schema_uri,
+            storage=storage,
+            description=description,
+            tags=tags or [],
+            license=license,
+            metadata=metadata_bytes,
+        )
+
+        return self.client.create_record(
+            collection=f"{LEXICON_NAMESPACE}.record",
+            record=dataset_record.to_record(),
+            rkey=rkey,
+            validate=False,
+        )
+
 
 class DatasetLoader:
     """Loads dataset records from ATProto.
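The publish_with_blobs method added here uploads each shard as a PDS blob and then commits a record that references them, so the blobs are retained. A minimal sketch of calling it, assuming DatasetPublisher is constructed with an authenticated client (mirroring the DatasetLoader(client) pattern shown below) and that a schema record already exists; the handle, shard bytes, and schema URI are placeholders:

    from atdata.atmosphere import AtmosphereClient
    from atdata.atmosphere.records import DatasetPublisher

    client = AtmosphereClient()
    client.login("handle.bsky.social", "app-password")

    shards = [b"...tar shard bytes...", b"...tar shard bytes..."]  # placeholder data

    publisher = DatasetPublisher(client)  # assumed constructor, analogous to DatasetLoader(client)
    uri = publisher.publish_with_blobs(
        shards,
        schema_uri="at://...",  # AT URI of an existing schema record (placeholder)
        name="mnist-mini",
        description="Toy example dataset",
        tags=["vision", "demo"],
    )
    print(uri)  # AT URI of the new dataset record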
@@ -195,7 +266,7 @@ class DatasetLoader:
     from them. Note that loading a dataset requires having the corresponding
     Python class for the sample type.
 
-
+    Examples:
         >>> client = AtmosphereClient()
         >>> loader = DatasetLoader(client)
        >>>
@@ -255,6 +326,29 @@
         """
         return self.client.list_datasets(repo=repo, limit=limit)
 
+    def get_storage_type(self, uri: str | AtUri) -> str:
+        """Get the storage type of a dataset record.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            Either "external" or "blobs".
+
+        Raises:
+            ValueError: If storage type is unknown.
+        """
+        record = self.get(uri)
+        storage = record.get("storage", {})
+        storage_type = storage.get("$type", "")
+
+        if "storageExternal" in storage_type:
+            return "external"
+        elif "storageBlobs" in storage_type:
+            return "blobs"
+        else:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
     def get_urls(self, uri: str | AtUri) -> list[str]:
         """Get the WebDataset URLs from a dataset record.
 
@@ -276,11 +370,70 @@
         elif "storageBlobs" in storage_type:
             raise ValueError(
                 "Dataset uses blob storage, not external URLs. "
-                "Use
+                "Use get_blob_urls() instead."
+            )
+        else:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
+    def get_blobs(self, uri: str | AtUri) -> list[dict]:
+        """Get the blob references from a dataset record.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            List of blob reference dicts with keys: $type, ref, mimeType, size.
+
+        Raises:
+            ValueError: If the storage type is not blobs.
+        """
+        record = self.get(uri)
+        storage = record.get("storage", {})
+
+        storage_type = storage.get("$type", "")
+        if "storageBlobs" in storage_type:
+            return storage.get("blobs", [])
+        elif "storageExternal" in storage_type:
+            raise ValueError(
+                "Dataset uses external URL storage, not blobs. Use get_urls() instead."
             )
         else:
             raise ValueError(f"Unknown storage type: {storage_type}")
 
+    def get_blob_urls(self, uri: str | AtUri) -> list[str]:
+        """Get fetchable URLs for blob-stored dataset shards.
+
+        This resolves the PDS endpoint and constructs URLs that can be
+        used to fetch the blob data directly.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            List of URLs for fetching the blob data.
+
+        Raises:
+            ValueError: If storage type is not blobs or PDS cannot be resolved.
+        """
+        if isinstance(uri, str):
+            parsed_uri = AtUri.parse(uri)
+        else:
+            parsed_uri = uri
+
+        blobs = self.get_blobs(uri)
+        did = parsed_uri.authority
+
+        urls = []
+        for blob in blobs:
+            # Extract CID from blob reference
+            ref = blob.get("ref", {})
+            cid = ref.get("$link") if isinstance(ref, dict) else str(ref)
+            if cid:
+                url = self.client.get_blob_url(did, cid)
+                urls.append(url)
+
+        return urls
+
     def get_metadata(self, uri: str | AtUri) -> Optional[dict]:
         """Get the metadata from a dataset record.
 
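The three loader additions above give callers a way to branch on how a record stores its data before deciding how to fetch it. A short sketch of the intended call pattern, assuming client is an already-authenticated AtmosphereClient and uri points at a dataset record (both placeholders):

    from atdata.atmosphere.records import DatasetLoader

    loader = DatasetLoader(client)

    kind = loader.get_storage_type(uri)          # "external" or "blobs"
    if kind == "external":
        shard_urls = loader.get_urls(uri)        # WebDataset URLs (e.g. s3:// or https://)
    else:
        shard_urls = loader.get_blob_urls(uri)   # HTTP URLs served by the owning PDS
        blob_refs = loader.get_blobs(uri)        # raw blob references ($type, ref, mimeType, size)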
@@ -309,6 +462,8 @@ class DatasetLoader:
         You must provide the sample type class, which should match the
         schema referenced by the record.
 
+        Supports both external URL storage and ATProto blob storage.
+
         Args:
             uri: The AT URI of the dataset record.
             sample_type: The Python class for the sample type.
@@ -317,9 +472,9 @@
             A Dataset instance configured from the record.
 
         Raises:
-            ValueError: If
+            ValueError: If no storage URLs can be resolved.
 
-
+        Examples:
             >>> loader = DatasetLoader(client)
             >>> dataset = loader.to_dataset(uri, MySampleType)
             >>> for batch in dataset.shuffled(batch_size=32):
@@ -328,9 +483,15 @@
         # Import here to avoid circular import
         from ..dataset import Dataset
 
-
+        storage_type = self.get_storage_type(uri)
+
+        if storage_type == "external":
+            urls = self.get_urls(uri)
+        else:
+            urls = self.get_blob_urls(uri)
+
         if not urls:
-            raise ValueError("Dataset record has no URLs")
+            raise ValueError("Dataset record has no storage URLs")
 
         # Use the first URL (multi-URL support could be added later)
         url = urls[0]
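With the dispatch above, loading no longer depends on how the record was published. A minimal sketch, reusing the loader from the previous example; MySample stands in for whatever @atdata.packable class matches the record's schema:

    dataset = loader.to_dataset(uri, MySample)

    # Iteration is identical for external-URL and blob-backed records.
    for batch in dataset.shuffled(batch_size=32):
        ...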
atdata/atmosphere/schema.py
CHANGED
@@ -6,8 +6,7 @@ records.
 """
 
 from dataclasses import fields, is_dataclass
-from typing import Type, TypeVar, Optional,
-import types
+from typing import Type, TypeVar, Optional, get_type_hints, get_origin, get_args
 
 from .client import AtmosphereClient
 from ._types import (
@@ -17,9 +16,15 @@ from ._types import (
     FieldType,
     LEXICON_NAMESPACE,
 )
+from .._type_utils import (
+    unwrap_optional,
+    is_ndarray_type,
+    extract_ndarray_dtype,
+)
 
 # Import for type checking only to avoid circular imports
 from typing import TYPE_CHECKING
+
 if TYPE_CHECKING:
     from ..dataset import PackableSample
 
@@ -32,7 +37,7 @@ class SchemaPublisher:
     This class introspects a PackableSample class to extract its field
     definitions and publishes them as an ATProto schema record.
 
-
+    Examples:
         >>> @atdata.packable
         ... class MySample:
         ...     image: NDArray
@@ -83,7 +88,9 @@ class SchemaPublisher:
             TypeError: If a field type is not supported.
         """
         if not is_dataclass(sample_type):
-            raise ValueError(
+            raise ValueError(
+                f"{sample_type.__name__} must be a dataclass (use @packable)"
+            )
 
         # Build the schema record
         schema_record = self._build_schema_record(
@@ -130,71 +137,38 @@ class SchemaPublisher:
 
     def _field_to_def(self, name: str, python_type) -> FieldDef:
         """Convert a Python field to a FieldDef."""
-
-        is_optional = False
-        origin = get_origin(python_type)
-
-        # Handle Union types (including Optional which is Union[T, None])
-        if origin is Union or isinstance(python_type, types.UnionType):
-            args = get_args(python_type)
-            non_none_args = [a for a in args if a is not type(None)]
-            if type(None) in args or len(non_none_args) < len(args):
-                is_optional = True
-            if len(non_none_args) == 1:
-                python_type = non_none_args[0]
-            elif len(non_none_args) > 1:
-                # Complex union type - not fully supported yet
-                raise TypeError(f"Complex union types not supported: {python_type}")
-
+        python_type, is_optional = unwrap_optional(python_type)
         field_type = self._python_type_to_field_type(python_type)
-
-        return FieldDef(
-            name=name,
-            field_type=field_type,
-            optional=is_optional,
-        )
+        return FieldDef(name=name, field_type=field_type, optional=is_optional)
 
     def _python_type_to_field_type(self, python_type) -> FieldType:
         """Map a Python type to a FieldType."""
-        # Handle primitives
         if python_type is str:
             return FieldType(kind="primitive", primitive="str")
-
+        if python_type is int:
             return FieldType(kind="primitive", primitive="int")
-
+        if python_type is float:
             return FieldType(kind="primitive", primitive="float")
-
+        if python_type is bool:
             return FieldType(kind="primitive", primitive="bool")
-
+        if python_type is bytes:
             return FieldType(kind="primitive", primitive="bytes")
 
-
-
-
-
-        # Try to extract dtype info if available
-        dtype = "float32"  # Default
-        args = get_args(python_type)
-        if args:
-            # NDArray[np.float64] or similar
-            dtype_arg = args[-1] if args else None
-            if dtype_arg is not None:
-                dtype = self._numpy_dtype_to_string(dtype_arg)
-
-        return FieldType(kind="ndarray", dtype=dtype, shape=None)
+        if is_ndarray_type(python_type):
+            return FieldType(
+                kind="ndarray", dtype=extract_ndarray_dtype(python_type), shape=None
+            )
 
-        # Check for list/array types
         origin = get_origin(python_type)
         if origin is list:
             args = get_args(python_type)
-
-
-
-
-
-
-
-        # Check for nested PackableSample (not yet supported)
+            items = (
+                self._python_type_to_field_type(args[0])
+                if args
+                else FieldType(kind="primitive", primitive="str")
+            )
+            return FieldType(kind="array", items=items)
+
         if is_dataclass(python_type):
             raise TypeError(
                 f"Nested dataclass types not yet supported: {python_type.__name__}. "
@@ -203,33 +177,6 @@ class SchemaPublisher:
 
         raise TypeError(f"Unsupported type for schema field: {python_type}")
 
-    def _numpy_dtype_to_string(self, dtype) -> str:
-        """Convert a numpy dtype annotation to a string."""
-        dtype_str = str(dtype)
-        # Handle common numpy dtypes
-        dtype_map = {
-            "float16": "float16",
-            "float32": "float32",
-            "float64": "float64",
-            "int8": "int8",
-            "int16": "int16",
-            "int32": "int32",
-            "int64": "int64",
-            "uint8": "uint8",
-            "uint16": "uint16",
-            "uint32": "uint32",
-            "uint64": "uint64",
-            "bool": "bool",
-            "complex64": "complex64",
-            "complex128": "complex128",
-        }
-
-        for key, value in dtype_map.items():
-            if key in dtype_str:
-                return value
-
-        return "float32"  # Default fallback
-
 
 class SchemaLoader:
     """Loads PackableSample schemas from ATProto.
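Taken together, the two hunks above replace the inline Union handling and the numpy dtype lookup table with shared helpers from atdata/_type_utils.py. As an illustration of the mapping the new code produces (the FieldType values are inferred from the constructors visible in these hunks; the sample class and the NDArray import path are hypothetical):

    from typing import Optional
    import numpy as np
    from numpy.typing import NDArray  # assumed source of the NDArray annotation

    import atdata

    @atdata.packable
    class MySample:
        image: NDArray[np.float32]  # -> FieldType(kind="ndarray", dtype="float32", shape=None)
        label: int                  # -> FieldType(kind="primitive", primitive="int")
        caption: Optional[str]      # -> optional=True, FieldType(kind="primitive", primitive="str")
        scores: list[float]         # -> FieldType(kind="array", items=FieldType(kind="primitive", primitive="float"))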
@@ -237,7 +184,7 @@ class SchemaLoader:
     This class fetches schema records from ATProto and can list available
     schemas from a repository.
 
-
+    Examples:
         >>> client = AtmosphereClient()
         >>> client.login("handle", "password")
         >>>
atdata/atmosphere/store.py
ADDED
@@ -0,0 +1,204 @@
+"""PDS blob storage for dataset shards.
+
+This module provides ``PDSBlobStore``, an implementation of the AbstractDataStore
+protocol that stores dataset shards as ATProto blobs in a Personal Data Server.
+
+This enables fully decentralized dataset storage where both metadata (records)
+and data (blobs) live on the AT Protocol network.
+
+Examples:
+    >>> from atdata.atmosphere import AtmosphereClient, PDSBlobStore
+    >>>
+    >>> client = AtmosphereClient()
+    >>> client.login("handle.bsky.social", "app-password")
+    >>>
+    >>> store = PDSBlobStore(client)
+    >>> urls = store.write_shards(dataset, prefix="mnist/v1")
+    >>> print(urls)
+    ['at://did:plc:.../blob/bafyrei...', ...]
+"""
+
+from __future__ import annotations
+
+import tempfile
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+import webdataset as wds
+
+if TYPE_CHECKING:
+    from ..dataset import Dataset
+    from .._sources import BlobSource
+    from .client import AtmosphereClient
+
+
+@dataclass
+class PDSBlobStore:
+    """PDS blob store implementing AbstractDataStore protocol.
+
+    Stores dataset shards as ATProto blobs, enabling decentralized dataset
+    storage on the AT Protocol network.
+
+    Each shard is written to a temporary tar file, then uploaded as a blob
+    to the user's PDS. The returned URLs are AT URIs that can be resolved
+    to HTTP URLs for streaming.
+
+    Attributes:
+        client: Authenticated AtmosphereClient instance.
+
+    Examples:
+        >>> store = PDSBlobStore(client)
+        >>> urls = store.write_shards(dataset, prefix="training/v1")
+        >>> # Returns AT URIs like:
+        >>> # ['at://did:plc:abc/blob/bafyrei...', ...]
+    """
+
+    client: "AtmosphereClient"
+
+    def write_shards(
+        self,
+        ds: "Dataset",
+        *,
+        prefix: str,
+        maxcount: int = 10000,
+        maxsize: float = 3e9,
+        **kwargs: Any,
+    ) -> list[str]:
+        """Write dataset shards as PDS blobs.
+
+        Creates tar archives from the dataset and uploads each as a blob
+        to the authenticated user's PDS.
+
+        Args:
+            ds: The Dataset to write.
+            prefix: Logical path prefix for naming (used in shard names only).
+            maxcount: Maximum samples per shard (default: 10000).
+            maxsize: Maximum shard size in bytes (default: 3GB, PDS limit).
+            **kwargs: Additional args passed to wds.ShardWriter.
+
+        Returns:
+            List of AT URIs for the written blobs, in format:
+            ``at://{did}/blob/{cid}``
+
+        Raises:
+            ValueError: If not authenticated.
+            RuntimeError: If no shards were written.
+
+        Note:
+            PDS blobs have size limits (typically 50MB-5GB depending on PDS).
+            Adjust maxcount/maxsize to stay within limits.
+        """
+        if not self.client.did:
+            raise ValueError("Client must be authenticated to upload blobs")
+
+        did = self.client.did
+        blob_urls: list[str] = []
+
+        # Write shards to temp files, upload each as blob
+        with tempfile.TemporaryDirectory() as temp_dir:
+            shard_pattern = f"{temp_dir}/shard-%06d.tar"
+            written_files: list[str] = []
+
+            # Track written files via custom post callback
+            def track_file(fname: str) -> None:
+                written_files.append(fname)
+
+            with wds.writer.ShardWriter(
+                shard_pattern,
+                maxcount=maxcount,
+                maxsize=maxsize,
+                post=track_file,
+                **kwargs,
+            ) as sink:
+                for sample in ds.ordered(batch_size=None):
+                    sink.write(sample.as_wds)
+
+            if not written_files:
+                raise RuntimeError("No shards written")
+
+            # Upload each shard as a blob
+            for shard_path in written_files:
+                with open(shard_path, "rb") as f:
+                    shard_data = f.read()
+
+                blob_ref = self.client.upload_blob(
+                    shard_data,
+                    mime_type="application/x-tar",
+                )
+
+                # Extract CID from blob reference
+                cid = blob_ref["ref"]["$link"]
+                at_uri = f"at://{did}/blob/{cid}"
+                blob_urls.append(at_uri)
+
+        return blob_urls
+
+    def read_url(self, url: str) -> str:
+        """Resolve an AT URI blob reference to an HTTP URL.
+
+        Transforms ``at://did/blob/cid`` URIs to HTTP URLs that can be
+        streamed by WebDataset.
+
+        Args:
+            url: AT URI in format ``at://{did}/blob/{cid}``.
+
+        Returns:
+            HTTP URL for fetching the blob via PDS API.
+
+        Raises:
+            ValueError: If URL format is invalid or PDS cannot be resolved.
+        """
+        if not url.startswith("at://"):
+            # Not an AT URI, return unchanged
+            return url
+
+        # Parse at://did/blob/cid
+        parts = url[5:].split("/")  # Remove 'at://'
+        if len(parts) != 3 or parts[1] != "blob":
+            raise ValueError(f"Invalid blob AT URI format: {url}")
+
+        did, _, cid = parts
+        return self.client.get_blob_url(did, cid)
+
+    def supports_streaming(self) -> bool:
+        """PDS blobs support streaming via HTTP.
+
+        Returns:
+            True.
+        """
+        return True
+
+    def create_source(self, urls: list[str]) -> "BlobSource":
+        """Create a BlobSource for reading these AT URIs.
+
+        This is a convenience method for creating a DataSource that can
+        stream the blobs written by this store.
+
+        Args:
+            urls: List of AT URIs from write_shards().
+
+        Returns:
+            BlobSource configured for the given URLs.
+
+        Raises:
+            ValueError: If URLs are not valid AT URIs.
+        """
+        from .._sources import BlobSource
+
+        blob_refs: list[dict[str, str]] = []
+
+        for url in urls:
+            if not url.startswith("at://"):
+                raise ValueError(f"Not an AT URI: {url}")
+
+            parts = url[5:].split("/")
+            if len(parts) != 3 or parts[1] != "blob":
+                raise ValueError(f"Invalid blob AT URI: {url}")
+
+            did, _, cid = parts
+            blob_refs.append({"did": did, "cid": cid})
+
+        return BlobSource(blob_refs=blob_refs)
+
+
+__all__ = ["PDSBlobStore"]
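The store above completes the blob path end to end: write_shards uploads tar shards and returns at://{did}/blob/{cid} URIs, read_url resolves one of those to a fetchable HTTP URL, and create_source wraps a list of them for streaming. A minimal round trip, assuming an authenticated client and an existing Dataset instance named dataset (both placeholders); the maxsize value is just an example chosen to stay under a typical 50 MB blob limit:

    from atdata.atmosphere import AtmosphereClient, PDSBlobStore

    client = AtmosphereClient()
    client.login("handle.bsky.social", "app-password")

    store = PDSBlobStore(client)

    # Upload: each shard becomes a PDS blob addressed as at://{did}/blob/{cid}.
    at_uris = store.write_shards(dataset, prefix="mnist/v1", maxsize=45e6)

    # Read back: resolve a single AT URI, or build a streaming source for all of them.
    http_url = store.read_url(at_uris[0])
    source = store.create_source(at_uris)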