atdata-0.2.3b1-py3-none-any.whl → atdata-0.3.1b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. atdata/.gitignore +1 -0
  2. atdata/__init__.py +39 -0
  3. atdata/_cid.py +0 -21
  4. atdata/_exceptions.py +168 -0
  5. atdata/_helpers.py +41 -15
  6. atdata/_hf_api.py +95 -11
  7. atdata/_logging.py +70 -0
  8. atdata/_protocols.py +77 -238
  9. atdata/_schema_codec.py +7 -6
  10. atdata/_stub_manager.py +5 -25
  11. atdata/_type_utils.py +28 -2
  12. atdata/atmosphere/__init__.py +31 -20
  13. atdata/atmosphere/_types.py +4 -4
  14. atdata/atmosphere/client.py +64 -12
  15. atdata/atmosphere/lens.py +11 -12
  16. atdata/atmosphere/records.py +12 -12
  17. atdata/atmosphere/schema.py +16 -18
  18. atdata/atmosphere/store.py +6 -7
  19. atdata/cli/__init__.py +161 -175
  20. atdata/cli/diagnose.py +2 -2
  21. atdata/cli/{local.py → infra.py} +11 -11
  22. atdata/cli/inspect.py +69 -0
  23. atdata/cli/preview.py +63 -0
  24. atdata/cli/schema.py +109 -0
  25. atdata/dataset.py +583 -328
  26. atdata/index/__init__.py +54 -0
  27. atdata/index/_entry.py +157 -0
  28. atdata/index/_index.py +1198 -0
  29. atdata/index/_schema.py +380 -0
  30. atdata/lens.py +9 -2
  31. atdata/lexicons/__init__.py +121 -0
  32. atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
  33. atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
  34. atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
  35. atdata/lexicons/ac.foundation.dataset.record.json +96 -0
  36. atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
  37. atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
  38. atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
  39. atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
  40. atdata/lexicons/ndarray_shim.json +16 -0
  41. atdata/local/__init__.py +70 -0
  42. atdata/local/_repo_legacy.py +218 -0
  43. atdata/manifest/__init__.py +28 -0
  44. atdata/manifest/_aggregates.py +156 -0
  45. atdata/manifest/_builder.py +163 -0
  46. atdata/manifest/_fields.py +154 -0
  47. atdata/manifest/_manifest.py +146 -0
  48. atdata/manifest/_query.py +150 -0
  49. atdata/manifest/_writer.py +74 -0
  50. atdata/promote.py +18 -14
  51. atdata/providers/__init__.py +25 -0
  52. atdata/providers/_base.py +140 -0
  53. atdata/providers/_factory.py +69 -0
  54. atdata/providers/_postgres.py +214 -0
  55. atdata/providers/_redis.py +171 -0
  56. atdata/providers/_sqlite.py +191 -0
  57. atdata/repository.py +323 -0
  58. atdata/stores/__init__.py +23 -0
  59. atdata/stores/_disk.py +123 -0
  60. atdata/stores/_s3.py +349 -0
  61. atdata/testing.py +341 -0
  62. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
  63. atdata-0.3.1b1.dist-info/RECORD +67 -0
  64. atdata/local.py +0 -1720
  65. atdata-0.2.3b1.dist-info/RECORD +0 -28
  66. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
  67. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
  68. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
atdata/stores/_s3.py ADDED
@@ -0,0 +1,349 @@
+ """S3-compatible data store and helper functions."""
+
+ from atdata import Dataset
+
+ from pathlib import Path
+ from uuid import uuid4
+ from tempfile import TemporaryDirectory
+ from dotenv import dotenv_values
+ from typing import Any, BinaryIO, cast
+
+ from s3fs import S3FileSystem
+ import webdataset as wds
+
+
+ def _s3_env(credentials_path: str | Path) -> dict[str, Any]:
+     """Load S3 credentials from .env file.
+
+     Args:
+         credentials_path: Path to .env file containing AWS_ENDPOINT,
+             AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.
+
+     Returns:
+         Dict with the three required credential keys.
+
+     Raises:
+         ValueError: If any required key is missing from the .env file.
+     """
+     credentials_path = Path(credentials_path)
+     env_values = dotenv_values(credentials_path)
+
+     required_keys = ("AWS_ENDPOINT", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
+     missing = [k for k in required_keys if k not in env_values]
+     if missing:
+         raise ValueError(
+             f"Missing required keys in {credentials_path}: {', '.join(missing)}"
+         )
+
+     return {k: env_values[k] for k in required_keys}
+
+
+ def _s3_from_credentials(creds: str | Path | dict) -> S3FileSystem:
+     """Create S3FileSystem from credentials dict or .env file path."""
+     if not isinstance(creds, dict):
+         creds = _s3_env(creds)
+
+     # Build kwargs, making endpoint_url optional
+     kwargs = {
+         "key": creds["AWS_ACCESS_KEY_ID"],
+         "secret": creds["AWS_SECRET_ACCESS_KEY"],
+     }
+     if "AWS_ENDPOINT" in creds:
+         kwargs["endpoint_url"] = creds["AWS_ENDPOINT"]
+
+     return S3FileSystem(**kwargs)
+
+
+ def _create_s3_write_callbacks(
+     credentials: dict[str, Any],
+     temp_dir: str,
+     written_shards: list[str],
+     fs: S3FileSystem | None,
+     cache_local: bool,
+     add_s3_prefix: bool = False,
+ ) -> tuple:
+     """Create opener and post callbacks for ShardWriter with S3 upload.
+
+     Args:
+         credentials: S3 credentials dict.
+         temp_dir: Temporary directory for local caching.
+         written_shards: List to append written shard paths to.
+         fs: S3FileSystem for direct writes (used when cache_local=False).
+         cache_local: If True, write locally then copy to S3.
+         add_s3_prefix: If True, prepend 's3://' to shard paths.
+
+     Returns:
+         Tuple of (writer_opener, writer_post) callbacks.
+     """
+     if cache_local:
+         import boto3
+
+         s3_client_kwargs = {
+             "aws_access_key_id": credentials["AWS_ACCESS_KEY_ID"],
+             "aws_secret_access_key": credentials["AWS_SECRET_ACCESS_KEY"],
+         }
+         if "AWS_ENDPOINT" in credentials:
+             s3_client_kwargs["endpoint_url"] = credentials["AWS_ENDPOINT"]
+         s3_client = boto3.client("s3", **s3_client_kwargs)
+
+         def _writer_opener(p: str):
+             local_path = Path(temp_dir) / p
+             local_path.parent.mkdir(parents=True, exist_ok=True)
+             return open(local_path, "wb")
+
+         def _writer_post(p: str):
+             local_path = Path(temp_dir) / p
+             path_parts = Path(p).parts
+             bucket = path_parts[0]
+             key = str(Path(*path_parts[1:]))
+
+             with open(local_path, "rb") as f_in:
+                 s3_client.put_object(Bucket=bucket, Key=key, Body=f_in.read())
+
+             local_path.unlink()
+             if add_s3_prefix:
+                 written_shards.append(f"s3://{p}")
+             else:
+                 written_shards.append(p)
+
+         return _writer_opener, _writer_post
+     else:
+         if fs is None:
+             raise ValueError("S3FileSystem required when cache_local=False")
+
+         def _direct_opener(s: str):
+             return cast(BinaryIO, fs.open(f"s3://{s}", "wb"))
+
+         def _direct_post(s: str):
+             if add_s3_prefix:
+                 written_shards.append(f"s3://{s}")
+             else:
+                 written_shards.append(s)
+
+         return _direct_opener, _direct_post
+
+
+ class S3DataStore:
+     """S3-compatible data store implementing AbstractDataStore protocol.
+
+     Handles writing dataset shards to S3-compatible object storage and
+     resolving URLs for reading.
+
+     Attributes:
+         credentials: S3 credentials dictionary.
+         bucket: Target bucket name.
+         _fs: S3FileSystem instance.
+     """
+
+     def __init__(
+         self,
+         credentials: str | Path | dict[str, Any],
+         *,
+         bucket: str,
+     ) -> None:
+         """Initialize an S3 data store.
+
+         Args:
+             credentials: Path to .env file or dict with AWS_ACCESS_KEY_ID,
+                 AWS_SECRET_ACCESS_KEY, and optionally AWS_ENDPOINT.
+             bucket: Name of the S3 bucket for storage.
+         """
+         if isinstance(credentials, dict):
+             self.credentials = credentials
+         else:
+             self.credentials = _s3_env(credentials)
+
+         self.bucket = bucket
+         self._fs = _s3_from_credentials(self.credentials)
+
+     def write_shards(
+         self,
+         ds: Dataset,
+         *,
+         prefix: str,
+         cache_local: bool = False,
+         manifest: bool = False,
+         schema_version: str = "1.0.0",
+         source_job_id: str | None = None,
+         parent_shards: list[str] | None = None,
+         pipeline_version: str | None = None,
+         **kwargs,
+     ) -> list[str]:
+         """Write dataset shards to S3.
+
+         Args:
+             ds: The Dataset to write.
+             prefix: Path prefix within bucket (e.g., 'datasets/mnist/v1').
+             cache_local: If True, write locally first then copy to S3.
+             manifest: If True, generate per-shard manifest files alongside
+                 each tar shard (``.manifest.json`` + ``.manifest.parquet``).
+             schema_version: Schema version for manifest headers.
+             source_job_id: Optional provenance job identifier for manifests.
+             parent_shards: Optional list of input shard identifiers for provenance.
+             pipeline_version: Optional pipeline version string for provenance.
+             **kwargs: Additional args passed to wds.ShardWriter (e.g., maxcount).
+
+         Returns:
+             List of S3 URLs for the written shards.
+
+         Raises:
+             RuntimeError: If no shards were written.
+         """
+         new_uuid = str(uuid4())
+         shard_pattern = f"{self.bucket}/{prefix}/data--{new_uuid}--%06d.tar"
+
+         written_shards: list[str] = []
+
+         # Manifest tracking state shared with the post callback
+         manifest_builders: list = []
+         current_builder: list = [None]  # mutable ref for closure
+         shard_counter: list[int] = [0]
+
+         if manifest:
+             from atdata.manifest import ManifestBuilder, ManifestWriter
+
+             def _make_builder(shard_idx: int) -> ManifestBuilder:
+                 shard_id = f"{self.bucket}/{prefix}/data--{new_uuid}--{shard_idx:06d}"
+                 return ManifestBuilder(
+                     sample_type=ds.sample_type,
+                     shard_id=shard_id,
+                     schema_version=schema_version,
+                     source_job_id=source_job_id,
+                     parent_shards=parent_shards,
+                     pipeline_version=pipeline_version,
+                 )
+
+             current_builder[0] = _make_builder(0)
+
+         with TemporaryDirectory() as temp_dir:
+             writer_opener, writer_post_orig = _create_s3_write_callbacks(
+                 credentials=self.credentials,
+                 temp_dir=temp_dir,
+                 written_shards=written_shards,
+                 fs=self._fs,
+                 cache_local=cache_local,
+                 add_s3_prefix=True,
+             )
+
+             if manifest:
+
+                 def writer_post(p: str):
+                     # Finalize the current manifest builder when a shard completes
+                     builder = current_builder[0]
+                     if builder is not None:
+                         manifest_builders.append(builder)
+                     # Advance to the next shard's builder
+                     shard_counter[0] += 1
+                     current_builder[0] = _make_builder(shard_counter[0])
+                     # Call original post callback
+                     writer_post_orig(p)
+             else:
+                 writer_post = writer_post_orig
+
+             offset = 0
+             with wds.writer.ShardWriter(
+                 shard_pattern,
+                 opener=writer_opener,
+                 post=writer_post,
+                 **kwargs,
+             ) as sink:
+                 for sample in ds.ordered(batch_size=None):
+                     wds_dict = sample.as_wds
+                     sink.write(wds_dict)
+
+                     if manifest and current_builder[0] is not None:
+                         packed_size = len(wds_dict.get("msgpack", b""))
+                         current_builder[0].add_sample(
+                             key=wds_dict["__key__"],
+                             offset=offset,
+                             size=packed_size,
+                             sample=sample,
+                         )
+                         # Approximate tar entry: 512-byte header + data rounded to 512
+                         offset += 512 + packed_size + (512 - packed_size % 512) % 512
+
+             # Finalize the last shard's builder (post isn't called for the last shard
+             # until ShardWriter closes, but we handle it here for safety)
+             if manifest and current_builder[0] is not None:
+                 builder = current_builder[0]
+                 if builder._rows:  # Only if samples were added
+                     manifest_builders.append(builder)
+
+             # Write all manifest files
+             if manifest:
+                 for builder in manifest_builders:
+                     built = builder.build()
+                     writer = ManifestWriter(Path(temp_dir) / Path(built.shard_id))
+                     json_path, parquet_path = writer.write(built)
+
+                     # Upload manifest files to S3 alongside shards
+                     shard_id = built.shard_id
+                     json_key = f"{shard_id}.manifest.json"
+                     parquet_key = f"{shard_id}.manifest.parquet"
+
+                     if cache_local:
+                         import boto3
+
+                         s3_kwargs = {
+                             "aws_access_key_id": self.credentials["AWS_ACCESS_KEY_ID"],
+                             "aws_secret_access_key": self.credentials[
+                                 "AWS_SECRET_ACCESS_KEY"
+                             ],
+                         }
+                         if "AWS_ENDPOINT" in self.credentials:
+                             s3_kwargs["endpoint_url"] = self.credentials["AWS_ENDPOINT"]
+                         s3_client = boto3.client("s3", **s3_kwargs)
+
+                         bucket_name = Path(shard_id).parts[0]
+                         json_s3_key = str(Path(*Path(json_key).parts[1:]))
+                         parquet_s3_key = str(Path(*Path(parquet_key).parts[1:]))
+
+                         with open(json_path, "rb") as f:
+                             s3_client.put_object(
+                                 Bucket=bucket_name, Key=json_s3_key, Body=f.read()
+                             )
+                         with open(parquet_path, "rb") as f:
+                             s3_client.put_object(
+                                 Bucket=bucket_name, Key=parquet_s3_key, Body=f.read()
+                             )
+                     else:
+                         self._fs.put(str(json_path), f"s3://{json_key}")
+                         self._fs.put(str(parquet_path), f"s3://{parquet_key}")
+
+         if len(written_shards) == 0:
+             raise RuntimeError("No shards written")
+
+         return written_shards
+
+     def read_url(self, url: str) -> str:
+         """Resolve an S3 URL for reading/streaming.
+
+         For S3-compatible stores with custom endpoints (like Cloudflare R2,
+         MinIO, etc.), converts s3:// URLs to HTTPS URLs that WebDataset can
+         stream directly.
+
+         For standard AWS S3 (no custom endpoint), URLs are returned unchanged
+         since WebDataset's built-in s3fs integration handles them.
+
+         Args:
+             url: S3 URL to resolve (e.g., 's3://bucket/path/file.tar').
+
+         Returns:
+             HTTPS URL if custom endpoint is configured, otherwise unchanged.
+             Example: 's3://bucket/path' -> 'https://endpoint.com/bucket/path'
+         """
+         endpoint = self.credentials.get("AWS_ENDPOINT")
+         if endpoint and url.startswith("s3://"):
+             # s3://bucket/path -> https://endpoint/bucket/path
+             path = url[5:]  # Remove 's3://' prefix
+             endpoint = endpoint.rstrip("/")
+             return f"{endpoint}/{path}"
+         return url
+
+     def supports_streaming(self) -> bool:
+         """S3 supports streaming reads.
+
+         Returns:
+             True.
+         """
+         return True
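
For orientation, a minimal usage sketch of the new store, pieced together only from the signatures shown in this diff; the credential values, bucket, prefix, and the Dataset instance `ds` are placeholders, not part of the release:

    from atdata.stores._s3 import S3DataStore

    # Placeholder credentials; per the docstrings above, a path to a .env file
    # containing AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY
    # would also be accepted. `ds` is an existing atdata Dataset (assumed).
    store = S3DataStore(
        {"AWS_ACCESS_KEY_ID": "key", "AWS_SECRET_ACCESS_KEY": "secret"},
        bucket="example-bucket",
    )

    # Write shards with per-shard manifests, then resolve one URL for streaming.
    shard_urls = store.write_shards(ds, prefix="datasets/example/v1", manifest=True)
    stream_url = store.read_url(shard_urls[0])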
atdata/testing.py ADDED
@@ -0,0 +1,341 @@
+ """Testing utilities for atdata.
+
+ Provides mock clients, dataset factories, and pytest fixtures for writing
+ tests against atdata without requiring external services (Redis, S3, ATProto PDS).
+
+ Usage::
+
+     import atdata.testing as at_test
+
+     # Create a dataset from samples
+     ds = at_test.make_dataset(tmp_path, [sample1, sample2])
+
+     # Generate random samples
+     samples = at_test.make_samples(MyType, n=100)
+
+     # Use mock atmosphere client
+     client = at_test.MockAtmosphere()
+
+     # Use in-memory index (SQLite backed, temporary)
+     index = at_test.mock_index(tmp_path)
+
+ Pytest fixtures (available when ``atdata`` is installed)::
+
+     def test_something(mock_atmosphere):
+         client = mock_atmosphere
+         client.login("user", "pass")
+         ...
+ """
+
+ from __future__ import annotations
+
+ import tempfile
+ import uuid
+ from dataclasses import fields as dc_fields
+ from pathlib import Path
+ from typing import Any, Sequence, Type, TypeVar
+
+ import numpy as np
+ import webdataset as wds
+
+ import atdata
+ from atdata import Dataset, PackableSample
+ from atdata.index._index import Index
+ from atdata.providers._sqlite import SqliteProvider
+
+ ST = TypeVar("ST")
+
+
+ # ---------------------------------------------------------------------------
+ # Mock Atmosphere Client
+ # ---------------------------------------------------------------------------
+
+
+ class MockAtmosphere:
+     """In-memory mock of ``Atmosphere`` for testing.
+
+     Simulates login, schema publishing, dataset publishing, and record
+     retrieval without requiring a live ATProto PDS.
+
+     Examples:
+         >>> client = MockAtmosphere()
+         >>> client.login("alice.test", "password")
+         >>> client.did
+         'did:plc:mock000000000000'
+     """
+
+     def __init__(
+         self,
+         did: str = "did:plc:mock000000000000",
+         handle: str = "test.mock.social",
+     ) -> None:
+         self.did = did
+         self.handle = handle
+         self._logged_in = False
+         self._records: dict[str, dict[str, Any]] = {}
+         self._schemas: dict[str, dict[str, Any]] = {}
+         self._datasets: dict[str, dict[str, Any]] = {}
+         self._blobs: dict[str, bytes] = {}
+         self._session_string = "mock-session-string"
+         self._call_log: list[tuple[str, dict[str, Any]]] = []
+
+     def login(self, handle: str, password: str) -> dict[str, Any]:
+         """Simulate login. Always succeeds."""
+         self._logged_in = True
+         self.handle = handle
+         self._call_log.append(("login", {"handle": handle}))
+         return {"did": self.did, "handle": self.handle}
+
+     @property
+     def is_authenticated(self) -> bool:
+         return self._logged_in
+
+     def export_session_string(self) -> str:
+         return self._session_string
+
+     def create_record(
+         self,
+         collection: str,
+         record: dict[str, Any],
+         rkey: str | None = None,
+     ) -> str:
+         """Simulate creating a record. Returns a mock AT URI."""
+         key = rkey or uuid.uuid4().hex[:12]
+         uri = f"at://{self.did}/{collection}/{key}"
+         self._records[uri] = record
+         self._call_log.append(
+             ("create_record", {"collection": collection, "rkey": key, "uri": uri})
+         )
+         return uri
+
+     def get_record(self, uri: str) -> dict[str, Any]:
+         """Retrieve a previously created record by URI."""
+         if uri not in self._records:
+             raise KeyError(f"Record not found: {uri}")
+         return self._records[uri]
+
+     def list_records(self, collection: str) -> list[dict[str, Any]]:
+         """List records for a collection."""
+         return [
+             {"uri": uri, "value": rec}
+             for uri, rec in self._records.items()
+             if collection in uri
+         ]
+
+     def upload_blob(self, data: bytes) -> dict[str, Any]:
+         """Simulate uploading a blob. Returns a mock blob ref."""
+         ref = f"blob:{uuid.uuid4().hex[:16]}"
+         self._blobs[ref] = data
+         self._call_log.append(("upload_blob", {"ref": ref, "size": len(data)}))
+         return {"ref": {"$link": ref}, "mimeType": "application/octet-stream"}
+
+     def get_blob(self, did: str, cid: str) -> bytes:
+         """Retrieve a previously uploaded blob."""
+         if cid not in self._blobs:
+             raise KeyError(f"Blob not found: {cid}")
+         return self._blobs[cid]
+
+     def reset(self) -> None:
+         """Clear all stored state."""
+         self._records.clear()
+         self._schemas.clear()
+         self._datasets.clear()
+         self._blobs.clear()
+         self._call_log.clear()
+         self._logged_in = False
+
+
+ # ---------------------------------------------------------------------------
+ # Dataset Factory
+ # ---------------------------------------------------------------------------
+
+
+ def make_dataset(
+     path: Path,
+     samples: Sequence[PackableSample],
+     *,
+     name: str = "test",
+     sample_type: type | None = None,
+ ) -> Dataset:
+     """Create a ``Dataset`` from a list of samples.
+
+     Writes the samples to a WebDataset tar file in *path* and returns a
+     ``Dataset`` configured to read them back.
+
+     Args:
+         path: Directory where the tar file will be created.
+         samples: Sequence of ``PackableSample`` (or ``@packable``) instances.
+         name: Filename prefix for the tar file.
+         sample_type: Explicit sample type for the Dataset generic parameter.
+             If ``None``, inferred from the first sample.
+
+     Returns:
+         A ``Dataset`` ready for iteration.
+
+     Examples:
+         >>> ds = make_dataset(tmp_path, [MySample(x=1), MySample(x=2)])
+         >>> assert len(list(ds.ordered())) == 2
+     """
+     if not samples:
+         raise ValueError("samples must be non-empty")
+
+     tar_path = path / f"{name}-000000.tar"
+     tar_path.parent.mkdir(parents=True, exist_ok=True)
+
+     with wds.writer.TarWriter(str(tar_path)) as writer:
+         for sample in samples:
+             writer.write(sample.as_wds)
+
+     st = sample_type or type(samples[0])
+     return Dataset[st](url=str(tar_path))
+
+
+ def make_samples(
+     sample_type: Type[ST], n: int = 10, seed: int | None = None
+ ) -> list[ST]:
+     """Generate *n* random instances of a ``@packable`` sample type.
+
+     Inspects the dataclass fields and generates appropriate random data:
+     - ``str`` fields get ``"field_name_0"``, ``"field_name_1"``, etc.
+     - ``int`` fields get sequential integers
+     - ``float`` fields get random floats in [0, 1)
+     - ``bool`` fields alternate True/False
+     - ``bytes`` fields get random 16 bytes
+     - NDArray fields get random ``(4, 4)`` float32 arrays
+
+     Args:
+         sample_type: A ``@packable``-decorated class or ``PackableSample`` subclass.
+         n: Number of samples to generate.
+         seed: Optional random seed for reproducibility.
+
+     Returns:
+         List of *n* sample instances.
+
+     Examples:
+         >>> @atdata.packable
+         ... class Point:
+         ...     x: float
+         ...     y: float
+         ...     label: str
+         >>> points = make_samples(Point, n=5, seed=42)
+         >>> len(points)
+         5
+     """
+     rng = np.random.default_rng(seed)
+     result: list[ST] = []
+
+     for i in range(n):
+         kwargs: dict[str, Any] = {}
+         for field in dc_fields(sample_type):
+             type_str = str(field.type)
+             fname = field.name
+
+             if field.type is str or type_str == "str":
+                 kwargs[fname] = f"{fname}_{i}"
+             elif field.type is int or type_str == "int":
+                 kwargs[fname] = i
+             elif field.type is float or type_str == "float":
+                 kwargs[fname] = float(rng.random())
+             elif field.type is bool or type_str == "bool":
+                 kwargs[fname] = i % 2 == 0
+             elif field.type is bytes or type_str == "bytes":
+                 kwargs[fname] = rng.bytes(16)
+             elif "NDArray" in type_str or "ndarray" in type_str.lower():
+                 kwargs[fname] = rng.standard_normal((4, 4)).astype(np.float32)
+             elif "list" in type_str.lower():
+                 kwargs[fname] = [f"{fname}_{i}_{j}" for j in range(3)]
+             elif "None" in type_str:
+                 # Optional field — leave at default
+                 if field.default is not field.default_factory:  # type: ignore[attr-defined]
+                     continue
+             else:
+                 kwargs[fname] = f"{fname}_{i}"
+
+         result.append(sample_type(**kwargs))
+
+     return result
+
+
+ # ---------------------------------------------------------------------------
+ # Mock Index
+ # ---------------------------------------------------------------------------
+
+
+ def mock_index(path: Path | None = None, **kwargs: Any) -> Index:
+     """Create an in-memory SQLite-backed ``Index`` for testing.
+
+     No Redis or external services required.
+
+     Args:
+         path: Directory for the SQLite database file. If ``None``, uses
+             a temporary directory.
+         **kwargs: Additional keyword arguments passed to ``Index()``.
+
+     Returns:
+         An ``Index`` instance backed by a temporary SQLite database.
+
+     Examples:
+         >>> index = mock_index(tmp_path)
+         >>> ref = index.publish_schema(MyType, version="1.0.0")
+     """
+     if path is None:
+         path = Path(tempfile.mkdtemp())
+     db_path = path / "test_index.db"
+     provider = SqliteProvider(str(db_path))
+     return Index(provider=provider, atmosphere=None, **kwargs)
+
+
+ # ---------------------------------------------------------------------------
+ # Pytest plugin (fixtures auto-discovered when atdata is installed)
+ # ---------------------------------------------------------------------------
+
+ try:
+     import pytest
+
+     @pytest.fixture
+     def mock_atmosphere():
+         """Provide a fresh ``MockAtmosphere`` for each test."""
+         client = MockAtmosphere()
+         client.login("test.mock.social", "test-password")
+         yield client
+         client.reset()
+
+     @pytest.fixture
+     def tmp_dataset(tmp_path: Path):
+         """Provide a small ``Dataset[SharedBasicSample]`` with 10 samples.
+
+         Uses ``SharedBasicSample`` (name: str, value: int) from the test suite.
+         """
+
+         @atdata.packable
+         class _TmpSample:
+             name: str
+             value: int
+
+         samples = [_TmpSample(name=f"s{i}", value=i) for i in range(10)]
+         return make_dataset(tmp_path, samples, sample_type=_TmpSample)
+
+     @pytest.fixture
+     def tmp_index(tmp_path: Path):
+         """Provide a fresh SQLite-backed ``Index`` for each test."""
+         return mock_index(tmp_path)
+
+ except ImportError:
+     # pytest not installed — skip fixture registration
+     _no_pytest = True
+
+
+ # ---------------------------------------------------------------------------
+ # Public API
+ # ---------------------------------------------------------------------------
+
+ # Deprecated alias for backward compatibility
+ MockAtmosphereClient = MockAtmosphere
+
+ __all__ = [
+     "MockAtmosphere",
+     "MockAtmosphereClient",  # deprecated alias
+     "make_dataset",
+     "make_samples",
+     "mock_index",
+ ]
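
And a minimal sketch of how these testing helpers might be combined in a test module; the `Point` sample type and the collection string are illustrative, and the `mock_atmosphere` fixture is assumed to be auto-registered via the packaged pytest plugin, as the module comment above suggests:

    import atdata
    import atdata.testing as at_test


    @atdata.packable
    class Point:
        x: float
        label: str


    def test_roundtrip(tmp_path):
        # Random samples -> tar-backed Dataset -> read back and count.
        samples = at_test.make_samples(Point, n=5, seed=0)
        ds = at_test.make_dataset(tmp_path, samples, sample_type=Point)
        assert len(list(ds.ordered())) == 5


    def test_create_record(mock_atmosphere):
        # The record round-trips through the in-memory mock unchanged.
        uri = mock_atmosphere.create_record("ac.foundation.dataset.record", {"name": "demo"})
        assert mock_atmosphere.get_record(uri) == {"name": "demo"}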