atdata 0.1.3b3__py3-none-any.whl → 0.2.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +39 -1
- atdata/_helpers.py +39 -3
- atdata/atmosphere/__init__.py +61 -0
- atdata/atmosphere/_types.py +329 -0
- atdata/atmosphere/client.py +393 -0
- atdata/atmosphere/lens.py +280 -0
- atdata/atmosphere/records.py +342 -0
- atdata/atmosphere/schema.py +296 -0
- atdata/dataset.py +336 -203
- atdata/lens.py +177 -77
- atdata/local.py +492 -0
- atdata-0.2.0a1.dist-info/METADATA +181 -0
- atdata-0.2.0a1.dist-info/RECORD +16 -0
- {atdata-0.1.3b3.dist-info → atdata-0.2.0a1.dist-info}/WHEEL +1 -1
- atdata-0.1.3b3.dist-info/METADATA +0 -18
- atdata-0.1.3b3.dist-info/RECORD +0 -9
- {atdata-0.1.3b3.dist-info → atdata-0.2.0a1.dist-info}/entry_points.txt +0 -0
- {atdata-0.1.3b3.dist-info → atdata-0.2.0a1.dist-info}/licenses/LICENSE +0 -0
atdata/local.py
ADDED
@@ -0,0 +1,492 @@
"""Local repository storage for atdata datasets.

This module provides a local storage backend for atdata datasets using:
- S3-compatible object storage for dataset tar files and metadata
- Redis for indexing and tracking datasets

The main classes are:
- Repo: Manages dataset storage in S3 with Redis indexing
- Index: Redis-backed index for tracking dataset metadata
- BasicIndexEntry: Index entry representing a stored dataset

This is intended for development and small-scale deployment before
migrating to the full atproto PDS infrastructure.
"""

##
# Imports

from atdata import (
    PackableSample,
    Dataset,
)

import os
from pathlib import Path
from uuid import uuid4
from tempfile import TemporaryDirectory
from dotenv import dotenv_values
import msgpack

from redis import Redis

from s3fs import (
    S3FileSystem,
)

import webdataset as wds

from dataclasses import (
    dataclass,
    asdict,
    field,
)
from typing import (
    Any,
    Optional,
    Dict,
    Type,
    TypeVar,
    Generator,
    BinaryIO,
    cast,
)

T = TypeVar( 'T', bound = PackableSample )


##
# Helpers

def _kind_str_for_sample_type( st: Type[PackableSample] ) -> str:
    """Convert a sample type to a fully-qualified string identifier.

    Args:
        st: The sample type class.

    Returns:
        A string in the format 'module.name' identifying the sample type.
    """
    return f'{st.__module__}.{st.__name__}'

def _decode_bytes_dict( d: dict[bytes, bytes] ) -> dict[str, str]:
    """Decode a dictionary with byte keys and values to strings.

    Redis returns dictionaries with bytes keys/values, this converts them to strings.

    Args:
        d: Dictionary with bytes keys and values.

    Returns:
        Dictionary with UTF-8 decoded string keys and values.
    """
    return {
        k.decode('utf-8'): v.decode('utf-8')
        for k, v in d.items()
    }


##
# Redis object model

@dataclass
class BasicIndexEntry:
    """Index entry for a dataset stored in the repository.

    Tracks metadata about a dataset stored in S3, including its location,
    type, and unique identifier.
    """
    ##

    wds_url: str
    """WebDataset URL for the dataset tar files, for use with atdata.Dataset."""

    sample_kind: str
    """Fully-qualified sample type name (e.g., 'module.ClassName')."""

    metadata_url: str | None
    """S3 URL to the dataset's metadata msgpack file, if any."""

    uuid: str = field( default_factory = lambda: str( uuid4() ) )
    """Unique identifier for this dataset entry. Defaults to a new UUID if not provided."""

    def write_to( self, redis: Redis ):
        """Persist this index entry to Redis.

        Stores the entry as a Redis hash with key 'BasicIndexEntry:{uuid}'.

        Args:
            redis: Redis connection to write to.
        """
        save_key = f'BasicIndexEntry:{self.uuid}'
        # Filter out None values - Redis doesn't accept None
        data = {k: v for k, v in asdict(self).items() if v is not None}
        # redis-py typing uses untyped dict, so type checker complains about dict[str, Any]
        redis.hset( save_key, mapping = data ) # type: ignore[arg-type]

def _s3_env( credentials_path: str | Path ) -> dict[str, Any]:
    """Load S3 credentials from a .env file.

    Args:
        credentials_path: Path to .env file containing S3 credentials.

    Returns:
        Dictionary with AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.

    Raises:
        AssertionError: If required credentials are missing from the file.
    """
    ##
    credentials_path = Path( credentials_path )
    env_values = dotenv_values( credentials_path )
    assert 'AWS_ENDPOINT' in env_values
    assert 'AWS_ACCESS_KEY_ID' in env_values
    assert 'AWS_SECRET_ACCESS_KEY' in env_values

    return {
        k: env_values[k]
        for k in (
            'AWS_ENDPOINT',
            'AWS_ACCESS_KEY_ID',
            'AWS_SECRET_ACCESS_KEY',
        )
    }

def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem:
    """Create an S3FileSystem from credentials.

    Args:
        creds: Either a path to a .env file with credentials, or a dict
            containing AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and optionally
            AWS_ENDPOINT.

    Returns:
        Configured S3FileSystem instance.
    """
    ##
    if not isinstance( creds, dict ):
        creds = _s3_env( creds )

    # Build kwargs, making endpoint_url optional
    kwargs = {
        'key': creds['AWS_ACCESS_KEY_ID'],
        'secret': creds['AWS_SECRET_ACCESS_KEY']
    }
    if 'AWS_ENDPOINT' in creds:
        kwargs['endpoint_url'] = creds['AWS_ENDPOINT']

    return S3FileSystem(**kwargs)


##
# Classes

class Repo:
    """Repository for storing and managing atdata datasets.

    Provides storage of datasets in S3-compatible object storage with Redis-based
    indexing. Datasets are stored as WebDataset tar files with optional metadata.

    Attributes:
        s3_credentials: S3 credentials dictionary or None.
        bucket_fs: S3FileSystem instance or None.
        hive_path: Path within S3 bucket for storing datasets.
        hive_bucket: Name of the S3 bucket.
        index: Index instance for tracking datasets.
    """

    ##

    def __init__( self,
                  #
                  s3_credentials: str | Path | dict[str, Any] | None = None,
                  hive_path: str | Path | None = None,
                  redis: Redis | None = None,
                  #
                  #
                  **kwargs
                  ) -> None:
        """Initialize a repository.

        Args:
            s3_credentials: Path to .env file with S3 credentials, or dict with
                AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.
                If None, S3 functionality will be disabled.
            hive_path: Path within the S3 bucket to store datasets.
                Required if s3_credentials is provided.
            redis: Redis connection for indexing. If None, creates a new connection.
            **kwargs: Additional arguments (reserved for future use).

        Raises:
            ValueError: If hive_path is not provided when s3_credentials is set.
        """

        if s3_credentials is None:
            self.s3_credentials = None
        elif isinstance( s3_credentials, dict ):
            self.s3_credentials = s3_credentials
        else:
            self.s3_credentials = _s3_env( s3_credentials )

        if self.s3_credentials is None:
            self.bucket_fs = None
        else:
            self.bucket_fs = _s3_from_credentials( self.s3_credentials )

        if self.bucket_fs is not None:
            if hive_path is None:
                raise ValueError( 'Must specify hive path within bucket' )
            self.hive_path = Path( hive_path )
            self.hive_bucket = self.hive_path.parts[0]
        else:
            self.hive_path = None
            self.hive_bucket = None

        #

        self.index = Index( redis = redis )

    ##

    def insert( self, ds: Dataset[T],
                #
                cache_local: bool = False,
                #
                **kwargs
                ) -> tuple[BasicIndexEntry, Dataset[T]]:
        """Insert a dataset into the repository.

        Writes the dataset to S3 as WebDataset tar files, stores metadata,
        and creates an index entry in Redis.

        Args:
            ds: The dataset to insert.
            cache_local: If True, write to local temporary storage first, then
                copy to S3. This can be faster for some workloads.
            **kwargs: Additional arguments passed to wds.ShardWriter.

        Returns:
            A tuple of (index_entry, new_dataset) where:
                - index_entry: BasicIndexEntry for the stored dataset
                - new_dataset: Dataset object pointing to the stored copy

        Raises:
            AssertionError: If S3 credentials or hive_path are not configured.
            RuntimeError: If no shards were written.
        """

        assert self.s3_credentials is not None
        assert self.hive_bucket is not None
        assert self.hive_path is not None

        new_uuid = str( uuid4() )

        hive_fs = _s3_from_credentials( self.s3_credentials )

        # Write metadata
        metadata_path = (
            self.hive_path
            / 'metadata'
            / f'atdata-metadata--{new_uuid}.msgpack'
        )
        # Note: S3 doesn't need directories created beforehand - s3fs handles this

        if ds.metadata is not None:
            # Use s3:// prefix to ensure s3fs treats this as an S3 path
            with cast( BinaryIO, hive_fs.open( f's3://{metadata_path.as_posix()}', 'wb' ) ) as f:
                meta_packed = msgpack.packb( ds.metadata )
                assert meta_packed is not None
                f.write( cast( bytes, meta_packed ) )


        # Write data
        shard_pattern = (
            self.hive_path
            / f'atdata--{new_uuid}--%06d.tar'
        ).as_posix()

        with TemporaryDirectory() as temp_dir:

            if cache_local:
                # For cache_local, we need to use boto3 directly to avoid s3fs async issues with moto
                import boto3

                # Create boto3 client from credentials
                s3_client_kwargs = {
                    'aws_access_key_id': self.s3_credentials['AWS_ACCESS_KEY_ID'],
                    'aws_secret_access_key': self.s3_credentials['AWS_SECRET_ACCESS_KEY']
                }
                if 'AWS_ENDPOINT' in self.s3_credentials:
                    s3_client_kwargs['endpoint_url'] = self.s3_credentials['AWS_ENDPOINT']
                s3_client = boto3.client('s3', **s3_client_kwargs)

                def _writer_opener( p: str ):
                    local_cache_path = Path( temp_dir ) / p
                    local_cache_path.parent.mkdir( parents = True, exist_ok = True )
                    return open( local_cache_path, 'wb' )
                writer_opener = _writer_opener

                def _writer_post( p: str ):
                    local_cache_path = Path( temp_dir ) / p

                    # Copy to S3 using boto3 client (avoids s3fs async issues)
                    path_parts = Path( p ).parts
                    bucket = path_parts[0]
                    key = str( Path( *path_parts[1:] ) )

                    with open( local_cache_path, 'rb' ) as f_in:
                        s3_client.put_object( Bucket=bucket, Key=key, Body=f_in.read() )

                    # Delete local cache file
                    local_cache_path.unlink()

                    written_shards.append( p )
                writer_post = _writer_post

            else:
                # Use s3:// prefix to ensure s3fs treats paths as S3 paths
                writer_opener = lambda s: cast( BinaryIO, hive_fs.open( f's3://{s}', 'wb' ) )
                writer_post = lambda s: written_shards.append( s )

            written_shards = []
            with wds.writer.ShardWriter(
                shard_pattern,
                opener = writer_opener,
                post = writer_post,
                **kwargs,
            ) as sink:
                for sample in ds.ordered( batch_size = None ):
                    sink.write( sample.as_wds )

        # Make a new Dataset object for the written dataset copy
        if len( written_shards ) == 0:
            raise RuntimeError( 'Cannot form new dataset entry -- did not write any shards' )

        elif len( written_shards ) < 2:
            new_dataset_url = (
                self.hive_path
                / ( Path( written_shards[0] ).name )
            ).as_posix()

        else:
            shard_s3_format = (
                (
                    self.hive_path
                    / f'atdata--{new_uuid}'
                ).as_posix()
            ) + '--{shard_id}.tar'
            shard_id_braced = '{' + f'{0:06d}..{len( written_shards ) - 1:06d}' + '}'
            new_dataset_url = shard_s3_format.format( shard_id = shard_id_braced )

        new_dataset = Dataset[ds.sample_type](
            url = new_dataset_url,
            metadata_url = metadata_path.as_posix(),
        )

        # Add to index
        new_entry = self.index.add_entry( new_dataset, uuid = new_uuid )

        return new_entry, new_dataset


class Index:
    """Redis-backed index for tracking datasets in a repository.

    Maintains a registry of BasicIndexEntry objects in Redis, allowing
    enumeration and lookup of stored datasets.

    Attributes:
        _redis: Redis connection for index storage.
    """

    ##

    def __init__( self,
                  redis: Redis | None = None,
                  **kwargs
                  ) -> None:
        """Initialize an index.

        Args:
            redis: Redis connection to use. If None, creates a new connection
                using the provided kwargs.
            **kwargs: Additional arguments passed to Redis() constructor if
                redis is None.
        """
        ##

        if redis is not None:
            self._redis = redis
        else:
            self._redis: Redis = Redis( **kwargs )

    @property
    def all_entries( self ) -> list[BasicIndexEntry]:
        """Get all index entries as a list.

        Returns:
            List of all BasicIndexEntry objects in the index.
        """
        return list( self.entries )

    @property
    def entries( self ) -> Generator[BasicIndexEntry, None, None]:
        """Iterate over all index entries.

        Scans Redis for all BasicIndexEntry keys and yields them one at a time.

        Yields:
            BasicIndexEntry objects from the index.
        """
        ##
        for key in self._redis.scan_iter( match = 'BasicIndexEntry:*' ):
            # hgetall returns dict[bytes, bytes] which we decode to dict[str, str]
            cur_entry_data = _decode_bytes_dict( cast(dict[bytes, bytes], self._redis.hgetall( key )) )

            # Provide default None for optional fields that may be missing
            # Type checker complains about None in dict[str, str], but BasicIndexEntry accepts it
            cur_entry_data: dict[str, Any] = dict( **cur_entry_data )
            cur_entry_data.setdefault('metadata_url', None)

            cur_entry = BasicIndexEntry( **cur_entry_data )
            yield cur_entry

        return

    def add_entry( self, ds: Dataset,
                   uuid: str | None = None,
                   ) -> BasicIndexEntry:
        """Add a dataset to the index.

        Creates a BasicIndexEntry for the dataset and persists it to Redis.

        Args:
            ds: The dataset to add to the index.
            uuid: Optional UUID for the entry. If None, a new UUID is generated.

        Returns:
            The created BasicIndexEntry object.
        """
        ##
        temp_sample_kind = _kind_str_for_sample_type( ds.sample_type )

        if uuid is None:
            ret_data = BasicIndexEntry(
                wds_url = ds.url,
                sample_kind = temp_sample_kind,
                metadata_url = ds.metadata_url,
            )
        else:
            ret_data = BasicIndexEntry(
                wds_url = ds.url,
                sample_kind = temp_sample_kind,
                metadata_url = ds.metadata_url,
                uuid = uuid,
            )

        ret_data.write_to( self._redis )

        return ret_data


#
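
A minimal usage sketch of the new `Repo` and `Index` classes, assuming a reachable Redis instance and an S3-compatible endpoint; the endpoint URL, bucket name, credentials, and shard paths below are placeholders rather than values from the package:

```python
import atdata
from atdata.local import Repo
from numpy.typing import NDArray
from redis import Redis

@atdata.packable
class ImageSample:
    image: NDArray
    label: str

# Placeholder S3 and Redis settings; any S3-compatible endpoint should work.
repo = Repo(
    s3_credentials = {
        'AWS_ENDPOINT': 'http://localhost:9000',
        'AWS_ACCESS_KEY_ID': '...',
        'AWS_SECRET_ACCESS_KEY': '...',
    },
    hive_path = 'my-bucket/datasets',   # first path component is the bucket
    redis = Redis( host = 'localhost', port = 6379 ),
)

# An existing dataset, addressed by its WebDataset URL.
ds = atdata.Dataset[ImageSample]( 'path/to/data-{000000..000009}.tar' )

# Upload the shards (and metadata, if present) and register them in the index.
entry, stored_ds = repo.insert( ds )
print( entry.uuid, entry.sample_kind, stored_ds.url )

# Each stored dataset is persisted as a 'BasicIndexEntry:<uuid>' hash in Redis.
for e in repo.index.all_entries:
    print( e.wds_url, e.metadata_url )
```

When more than one shard is written, `insert` records a brace-expansion URL of the form `.../atdata--<uuid>--{000000..NNNNNN}.tar`, so the stored copy can be read back through the same `Dataset` API.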
atdata-0.2.0a1.dist-info/METADATA
ADDED
@@ -0,0 +1,181 @@
Metadata-Version: 2.4
Name: atdata
Version: 0.2.0a1
Summary: A loose federation of distributed, typed datasets
Author-email: Maxine Levesque <hello@maxine.science>
License-File: LICENSE
Requires-Python: >=3.12
Requires-Dist: atproto>=0.0.65
Requires-Dist: fastparquet>=2024.11.0
Requires-Dist: msgpack>=1.1.2
Requires-Dist: numpy>=2.3.4
Requires-Dist: ormsgpack>=1.11.0
Requires-Dist: pandas>=2.3.3
Requires-Dist: pydantic>=2.12.5
Requires-Dist: python-dotenv>=1.2.1
Requires-Dist: redis-om>=0.3.5
Requires-Dist: requests>=2.32.5
Requires-Dist: s3fs>=2025.12.0
Requires-Dist: schemamodels>=0.9.1
Requires-Dist: tqdm>=4.67.1
Requires-Dist: webdataset>=1.0.2
Provides-Extra: atmosphere
Requires-Dist: atproto>=0.0.55; extra == 'atmosphere'
Description-Content-Type: text/markdown

# atdata

[](https://codecov.io/gh/foundation-ac/atdata)

A loose federation of distributed, typed datasets built on WebDataset.

**atdata** provides a type-safe, composable framework for working with large-scale datasets. It combines the efficiency of WebDataset's tar-based storage with Python's type system and functional programming patterns.

## Features

- **Typed Samples** - Define dataset schemas using Python dataclasses with automatic msgpack serialization
- **Lens Transformations** - Bidirectional, composable transformations between different dataset views
- **Automatic Batching** - Smart batch aggregation with numpy array stacking
- **WebDataset Integration** - Efficient storage and streaming for large-scale datasets

## Installation

```bash
pip install atdata
```

Requires Python 3.12 or later.

## Quick Start

### Defining Sample Types

Use the `@packable` decorator to create typed dataset samples:

```python
import atdata
from numpy.typing import NDArray

@atdata.packable
class ImageSample:
    image: NDArray
    label: str
    metadata: dict
```

### Creating Datasets

```python
# Create a dataset
dataset = atdata.Dataset[ImageSample]("path/to/data-{000000..000009}.tar")

# Iterate over samples in order
for sample in dataset.ordered(batch_size=None):
    print(f"Label: {sample.label}, Image shape: {sample.image.shape}")

# Iterate with shuffling and batching
for batch in dataset.shuffled(batch_size=32):
    # batch.image is automatically stacked into shape (32, ...)
    # batch.label is a list of 32 labels
    process_batch(batch.image, batch.label)
```

### Lens Transformations

Define reusable transformations between sample types:

```python
@atdata.packable
class ProcessedSample:
    features: NDArray
    label: str

@atdata.lens
def preprocess(sample: ImageSample) -> ProcessedSample:
    features = extract_features(sample.image)
    return ProcessedSample(features=features, label=sample.label)

# Apply lens to view dataset as ProcessedSample
processed_ds = dataset.as_type(ProcessedSample)

for sample in processed_ds.ordered(batch_size=None):
    # sample is now a ProcessedSample
    print(sample.features.shape)
```

## Core Concepts

### PackableSample

Base class for serializable samples. Fields annotated as `NDArray` are automatically handled:

```python
@atdata.packable
class MySample:
    array_field: NDArray # Automatically serialized
    optional_array: NDArray | None
    regular_field: str
```

### Lens

Bidirectional transformations with getter/putter semantics:

```python
@atdata.lens
def my_lens(source: SourceType) -> ViewType:
    # Transform source -> view
    return ViewType(...)

@my_lens.putter
def my_lens_put(view: ViewType, source: SourceType) -> SourceType:
    # Transform view -> source
    return SourceType(...)
```

### Dataset URLs

Uses WebDataset brace expansion for sharded datasets:

- Single file: `"data/dataset-000000.tar"`
- Multiple shards: `"data/dataset-{000000..000099}.tar"`
- Multiple patterns: `"data/{train,val}/dataset-{000000..000009}.tar"`
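
As a rough sketch of the multiple-pattern form (the `train`/`val` split names and shard count are hypothetical, and `ImageSample` is the Quick Start type above), the two brace groups expand combinatorially to 2 × 10 = 20 shards:

```python
# Hypothetical layout: two splits with ten shards each, 20 tar files in total.
splits = atdata.Dataset[ImageSample]("data/{train,val}/dataset-{000000..000009}.tar")

for sample in splits.ordered(batch_size=None):
    print(sample.label)
```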

## Development

### Setup

```bash
# Install uv if not already available
python -m pip install uv

# Install dependencies
uv sync
```

### Testing

```bash
# Run all tests with coverage
pytest

# Run specific test file
pytest tests/test_dataset.py

# Run single test
pytest tests/test_lens.py::test_lens
```

### Building

```bash
uv build
```

## Contributing

Contributions are welcome! This project is in beta, so the API may still evolve.

## License

This project is licensed under the Mozilla Public License 2.0. See [LICENSE](LICENSE) for details.
atdata-0.2.0a1.dist-info/RECORD
ADDED
@@ -0,0 +1,16 @@
atdata/__init__.py,sha256=6RYvy9GJwqtSQbCS81HaQyOyAVgLxm63kBt0SH5Qapo,1642
atdata/_helpers.py,sha256=RvA-Xlj3AvgSWuiPdS8YTBp8AJT-u32BaLpxsu4PIIA,1564
atdata/dataset.py,sha256=O2j1_ABvTFcs83_y-GGDRROD9zRe-237O2OiI1NhySg,24173
atdata/lens.py,sha256=lFFVeuKXa17KYjfz3VFqE9Xf0vy3C6puSiF78hyIaAI,9673
atdata/local.py,sha256=IdNOTA0nvszG-XRkRMkT_zkMivIx93WKh3bpgIx_u_o,15458
atdata/atmosphere/__init__.py,sha256=8tPDziazrQWdyvetWTVV1eWRt6JBy86WfnvAeyh8iJE,1743
atdata/atmosphere/_types.py,sha256=0606wb2c8Ty7cmZWTh5mb_qwJmAwYf5oaJU_wk9moa8,9564
atdata/atmosphere/client.py,sha256=tihVBlhPCz3TZBHs_Ce7uYwE70IzKyeXNpDKsN_qc5U,11358
atdata/atmosphere/lens.py,sha256=BzUdagItYsyzYHtK1jqppJJ1VUHJVQRw0hi7LuvJG5Q,9267
atdata/atmosphere/records.py,sha256=-9hhSLsr6sDHkzCVWDudZtxTMHXcVyUHeVojlNcGdL4,10672
atdata/atmosphere/schema.py,sha256=6gQMGSRjgESaXZzBYMfO51qL9JMiyNGrqJe4iWarO7w,9872
atdata-0.2.0a1.dist-info/METADATA,sha256=EBwfarL5lmzP2lMdn7Z9yfZBjP6TwTalnTDC8cc7cdY,4471
atdata-0.2.0a1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
atdata-0.2.0a1.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
atdata-0.2.0a1.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
atdata-0.2.0a1.dist-info/RECORD,,