atdata 0.3.1b1__py3-none-any.whl → 0.3.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,11 +10,16 @@ import msgpack
 
 from .client import Atmosphere
 from .schema import SchemaPublisher
-from ._types import (
-    AtUri,
-    DatasetRecord,
-    StorageLocation,
-    LEXICON_NAMESPACE,
+from ._types import AtUri, LEXICON_NAMESPACE
+from ._lexicon_types import (
+    LexDatasetRecord,
+    StorageHttp,
+    StorageS3,
+    StorageBlobs,
+    HttpShardEntry,
+    S3ShardEntry,
+    BlobEntry,
+    ShardChecksum,
 )
 
 # Import for type checking only to avoid circular imports
@@ -27,14 +32,19 @@ if TYPE_CHECKING:
 ST = TypeVar("ST", bound="Packable")
 
 
+def _placeholder_checksum() -> ShardChecksum:
+    """Return an empty checksum placeholder for shards without pre-computed digests."""
+    return ShardChecksum(algorithm="none", digest="")
+
+
 class DatasetPublisher:
     """Publishes dataset index records to ATProto.
 
     This class creates dataset records that reference a schema and point to
-    external storage (WebDataset URLs) or ATProto blobs.
+    HTTP storage, S3 storage, or ATProto blobs.
 
     Examples:
-        >>> dataset = atdata.Dataset[MySample]("s3://bucket/data-{000000..000009}.tar")
+        >>> dataset = atdata.Dataset[MySample]("https://example.com/data-000000.tar")
         >>>
         >>> atmo = Atmosphere.login("handle", "password")
         >>>
@@ -56,6 +66,40 @@ class DatasetPublisher:
         self.client = client
         self._schema_publisher = SchemaPublisher(client)
 
+    def _create_record(
+        self,
+        storage: "StorageHttp | StorageS3 | StorageBlobs",
+        *,
+        name: str,
+        schema_uri: str,
+        description: Optional[str] = None,
+        tags: Optional[list[str]] = None,
+        license: Optional[str] = None,
+        metadata: Optional[dict] = None,
+        rkey: Optional[str] = None,
+    ) -> AtUri:
+        """Build a LexDatasetRecord and publish it to ATProto."""
+        metadata_bytes: Optional[bytes] = None
+        if metadata is not None:
+            metadata_bytes = msgpack.packb(metadata)
+
+        dataset_record = LexDatasetRecord(
+            name=name,
+            schema_ref=schema_uri,
+            storage=storage,
+            description=description,
+            tags=tags or [],
+            license=license,
+            metadata=metadata_bytes,
+        )
+
+        return self.client.create_record(
+            collection=f"{LEXICON_NAMESPACE}.record",
+            record=dataset_record.to_record(),
+            rkey=rkey,
+            validate=False,
+        )
+
     def publish(
         self,
         dataset: "Dataset[ST]",
@@ -90,46 +134,34 @@ class DatasetPublisher:
         Raises:
             ValueError: If schema_uri is not provided and auto_publish_schema is False.
         """
-        # Ensure we have a schema reference
         if schema_uri is None:
             if not auto_publish_schema:
                 raise ValueError(
                     "schema_uri is required when auto_publish_schema=False"
                 )
-            # Auto-publish the schema
             schema_uri_obj = self._schema_publisher.publish(
                 dataset.sample_type,
                 version=schema_version,
            )
             schema_uri = str(schema_uri_obj)
 
-        # Build the storage location
-        storage = StorageLocation(
-            kind="external",
-            urls=[dataset.url],
+        shard_urls = dataset.list_shards()
+        storage = StorageHttp(
+            shards=[
+                HttpShardEntry(url=url, checksum=_placeholder_checksum())
+                for url in shard_urls
+            ]
         )
 
-        # Build dataset record
-        metadata_bytes: Optional[bytes] = None
-        if dataset.metadata is not None:
-            metadata_bytes = msgpack.packb(dataset.metadata)
-
-        dataset_record = DatasetRecord(
+        return self._create_record(
+            storage,
             name=name,
-            schema_ref=schema_uri,
-            storage=storage,
+            schema_uri=schema_uri,
             description=description,
-            tags=tags or [],
+            tags=tags,
             license=license,
-            metadata=metadata_bytes,
-        )
-
-        # Publish to ATProto
-        return self.client.create_record(
-            collection=f"{LEXICON_NAMESPACE}.record",
-            record=dataset_record.to_record(),
+            metadata=dataset.metadata,
             rkey=rkey,
-            validate=False,
         )
 
     def publish_with_urls(
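The reworked publish() path above can be exercised roughly as follows. This is a minimal sketch, not taken from the package: MySample, the handle, the app password, and the shard URL are placeholders, and the import path for DatasetPublisher is assumed.

import atdata
from atdata.atmosphere import Atmosphere, DatasetPublisher  # import path assumed

dataset = atdata.Dataset[MySample]("https://example.com/data-000000.tar")  # MySample is a placeholder Packable type
atmo = Atmosphere.login("my.handle", "app-password")
publisher = DatasetPublisher(atmo)

# publish() now expands dataset.list_shards() into storageHttp shard entries,
# each carrying a placeholder checksum (algorithm="none", empty digest).
uri = publisher.publish(dataset, name="my-dataset")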
@@ -142,50 +174,162 @@ class DatasetPublisher:
         tags: Optional[list[str]] = None,
         license: Optional[str] = None,
         metadata: Optional[dict] = None,
+        checksums: Optional[list[ShardChecksum]] = None,
         rkey: Optional[str] = None,
     ) -> AtUri:
-        """Publish a dataset record with explicit URLs.
+        """Publish a dataset record with explicit HTTP URLs.
 
         This method allows publishing a dataset record without having a
         Dataset object, useful for registering existing WebDataset files.
+        Each URL should be an individual shard (no brace notation).
 
         Args:
-            urls: List of WebDataset URLs with brace notation.
+            urls: List of individual shard URLs.
             schema_uri: AT URI of the schema record.
             name: Human-readable dataset name.
             description: Human-readable description.
             tags: Searchable tags for discovery.
             license: SPDX license identifier.
             metadata: Arbitrary metadata dictionary.
+            checksums: Per-shard checksums. If not provided, empty checksums
+                are used.
             rkey: Optional explicit record key.
 
         Returns:
             The AT URI of the created dataset record.
         """
-        storage = StorageLocation(
-            kind="external",
-            urls=urls,
+        if checksums and len(checksums) != len(urls):
+            raise ValueError(
+                f"checksums length ({len(checksums)}) must match "
+                f"urls length ({len(urls)})"
+            )
+
+        shards = [
+            HttpShardEntry(
+                url=url,
+                checksum=checksums[i] if checksums else _placeholder_checksum(),
+            )
+            for i, url in enumerate(urls)
+        ]
+
+        return self._create_record(
+            StorageHttp(shards=shards),
+            name=name,
+            schema_uri=schema_uri,
+            description=description,
+            tags=tags,
+            license=license,
+            metadata=metadata,
+            rkey=rkey,
         )
 
-        metadata_bytes: Optional[bytes] = None
-        if metadata is not None:
-            metadata_bytes = msgpack.packb(metadata)
+    def publish_with_s3(
+        self,
+        bucket: str,
+        keys: list[str],
+        schema_uri: str,
+        *,
+        name: str,
+        region: Optional[str] = None,
+        endpoint: Optional[str] = None,
+        description: Optional[str] = None,
+        tags: Optional[list[str]] = None,
+        license: Optional[str] = None,
+        metadata: Optional[dict] = None,
+        checksums: Optional[list[ShardChecksum]] = None,
+        rkey: Optional[str] = None,
+    ) -> AtUri:
+        """Publish a dataset record with S3 storage.
 
-        dataset_record = DatasetRecord(
+        Args:
+            bucket: S3 bucket name.
+            keys: List of S3 object keys for shard files.
+            schema_uri: AT URI of the schema record.
+            name: Human-readable dataset name.
+            region: AWS region (e.g., 'us-east-1').
+            endpoint: Custom S3-compatible endpoint URL.
+            description: Human-readable description.
+            tags: Searchable tags for discovery.
+            license: SPDX license identifier.
+            metadata: Arbitrary metadata dictionary.
+            checksums: Per-shard checksums.
+            rkey: Optional explicit record key.
+
+        Returns:
+            The AT URI of the created dataset record.
+        """
+        if checksums and len(checksums) != len(keys):
+            raise ValueError(
+                f"checksums length ({len(checksums)}) must match "
+                f"keys length ({len(keys)})"
+            )
+
+        shards = [
+            S3ShardEntry(
+                key=key,
+                checksum=checksums[i] if checksums else _placeholder_checksum(),
+            )
+            for i, key in enumerate(keys)
+        ]
+
+        return self._create_record(
+            StorageS3(bucket=bucket, shards=shards, region=region, endpoint=endpoint),
             name=name,
-            schema_ref=schema_uri,
-            storage=storage,
+            schema_uri=schema_uri,
             description=description,
-            tags=tags or [],
+            tags=tags,
             license=license,
-            metadata=metadata_bytes,
+            metadata=metadata,
+            rkey=rkey,
         )
 
-        return self.client.create_record(
-            collection=f"{LEXICON_NAMESPACE}.record",
-            record=dataset_record.to_record(),
+    def publish_with_blob_refs(
+        self,
+        blob_refs: list[dict],
+        schema_uri: str,
+        *,
+        name: str,
+        description: Optional[str] = None,
+        tags: Optional[list[str]] = None,
+        license: Optional[str] = None,
+        metadata: Optional[dict] = None,
+        rkey: Optional[str] = None,
+    ) -> AtUri:
+        """Publish a dataset record with pre-uploaded blob references.
+
+        Unlike ``publish_with_blobs`` (which takes raw bytes and uploads them),
+        this method accepts blob ref dicts that have already been uploaded to
+        the PDS. The refs are embedded directly in the record so the PDS
+        retains the blobs.
+
+        Args:
+            blob_refs: List of blob reference dicts as returned by
+                ``Atmosphere.upload_blob()``. Each dict must contain
+                ``$type``, ``ref`` (with ``$link``), ``mimeType``, and ``size``.
+            schema_uri: AT URI of the schema record.
+            name: Human-readable dataset name.
+            description: Human-readable description.
+            tags: Searchable tags for discovery.
+            license: SPDX license identifier.
+            metadata: Arbitrary metadata dictionary.
+            rkey: Optional explicit record key.
+
+        Returns:
+            The AT URI of the created dataset record.
+        """
+        blob_entries = [
+            BlobEntry(blob=ref, checksum=_placeholder_checksum()) for ref in blob_refs
+        ]
+
+        return self._create_record(
+            StorageBlobs(blobs=blob_entries),
+            name=name,
+            schema_uri=schema_uri,
+            description=description,
+            tags=tags,
+            license=license,
+            metadata=metadata,
             rkey=rkey,
-            validate=False,
         )
 
     def publish_with_blobs(
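A sketch of calling the new publish_with_s3() method, continuing from the hypothetical publisher above. The bucket, keys, schema URI, and digests are illustrative, and the ShardChecksum import path is assumed.

from atdata.atmosphere import ShardChecksum  # import path assumed

checksums = [
    ShardChecksum(algorithm="sha256", digest="0" * 64),  # made-up digests
    ShardChecksum(algorithm="sha256", digest="1" * 64),
]
uri = publisher.publish_with_s3(
    bucket="my-bucket",
    keys=["data-000000.tar", "data-000001.tar"],
    schema_uri="at://did:example/placeholder.schema/rkey",
    name="my-dataset",
    region="us-east-1",
    checksums=checksums,  # optional; placeholder checksums are used if omitted
)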
@@ -225,37 +369,28 @@ class DatasetPublisher:
         Blobs are only retained by the PDS when referenced in a committed
         record. This method handles that automatically.
         """
-        # Upload all blobs
-        blob_refs = []
+        blob_entries = []
         for blob_data in blobs:
             blob_ref = self.client.upload_blob(blob_data, mime_type=mime_type)
-            blob_refs.append(blob_ref)
-
-        # Create storage location with blob references
-        storage = StorageLocation(
-            kind="blobs",
-            blob_refs=blob_refs,
-        )
+            import hashlib
 
-        metadata_bytes: Optional[bytes] = None
-        if metadata is not None:
-            metadata_bytes = msgpack.packb(metadata)
+            digest = hashlib.sha256(blob_data).hexdigest()
+            blob_entries.append(
+                BlobEntry(
+                    blob=blob_ref,
+                    checksum=ShardChecksum(algorithm="sha256", digest=digest),
+                )
+            )
 
-        dataset_record = DatasetRecord(
+        return self._create_record(
+            StorageBlobs(blobs=blob_entries),
             name=name,
-            schema_ref=schema_uri,
-            storage=storage,
+            schema_uri=schema_uri,
             description=description,
-            tags=tags or [],
+            tags=tags,
             license=license,
-            metadata=metadata_bytes,
-        )
-
-        return self.client.create_record(
-            collection=f"{LEXICON_NAMESPACE}.record",
-            record=dataset_record.to_record(),
+            metadata=metadata,
             rkey=rkey,
-            validate=False,
         )
 
 
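Because publish_with_blobs() now records a sha256 digest per shard, a consumer can verify downloaded bytes against the record. This is a hedged sketch, assuming the blob entries round-trip as dicts with checksum/algorithm/digest keys (the diff only shows the writer side).

import hashlib

def verify_shard(shard_bytes: bytes, entry: dict) -> bool:
    # entry is one blobEntry dict as returned by DatasetLoader.get_blobs()
    checksum = entry.get("checksum", {})
    if checksum.get("algorithm") != "sha256":
        return True  # placeholder ("none") or unknown algorithm: nothing to verify
    return hashlib.sha256(shard_bytes).hexdigest() == checksum.get("digest")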
@@ -310,6 +445,18 @@ class DatasetLoader:
 
         return record
 
+    def get_typed(self, uri: str | AtUri) -> LexDatasetRecord:
+        """Fetch a dataset record and return as a typed object.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            LexDatasetRecord instance.
+        """
+        record = self.get(uri)
+        return LexDatasetRecord.from_record(record)
+
     def list_all(
         self,
         repo: Optional[str] = None,
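The new get_typed() accessor can be used roughly as follows. The AT URI is a placeholder, and constructing DatasetLoader with a client is assumed to mirror DatasetPublisher.

loader = DatasetLoader(atmo)  # constructor signature assumed
uri = "at://did:example/placeholder.record/rkey"  # placeholder URI
record = loader.get_typed(uri)
print(record.name)  # assumes attribute-style access to the fields passed to _create_record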
@@ -333,7 +480,7 @@ class DatasetLoader:
             uri: The AT URI of the dataset record.
 
         Returns:
-            Either "external" or "blobs".
+            One of "http", "s3", "blobs", or "external" (legacy).
 
         Raises:
             ValueError: If storage type is unknown.
@@ -342,16 +489,22 @@ class DatasetLoader:
         storage = record.get("storage", {})
         storage_type = storage.get("$type", "")
 
-        if "storageExternal" in storage_type:
-            return "external"
+        if "storageHttp" in storage_type:
+            return "http"
+        elif "storageS3" in storage_type:
+            return "s3"
         elif "storageBlobs" in storage_type:
             return "blobs"
+        elif "storageExternal" in storage_type:
+            return "external"
         else:
             raise ValueError(f"Unknown storage type: {storage_type}")
 
     def get_urls(self, uri: str | AtUri) -> list[str]:
         """Get the WebDataset URLs from a dataset record.
 
+        Supports storageHttp, storageS3, and legacy storageExternal formats.
+
         Args:
             uri: The AT URI of the dataset record.
 
@@ -359,22 +512,61 @@ class DatasetLoader:
             List of WebDataset URLs.
 
         Raises:
-            ValueError: If the storage type is not external URLs.
+            ValueError: If the storage type is blob-only.
         """
         record = self.get(uri)
         storage = record.get("storage", {})
-
         storage_type = storage.get("$type", "")
-        if "storageExternal" in storage_type:
+
+        if "storageHttp" in storage_type:
+            return [s["url"] for s in storage.get("shards", [])]
+        elif "storageS3" in storage_type:
+            bucket = storage.get("bucket", "")
+            endpoint = storage.get("endpoint")
+            urls = []
+            for s in storage.get("shards", []):
+                if endpoint:
+                    urls.append(f"{endpoint.rstrip('/')}/{bucket}/{s['key']}")
+                else:
+                    urls.append(f"s3://{bucket}/{s['key']}")
+            return urls
+        elif "storageExternal" in storage_type:
             return storage.get("urls", [])
         elif "storageBlobs" in storage_type:
             raise ValueError(
-                "Dataset uses blob storage, not external URLs. "
-                "Use get_blob_urls() instead."
+                "Dataset uses blob storage, not URLs. Use get_blob_urls() instead."
             )
         else:
             raise ValueError(f"Unknown storage type: {storage_type}")
 
+    def get_s3_info(self, uri: str | AtUri) -> dict:
+        """Get S3 storage details from a dataset record.
+
+        Args:
+            uri: The AT URI of the dataset record.
+
+        Returns:
+            Dict with keys: bucket, keys, region (optional), endpoint (optional).
+
+        Raises:
+            ValueError: If the storage type is not S3.
+        """
+        record = self.get(uri)
+        storage = record.get("storage", {})
+        storage_type = storage.get("$type", "")
+
+        if "storageS3" not in storage_type:
+            raise ValueError(
+                f"Dataset does not use S3 storage. Storage type: {storage_type}"
+            )
+
+        return {
+            "bucket": storage.get("bucket", ""),
+            "keys": [s["key"] for s in storage.get("shards", [])],
+            "region": storage.get("region"),
+            "endpoint": storage.get("endpoint"),
+        }
+
     def get_blobs(self, uri: str | AtUri) -> list[dict]:
         """Get the blob references from a dataset record.
 
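A sketch of consuming the new S3 accessors, continuing from the hypothetical loader and uri above; the boto3 usage is illustrative and not part of this package.

import boto3  # illustrative dependency, not required by atdata

info = loader.get_s3_info(uri)  # {"bucket": ..., "keys": [...], "region": ..., "endpoint": ...}
s3 = boto3.client("s3", region_name=info["region"], endpoint_url=info["endpoint"])
for key in info["keys"]:
    # download each shard next to the working directory, named after the key's basename
    s3.download_file(info["bucket"], key, key.split("/")[-1])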
@@ -382,7 +574,7 @@ class DatasetLoader:
             uri: The AT URI of the dataset record.
 
         Returns:
-            List of blob reference dicts with keys: $type, ref, mimeType, size.
+            List of blob entry dicts.
 
         Raises:
             ValueError: If the storage type is not blobs.
@@ -393,12 +585,11 @@ class DatasetLoader:
         storage_type = storage.get("$type", "")
         if "storageBlobs" in storage_type:
             return storage.get("blobs", [])
-        elif "storageExternal" in storage_type:
+        else:
             raise ValueError(
-                "Dataset uses external URL storage, not blobs. Use get_urls() instead."
+                f"Dataset does not use blob storage. Storage type: {storage_type}. "
+                "Use get_urls() instead."
             )
-        else:
-            raise ValueError(f"Unknown storage type: {storage_type}")
 
     def get_blob_urls(self, uri: str | AtUri) -> list[str]:
         """Get fetchable URLs for blob-stored dataset shards.
@@ -420,12 +611,13 @@ class DatasetLoader:
         else:
             parsed_uri = uri
 
-        blobs = self.get_blobs(uri)
+        blob_entries = self.get_blobs(uri)
         did = parsed_uri.authority
 
         urls = []
-        for blob in blobs:
-            # Extract CID from blob reference
+        for entry in blob_entries:
+            # Handle both new blobEntry format and legacy bare blob format
+            blob = entry.get("blob", entry)
             ref = blob.get("ref", {})
             cid = ref.get("$link") if isinstance(ref, dict) else str(ref)
             if cid:
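The entry.get("blob", entry) fallback above is what makes CID extraction work for both record shapes. Illustrative data only; the CID and size are made up.

legacy = {"$type": "blob", "ref": {"$link": "bafyexamplecid"}, "mimeType": "application/x-tar", "size": 1024}
wrapped = {"blob": legacy, "checksum": {"algorithm": "none", "digest": ""}}  # new blobEntry shape
for entry in (legacy, wrapped):
    blob = entry.get("blob", entry)
    ref = blob.get("ref", {})
    cid = ref.get("$link") if isinstance(ref, dict) else str(ref)
    print(cid)  # prints the same CID for both shapes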
@@ -462,7 +654,7 @@ class DatasetLoader:
         You must provide the sample type class, which should match the
         schema referenced by the record.
 
-        Supports both external URL storage and ATProto blob storage.
+        Supports HTTP, S3, blob, and legacy external storage.
 
         Args:
             uri: The AT URI of the dataset record.
@@ -485,10 +677,10 @@ class DatasetLoader:
 
         storage_type = self.get_storage_type(uri)
 
-        if storage_type == "external":
-            urls = self.get_urls(uri)
-        else:
+        if storage_type == "blobs":
             urls = self.get_blob_urls(uri)
+        else:
+            urls = self.get_urls(uri)
 
         if not urls:
             raise ValueError("Dataset record has no storage URLs")