atdata 0.2.0a1__py3-none-any.whl → 0.2.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atdata/_sources.py ADDED
@@ -0,0 +1,515 @@
+ """Data source implementations for streaming dataset shards.
+
+ This module provides concrete implementations of the DataSource protocol,
+ enabling Dataset to work with various data backends without URL transformation
+ hacks.
+
+ Classes:
+     URLSource: WebDataset-compatible URLs (http, https, pipe, gs, etc.)
+     S3Source: S3-compatible storage with explicit credentials
+     BlobSource: ATProto PDS blob storage addressed by DID and CID
+
+ The key insight is that WebDataset's tar_file_expander only needs
+ {url: str, stream: IO} dicts; it doesn't care how streams are created.
+ By providing streams directly, we can support private repos, custom
+ endpoints, and backends like ATProto blobs.
+
+ Example:
+     ::
+
+         >>> # Standard URL (uses WebDataset's gopen)
+         >>> source = URLSource("https://example.com/data-{000..009}.tar")
+         >>> ds = Dataset[MySample](source)
+         >>>
+         >>> # Private S3 with credentials
+         >>> source = S3Source(
+         ...     bucket="my-bucket",
+         ...     keys=["train/shard-000.tar", "train/shard-001.tar"],
+         ...     endpoint="https://my-r2.cloudflarestorage.com",
+         ...     access_key="...",
+         ...     secret_key="...",
+         ... )
+         >>> ds = Dataset[MySample](source)
+ """
+
+ from __future__ import annotations
+
+ import os
+ from dataclasses import dataclass, field
+ from typing import IO, Iterator, Any
+
+ import braceexpand
+ import webdataset as wds
+
+
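The module docstring references a DataSource protocol defined elsewhere in the package. Going by the methods the three classes below share, its shape can be sketched roughly as follows; the Protocol name here is illustrative, not the package's actual definition:

    from typing import IO, Iterator, Protocol

    class DataSourceLike(Protocol):
        """Rough structural sketch of the shared source contract."""

        def list_shards(self) -> list[str]:
            """Return stable identifiers for every shard."""
            ...

        @property
        def shards(self) -> Iterator[tuple[str, IO[bytes]]]:
            """Lazily yield (shard_id, byte stream) pairs."""
            ...

        def open_shard(self, shard_id: str) -> IO[bytes]:
            """Open one shard by its identifier."""
            ...
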
+ @dataclass
+ class URLSource:
+     """Data source for WebDataset-compatible URLs.
+
+     Wraps WebDataset's gopen to open URLs using built-in handlers for
+     http, https, pipe, gs, hf, sftp, etc. Supports brace expansion
+     for shard patterns like "data-{000..099}.tar".
+
+     This is the default source type when a string URL is passed to Dataset.
+
+     Attributes:
+         url: URL or brace pattern for the shards.
+
+     Example:
+         ::
+
+             >>> source = URLSource("https://example.com/train-{000..009}.tar")
+             >>> for shard_id, stream in source.shards:
+             ...     print(f"Streaming {shard_id}")
+     """
+
+     url: str
+
+     def list_shards(self) -> list[str]:
+         """Expand brace pattern and return list of shard URLs."""
+         return list(braceexpand.braceexpand(self.url))
+
+     # Legacy alias for backwards compatibility
+     @property
+     def shard_list(self) -> list[str]:
+         """Expand brace pattern and return list of shard URLs (deprecated, use list_shards())."""
+         return self.list_shards()
+
+     @property
+     def shards(self) -> Iterator[tuple[str, IO[bytes]]]:
+         """Lazily yield (url, stream) pairs for each shard.
+
+         Uses WebDataset's gopen to open URLs, which handles various schemes:
+         - http/https: via curl
+         - pipe: shell command streaming
+         - gs: Google Cloud Storage via gsutil
+         - hf: HuggingFace Hub
+         - file or no scheme: local filesystem
+
+         Yields:
+             Tuple of (url, file-like stream).
+         """
+         for url in self.list_shards():
+             stream = wds.gopen(url, mode="rb")
+             yield url, stream
+
+     def open_shard(self, shard_id: str) -> IO[bytes]:
+         """Open a single shard by URL.
+
+         Args:
+             shard_id: URL of the shard to open.
+
+         Returns:
+             File-like stream from gopen.
+
+         Raises:
+             KeyError: If shard_id is not in list_shards().
+         """
+         if shard_id not in self.list_shards():
+             raise KeyError(f"Shard not found: {shard_id}")
+         return wds.gopen(shard_id, mode="rb")
+
+
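To make the "{url, stream} dicts" point from the module docstring concrete, here is a minimal sketch of wiring a URLSource into WebDataset's tar_file_expander by hand; the shard URL is a placeholder:

    import webdataset as wds
    from atdata._sources import URLSource

    source = URLSource("https://example.com/train-{000..001}.tar")

    # tar_file_expander only needs dicts with "url" and "stream" keys;
    # it never opens URLs itself, so any source of streams will do.
    samples = ({"url": u, "stream": s} for u, s in source.shards)
    for item in wds.tariterators.tar_file_expander(samples):
        print(item["fname"], len(item["data"]))
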
+ @dataclass
+ class S3Source:
+     """Data source for S3-compatible storage with explicit credentials.
+
+     Uses boto3 to stream directly from S3, supporting:
+     - Standard AWS S3
+     - S3-compatible endpoints (Cloudflare R2, MinIO, etc.)
+     - Private buckets with credentials
+     - IAM role authentication (when keys are not provided)
+
+     Unlike URL-based approaches, this doesn't require URL transformation
+     or global gopen_schemes registration. Credentials are scoped to the
+     source instance.
+
+     Attributes:
+         bucket: S3 bucket name.
+         keys: List of object keys (paths within the bucket).
+         endpoint: Optional custom endpoint URL for S3-compatible services.
+         access_key: Optional AWS access key ID.
+         secret_key: Optional AWS secret access key.
+         region: Optional AWS region. When unset and no custom endpoint is
+             given, falls back to $AWS_DEFAULT_REGION, then us-east-1.
+
+     Example:
+         ::
+
+             >>> source = S3Source(
+             ...     bucket="my-datasets",
+             ...     keys=["train/shard-000.tar", "train/shard-001.tar"],
+             ...     endpoint="https://abc123.r2.cloudflarestorage.com",
+             ...     access_key="AKIAIOSFODNN7EXAMPLE",
+             ...     secret_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
+             ... )
+             >>> for shard_id, stream in source.shards:
+             ...     process(stream)
+     """
+
+     bucket: str
+     keys: list[str]
+     endpoint: str | None = None
+     access_key: str | None = None
+     secret_key: str | None = None
+     region: str | None = None
+     # Lazily created boto3 client; see _get_client().
+     _client: Any = field(default=None, repr=False, compare=False)
+
+     def _get_client(self) -> Any:
+         """Get or create the boto3 S3 client."""
+         if self._client is not None:
+             return self._client
+
+         import boto3
+
+         client_kwargs: dict[str, Any] = {}
+
+         if self.endpoint:
+             client_kwargs["endpoint_url"] = self.endpoint
+
+         if self.access_key and self.secret_key:
+             client_kwargs["aws_access_key_id"] = self.access_key
+             client_kwargs["aws_secret_access_key"] = self.secret_key
+
+         if self.region:
+             client_kwargs["region_name"] = self.region
+         elif not self.endpoint:
+             # Default region for AWS S3
+             client_kwargs["region_name"] = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
+
+         self._client = boto3.client("s3", **client_kwargs)
+         return self._client
+
+     def list_shards(self) -> list[str]:
+         """Return list of S3 URIs for the shards."""
+         return [f"s3://{self.bucket}/{key}" for key in self.keys]
+
+     # Legacy alias for backwards compatibility
+     @property
+     def shard_list(self) -> list[str]:
+         """Return list of S3 URIs for the shards (deprecated, use list_shards())."""
+         return self.list_shards()
+
+     @property
+     def shards(self) -> Iterator[tuple[str, IO[bytes]]]:
+         """Lazily yield (s3_uri, stream) pairs for each shard.
+
+         Uses boto3 to get streaming response bodies, which are file-like
+         objects that can be read directly by tarfile.
+
+         Yields:
+             Tuple of (s3://bucket/key URI, StreamingBody).
+         """
+         client = self._get_client()
+
+         for key in self.keys:
+             response = client.get_object(Bucket=self.bucket, Key=key)
+             stream = response["Body"]
+             uri = f"s3://{self.bucket}/{key}"
+             yield uri, stream
+
+     def open_shard(self, shard_id: str) -> IO[bytes]:
+         """Open a single shard by S3 URI.
+
+         Args:
+             shard_id: S3 URI of the shard (s3://bucket/key).
+
+         Returns:
+             StreamingBody for reading the object.
+
+         Raises:
+             KeyError: If shard_id is not in list_shards().
+         """
+         if shard_id not in self.list_shards():
+             raise KeyError(f"Shard not found: {shard_id}")
+
+         # Parse s3://bucket/key -> key
+         if not shard_id.startswith(f"s3://{self.bucket}/"):
+             raise KeyError(f"Shard not in this bucket: {shard_id}")
+
+         key = shard_id[len(f"s3://{self.bucket}/"):]
+         client = self._get_client()
+         response = client.get_object(Bucket=self.bucket, Key=key)
+         return response["Body"]
+
+     @classmethod
+     def from_urls(
+         cls,
+         urls: list[str],
+         *,
+         endpoint: str | None = None,
+         access_key: str | None = None,
+         secret_key: str | None = None,
+         region: str | None = None,
+     ) -> "S3Source":
+         """Create S3Source from s3:// URLs.
+
+         Parses s3://bucket/key URLs and extracts the bucket and keys.
+         All URLs must be in the same bucket.
+
+         Args:
+             urls: List of s3:// URLs.
+             endpoint: Optional custom endpoint.
+             access_key: Optional access key.
+             secret_key: Optional secret key.
+             region: Optional region.
+
+         Returns:
+             S3Source configured for the given URLs.
+
+         Raises:
+             ValueError: If URLs are not valid s3:// URLs or span multiple buckets.
+
+         Example:
+             ::
+
+                 >>> source = S3Source.from_urls(
+                 ...     ["s3://my-bucket/train-000.tar", "s3://my-bucket/train-001.tar"],
+                 ...     endpoint="https://r2.example.com",
+                 ... )
+         """
+         if not urls:
+             raise ValueError("urls cannot be empty")
+
+         buckets: set[str] = set()
+         keys: list[str] = []
+
+         for url in urls:
+             if not url.startswith("s3://"):
+                 raise ValueError(f"Not an S3 URL: {url}")
+
+             # s3://bucket/path/to/key -> bucket, path/to/key
+             path = url[5:]  # Remove 's3://'
+             if "/" not in path:
+                 raise ValueError(f"Invalid S3 URL (no key): {url}")
+
+             bucket, key = path.split("/", 1)
+             buckets.add(bucket)
+             keys.append(key)
+
+         if len(buckets) > 1:
+             raise ValueError(f"All URLs must be in the same bucket, got: {buckets}")
+
+         return cls(
+             bucket=buckets.pop(),
+             keys=keys,
+             endpoint=endpoint,
+             access_key=access_key,
+             secret_key=secret_key,
+             region=region,
+         )
+
+     @classmethod
+     def from_credentials(
+         cls,
+         credentials: dict[str, str],
+         bucket: str,
+         keys: list[str],
+     ) -> "S3Source":
+         """Create S3Source from a credentials dictionary.
+
+         Accepts the same credential format used by S3DataStore.
+
+         Args:
+             credentials: Dict with AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
+                 and optionally AWS_ENDPOINT and AWS_REGION.
+             bucket: S3 bucket name.
+             keys: List of object keys.
+
+         Returns:
+             Configured S3Source.
+
+         Example:
+             ::
+
+                 >>> creds = {
+                 ...     "AWS_ACCESS_KEY_ID": "...",
+                 ...     "AWS_SECRET_ACCESS_KEY": "...",
+                 ...     "AWS_ENDPOINT": "https://r2.example.com",
+                 ... }
+                 >>> source = S3Source.from_credentials(creds, "my-bucket", ["data.tar"])
+         """
+         return cls(
+             bucket=bucket,
+             keys=keys,
+             endpoint=credentials.get("AWS_ENDPOINT"),
+             access_key=credentials.get("AWS_ACCESS_KEY_ID"),
+             secret_key=credentials.get("AWS_SECRET_ACCESS_KEY"),
+             region=credentials.get("AWS_REGION"),
+         )
+
+
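A short sketch of the from_credentials path end to end; the endpoint and bucket names are placeholders, and the credentials come from the environment rather than being hard-coded:

    import os
    from atdata._sources import S3Source

    # Same dict shape from_credentials() expects; AWS_REGION is optional.
    creds = {
        "AWS_ACCESS_KEY_ID": os.environ["AWS_ACCESS_KEY_ID"],
        "AWS_SECRET_ACCESS_KEY": os.environ["AWS_SECRET_ACCESS_KEY"],
        "AWS_ENDPOINT": "https://abc123.r2.cloudflarestorage.com",
    }
    source = S3Source.from_credentials(creds, "my-datasets", ["train/shard-000.tar"])

    for uri, body in source.shards:
        # body is a boto3 StreamingBody; read it lazily or hand it to tarfile
        print(uri, len(body.read()), "bytes")
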
+ @dataclass
+ class BlobSource:
+     """Data source for ATProto PDS blob storage.
+
+     Streams dataset shards stored as blobs on an ATProto Personal Data Server.
+     Each shard is identified by a blob reference containing the DID and CID.
+
+     This source resolves blob references to HTTP URLs and streams the content
+     directly, supporting efficient iteration over shards without downloading
+     everything upfront.
+
+     Attributes:
+         blob_refs: List of blob reference dicts with 'did' and 'cid' keys.
+         pds_endpoint: Optional PDS endpoint URL. If not provided, resolved
+             from the DID via plc.directory (did:plc DIDs only).
+
+     Example:
+         ::
+
+             >>> source = BlobSource(
+             ...     blob_refs=[
+             ...         {"did": "did:plc:abc123", "cid": "bafyrei..."},
+             ...         {"did": "did:plc:abc123", "cid": "bafyrei..."},
+             ...     ],
+             ... )
+             >>> for shard_id, stream in source.shards:
+             ...     process(stream)
+     """
+
+     blob_refs: list[dict[str, str]]
+     pds_endpoint: str | None = None
+     _endpoint_cache: dict[str, str] = field(default_factory=dict, repr=False, compare=False)
+
+     def _resolve_pds_endpoint(self, did: str) -> str:
+         """Resolve the PDS endpoint for a DID, with caching."""
+         if did in self._endpoint_cache:
+             return self._endpoint_cache[did]
+
+         if self.pds_endpoint:
+             self._endpoint_cache[did] = self.pds_endpoint
+             return self.pds_endpoint
+
+         import requests
+
+         # Resolve via plc.directory
+         if did.startswith("did:plc:"):
+             plc_url = f"https://plc.directory/{did}"
+             response = requests.get(plc_url, timeout=10)
+             response.raise_for_status()
+             doc = response.json()
+
+             for service in doc.get("service", []):
+                 if service.get("type") == "AtprotoPersonalDataServer":
+                     endpoint = service.get("serviceEndpoint", "")
+                     self._endpoint_cache[did] = endpoint
+                     return endpoint
+
+         raise ValueError(f"Could not resolve PDS endpoint for {did}")
+
+     def _get_blob_url(self, did: str, cid: str) -> str:
+         """Get the HTTP URL for fetching a blob."""
+         endpoint = self._resolve_pds_endpoint(did)
+         return f"{endpoint}/xrpc/com.atproto.sync.getBlob?did={did}&cid={cid}"
+
+     def _make_shard_id(self, ref: dict[str, str]) -> str:
+         """Create a shard identifier from a blob reference."""
+         return f"at://{ref['did']}/blob/{ref['cid']}"
+
+     def list_shards(self) -> list[str]:
+         """Return list of AT URI-style shard identifiers."""
+         return [self._make_shard_id(ref) for ref in self.blob_refs]
+
+     @property
+     def shards(self) -> Iterator[tuple[str, IO[bytes]]]:
+         """Lazily yield (at_uri, stream) pairs for each shard.
+
+         Fetches blobs via HTTP from the PDS and yields streaming responses.
+
+         Yields:
+             Tuple of (at://did/blob/cid URI, streaming response body).
+         """
+         import requests
+
+         for ref in self.blob_refs:
+             did = ref["did"]
+             cid = ref["cid"]
+             url = self._get_blob_url(did, cid)
+
+             response = requests.get(url, stream=True, timeout=60)
+             response.raise_for_status()
+
+             shard_id = self._make_shard_id(ref)
+             # response.raw is the underlying urllib3 file-like stream
+             yield shard_id, response.raw
+
+     def open_shard(self, shard_id: str) -> IO[bytes]:
+         """Open a single shard by its AT URI.
+
+         Args:
+             shard_id: AT URI of the shard (at://did/blob/cid).
+
+         Returns:
+             Streaming response body for reading the blob.
+
+         Raises:
+             KeyError: If shard_id is not in list_shards().
+             ValueError: If shard_id format is invalid.
+         """
+         if shard_id not in self.list_shards():
+             raise KeyError(f"Shard not found: {shard_id}")
+
+         # Parse at://did/blob/cid
+         if not shard_id.startswith("at://"):
+             raise ValueError(f"Invalid shard ID format: {shard_id}")
+
+         parts = shard_id[5:].split("/")  # Remove 'at://'
+         if len(parts) != 3 or parts[1] != "blob":
+             raise ValueError(f"Invalid blob URI format: {shard_id}")
+
+         did, _, cid = parts
+         url = self._get_blob_url(did, cid)
+
+         import requests
+
+         response = requests.get(url, stream=True, timeout=60)
+         response.raise_for_status()
+         return response.raw
+
+     @classmethod
+     def from_refs(
+         cls,
+         refs: list[dict],
+         *,
+         pds_endpoint: str | None = None,
+     ) -> "BlobSource":
+         """Create BlobSource from blob reference dicts.
+
+         Accepts the simplified format: ``{"did": "...", "cid": "..."}``
+
+         Raw ATProto blob references as returned by upload_blob
+         (``{"$type": "blob", "ref": {"$link": "cid"}, ...}``) are rejected,
+         since they do not carry the DID needed to fetch the blob.
+
+         Args:
+             refs: List of blob reference dicts.
+             pds_endpoint: Optional PDS endpoint to use for all blobs.
+
+         Returns:
+             Configured BlobSource.
+
+         Raises:
+             ValueError: If refs is empty or the format is invalid.
+         """
+         if not refs:
+             raise ValueError("refs cannot be empty")
+
+         blob_refs: list[dict[str, str]] = []
+
+         for ref in refs:
+             if "did" in ref and "cid" in ref:
+                 # Simple format
+                 blob_refs.append({"did": ref["did"], "cid": ref["cid"]})
+             elif "ref" in ref and "$link" in ref.get("ref", {}):
+                 # ATProto blob format; the DID must come from elsewhere
+                 raise ValueError(
+                     "ATProto blob format requires 'did' field. "
+                     "Use from_record_storage() for records with storage.blobs."
+                 )
+             else:
+                 raise ValueError(f"Invalid blob reference format: {ref}")
+
+         return cls(blob_refs=blob_refs, pds_endpoint=pds_endpoint)
+
+
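The shard identifiers BlobSource produces are plain strings, and pinning pds_endpoint skips the plc.directory lookup, so everything up to the actual HTTP fetch can be exercised offline. A sketch with placeholder DID/CID values:

    from atdata._sources import BlobSource

    # Placeholder values in the simplified {"did", "cid"} format.
    source = BlobSource.from_refs(
        [{"did": "did:plc:abc123", "cid": "bafyreiexample"}],
        pds_endpoint="https://pds.example.com",
    )

    # Identifier construction and endpoint pinning are purely local:
    print(source.list_shards())
    # ['at://did:plc:abc123/blob/bafyreiexample']
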
+ __all__ = [
+     "URLSource",
+     "S3Source",
+     "BlobSource",
+ ]
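
The from_urls() validation rules (non-empty list, key required, single bucket) are easy to confirm locally; a quick, fully offline sketch:

    from atdata._sources import S3Source

    src = S3Source.from_urls(
        ["s3://bucket-a/x.tar", "s3://bucket-a/y.tar"],
        endpoint="https://r2.example.com",
    )
    assert src.bucket == "bucket-a"
    assert src.keys == ["x.tar", "y.tar"]

    # Mixed buckets are rejected:
    try:
        S3Source.from_urls(["s3://bucket-a/x.tar", "s3://bucket-b/y.tar"])
    except ValueError as exc:
        print(exc)  # All URLs must be in the same bucket, got: ...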