atdata 0.1.3b4__py3-none-any.whl → 0.2.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,239 @@
+"""Schema publishing and loading for ATProto.
+
+This module provides classes for publishing PackableSample schemas to ATProto
+and loading them back. Schemas are published as ``ac.foundation.dataset.sampleSchema``
+records.
+"""
+
+from dataclasses import fields, is_dataclass
+from typing import Type, TypeVar, Optional, get_type_hints, get_origin, get_args
+
+from .client import AtmosphereClient
+from ._types import (
+    AtUri,
+    SchemaRecord,
+    FieldDef,
+    FieldType,
+    LEXICON_NAMESPACE,
+)
+from .._type_utils import (
+    unwrap_optional,
+    is_ndarray_type,
+    extract_ndarray_dtype,
+)
+
+# Import for type checking only to avoid circular imports
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from ..dataset import PackableSample
+
+ST = TypeVar("ST", bound="PackableSample")
+
+
+class SchemaPublisher:
+    """Publishes PackableSample schemas to ATProto.
+
+    This class introspects a PackableSample class to extract its field
+    definitions and publishes them as an ATProto schema record.
+
+    Example:
+        ::
+
+            >>> @atdata.packable
+            ... class MySample:
+            ...     image: NDArray
+            ...     label: str
+            ...
+            >>> client = AtmosphereClient()
+            >>> client.login("handle", "password")
+            >>>
+            >>> publisher = SchemaPublisher(client)
+            >>> uri = publisher.publish(MySample, version="1.0.0")
+            >>> print(uri)
+            at://did:plc:.../ac.foundation.dataset.sampleSchema/...
+    """
+
+    def __init__(self, client: AtmosphereClient):
+        """Initialize the schema publisher.
+
+        Args:
+            client: Authenticated AtmosphereClient instance.
+        """
+        self.client = client
+
+    def publish(
+        self,
+        sample_type: Type[ST],
+        *,
+        name: Optional[str] = None,
+        version: str = "1.0.0",
+        description: Optional[str] = None,
+        metadata: Optional[dict] = None,
+        rkey: Optional[str] = None,
+    ) -> AtUri:
+        """Publish a PackableSample schema to ATProto.
+
+        Args:
+            sample_type: The PackableSample class to publish.
+            name: Human-readable name. Defaults to the class name.
+            version: Semantic version string (e.g., '1.0.0').
+            description: Human-readable description.
+            metadata: Arbitrary metadata dictionary.
+            rkey: Optional explicit record key. If not provided, a TID is generated.
+
+        Returns:
+            The AT URI of the created schema record.
+
+        Raises:
+            ValueError: If sample_type is not a dataclass or the client is not
+                authenticated.
+            TypeError: If a field type is not supported.
+        """
+        if not is_dataclass(sample_type):
+            raise ValueError(f"{sample_type.__name__} must be a dataclass (use @packable)")
+
+        # Build the schema record
+        schema_record = self._build_schema_record(
+            sample_type,
+            name=name,
+            version=version,
+            description=description,
+            metadata=metadata,
+        )
+
+        # Publish to ATProto
+        return self.client.create_record(
+            collection=f"{LEXICON_NAMESPACE}.sampleSchema",
+            record=schema_record.to_record(),
+            rkey=rkey,
+            validate=False,  # PDS doesn't know our lexicon
+        )
+
+    def _build_schema_record(
+        self,
+        sample_type: Type[ST],
+        *,
+        name: Optional[str],
+        version: str,
+        description: Optional[str],
+        metadata: Optional[dict],
+    ) -> SchemaRecord:
+        """Build a SchemaRecord from a PackableSample class."""
+        field_defs = []
+        type_hints = get_type_hints(sample_type)
+
+        for f in fields(sample_type):
+            field_type = type_hints.get(f.name, f.type)
+            field_def = self._field_to_def(f.name, field_type)
+            field_defs.append(field_def)
+
+        return SchemaRecord(
+            name=name or sample_type.__name__,
+            version=version,
+            description=description,
+            fields=field_defs,
+            metadata=metadata,
+        )
+
+    def _field_to_def(self, name: str, python_type) -> FieldDef:
+        """Convert a Python field to a FieldDef."""
+        python_type, is_optional = unwrap_optional(python_type)
+        field_type = self._python_type_to_field_type(python_type)
+        return FieldDef(name=name, field_type=field_type, optional=is_optional)
+
+    def _python_type_to_field_type(self, python_type) -> FieldType:
+        """Map a Python type to a FieldType."""
+        if python_type is str:
+            return FieldType(kind="primitive", primitive="str")
+        if python_type is int:
+            return FieldType(kind="primitive", primitive="int")
+        if python_type is float:
+            return FieldType(kind="primitive", primitive="float")
+        if python_type is bool:
+            return FieldType(kind="primitive", primitive="bool")
+        if python_type is bytes:
+            return FieldType(kind="primitive", primitive="bytes")
+
+        if is_ndarray_type(python_type):
+            return FieldType(kind="ndarray", dtype=extract_ndarray_dtype(python_type), shape=None)
+
+        origin = get_origin(python_type)
+        if origin is list:
+            args = get_args(python_type)
+            items = (
+                self._python_type_to_field_type(args[0])
+                if args
+                else FieldType(kind="primitive", primitive="str")
+            )
+            return FieldType(kind="array", items=items)
+
+        if is_dataclass(python_type):
+            raise TypeError(
+                f"Nested dataclass types not yet supported: {python_type.__name__}. "
+                "Publish nested types separately and use references."
+            )
+
+        raise TypeError(f"Unsupported type for schema field: {python_type}")
+
+
+class SchemaLoader:
+    """Loads PackableSample schemas from ATProto.
+
+    This class fetches schema records from ATProto and can list available
+    schemas from a repository.
+
+    Example:
+        ::
+
+            >>> client = AtmosphereClient()
+            >>> client.login("handle", "password")
+            >>>
+            >>> loader = SchemaLoader(client)
+            >>> schema = loader.get("at://did:plc:.../ac.foundation.dataset.sampleSchema/...")
+            >>> print(schema["name"])
+            MySample
+    """
+
+    def __init__(self, client: AtmosphereClient):
+        """Initialize the schema loader.
+
+        Args:
+            client: AtmosphereClient instance (authentication optional for reads).
+        """
+        self.client = client
+
+    def get(self, uri: str | AtUri) -> dict:
+        """Fetch a schema record by AT URI.
+
+        Args:
+            uri: The AT URI of the schema record.
+
+        Returns:
+            The schema record as a dictionary.
+
+        Raises:
+            ValueError: If the record is not a schema record.
+            atproto.exceptions.AtProtocolError: If the record is not found.
+        """
+        record = self.client.get_record(uri)
+
+        expected_type = f"{LEXICON_NAMESPACE}.sampleSchema"
+        if record.get("$type") != expected_type:
+            raise ValueError(
+                f"Record at {uri} is not a schema record. "
+                f"Expected $type='{expected_type}', got '{record.get('$type')}'"
+            )
+
+        return record
+
+    def list_all(
+        self,
+        repo: Optional[str] = None,
+        limit: int = 100,
+    ) -> list[dict]:
+        """List schema records from a repository.
+
+        Args:
+            repo: The DID of the repository. Defaults to the authenticated user.
+            limit: Maximum number of records to return.
+
+        Returns:
+            List of schema records.
+        """
+        return self.client.list_schemas(repo=repo, limit=limit)
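Taken together, SchemaPublisher and SchemaLoader form a publish/load round-trip: introspect a dataclass, write it as an ``ac.foundation.dataset.sampleSchema`` record, then fetch it back by AT URI. A minimal sketch of that flow, assuming a reachable PDS and valid credentials; the import paths, handle, and app password below are placeholders, and ``NDArray`` is assumed to come from ``numpy.typing``:

    import atdata
    from numpy.typing import NDArray
    from atdata.atmosphere import AtmosphereClient, SchemaPublisher, SchemaLoader

    @atdata.packable
    class MySample:
        image: NDArray
        label: str

    client = AtmosphereClient()
    client.login("alice.example.com", "app-password")  # placeholder credentials

    # Publish the schema, then read the record back by its AT URI.
    publisher = SchemaPublisher(client)
    uri = publisher.publish(MySample, version="1.0.0", description="demo schema")

    loader = SchemaLoader(client)
    record = loader.get(uri)
    assert record["name"] == "MySample"

Note that publish() rejects anything that is not a dataclass, and get() rejects records whose ``$type`` is not the schema lexicon, so both ends of the round-trip fail fast on mismatched input. The PDS blob store in the next file handles the data side of the same picture.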
@@ -0,0 +1,208 @@
+"""PDS blob storage for dataset shards.
+
+This module provides ``PDSBlobStore``, an implementation of the AbstractDataStore
+protocol that stores dataset shards as ATProto blobs in a Personal Data Server.
+
+This enables fully decentralized dataset storage where both metadata (records)
+and data (blobs) live on the AT Protocol network.
+
+Example:
+    ::
+
+        >>> from atdata.atmosphere import AtmosphereClient, PDSBlobStore
+        >>>
+        >>> client = AtmosphereClient()
+        >>> client.login("handle.bsky.social", "app-password")
+        >>>
+        >>> store = PDSBlobStore(client)
+        >>> urls = store.write_shards(dataset, prefix="mnist/v1")
+        >>> print(urls)
+        ['at://did:plc:.../blob/bafyrei...', ...]
+"""
+
+from __future__ import annotations
+
+import tempfile
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+import webdataset as wds
+
+if TYPE_CHECKING:
+    from ..dataset import Dataset
+    from .._sources import BlobSource
+    from .client import AtmosphereClient
+
+
+@dataclass
+class PDSBlobStore:
+    """PDS blob store implementing the AbstractDataStore protocol.
+
+    Stores dataset shards as ATProto blobs, enabling decentralized dataset
+    storage on the AT Protocol network.
+
+    Each shard is written to a temporary tar file, then uploaded as a blob
+    to the user's PDS. The returned URLs are AT URIs that can be resolved
+    to HTTP URLs for streaming.
+
+    Attributes:
+        client: Authenticated AtmosphereClient instance.
+
+    Example:
+        ::
+
+            >>> store = PDSBlobStore(client)
+            >>> urls = store.write_shards(dataset, prefix="training/v1")
+            >>> # Returns AT URIs like:
+            >>> # ['at://did:plc:abc/blob/bafyrei...', ...]
+    """
+
+    client: "AtmosphereClient"
+
+    def write_shards(
+        self,
+        ds: "Dataset",
+        *,
+        prefix: str,
+        maxcount: int = 10000,
+        maxsize: float = 3e9,
+        **kwargs: Any,
+    ) -> list[str]:
+        """Write dataset shards as PDS blobs.
+
+        Creates tar archives from the dataset and uploads each as a blob
+        to the authenticated user's PDS.
+
+        Args:
+            ds: The Dataset to write.
+            prefix: Logical path prefix for naming (used in shard names only).
+            maxcount: Maximum samples per shard (default: 10000).
+            maxsize: Maximum shard size in bytes (default: 3GB, PDS limit).
+            **kwargs: Additional args passed to wds.ShardWriter.
+
+        Returns:
+            List of AT URIs for the written blobs, in the format
+            ``at://{did}/blob/{cid}``.
+
+        Raises:
+            ValueError: If not authenticated.
+            RuntimeError: If no shards were written.
+
+        Note:
+            PDS blobs have size limits (typically 50MB-5GB depending on the PDS).
+            Adjust maxcount/maxsize to stay within those limits.
+        """
+        if not self.client.did:
+            raise ValueError("Client must be authenticated to upload blobs")
+
+        did = self.client.did
+        blob_urls: list[str] = []
+
+        # Write shards to temp files, upload each as a blob
+        with tempfile.TemporaryDirectory() as temp_dir:
+            shard_pattern = f"{temp_dir}/shard-%06d.tar"
+            written_files: list[str] = []
+
+            # Track written files via custom post callback
+            def track_file(fname: str) -> None:
+                written_files.append(fname)
+
+            with wds.writer.ShardWriter(
+                shard_pattern,
+                maxcount=maxcount,
+                maxsize=maxsize,
+                post=track_file,
+                **kwargs,
+            ) as sink:
+                for sample in ds.ordered(batch_size=None):
+                    sink.write(sample.as_wds)
+
+            if not written_files:
+                raise RuntimeError("No shards written")
+
+            # Upload each shard as a blob
+            for shard_path in written_files:
+                with open(shard_path, "rb") as f:
+                    shard_data = f.read()
+
+                blob_ref = self.client.upload_blob(
+                    shard_data,
+                    mime_type="application/x-tar",
+                )
+
+                # Extract CID from the blob reference
+                cid = blob_ref["ref"]["$link"]
+                at_uri = f"at://{did}/blob/{cid}"
+                blob_urls.append(at_uri)
+
+        return blob_urls
+
+    def read_url(self, url: str) -> str:
+        """Resolve an AT URI blob reference to an HTTP URL.
+
+        Transforms ``at://did/blob/cid`` URIs into HTTP URLs that can be
+        streamed by WebDataset.
+
+        Args:
+            url: AT URI in the format ``at://{did}/blob/{cid}``.
+
+        Returns:
+            HTTP URL for fetching the blob via the PDS API.
+
+        Raises:
+            ValueError: If the URL format is invalid or the PDS cannot be resolved.
+        """
+        if not url.startswith("at://"):
+            # Not an AT URI, return unchanged
+            return url
+
+        # Parse at://did/blob/cid
+        parts = url[5:].split("/")  # Remove 'at://'
+        if len(parts) != 3 or parts[1] != "blob":
+            raise ValueError(f"Invalid blob AT URI format: {url}")
+
+        did, _, cid = parts
+        return self.client.get_blob_url(did, cid)
+
+    def supports_streaming(self) -> bool:
+        """PDS blobs support streaming via HTTP.
+
+        Returns:
+            True.
+        """
+        return True
+
+    def create_source(self, urls: list[str]) -> "BlobSource":
+        """Create a BlobSource for reading these AT URIs.
+
+        This is a convenience method for creating a DataSource that can
+        stream the blobs written by this store.
+
+        Args:
+            urls: List of AT URIs from write_shards().
+
+        Returns:
+            BlobSource configured for the given URLs.
+
+        Raises:
+            ValueError: If any URL is not a valid AT URI.
+        """
+        from .._sources import BlobSource
+
+        blob_refs: list[dict[str, str]] = []
+
+        for url in urls:
+            if not url.startswith("at://"):
+                raise ValueError(f"Not an AT URI: {url}")
+
+            parts = url[5:].split("/")
+            if len(parts) != 3 or parts[1] != "blob":
+                raise ValueError(f"Invalid blob AT URI: {url}")
+
+            did, _, cid = parts
+            blob_refs.append({"did": did, "cid": cid})
+
+        return BlobSource(blob_refs=blob_refs)
+
+
+__all__ = ["PDSBlobStore"]
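The AT URI convention used by PDSBlobStore is simple: one blob per shard, addressed as ``at://{did}/blob/{cid}``, with read_url() rewriting that to an HTTP endpoint and passing non-AT URLs through untouched. A sketch of the write/read cycle under those rules; ``dataset`` stands in for an existing atdata Dataset (construction omitted), and the credentials are placeholders:

    from atdata.atmosphere import AtmosphereClient, PDSBlobStore

    client = AtmosphereClient()
    client.login("alice.example.com", "app-password")  # placeholder credentials

    store = PDSBlobStore(client)

    # Upload: each shard becomes one PDS blob named at://{did}/blob/{cid}.
    # `dataset` is assumed to be a previously built atdata Dataset.
    urls = store.write_shards(dataset, prefix="mnist/v1", maxcount=5000)

    # Resolve: AT URIs become HTTP URLs for WebDataset streaming;
    # anything that is not an at:// URI is returned unchanged.
    http_urls = [store.read_url(u) for u in urls]

    # Or wrap the AT URIs in a streaming source directly.
    source = store.create_source(urls)

Keeping maxsize under the PDS blob limit matters more than maxcount here: a shard over the limit would presumably be rejected by the PDS at upload time, after the tar file has already been written.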
atdata/cli/__init__.py ADDED
@@ -0,0 +1,213 @@
+"""Command-line interface for atdata.
+
+This module provides CLI commands for managing local development infrastructure
+and diagnosing configuration issues.
+
+Commands:
+    atdata local up      Start Redis and MinIO containers for local development
+    atdata local down    Stop local development containers
+    atdata diagnose      Check Redis configuration and connectivity
+    atdata version       Show version information
+
+Example:
+    $ atdata local up
+    Starting Redis on port 6379...
+    Starting MinIO on port 9000...
+    Local infrastructure ready.
+
+    $ atdata diagnose
+    Checking Redis configuration...
+    ✓ Redis connected
+    ✓ Persistence enabled (AOF)
+    ✓ Memory policy: noeviction
+"""
+
+import argparse
+import sys
+from typing import Sequence
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+    """Main entry point for the atdata CLI.
+
+    Args:
+        argv: Command-line arguments. If None, uses sys.argv[1:].
+
+    Returns:
+        Exit code (0 for success, non-zero for errors).
+    """
+    parser = argparse.ArgumentParser(
+        prog="atdata",
+        description="A loose federation of distributed, typed datasets",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    parser.add_argument(
+        "--version", "-v",
+        action="store_true",
+        help="Show version information",
+    )
+
+    subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
+    # 'local' command group
+    local_parser = subparsers.add_parser(
+        "local",
+        help="Manage local development infrastructure",
+    )
+    local_subparsers = local_parser.add_subparsers(
+        dest="local_command",
+        help="Local infrastructure commands",
+    )
+
+    # 'local up' command
+    up_parser = local_subparsers.add_parser(
+        "up",
+        help="Start Redis and MinIO containers",
+    )
+    up_parser.add_argument(
+        "--redis-port",
+        type=int,
+        default=6379,
+        help="Redis port (default: 6379)",
+    )
+    up_parser.add_argument(
+        "--minio-port",
+        type=int,
+        default=9000,
+        help="MinIO API port (default: 9000)",
+    )
+    up_parser.add_argument(
+        "--minio-console-port",
+        type=int,
+        default=9001,
+        help="MinIO console port (default: 9001)",
+    )
+    up_parser.add_argument(
+        "--detach", "-d",
+        # BooleanOptionalAction so --no-detach can actually disable detaching;
+        # store_true with default=True would make the flag a no-op.
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Run containers in detached mode (default: True)",
+    )
+
+    # 'local down' command
+    down_parser = local_subparsers.add_parser(
+        "down",
+        help="Stop local development containers",
+    )
+    down_parser.add_argument(
+        "--volumes", "-v",
+        action="store_true",
+        help="Also remove volumes (deletes all data)",
+    )
+
+    # 'local status' command
+    local_subparsers.add_parser(
+        "status",
+        help="Show status of local infrastructure",
+    )
+
+    # 'diagnose' command
+    diagnose_parser = subparsers.add_parser(
+        "diagnose",
+        help="Diagnose Redis configuration and connectivity",
+    )
+    diagnose_parser.add_argument(
+        "--host",
+        default="localhost",
+        help="Redis host (default: localhost)",
+    )
+    diagnose_parser.add_argument(
+        "--port",
+        type=int,
+        default=6379,
+        help="Redis port (default: 6379)",
+    )
+
+    # 'version' command (alternative to the --version flag)
+    subparsers.add_parser(
+        "version",
+        help="Show version information",
+    )
+
+    args = parser.parse_args(argv)
+
+    # Handle the --version flag
+    if args.version or args.command == "version":
+        return _cmd_version()
+
+    # Handle 'local' commands
+    if args.command == "local":
+        if args.local_command == "up":
+            return _cmd_local_up(
+                redis_port=args.redis_port,
+                minio_port=args.minio_port,
+                minio_console_port=args.minio_console_port,
+                detach=args.detach,
+            )
+        elif args.local_command == "down":
+            return _cmd_local_down(remove_volumes=args.volumes)
+        elif args.local_command == "status":
+            return _cmd_local_status()
+        else:
+            local_parser.print_help()
+            return 1
+
+    # Handle 'diagnose' command
+    if args.command == "diagnose":
+        return _cmd_diagnose(host=args.host, port=args.port)
+
+    # No command given
+    parser.print_help()
+    return 0
+
+
+def _cmd_version() -> int:
+    """Show version information."""
+    try:
+        from atdata import __version__
+        version = __version__
+    except ImportError:
+        # Fall back to package metadata
+        from importlib.metadata import version as pkg_version
+        version = pkg_version("atdata")
+
+    print(f"atdata {version}")
+    return 0
+
+
+def _cmd_local_up(
+    redis_port: int,
+    minio_port: int,
+    minio_console_port: int,
+    detach: bool,
+) -> int:
+    """Start local development infrastructure."""
+    from .local import local_up
+    return local_up(
+        redis_port=redis_port,
+        minio_port=minio_port,
+        minio_console_port=minio_console_port,
+        detach=detach,
+    )
+
+
+def _cmd_local_down(remove_volumes: bool) -> int:
+    """Stop local development infrastructure."""
+    from .local import local_down
+    return local_down(remove_volumes=remove_volumes)
+
+
+def _cmd_local_status() -> int:
+    """Show status of local infrastructure."""
+    from .local import local_status
+    return local_status()
+
+
+def _cmd_diagnose(host: str, port: int) -> int:
+    """Diagnose Redis configuration."""
+    from .diagnose import diagnose_redis
+    return diagnose_redis(host=host, port=port)
+
+
+if __name__ == "__main__":
+    sys.exit(main())
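Because main() takes an explicit argv, the CLI can be driven programmatically as well as from the shell, which is also convenient for testing. A short sketch, assuming the package is installed; ``--no-detach`` refers to the boolean-flag form of ``--detach`` used above, and the Redis host/port values are placeholders:

    from atdata.cli import main

    # Equivalent to running `atdata version` from the shell.
    assert main(["version"]) == 0

    # Point the diagnostics at a non-default Redis instance.
    main(["diagnose", "--host", "redis.internal", "--port", "6380"])

    # Start local infrastructure on an alternate Redis port, in the foreground.
    main(["local", "up", "--redis-port", "6380", "--no-detach"])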