atdata 0.2.3b1-py3-none-any.whl → 0.3.1b1-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- atdata/.gitignore +1 -0
- atdata/__init__.py +39 -0
- atdata/_cid.py +0 -21
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +41 -15
- atdata/_hf_api.py +95 -11
- atdata/_logging.py +70 -0
- atdata/_protocols.py +77 -238
- atdata/_schema_codec.py +7 -6
- atdata/_stub_manager.py +5 -25
- atdata/_type_utils.py +28 -2
- atdata/atmosphere/__init__.py +31 -20
- atdata/atmosphere/_types.py +4 -4
- atdata/atmosphere/client.py +64 -12
- atdata/atmosphere/lens.py +11 -12
- atdata/atmosphere/records.py +12 -12
- atdata/atmosphere/schema.py +16 -18
- atdata/atmosphere/store.py +6 -7
- atdata/cli/__init__.py +161 -175
- atdata/cli/diagnose.py +2 -2
- atdata/cli/{local.py → infra.py} +11 -11
- atdata/cli/inspect.py +69 -0
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +583 -328
- atdata/index/__init__.py +54 -0
- atdata/index/_entry.py +157 -0
- atdata/index/_index.py +1198 -0
- atdata/index/_schema.py +380 -0
- atdata/lens.py +9 -2
- atdata/lexicons/__init__.py +121 -0
- atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
- atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
- atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
- atdata/lexicons/ac.foundation.dataset.record.json +96 -0
- atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
- atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
- atdata/lexicons/ndarray_shim.json +16 -0
- atdata/local/__init__.py +70 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +18 -14
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/stores/__init__.py +23 -0
- atdata/stores/_disk.py +123 -0
- atdata/stores/_s3.py +349 -0
- atdata/testing.py +341 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
- atdata-0.3.1b1.dist-info/RECORD +67 -0
- atdata/local.py +0 -1720
- atdata-0.2.3b1.dist-info/RECORD +0 -28
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
atdata/stores/_s3.py
ADDED
@@ -0,0 +1,349 @@
+"""S3-compatible data store and helper functions."""
+
+from atdata import Dataset
+
+from pathlib import Path
+from uuid import uuid4
+from tempfile import TemporaryDirectory
+from dotenv import dotenv_values
+from typing import Any, BinaryIO, cast
+
+from s3fs import S3FileSystem
+import webdataset as wds
+
+
+def _s3_env(credentials_path: str | Path) -> dict[str, Any]:
+    """Load S3 credentials from .env file.
+
+    Args:
+        credentials_path: Path to .env file containing AWS_ENDPOINT,
+            AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.
+
+    Returns:
+        Dict with the three required credential keys.
+
+    Raises:
+        ValueError: If any required key is missing from the .env file.
+    """
+    credentials_path = Path(credentials_path)
+    env_values = dotenv_values(credentials_path)
+
+    required_keys = ("AWS_ENDPOINT", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
+    missing = [k for k in required_keys if k not in env_values]
+    if missing:
+        raise ValueError(
+            f"Missing required keys in {credentials_path}: {', '.join(missing)}"
+        )
+
+    return {k: env_values[k] for k in required_keys}
+
+
+def _s3_from_credentials(creds: str | Path | dict) -> S3FileSystem:
+    """Create S3FileSystem from credentials dict or .env file path."""
+    if not isinstance(creds, dict):
+        creds = _s3_env(creds)
+
+    # Build kwargs, making endpoint_url optional
+    kwargs = {
+        "key": creds["AWS_ACCESS_KEY_ID"],
+        "secret": creds["AWS_SECRET_ACCESS_KEY"],
+    }
+    if "AWS_ENDPOINT" in creds:
+        kwargs["endpoint_url"] = creds["AWS_ENDPOINT"]
+
+    return S3FileSystem(**kwargs)
+
+
+def _create_s3_write_callbacks(
+    credentials: dict[str, Any],
+    temp_dir: str,
+    written_shards: list[str],
+    fs: S3FileSystem | None,
+    cache_local: bool,
+    add_s3_prefix: bool = False,
+) -> tuple:
+    """Create opener and post callbacks for ShardWriter with S3 upload.
+
+    Args:
+        credentials: S3 credentials dict.
+        temp_dir: Temporary directory for local caching.
+        written_shards: List to append written shard paths to.
+        fs: S3FileSystem for direct writes (used when cache_local=False).
+        cache_local: If True, write locally then copy to S3.
+        add_s3_prefix: If True, prepend 's3://' to shard paths.
+
+    Returns:
+        Tuple of (writer_opener, writer_post) callbacks.
+    """
+    if cache_local:
+        import boto3
+
+        s3_client_kwargs = {
+            "aws_access_key_id": credentials["AWS_ACCESS_KEY_ID"],
+            "aws_secret_access_key": credentials["AWS_SECRET_ACCESS_KEY"],
+        }
+        if "AWS_ENDPOINT" in credentials:
+            s3_client_kwargs["endpoint_url"] = credentials["AWS_ENDPOINT"]
+        s3_client = boto3.client("s3", **s3_client_kwargs)
+
+        def _writer_opener(p: str):
+            local_path = Path(temp_dir) / p
+            local_path.parent.mkdir(parents=True, exist_ok=True)
+            return open(local_path, "wb")
+
+        def _writer_post(p: str):
+            local_path = Path(temp_dir) / p
+            path_parts = Path(p).parts
+            bucket = path_parts[0]
+            key = str(Path(*path_parts[1:]))
+
+            with open(local_path, "rb") as f_in:
+                s3_client.put_object(Bucket=bucket, Key=key, Body=f_in.read())
+
+            local_path.unlink()
+            if add_s3_prefix:
+                written_shards.append(f"s3://{p}")
+            else:
+                written_shards.append(p)
+
+        return _writer_opener, _writer_post
+    else:
+        if fs is None:
+            raise ValueError("S3FileSystem required when cache_local=False")
+
+        def _direct_opener(s: str):
+            return cast(BinaryIO, fs.open(f"s3://{s}", "wb"))
+
+        def _direct_post(s: str):
+            if add_s3_prefix:
+                written_shards.append(f"s3://{s}")
+            else:
+                written_shards.append(s)
+
+        return _direct_opener, _direct_post
+
+
+class S3DataStore:
+    """S3-compatible data store implementing AbstractDataStore protocol.
+
+    Handles writing dataset shards to S3-compatible object storage and
+    resolving URLs for reading.
+
+    Attributes:
+        credentials: S3 credentials dictionary.
+        bucket: Target bucket name.
+        _fs: S3FileSystem instance.
+    """
+
+    def __init__(
+        self,
+        credentials: str | Path | dict[str, Any],
+        *,
+        bucket: str,
+    ) -> None:
+        """Initialize an S3 data store.
+
+        Args:
+            credentials: Path to .env file or dict with AWS_ACCESS_KEY_ID,
+                AWS_SECRET_ACCESS_KEY, and optionally AWS_ENDPOINT.
+            bucket: Name of the S3 bucket for storage.
+        """
+        if isinstance(credentials, dict):
+            self.credentials = credentials
+        else:
+            self.credentials = _s3_env(credentials)
+
+        self.bucket = bucket
+        self._fs = _s3_from_credentials(self.credentials)
+
+    def write_shards(
+        self,
+        ds: Dataset,
+        *,
+        prefix: str,
+        cache_local: bool = False,
+        manifest: bool = False,
+        schema_version: str = "1.0.0",
+        source_job_id: str | None = None,
+        parent_shards: list[str] | None = None,
+        pipeline_version: str | None = None,
+        **kwargs,
+    ) -> list[str]:
+        """Write dataset shards to S3.
+
+        Args:
+            ds: The Dataset to write.
+            prefix: Path prefix within bucket (e.g., 'datasets/mnist/v1').
+            cache_local: If True, write locally first then copy to S3.
+            manifest: If True, generate per-shard manifest files alongside
+                each tar shard (``.manifest.json`` + ``.manifest.parquet``).
+            schema_version: Schema version for manifest headers.
+            source_job_id: Optional provenance job identifier for manifests.
+            parent_shards: Optional list of input shard identifiers for provenance.
+            pipeline_version: Optional pipeline version string for provenance.
+            **kwargs: Additional args passed to wds.ShardWriter (e.g., maxcount).
+
+        Returns:
+            List of S3 URLs for the written shards.
+
+        Raises:
+            RuntimeError: If no shards were written.
+        """
+        new_uuid = str(uuid4())
+        shard_pattern = f"{self.bucket}/{prefix}/data--{new_uuid}--%06d.tar"
+
+        written_shards: list[str] = []
+
+        # Manifest tracking state shared with the post callback
+        manifest_builders: list = []
+        current_builder: list = [None]  # mutable ref for closure
+        shard_counter: list[int] = [0]
+
+        if manifest:
+            from atdata.manifest import ManifestBuilder, ManifestWriter
+
+            def _make_builder(shard_idx: int) -> ManifestBuilder:
+                shard_id = f"{self.bucket}/{prefix}/data--{new_uuid}--{shard_idx:06d}"
+                return ManifestBuilder(
+                    sample_type=ds.sample_type,
+                    shard_id=shard_id,
+                    schema_version=schema_version,
+                    source_job_id=source_job_id,
+                    parent_shards=parent_shards,
+                    pipeline_version=pipeline_version,
+                )
+
+            current_builder[0] = _make_builder(0)
+
+        with TemporaryDirectory() as temp_dir:
+            writer_opener, writer_post_orig = _create_s3_write_callbacks(
+                credentials=self.credentials,
+                temp_dir=temp_dir,
+                written_shards=written_shards,
+                fs=self._fs,
+                cache_local=cache_local,
+                add_s3_prefix=True,
+            )
+
+            if manifest:
+
+                def writer_post(p: str):
+                    # Finalize the current manifest builder when a shard completes
+                    builder = current_builder[0]
+                    if builder is not None:
+                        manifest_builders.append(builder)
+                    # Advance to the next shard's builder
+                    shard_counter[0] += 1
+                    current_builder[0] = _make_builder(shard_counter[0])
+                    # Call original post callback
+                    writer_post_orig(p)
+            else:
+                writer_post = writer_post_orig
+
+            offset = 0
+            with wds.writer.ShardWriter(
+                shard_pattern,
+                opener=writer_opener,
+                post=writer_post,
+                **kwargs,
+            ) as sink:
+                for sample in ds.ordered(batch_size=None):
+                    wds_dict = sample.as_wds
+                    sink.write(wds_dict)
+
+                    if manifest and current_builder[0] is not None:
+                        packed_size = len(wds_dict.get("msgpack", b""))
+                        current_builder[0].add_sample(
+                            key=wds_dict["__key__"],
+                            offset=offset,
+                            size=packed_size,
+                            sample=sample,
+                        )
+                        # Approximate tar entry: 512-byte header + data rounded to 512
+                        offset += 512 + packed_size + (512 - packed_size % 512) % 512
+
+            # Finalize the last shard's builder (post isn't called for the last shard
+            # until ShardWriter closes, but we handle it here for safety)
+            if manifest and current_builder[0] is not None:
+                builder = current_builder[0]
+                if builder._rows:  # Only if samples were added
+                    manifest_builders.append(builder)
+
+            # Write all manifest files
+            if manifest:
+                for builder in manifest_builders:
+                    built = builder.build()
+                    writer = ManifestWriter(Path(temp_dir) / Path(built.shard_id))
+                    json_path, parquet_path = writer.write(built)
+
+                    # Upload manifest files to S3 alongside shards
+                    shard_id = built.shard_id
+                    json_key = f"{shard_id}.manifest.json"
+                    parquet_key = f"{shard_id}.manifest.parquet"
+
+                    if cache_local:
+                        import boto3
+
+                        s3_kwargs = {
+                            "aws_access_key_id": self.credentials["AWS_ACCESS_KEY_ID"],
+                            "aws_secret_access_key": self.credentials[
+                                "AWS_SECRET_ACCESS_KEY"
+                            ],
+                        }
+                        if "AWS_ENDPOINT" in self.credentials:
+                            s3_kwargs["endpoint_url"] = self.credentials["AWS_ENDPOINT"]
+                        s3_client = boto3.client("s3", **s3_kwargs)
+
+                        bucket_name = Path(shard_id).parts[0]
+                        json_s3_key = str(Path(*Path(json_key).parts[1:]))
+                        parquet_s3_key = str(Path(*Path(parquet_key).parts[1:]))
+
+                        with open(json_path, "rb") as f:
+                            s3_client.put_object(
+                                Bucket=bucket_name, Key=json_s3_key, Body=f.read()
+                            )
+                        with open(parquet_path, "rb") as f:
+                            s3_client.put_object(
+                                Bucket=bucket_name, Key=parquet_s3_key, Body=f.read()
+                            )
+                    else:
+                        self._fs.put(str(json_path), f"s3://{json_key}")
+                        self._fs.put(str(parquet_path), f"s3://{parquet_key}")
+
+        if len(written_shards) == 0:
+            raise RuntimeError("No shards written")
+
+        return written_shards
+
+    def read_url(self, url: str) -> str:
+        """Resolve an S3 URL for reading/streaming.
+
+        For S3-compatible stores with custom endpoints (like Cloudflare R2,
+        MinIO, etc.), converts s3:// URLs to HTTPS URLs that WebDataset can
+        stream directly.
+
+        For standard AWS S3 (no custom endpoint), URLs are returned unchanged
+        since WebDataset's built-in s3fs integration handles them.
+
+        Args:
+            url: S3 URL to resolve (e.g., 's3://bucket/path/file.tar').
+
+        Returns:
+            HTTPS URL if custom endpoint is configured, otherwise unchanged.
+            Example: 's3://bucket/path' -> 'https://endpoint.com/bucket/path'
+        """
+        endpoint = self.credentials.get("AWS_ENDPOINT")
+        if endpoint and url.startswith("s3://"):
+            # s3://bucket/path -> https://endpoint/bucket/path
+            path = url[5:]  # Remove 's3://' prefix
+            endpoint = endpoint.rstrip("/")
+            return f"{endpoint}/{path}"
+        return url
+
+    def supports_streaming(self) -> bool:
+        """S3 supports streaming reads.
+
+        Returns:
+            True.
+        """
+        return True
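Taken together, S3DataStore.write_shards drives a wds.ShardWriter through one of two upload paths (direct S3FileSystem writes, or local caching plus boto3 uploads), and read_url rewrites s3:// URLs for custom endpoints. A minimal usage sketch based only on the signatures above; the bucket name, credential values, and the ds dataset are placeholders, not part of the package:

    from atdata.stores._s3 import S3DataStore

    # Placeholder credentials; AWS_ENDPOINT is only needed for R2/MinIO-style endpoints.
    creds = {
        "AWS_ACCESS_KEY_ID": "<access-key>",
        "AWS_SECRET_ACCESS_KEY": "<secret-key>",
        "AWS_ENDPOINT": "https://<account>.r2.cloudflarestorage.com",
    }

    store = S3DataStore(creds, bucket="my-bucket")

    # `ds` is assumed to be an existing atdata Dataset instance.
    shards = store.write_shards(
        ds,
        prefix="datasets/mnist/v1",
        manifest=True,      # also writes .manifest.json / .manifest.parquet per shard
        maxcount=10_000,    # forwarded to wds.ShardWriter
    )

    # With a custom endpoint configured, convert s3:// URLs to streamable HTTPS URLs.
    urls = [store.read_url(s) for s in shards]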
atdata/testing.py
ADDED
@@ -0,0 +1,341 @@
+"""Testing utilities for atdata.
+
+Provides mock clients, dataset factories, and pytest fixtures for writing
+tests against atdata without requiring external services (Redis, S3, ATProto PDS).
+
+Usage::
+
+    import atdata.testing as at_test
+
+    # Create a dataset from samples
+    ds = at_test.make_dataset(tmp_path, [sample1, sample2])
+
+    # Generate random samples
+    samples = at_test.make_samples(MyType, n=100)
+
+    # Use mock atmosphere client
+    client = at_test.MockAtmosphere()
+
+    # Use in-memory index (SQLite backed, temporary)
+    index = at_test.mock_index(tmp_path)
+
+Pytest fixtures (available when ``atdata`` is installed)::
+
+    def test_something(mock_atmosphere):
+        client = mock_atmosphere
+        client.login("user", "pass")
+        ...
+"""
+
+from __future__ import annotations
+
+import tempfile
+import uuid
+from dataclasses import fields as dc_fields
+from pathlib import Path
+from typing import Any, Sequence, Type, TypeVar
+
+import numpy as np
+import webdataset as wds
+
+import atdata
+from atdata import Dataset, PackableSample
+from atdata.index._index import Index
+from atdata.providers._sqlite import SqliteProvider
+
+ST = TypeVar("ST")
+
+
+# ---------------------------------------------------------------------------
+# Mock Atmosphere Client
+# ---------------------------------------------------------------------------
+
+
+class MockAtmosphere:
+    """In-memory mock of ``Atmosphere`` for testing.
+
+    Simulates login, schema publishing, dataset publishing, and record
+    retrieval without requiring a live ATProto PDS.
+
+    Examples:
+        >>> client = MockAtmosphere()
+        >>> client.login("alice.test", "password")
+        >>> client.did
+        'did:plc:mock000000000000'
+    """
+
+    def __init__(
+        self,
+        did: str = "did:plc:mock000000000000",
+        handle: str = "test.mock.social",
+    ) -> None:
+        self.did = did
+        self.handle = handle
+        self._logged_in = False
+        self._records: dict[str, dict[str, Any]] = {}
+        self._schemas: dict[str, dict[str, Any]] = {}
+        self._datasets: dict[str, dict[str, Any]] = {}
+        self._blobs: dict[str, bytes] = {}
+        self._session_string = "mock-session-string"
+        self._call_log: list[tuple[str, dict[str, Any]]] = []
+
+    def login(self, handle: str, password: str) -> dict[str, Any]:
+        """Simulate login. Always succeeds."""
+        self._logged_in = True
+        self.handle = handle
+        self._call_log.append(("login", {"handle": handle}))
+        return {"did": self.did, "handle": self.handle}
+
+    @property
+    def is_authenticated(self) -> bool:
+        return self._logged_in
+
+    def export_session_string(self) -> str:
+        return self._session_string
+
+    def create_record(
+        self,
+        collection: str,
+        record: dict[str, Any],
+        rkey: str | None = None,
+    ) -> str:
+        """Simulate creating a record. Returns a mock AT URI."""
+        key = rkey or uuid.uuid4().hex[:12]
+        uri = f"at://{self.did}/{collection}/{key}"
+        self._records[uri] = record
+        self._call_log.append(
+            ("create_record", {"collection": collection, "rkey": key, "uri": uri})
+        )
+        return uri
+
+    def get_record(self, uri: str) -> dict[str, Any]:
+        """Retrieve a previously created record by URI."""
+        if uri not in self._records:
+            raise KeyError(f"Record not found: {uri}")
+        return self._records[uri]
+
+    def list_records(self, collection: str) -> list[dict[str, Any]]:
+        """List records for a collection."""
+        return [
+            {"uri": uri, "value": rec}
+            for uri, rec in self._records.items()
+            if collection in uri
+        ]
+
+    def upload_blob(self, data: bytes) -> dict[str, Any]:
+        """Simulate uploading a blob. Returns a mock blob ref."""
+        ref = f"blob:{uuid.uuid4().hex[:16]}"
+        self._blobs[ref] = data
+        self._call_log.append(("upload_blob", {"ref": ref, "size": len(data)}))
+        return {"ref": {"$link": ref}, "mimeType": "application/octet-stream"}
+
+    def get_blob(self, did: str, cid: str) -> bytes:
+        """Retrieve a previously uploaded blob."""
+        if cid not in self._blobs:
+            raise KeyError(f"Blob not found: {cid}")
+        return self._blobs[cid]
+
+    def reset(self) -> None:
+        """Clear all stored state."""
+        self._records.clear()
+        self._schemas.clear()
+        self._datasets.clear()
+        self._blobs.clear()
+        self._call_log.clear()
+        self._logged_in = False
+
+
+# ---------------------------------------------------------------------------
+# Dataset Factory
+# ---------------------------------------------------------------------------
+
+
+def make_dataset(
+    path: Path,
+    samples: Sequence[PackableSample],
+    *,
+    name: str = "test",
+    sample_type: type | None = None,
+) -> Dataset:
+    """Create a ``Dataset`` from a list of samples.
+
+    Writes the samples to a WebDataset tar file in *path* and returns a
+    ``Dataset`` configured to read them back.
+
+    Args:
+        path: Directory where the tar file will be created.
+        samples: Sequence of ``PackableSample`` (or ``@packable``) instances.
+        name: Filename prefix for the tar file.
+        sample_type: Explicit sample type for the Dataset generic parameter.
+            If ``None``, inferred from the first sample.
+
+    Returns:
+        A ``Dataset`` ready for iteration.
+
+    Examples:
+        >>> ds = make_dataset(tmp_path, [MySample(x=1), MySample(x=2)])
+        >>> assert len(list(ds.ordered())) == 2
+    """
+    if not samples:
+        raise ValueError("samples must be non-empty")
+
+    tar_path = path / f"{name}-000000.tar"
+    tar_path.parent.mkdir(parents=True, exist_ok=True)
+
+    with wds.writer.TarWriter(str(tar_path)) as writer:
+        for sample in samples:
+            writer.write(sample.as_wds)
+
+    st = sample_type or type(samples[0])
+    return Dataset[st](url=str(tar_path))
+
+
+def make_samples(
+    sample_type: Type[ST], n: int = 10, seed: int | None = None
+) -> list[ST]:
+    """Generate *n* random instances of a ``@packable`` sample type.
+
+    Inspects the dataclass fields and generates appropriate random data:
+    - ``str`` fields get ``"field_name_0"``, ``"field_name_1"``, etc.
+    - ``int`` fields get sequential integers
+    - ``float`` fields get random floats in [0, 1)
+    - ``bool`` fields alternate True/False
+    - ``bytes`` fields get random 16 bytes
+    - NDArray fields get random ``(4, 4)`` float32 arrays
+
+    Args:
+        sample_type: A ``@packable``-decorated class or ``PackableSample`` subclass.
+        n: Number of samples to generate.
+        seed: Optional random seed for reproducibility.
+
+    Returns:
+        List of *n* sample instances.
+
+    Examples:
+        >>> @atdata.packable
+        ... class Point:
+        ...     x: float
+        ...     y: float
+        ...     label: str
+        >>> points = make_samples(Point, n=5, seed=42)
+        >>> len(points)
+        5
+    """
+    rng = np.random.default_rng(seed)
+    result: list[ST] = []
+
+    for i in range(n):
+        kwargs: dict[str, Any] = {}
+        for field in dc_fields(sample_type):
+            type_str = str(field.type)
+            fname = field.name
+
+            if field.type is str or type_str == "str":
+                kwargs[fname] = f"{fname}_{i}"
+            elif field.type is int or type_str == "int":
+                kwargs[fname] = i
+            elif field.type is float or type_str == "float":
+                kwargs[fname] = float(rng.random())
+            elif field.type is bool or type_str == "bool":
+                kwargs[fname] = i % 2 == 0
+            elif field.type is bytes or type_str == "bytes":
+                kwargs[fname] = rng.bytes(16)
+            elif "NDArray" in type_str or "ndarray" in type_str.lower():
+                kwargs[fname] = rng.standard_normal((4, 4)).astype(np.float32)
+            elif "list" in type_str.lower():
+                kwargs[fname] = [f"{fname}_{i}_{j}" for j in range(3)]
+            elif "None" in type_str:
+                # Optional field — leave at default
+                if field.default is not field.default_factory:  # type: ignore[attr-defined]
+                    continue
+            else:
+                kwargs[fname] = f"{fname}_{i}"
+
+        result.append(sample_type(**kwargs))
+
+    return result
+
+
+# ---------------------------------------------------------------------------
+# Mock Index
+# ---------------------------------------------------------------------------
+
+
+def mock_index(path: Path | None = None, **kwargs: Any) -> Index:
+    """Create an in-memory SQLite-backed ``Index`` for testing.
+
+    No Redis or external services required.
+
+    Args:
+        path: Directory for the SQLite database file. If ``None``, uses
+            a temporary directory.
+        **kwargs: Additional keyword arguments passed to ``Index()``.
+
+    Returns:
+        An ``Index`` instance backed by a temporary SQLite database.
+
+    Examples:
+        >>> index = mock_index(tmp_path)
+        >>> ref = index.publish_schema(MyType, version="1.0.0")
+    """
+    if path is None:
+        path = Path(tempfile.mkdtemp())
+    db_path = path / "test_index.db"
+    provider = SqliteProvider(str(db_path))
+    return Index(provider=provider, atmosphere=None, **kwargs)
+
+
+# ---------------------------------------------------------------------------
+# Pytest plugin (fixtures auto-discovered when atdata is installed)
+# ---------------------------------------------------------------------------
+
+try:
+    import pytest
+
+    @pytest.fixture
+    def mock_atmosphere():
+        """Provide a fresh ``MockAtmosphere`` for each test."""
+        client = MockAtmosphere()
+        client.login("test.mock.social", "test-password")
+        yield client
+        client.reset()
+
+    @pytest.fixture
+    def tmp_dataset(tmp_path: Path):
+        """Provide a small ``Dataset[SharedBasicSample]`` with 10 samples.
+
+        Uses ``SharedBasicSample`` (name: str, value: int) from the test suite.
+        """
+
+        @atdata.packable
+        class _TmpSample:
+            name: str
+            value: int
+
+        samples = [_TmpSample(name=f"s{i}", value=i) for i in range(10)]
+        return make_dataset(tmp_path, samples, sample_type=_TmpSample)
+
+    @pytest.fixture
+    def tmp_index(tmp_path: Path):
+        """Provide a fresh SQLite-backed ``Index`` for each test."""
+        return mock_index(tmp_path)
+
+except ImportError:
+    # pytest not installed — skip fixture registration
+    _no_pytest = True
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+# Deprecated alias for backward compatibility
+MockAtmosphereClient = MockAtmosphere
+
+__all__ = [
+    "MockAtmosphere",
+    "MockAtmosphereClient",  # deprecated alias
+    "make_dataset",
+    "make_samples",
+    "mock_index",
+]
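For orientation, a pytest-style sketch that combines these helpers; Point mirrors the make_samples docstring example, test_roundtrip is an illustrative name, and tmp_path is pytest's built-in temporary-directory fixture:

    import atdata
    import atdata.testing as at_test


    @atdata.packable
    class Point:
        x: float
        y: float
        label: str


    def test_roundtrip(tmp_path):
        # Deterministic random samples written to a local WebDataset tar shard.
        samples = at_test.make_samples(Point, n=5, seed=42)
        ds = at_test.make_dataset(tmp_path, samples, sample_type=Point)
        assert len(list(ds.ordered())) == 5

        # SQLite-backed Index and in-memory PDS stand-in; no external services needed.
        index = at_test.mock_index(tmp_path)
        client = at_test.MockAtmosphere()
        client.login("alice.test", "password")
        assert client.is_authenticated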