atdata 0.2.0a1__py3-none-any.whl → 0.2.3b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,7 @@ from ._types import (
 
 # Import for type checking only to avoid circular imports
 from typing import TYPE_CHECKING
+
 if TYPE_CHECKING:
     from ..dataset import PackableSample, Dataset
 
@@ -31,7 +32,7 @@ class DatasetPublisher:
     This class creates dataset records that reference a schema and point to
     external storage (WebDataset URLs) or ATProto blobs.
 
-    Example:
+    Examples:
         >>> dataset = atdata.Dataset[MySample]("s3://bucket/data-{000000..000009}.tar")
         >>>
         >>> client = AtmosphereClient()
@@ -187,6 +188,76 @@ class DatasetPublisher:
             validate=False,
         )
 
+    def publish_with_blobs(
+        self,
+        blobs: list[bytes],
+        schema_uri: str,
+        *,
+        name: str,
+        description: Optional[str] = None,
+        tags: Optional[list[str]] = None,
+        license: Optional[str] = None,
+        metadata: Optional[dict] = None,
+        mime_type: str = "application/x-tar",
+        rkey: Optional[str] = None,
+    ) -> AtUri:
+        """Publish a dataset with data stored as ATProto blobs.
+
+        This method uploads the provided data as blobs to the PDS and creates
+        a dataset record referencing them. Suitable for smaller datasets that
+        fit within blob size limits (typically 50MB per blob, configurable).
+
+        Args:
+            blobs: List of binary data (e.g., tar shards) to upload as blobs.
+            schema_uri: AT URI of the schema record.
+            name: Human-readable dataset name.
+            description: Human-readable description.
+            tags: Searchable tags for discovery.
+            license: SPDX license identifier.
+            metadata: Arbitrary metadata dictionary.
+            mime_type: MIME type for the blobs (default: application/x-tar).
+            rkey: Optional explicit record key.
+
+        Returns:
+            The AT URI of the created dataset record.
+
+        Note:
+            Blobs are only retained by the PDS when referenced in a committed
+            record. This method handles that automatically.
+        """
+        # Upload all blobs
+        blob_refs = []
+        for blob_data in blobs:
+            blob_ref = self.client.upload_blob(blob_data, mime_type=mime_type)
+            blob_refs.append(blob_ref)
+
+        # Create storage location with blob references
+        storage = StorageLocation(
+            kind="blobs",
+            blob_refs=blob_refs,
+        )
+
+        metadata_bytes: Optional[bytes] = None
+        if metadata is not None:
+            metadata_bytes = msgpack.packb(metadata)
+
+        dataset_record = DatasetRecord(
+            name=name,
+            schema_ref=schema_uri,
+            storage=storage,
+            description=description,
+            tags=tags or [],
+            license=license,
+            metadata=metadata_bytes,
+        )
+
+        return self.client.create_record(
+            collection=f"{LEXICON_NAMESPACE}.record",
+            record=dataset_record.to_record(),
+            rkey=rkey,
+            validate=False,
+        )
+
 
 class DatasetLoader:
     """Loads dataset records from ATProto.
@@ -195,7 +266,7 @@ class DatasetLoader:
     from them. Note that loading a dataset requires having the corresponding
     Python class for the sample type.
 
-    Example:
+    Examples:
         >>> client = AtmosphereClient()
         >>> loader = DatasetLoader(client)
         >>>
@@ -255,6 +326,29 @@ class DatasetLoader:
         """
         return self.client.list_datasets(repo=repo, limit=limit)
 
+    def get_storage_type(self, uri: str | AtUri) -> str:
+        """Get the storage type of a dataset record.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            Either "external" or "blobs".
+
+        Raises:
+            ValueError: If storage type is unknown.
+        """
+        record = self.get(uri)
+        storage = record.get("storage", {})
+        storage_type = storage.get("$type", "")
+
+        if "storageExternal" in storage_type:
+            return "external"
+        elif "storageBlobs" in storage_type:
+            return "blobs"
+        else:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
     def get_urls(self, uri: str | AtUri) -> list[str]:
         """Get the WebDataset URLs from a dataset record.
 
@@ -276,11 +370,70 @@ class DatasetLoader:
         elif "storageBlobs" in storage_type:
             raise ValueError(
                 "Dataset uses blob storage, not external URLs. "
-                "Use get_blobs() instead."
+                "Use get_blob_urls() instead."
+            )
+        else:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
+    def get_blobs(self, uri: str | AtUri) -> list[dict]:
+        """Get the blob references from a dataset record.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            List of blob reference dicts with keys: $type, ref, mimeType, size.
+
+        Raises:
+            ValueError: If the storage type is not blobs.
+        """
+        record = self.get(uri)
+        storage = record.get("storage", {})
+
+        storage_type = storage.get("$type", "")
+        if "storageBlobs" in storage_type:
+            return storage.get("blobs", [])
+        elif "storageExternal" in storage_type:
+            raise ValueError(
+                "Dataset uses external URL storage, not blobs. Use get_urls() instead."
             )
         else:
             raise ValueError(f"Unknown storage type: {storage_type}")
 
+    def get_blob_urls(self, uri: str | AtUri) -> list[str]:
+        """Get fetchable URLs for blob-stored dataset shards.
+
+        This resolves the PDS endpoint and constructs URLs that can be
+        used to fetch the blob data directly.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            List of URLs for fetching the blob data.
+
+        Raises:
+            ValueError: If storage type is not blobs or PDS cannot be resolved.
+        """
+        if isinstance(uri, str):
+            parsed_uri = AtUri.parse(uri)
+        else:
+            parsed_uri = uri
+
+        blobs = self.get_blobs(uri)
+        did = parsed_uri.authority
+
+        urls = []
+        for blob in blobs:
+            # Extract CID from blob reference
+            ref = blob.get("ref", {})
+            cid = ref.get("$link") if isinstance(ref, dict) else str(ref)
+            if cid:
+                url = self.client.get_blob_url(did, cid)
+                urls.append(url)
+
+        return urls
+
     def get_metadata(self, uri: str | AtUri) -> Optional[dict]:
         """Get the metadata from a dataset record.
 
@@ -309,6 +462,8 @@ class DatasetLoader:
         You must provide the sample type class, which should match the
         schema referenced by the record.
 
+        Supports both external URL storage and ATProto blob storage.
+
         Args:
             uri: The AT URI of the dataset record.
             sample_type: The Python class for the sample type.
@@ -317,9 +472,9 @@ class DatasetLoader:
             A Dataset instance configured from the record.
 
         Raises:
-            ValueError: If the storage type is not external URLs.
+            ValueError: If no storage URLs can be resolved.
 
-        Example:
+        Examples:
             >>> loader = DatasetLoader(client)
            >>> dataset = loader.to_dataset(uri, MySampleType)
            >>> for batch in dataset.shuffled(batch_size=32):
@@ -328,9 +483,15 @@ class DatasetLoader:
         # Import here to avoid circular import
         from ..dataset import Dataset
 
-        urls = self.get_urls(uri)
+        storage_type = self.get_storage_type(uri)
+
+        if storage_type == "external":
+            urls = self.get_urls(uri)
+        else:
+            urls = self.get_blob_urls(uri)
+
         if not urls:
-            raise ValueError("Dataset record has no URLs")
+            raise ValueError("Dataset record has no storage URLs")
 
         # Use the first URL (multi-URL support could be added later)
         url = urls[0]
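
On the loader side, `get_storage_type` distinguishes the two storage variants, `get_blobs` and `get_blob_urls` expose blob-backed shards, and `to_dataset` now resolves URLs for either variant before building the `Dataset`. A sketch of the two call paths is below; the record URI, the `MySample` class, and the authenticated `client` are placeholders, not values from this diff.

    loader = DatasetLoader(client)
    uri = "at://did:plc:example/record-key"          # placeholder dataset record AT URI

    if loader.get_storage_type(uri) == "external":
        shard_urls = loader.get_urls(uri)            # WebDataset URLs, e.g. s3://... shards
    else:
        shard_urls = loader.get_blob_urls(uri)       # HTTP URLs served by the owner's PDS

    # Or skip the manual dispatch entirely; to_dataset now handles both cases:
    dataset = loader.to_dataset(uri, MySample)
    for batch in dataset.shuffled(batch_size=32):
        ...
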
@@ -6,8 +6,7 @@ records.
 """
 
 from dataclasses import fields, is_dataclass
-from typing import Type, TypeVar, Optional, Union, get_type_hints, get_origin, get_args
-import types
+from typing import Type, TypeVar, Optional, get_type_hints, get_origin, get_args
 
 from .client import AtmosphereClient
 from ._types import (
@@ -17,9 +16,15 @@ from ._types import (
     FieldType,
     LEXICON_NAMESPACE,
 )
+from .._type_utils import (
+    unwrap_optional,
+    is_ndarray_type,
+    extract_ndarray_dtype,
+)
 
 # Import for type checking only to avoid circular imports
 from typing import TYPE_CHECKING
+
 if TYPE_CHECKING:
     from ..dataset import PackableSample
 
@@ -32,7 +37,7 @@ class SchemaPublisher:
     This class introspects a PackableSample class to extract its field
     definitions and publishes them as an ATProto schema record.
 
-    Example:
+    Examples:
         >>> @atdata.packable
         ... class MySample:
         ...     image: NDArray
@@ -83,7 +88,9 @@ class SchemaPublisher:
             TypeError: If a field type is not supported.
         """
         if not is_dataclass(sample_type):
-            raise ValueError(f"{sample_type.__name__} must be a dataclass (use @packable)")
+            raise ValueError(
+                f"{sample_type.__name__} must be a dataclass (use @packable)"
+            )
 
         # Build the schema record
         schema_record = self._build_schema_record(
@@ -130,71 +137,38 @@ class SchemaPublisher:
 
     def _field_to_def(self, name: str, python_type) -> FieldDef:
         """Convert a Python field to a FieldDef."""
-        # Check for Optional types (Union with None)
-        is_optional = False
-        origin = get_origin(python_type)
-
-        # Handle Union types (including Optional which is Union[T, None])
-        if origin is Union or isinstance(python_type, types.UnionType):
-            args = get_args(python_type)
-            non_none_args = [a for a in args if a is not type(None)]
-            if type(None) in args or len(non_none_args) < len(args):
-                is_optional = True
-            if len(non_none_args) == 1:
-                python_type = non_none_args[0]
-            elif len(non_none_args) > 1:
-                # Complex union type - not fully supported yet
-                raise TypeError(f"Complex union types not supported: {python_type}")
-
+        python_type, is_optional = unwrap_optional(python_type)
         field_type = self._python_type_to_field_type(python_type)
-
-        return FieldDef(
-            name=name,
-            field_type=field_type,
-            optional=is_optional,
-        )
+        return FieldDef(name=name, field_type=field_type, optional=is_optional)
 
     def _python_type_to_field_type(self, python_type) -> FieldType:
         """Map a Python type to a FieldType."""
-        # Handle primitives
         if python_type is str:
             return FieldType(kind="primitive", primitive="str")
-        elif python_type is int:
+        if python_type is int:
             return FieldType(kind="primitive", primitive="int")
-        elif python_type is float:
+        if python_type is float:
             return FieldType(kind="primitive", primitive="float")
-        elif python_type is bool:
+        if python_type is bool:
             return FieldType(kind="primitive", primitive="bool")
-        elif python_type is bytes:
+        if python_type is bytes:
             return FieldType(kind="primitive", primitive="bytes")
 
-        # Check for NDArray
-        # NDArray from numpy.typing is a special generic alias
-        type_str = str(python_type)
-        if "NDArray" in type_str or "ndarray" in type_str.lower():
-            # Try to extract dtype info if available
-            dtype = "float32"  # Default
-            args = get_args(python_type)
-            if args:
-                # NDArray[np.float64] or similar
-                dtype_arg = args[-1] if args else None
-                if dtype_arg is not None:
-                    dtype = self._numpy_dtype_to_string(dtype_arg)
-
-            return FieldType(kind="ndarray", dtype=dtype, shape=None)
+        if is_ndarray_type(python_type):
+            return FieldType(
+                kind="ndarray", dtype=extract_ndarray_dtype(python_type), shape=None
+            )
 
-        # Check for list/array types
         origin = get_origin(python_type)
         if origin is list:
             args = get_args(python_type)
-            if args:
-                items = self._python_type_to_field_type(args[0])
-                return FieldType(kind="array", items=items)
-            else:
-                # Untyped list
-                return FieldType(kind="array", items=FieldType(kind="primitive", primitive="str"))
-
-        # Check for nested PackableSample (not yet supported)
+            items = (
+                self._python_type_to_field_type(args[0])
+                if args
+                else FieldType(kind="primitive", primitive="str")
+            )
+            return FieldType(kind="array", items=items)
+
         if is_dataclass(python_type):
             raise TypeError(
                 f"Nested dataclass types not yet supported: {python_type.__name__}. "
@@ -203,33 +177,6 @@ class SchemaPublisher:
 
         raise TypeError(f"Unsupported type for schema field: {python_type}")
 
-    def _numpy_dtype_to_string(self, dtype) -> str:
-        """Convert a numpy dtype annotation to a string."""
-        dtype_str = str(dtype)
-        # Handle common numpy dtypes
-        dtype_map = {
-            "float16": "float16",
-            "float32": "float32",
-            "float64": "float64",
-            "int8": "int8",
-            "int16": "int16",
-            "int32": "int32",
-            "int64": "int64",
-            "uint8": "uint8",
-            "uint16": "uint16",
-            "uint32": "uint32",
-            "uint64": "uint64",
-            "bool": "bool",
-            "complex64": "complex64",
-            "complex128": "complex128",
-        }
-
-        for key, value in dtype_map.items():
-            if key in dtype_str:
-                return value
-
-        return "float32"  # Default fallback
-
 
 class SchemaLoader:
     """Loads PackableSample schemas from ATProto.
@@ -237,7 +184,7 @@ class SchemaLoader:
     This class fetches schema records from ATProto and can list available
     schemas from a repository.
 
-    Example:
+    Examples:
        >>> client = AtmosphereClient()
        >>> client.login("handle", "password")
        >>>
@@ -0,0 +1,204 @@
+"""PDS blob storage for dataset shards.
+
+This module provides ``PDSBlobStore``, an implementation of the AbstractDataStore
+protocol that stores dataset shards as ATProto blobs in a Personal Data Server.
+
+This enables fully decentralized dataset storage where both metadata (records)
+and data (blobs) live on the AT Protocol network.
+
+Examples:
+    >>> from atdata.atmosphere import AtmosphereClient, PDSBlobStore
+    >>>
+    >>> client = AtmosphereClient()
+    >>> client.login("handle.bsky.social", "app-password")
+    >>>
+    >>> store = PDSBlobStore(client)
+    >>> urls = store.write_shards(dataset, prefix="mnist/v1")
+    >>> print(urls)
+    ['at://did:plc:.../blob/bafyrei...', ...]
+"""
+
+from __future__ import annotations
+
+import tempfile
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+import webdataset as wds
+
+if TYPE_CHECKING:
+    from ..dataset import Dataset
+    from .._sources import BlobSource
+    from .client import AtmosphereClient
+
+
+@dataclass
+class PDSBlobStore:
+    """PDS blob store implementing AbstractDataStore protocol.
+
+    Stores dataset shards as ATProto blobs, enabling decentralized dataset
+    storage on the AT Protocol network.
+
+    Each shard is written to a temporary tar file, then uploaded as a blob
+    to the user's PDS. The returned URLs are AT URIs that can be resolved
+    to HTTP URLs for streaming.
+
+    Attributes:
+        client: Authenticated AtmosphereClient instance.
+
+    Examples:
+        >>> store = PDSBlobStore(client)
+        >>> urls = store.write_shards(dataset, prefix="training/v1")
+        >>> # Returns AT URIs like:
+        >>> # ['at://did:plc:abc/blob/bafyrei...', ...]
+    """
+
+    client: "AtmosphereClient"
+
+    def write_shards(
+        self,
+        ds: "Dataset",
+        *,
+        prefix: str,
+        maxcount: int = 10000,
+        maxsize: float = 3e9,
+        **kwargs: Any,
+    ) -> list[str]:
+        """Write dataset shards as PDS blobs.
+
+        Creates tar archives from the dataset and uploads each as a blob
+        to the authenticated user's PDS.
+
+        Args:
+            ds: The Dataset to write.
+            prefix: Logical path prefix for naming (used in shard names only).
+            maxcount: Maximum samples per shard (default: 10000).
+            maxsize: Maximum shard size in bytes (default: 3GB, PDS limit).
+            **kwargs: Additional args passed to wds.ShardWriter.
+
+        Returns:
+            List of AT URIs for the written blobs, in format:
+            ``at://{did}/blob/{cid}``
+
+        Raises:
+            ValueError: If not authenticated.
+            RuntimeError: If no shards were written.
+
+        Note:
+            PDS blobs have size limits (typically 50MB-5GB depending on PDS).
+            Adjust maxcount/maxsize to stay within limits.
+        """
+        if not self.client.did:
+            raise ValueError("Client must be authenticated to upload blobs")
+
+        did = self.client.did
+        blob_urls: list[str] = []
+
+        # Write shards to temp files, upload each as blob
+        with tempfile.TemporaryDirectory() as temp_dir:
+            shard_pattern = f"{temp_dir}/shard-%06d.tar"
+            written_files: list[str] = []
+
+            # Track written files via custom post callback
+            def track_file(fname: str) -> None:
+                written_files.append(fname)
+
+            with wds.writer.ShardWriter(
+                shard_pattern,
+                maxcount=maxcount,
+                maxsize=maxsize,
+                post=track_file,
+                **kwargs,
+            ) as sink:
+                for sample in ds.ordered(batch_size=None):
+                    sink.write(sample.as_wds)
+
+            if not written_files:
+                raise RuntimeError("No shards written")
+
+            # Upload each shard as a blob
+            for shard_path in written_files:
+                with open(shard_path, "rb") as f:
+                    shard_data = f.read()
+
+                blob_ref = self.client.upload_blob(
+                    shard_data,
+                    mime_type="application/x-tar",
+                )
+
+                # Extract CID from blob reference
+                cid = blob_ref["ref"]["$link"]
+                at_uri = f"at://{did}/blob/{cid}"
+                blob_urls.append(at_uri)
+
+        return blob_urls
+
+    def read_url(self, url: str) -> str:
+        """Resolve an AT URI blob reference to an HTTP URL.
+
+        Transforms ``at://did/blob/cid`` URIs to HTTP URLs that can be
+        streamed by WebDataset.
+
+        Args:
+            url: AT URI in format ``at://{did}/blob/{cid}``.
+
+        Returns:
+            HTTP URL for fetching the blob via PDS API.
+
+        Raises:
+            ValueError: If URL format is invalid or PDS cannot be resolved.
+        """
+        if not url.startswith("at://"):
+            # Not an AT URI, return unchanged
+            return url
+
+        # Parse at://did/blob/cid
+        parts = url[5:].split("/")  # Remove 'at://'
+        if len(parts) != 3 or parts[1] != "blob":
+            raise ValueError(f"Invalid blob AT URI format: {url}")
+
+        did, _, cid = parts
+        return self.client.get_blob_url(did, cid)
+
+    def supports_streaming(self) -> bool:
+        """PDS blobs support streaming via HTTP.
+
+        Returns:
+            True.
+        """
+        return True
+
+    def create_source(self, urls: list[str]) -> "BlobSource":
+        """Create a BlobSource for reading these AT URIs.
+
+        This is a convenience method for creating a DataSource that can
+        stream the blobs written by this store.
+
+        Args:
+            urls: List of AT URIs from write_shards().
+
+        Returns:
+            BlobSource configured for the given URLs.
+
+        Raises:
+            ValueError: If URLs are not valid AT URIs.
+        """
+        from .._sources import BlobSource
+
+        blob_refs: list[dict[str, str]] = []
+
+        for url in urls:
+            if not url.startswith("at://"):
+                raise ValueError(f"Not an AT URI: {url}")
+
+            parts = url[5:].split("/")
+            if len(parts) != 3 or parts[1] != "blob":
+                raise ValueError(f"Invalid blob AT URI: {url}")
+
+            did, _, cid = parts
+            blob_refs.append({"did": did, "cid": cid})
+
+        return BlobSource(blob_refs=blob_refs)
+
+
+__all__ = ["PDSBlobStore"]
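
Taken together with the new publisher and loader methods, the new `PDSBlobStore` module closes the loop: shards can be written to a PDS, referenced from a dataset record, and streamed back. A round-trip sketch using only the calls shown in this file; the login credentials are placeholders and the `dataset` object is assumed to be an existing atdata Dataset built elsewhere.

    from atdata.atmosphere import AtmosphereClient, PDSBlobStore

    client = AtmosphereClient()
    client.login("alice.example.com", "app-password")          # placeholder credentials

    store = PDSBlobStore(client)
    at_uris = store.write_shards(dataset, prefix="mnist/v1")   # `dataset` built elsewhere

    # AT URIs resolve to plain HTTP URLs that WebDataset can stream
    http_urls = [store.read_url(uri) for uri in at_uris]

    # Or hand the AT URIs straight to a BlobSource
    source = store.create_source(at_uris)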