atdata 0.2.0a1__py3-none-any.whl → 0.2.3b1__py3-none-any.whl

atdata/_sources.py ADDED
"""Data source implementations for streaming dataset shards.

This module provides concrete implementations of the DataSource protocol,
enabling Dataset to work with various data backends without URL-transformation
hacks.

Classes:
    URLSource: WebDataset-compatible URLs (http, https, pipe, gs, etc.)
    S3Source: S3-compatible storage with explicit credentials
    BlobSource: ATProto PDS blobs addressed by DID and CID

The key insight is that WebDataset's tar_file_expander only needs
{url: str, stream: IO} dicts; it does not care how the streams are
created. By providing streams directly, we can support private repos,
custom endpoints, and backends like ATProto blobs. (A minimal adapter
illustrating this follows the imports below.)

Examples:
    >>> # Standard URL (uses WebDataset's gopen)
    >>> source = URLSource("https://example.com/data-{000..009}.tar")
    >>> ds = Dataset[MySample](source)
    >>>
    >>> # Private S3 with credentials
    >>> source = S3Source(
    ...     bucket="my-bucket",
    ...     keys=["train/shard-000.tar", "train/shard-001.tar"],
    ...     endpoint="https://my-r2.cloudflarestorage.com",
    ...     access_key="...",
    ...     secret_key="...",
    ... )
    >>> ds = Dataset[MySample](source)
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import IO, Any, Iterator

import braceexpand
import webdataset as wds
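

# Illustrative sketch (not part of the released file): the module docstring's
# "key insight" in code form. Any source exposing ``shards`` can be adapted to
# the {"url": ..., "stream": ...} dicts consumed by WebDataset's
# tar_file_expander (assuming, as in recent webdataset releases, it is exposed
# at wds.tariterators.tar_file_expander). The wiring below is an assumed
# usage, not Dataset's actual implementation.
def _sketch_iter_tar_members(source: Any) -> Iterator[dict]:
    """Return an iterator of per-file samples from a source's shard streams."""
    dicts = ({"url": url, "stream": stream} for url, stream in source.shards)
    return wds.tariterators.tar_file_expander(dicts)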


@dataclass
class URLSource:
    """Data source for WebDataset-compatible URLs.

    Wraps WebDataset's gopen to open URLs using built-in handlers for
    http, https, pipe, gs, hf, sftp, etc. Supports brace expansion
    for shard patterns like "data-{000..099}.tar".

    This is the default source type when a string URL is passed to Dataset.

    Attributes:
        url: URL or brace pattern for the shards.

    Examples:
        >>> source = URLSource("https://example.com/train-{000..009}.tar")
        >>> for shard_id, stream in source.shards:
        ...     print(f"Streaming {shard_id}")
    """

    url: str

    def list_shards(self) -> list[str]:
        """Expand the brace pattern and return the list of shard URLs."""
        return list(braceexpand.braceexpand(self.url))

    # Legacy alias for backwards compatibility
    @property
    def shard_list(self) -> list[str]:
        """Deprecated alias for list_shards()."""
        return self.list_shards()

    @property
    def shards(self) -> Iterator[tuple[str, IO[bytes]]]:
        """Lazily yield (url, stream) pairs for each shard.

        Uses WebDataset's gopen to open URLs, which handles various schemes:
        - http/https: via curl
        - pipe: shell command streaming
        - gs: Google Cloud Storage via gsutil
        - hf: HuggingFace Hub
        - file or no scheme: local filesystem

        Yields:
            Tuple of (url, file-like stream).
        """
        for url in self.list_shards():
            stream = wds.gopen(url, mode="rb")
            yield url, stream

    def open_shard(self, shard_id: str) -> IO[bytes]:
        """Open a single shard by URL.

        Args:
            shard_id: URL of the shard to open.

        Returns:
            File-like stream from gopen.

        Raises:
            KeyError: If shard_id is not in list_shards().
        """
        if shard_id not in self.list_shards():
            raise KeyError(f"Shard not found: {shard_id}")
        return wds.gopen(shard_id, mode="rb")
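

# Example sketch: brace expansion and streaming with URLSource. The shard
# names are hypothetical; with three such tar files on local disk, gopen
# streams each one in turn.
def _sketch_url_source() -> None:
    src = URLSource("train-{000..002}.tar")
    assert src.list_shards() == ["train-000.tar", "train-001.tar", "train-002.tar"]
    for url, stream in src.shards:
        stream.read(512)  # one 512-byte tar header block
        stream.close()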
107
+
108
+ @dataclass
109
+ class S3Source:
110
+ """Data source for S3-compatible storage with explicit credentials.
111
+
112
+ Uses boto3 to stream directly from S3, supporting:
113
+ - Standard AWS S3
114
+ - S3-compatible endpoints (Cloudflare R2, MinIO, etc.)
115
+ - Private buckets with credentials
116
+ - IAM role authentication (when keys not provided)
117
+
118
+ Unlike URL-based approaches, this doesn't require URL transformation
119
+ or global gopen_schemes registration. Credentials are scoped to the
120
+ source instance.
121
+
122
+ Attributes:
123
+ bucket: S3 bucket name.
124
+ keys: List of object keys (paths within bucket).
125
+ endpoint: Optional custom endpoint URL for S3-compatible services.
126
+ access_key: Optional AWS access key ID.
127
+ secret_key: Optional AWS secret access key.
128
+ region: Optional AWS region (defaults to us-east-1).
129
+
130
+ Examples:
131
+ >>> source = S3Source(
132
+ ... bucket="my-datasets",
133
+ ... keys=["train/shard-000.tar", "train/shard-001.tar"],
134
+ ... endpoint="https://abc123.r2.cloudflarestorage.com",
135
+ ... access_key="AKIAIOSFODNN7EXAMPLE",
136
+ ... secret_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
137
+ ... )
138
+ >>> for shard_id, stream in source.shards:
139
+ ... process(stream)
140
+ """
141
+
142
+ bucket: str
143
+ keys: list[str]
144
+ endpoint: str | None = None
145
+ access_key: str | None = None
146
+ secret_key: str | None = None
147
+ region: str | None = None
148
+ _client: Any = field(default=None, repr=False, compare=False)
149
+
150
+ def _get_client(self) -> Any:
151
+ """Get or create boto3 S3 client."""
152
+ if self._client is not None:
153
+ return self._client
154
+
155
+ import boto3
156
+
157
+ client_kwargs: dict[str, Any] = {}
158
+
159
+ if self.endpoint:
160
+ client_kwargs["endpoint_url"] = self.endpoint
161
+
162
+ if self.access_key and self.secret_key:
163
+ client_kwargs["aws_access_key_id"] = self.access_key
164
+ client_kwargs["aws_secret_access_key"] = self.secret_key
165
+
166
+ if self.region:
167
+ client_kwargs["region_name"] = self.region
168
+ elif not self.endpoint:
169
+ # Default region for AWS S3
170
+ client_kwargs["region_name"] = os.environ.get(
171
+ "AWS_DEFAULT_REGION", "us-east-1"
172
+ )
173
+
174
+ self._client = boto3.client("s3", **client_kwargs)
175
+ return self._client
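
    # Illustrative note (inferred from the branch logic in _get_client above):
    # for a Cloudflare R2 configuration such as
    #     S3Source(bucket="b", keys=["k.tar"],
    #              endpoint="https://<acct>.r2.cloudflarestorage.com",
    #              access_key="...", secret_key="...")
    # the client kwargs reduce to endpoint_url plus the two credential keys,
    # with no region_name, because a custom endpoint bypasses the AWS
    # default-region fallback.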

    def list_shards(self) -> list[str]:
        """Return the list of s3:// URIs for the shards."""
        return [f"s3://{self.bucket}/{key}" for key in self.keys]

    # Legacy alias for backwards compatibility
    @property
    def shard_list(self) -> list[str]:
        """Deprecated alias for list_shards()."""
        return self.list_shards()

    @property
    def shards(self) -> Iterator[tuple[str, IO[bytes]]]:
        """Lazily yield (s3_uri, stream) pairs for each shard.

        Uses boto3 to get streaming response bodies, which are file-like
        objects that can be read directly by tarfile.

        Yields:
            Tuple of (s3://bucket/key URI, StreamingBody).
        """
        client = self._get_client()

        for key in self.keys:
            response = client.get_object(Bucket=self.bucket, Key=key)
            stream = response["Body"]
            uri = f"s3://{self.bucket}/{key}"
            yield uri, stream

    def open_shard(self, shard_id: str) -> IO[bytes]:
        """Open a single shard by S3 URI.

        Args:
            shard_id: S3 URI of the shard (s3://bucket/key).

        Returns:
            StreamingBody for reading the object.

        Raises:
            KeyError: If shard_id is not in list_shards().
        """
        if shard_id not in self.list_shards():
            raise KeyError(f"Shard not found: {shard_id}")

        # Strip the "s3://bucket/" prefix to recover the object key.
        prefix = f"s3://{self.bucket}/"
        if not shard_id.startswith(prefix):
            raise KeyError(f"Shard not in this bucket: {shard_id}")

        key = shard_id[len(prefix):]
        client = self._get_client()
        response = client.get_object(Bucket=self.bucket, Key=key)
        return response["Body"]

    @classmethod
    def from_urls(
        cls,
        urls: list[str],
        *,
        endpoint: str | None = None,
        access_key: str | None = None,
        secret_key: str | None = None,
        region: str | None = None,
    ) -> "S3Source":
        """Create an S3Source from s3:// URLs.

        Parses s3://bucket/key URLs and extracts the bucket and keys.
        All URLs must point into the same bucket.

        Args:
            urls: List of s3:// URLs.
            endpoint: Optional custom endpoint.
            access_key: Optional access key.
            secret_key: Optional secret key.
            region: Optional region.

        Returns:
            S3Source configured for the given URLs.

        Raises:
            ValueError: If URLs are not valid s3:// URLs or span multiple buckets.

        Examples:
            >>> source = S3Source.from_urls(
            ...     ["s3://my-bucket/train-000.tar", "s3://my-bucket/train-001.tar"],
            ...     endpoint="https://r2.example.com",
            ... )
        """
        if not urls:
            raise ValueError("urls cannot be empty")

        buckets: set[str] = set()
        keys: list[str] = []

        for url in urls:
            if not url.startswith("s3://"):
                raise ValueError(f"Not an S3 URL: {url}")

            # s3://bucket/path/to/key -> bucket, path/to/key
            path = url[5:]  # Remove 's3://'
            if "/" not in path:
                raise ValueError(f"Invalid S3 URL (no key): {url}")

            bucket, key = path.split("/", 1)
            buckets.add(bucket)
            keys.append(key)

        if len(buckets) > 1:
            raise ValueError(f"All URLs must be in the same bucket, got: {buckets}")

        return cls(
            bucket=buckets.pop(),
            keys=keys,
            endpoint=endpoint,
            access_key=access_key,
            secret_key=secret_key,
            region=region,
        )

    @classmethod
    def from_credentials(
        cls,
        credentials: dict[str, str],
        bucket: str,
        keys: list[str],
    ) -> "S3Source":
        """Create an S3Source from a credentials dictionary.

        Accepts the same credential format used by S3DataStore.

        Args:
            credentials: Dict with AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
                and optionally AWS_ENDPOINT and AWS_REGION.
            bucket: S3 bucket name.
            keys: List of object keys.

        Returns:
            Configured S3Source.

        Examples:
            >>> creds = {
            ...     "AWS_ACCESS_KEY_ID": "...",
            ...     "AWS_SECRET_ACCESS_KEY": "...",
            ...     "AWS_ENDPOINT": "https://r2.example.com",
            ... }
            >>> source = S3Source.from_credentials(creds, "my-bucket", ["data.tar"])
        """
        return cls(
            bucket=bucket,
            keys=keys,
            endpoint=credentials.get("AWS_ENDPOINT"),
            access_key=credentials.get("AWS_ACCESS_KEY_ID"),
            secret_key=credentials.get("AWS_SECRET_ACCESS_KEY"),
            region=credentials.get("AWS_REGION"),
        )
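

# Example sketch: S3Source against a local MinIO server. The endpoint and
# credentials are placeholders; the boto3 client is only created when shards
# are first read, so constructing the source is cheap.
def _sketch_s3_source() -> None:
    src = S3Source.from_urls(
        ["s3://demo/train-000.tar", "s3://demo/train-001.tar"],
        endpoint="http://localhost:9000",
        access_key="minioadmin",
        secret_key="minioadmin",
    )
    assert src.list_shards() == [
        "s3://demo/train-000.tar",
        "s3://demo/train-001.tar",
    ]
    for uri, body in src.shards:
        body.read(512)  # StreamingBody supports incremental reads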


@dataclass
class BlobSource:
    """Data source for ATProto PDS blob storage.

    Streams dataset shards stored as blobs on an ATProto Personal Data Server.
    Each shard is identified by a blob reference containing the DID and CID.

    This source resolves blob references to HTTP URLs and streams the content
    directly, supporting efficient iteration over shards without downloading
    everything upfront.

    Attributes:
        blob_refs: List of blob reference dicts with 'did' and 'cid' keys.
        pds_endpoint: Optional PDS endpoint URL. If not provided, the
            endpoint is resolved from each DID.

    Examples:
        >>> source = BlobSource(
        ...     blob_refs=[
        ...         {"did": "did:plc:abc123", "cid": "bafyrei..."},
        ...         {"did": "did:plc:abc123", "cid": "bafyrei..."},
        ...     ],
        ... )
        >>> for shard_id, stream in source.shards:
        ...     process(stream)
    """

    blob_refs: list[dict[str, str]]
    pds_endpoint: str | None = None
    # Per-DID endpoint cache; excluded from __init__, repr, and equality.
    _endpoint_cache: dict[str, str] = field(
        default_factory=dict, init=False, repr=False, compare=False
    )

    def _resolve_pds_endpoint(self, did: str) -> str:
        """Resolve the PDS endpoint for a DID, with caching.

        Only did:plc identifiers are resolved automatically (via
        plc.directory); for other DID methods, pass pds_endpoint explicitly.
        """
        if did in self._endpoint_cache:
            return self._endpoint_cache[did]

        if self.pds_endpoint:
            self._endpoint_cache[did] = self.pds_endpoint
            return self.pds_endpoint

        import requests

        # Resolve via plc.directory
        if did.startswith("did:plc:"):
            plc_url = f"https://plc.directory/{did}"
            response = requests.get(plc_url, timeout=10)
            response.raise_for_status()
            doc = response.json()

            for service in doc.get("service", []):
                if service.get("type") == "AtprotoPersonalDataServer":
                    endpoint = service.get("serviceEndpoint", "")
                    self._endpoint_cache[did] = endpoint
                    return endpoint

        raise ValueError(f"Could not resolve PDS endpoint for {did}")

    def _get_blob_url(self, did: str, cid: str) -> str:
        """Get the HTTP URL for fetching a blob."""
        endpoint = self._resolve_pds_endpoint(did)
        return f"{endpoint}/xrpc/com.atproto.sync.getBlob?did={did}&cid={cid}"
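
    # For example, a blob owned by did:plc:abc123 and hosted at
    # https://pds.example.com resolves to
    #     https://pds.example.com/xrpc/com.atproto.sync.getBlob?did=did:plc:abc123&cid=bafyrei...
    # (hypothetical host, DID, and CID, shown for illustration).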

    def _make_shard_id(self, ref: dict[str, str]) -> str:
        """Create a shard identifier from a blob reference."""
        return f"at://{ref['did']}/blob/{ref['cid']}"

    def list_shards(self) -> list[str]:
        """Return the list of AT-URI-style shard identifiers."""
        return [self._make_shard_id(ref) for ref in self.blob_refs]

    @property
    def shards(self) -> Iterator[tuple[str, IO[bytes]]]:
        """Lazily yield (at_uri, stream) pairs for each shard.

        Fetches blobs via HTTP from the PDS and yields streaming responses.

        Yields:
            Tuple of (at://did/blob/cid URI, streaming response body).
        """
        import requests

        for ref in self.blob_refs:
            did = ref["did"]
            cid = ref["cid"]
            url = self._get_blob_url(did, cid)

            response = requests.get(url, stream=True, timeout=60)
            response.raise_for_status()

            shard_id = self._make_shard_id(ref)
            # response.raw is the underlying urllib3 file-like object
            yield shard_id, response.raw

    def open_shard(self, shard_id: str) -> IO[bytes]:
        """Open a single shard by its AT URI.

        Args:
            shard_id: AT URI of the shard (at://did/blob/cid).

        Returns:
            Streaming response body for reading the blob.

        Raises:
            KeyError: If shard_id is not in list_shards().
            ValueError: If the shard_id format is invalid.
        """
        if shard_id not in self.list_shards():
            raise KeyError(f"Shard not found: {shard_id}")

        # Parse at://did/blob/cid
        if not shard_id.startswith("at://"):
            raise ValueError(f"Invalid shard ID format: {shard_id}")

        parts = shard_id[5:].split("/")  # Remove 'at://'
        if len(parts) != 3 or parts[1] != "blob":
            raise ValueError(f"Invalid blob URI format: {shard_id}")

        did, _, cid = parts
        url = self._get_blob_url(did, cid)

        import requests

        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()
        return response.raw

    @classmethod
    def from_refs(
        cls,
        refs: list[dict],
        *,
        pds_endpoint: str | None = None,
    ) -> "BlobSource":
        """Create a BlobSource from blob reference dicts.

        Accepts the simplified format ``{"did": "...", "cid": "..."}``.
        Raw ATProto blob references as returned by upload_blob
        (``{"$type": "blob", "ref": {"$link": "cid"}, ...}``) carry no DID
        and are rejected with a ValueError; use from_record_storage() for
        records with storage.blobs.

        Args:
            refs: List of blob reference dicts.
            pds_endpoint: Optional PDS endpoint to use for all blobs.

        Returns:
            Configured BlobSource.

        Raises:
            ValueError: If refs is empty or the format is invalid.
        """
        if not refs:
            raise ValueError("refs cannot be empty")

        blob_refs: list[dict[str, str]] = []

        for ref in refs:
            if "did" in ref and "cid" in ref:
                # Simple format
                blob_refs.append({"did": ref["did"], "cid": ref["cid"]})
            elif "ref" in ref and "$link" in ref.get("ref", {}):
                # Raw ATProto blob format - the DID must come from elsewhere
                raise ValueError(
                    "ATProto blob format requires 'did' field. "
                    "Use from_record_storage() for records with storage.blobs."
                )
            else:
                raise ValueError(f"Invalid blob reference format: {ref}")

        return cls(blob_refs=blob_refs, pds_endpoint=pds_endpoint)
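

# Example sketch: the resolution BlobSource performs internally for a
# did:plc identity, spelled out with plain requests calls. The DID and CID
# are placeholders; next() raises StopIteration if the DID document lists
# no AtprotoPersonalDataServer service.
def _sketch_manual_blob_fetch(did: str, cid: str) -> bytes:
    import requests

    doc = requests.get(f"https://plc.directory/{did}", timeout=10).json()
    endpoint = next(
        s["serviceEndpoint"]
        for s in doc.get("service", [])
        if s.get("type") == "AtprotoPersonalDataServer"
    )
    blob_url = f"{endpoint}/xrpc/com.atproto.sync.getBlob?did={did}&cid={cid}"
    response = requests.get(blob_url, timeout=60)
    response.raise_for_status()
    return response.content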


__all__ = [
    "URLSource",
    "S3Source",
    "BlobSource",
]
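

# The real DataSource protocol lives elsewhere in atdata; the structural
# sketch below is inferred from the three classes above and is an assumption,
# not the library's actual definition.
from typing import Protocol


class _DataSourceSketch(Protocol):
    def list_shards(self) -> list[str]: ...

    @property
    def shards(self) -> Iterator[tuple[str, IO[bytes]]]: ...

    def open_shard(self, shard_id: str) -> IO[bytes]: ...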