atdata 0.1.3b3__py3-none-any.whl → 0.2.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +39 -1
- atdata/_helpers.py +39 -3
- atdata/atmosphere/__init__.py +61 -0
- atdata/atmosphere/_types.py +329 -0
- atdata/atmosphere/client.py +393 -0
- atdata/atmosphere/lens.py +280 -0
- atdata/atmosphere/records.py +342 -0
- atdata/atmosphere/schema.py +296 -0
- atdata/dataset.py +336 -203
- atdata/lens.py +177 -77
- atdata/local.py +492 -0
- atdata-0.2.0a1.dist-info/METADATA +181 -0
- atdata-0.2.0a1.dist-info/RECORD +16 -0
- {atdata-0.1.3b3.dist-info → atdata-0.2.0a1.dist-info}/WHEEL +1 -1
- atdata-0.1.3b3.dist-info/METADATA +0 -18
- atdata-0.1.3b3.dist-info/RECORD +0 -9
- {atdata-0.1.3b3.dist-info → atdata-0.2.0a1.dist-info}/entry_points.txt +0 -0
- {atdata-0.1.3b3.dist-info → atdata-0.2.0a1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
"""Dataset record publishing and loading for ATProto.
|
|
2
|
+
|
|
3
|
+
This module provides classes for publishing dataset index records to ATProto
|
|
4
|
+
and loading them back. Dataset records are published as
|
|
5
|
+
``ac.foundation.dataset.record`` records.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Type, TypeVar, Optional
|
|
9
|
+
import msgpack
|
|
10
|
+
|
|
11
|
+
from .client import AtmosphereClient
|
|
12
|
+
from .schema import SchemaPublisher
|
|
13
|
+
from ._types import (
|
|
14
|
+
AtUri,
|
|
15
|
+
DatasetRecord,
|
|
16
|
+
StorageLocation,
|
|
17
|
+
LEXICON_NAMESPACE,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# Import for type checking only to avoid circular imports
|
|
21
|
+
from typing import TYPE_CHECKING
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from ..dataset import PackableSample, Dataset
|
|
24
|
+
|
|
25
|
+
# Generic sample-type variable used for ``Dataset[ST]`` annotations. The bound
# is a string forward reference because ``PackableSample`` is only imported
# under ``TYPE_CHECKING`` (to avoid a circular import at runtime).
ST = TypeVar("ST", bound="PackableSample")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class DatasetPublisher:
    """Publishes dataset index records to ATProto.

    This class creates ``ac.foundation.dataset.record`` records that reference
    a schema record and point to external storage (WebDataset URLs). Both
    publishing methods build ``kind="external"`` storage locations; blob-backed
    storage is not produced by this class.

    Example:
        >>> dataset = atdata.Dataset[MySample]("s3://bucket/data-{000000..000009}.tar")
        >>>
        >>> client = AtmosphereClient()
        >>> client.login("handle", "password")
        >>>
        >>> publisher = DatasetPublisher(client)
        >>> uri = publisher.publish(
        ...     dataset,
        ...     name="My Training Data",
        ...     description="Training data for my model",
        ...     tags=["computer-vision", "training"],
        ... )
    """

    def __init__(self, client: AtmosphereClient):
        """Initialize the dataset publisher.

        Args:
            client: Authenticated AtmosphereClient instance. The same client is
                handed to an internal SchemaPublisher used for auto-publishing
                schemas.
        """
        self.client = client
        self._schema_publisher = SchemaPublisher(client)

    def publish(
        self,
        dataset: "Dataset[ST]",
        *,
        name: str,
        schema_uri: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[list[str]] = None,
        license: Optional[str] = None,
        auto_publish_schema: bool = True,
        schema_version: str = "1.0.0",
        rkey: Optional[str] = None,
    ) -> AtUri:
        """Publish a dataset index record to ATProto.

        The record points at ``dataset.url`` as its single external URL. If
        ``dataset.metadata`` is not None, it is msgpack-encoded into the record.

        Args:
            dataset: The Dataset to publish.
            name: Human-readable dataset name.
            schema_uri: AT URI of the schema record. If not provided and
                auto_publish_schema is True, the schema will be published.
            description: Human-readable description.
            tags: Searchable tags for discovery.
            license: SPDX license identifier (e.g., 'MIT', 'Apache-2.0').
            auto_publish_schema: If True and schema_uri not provided,
                automatically publish the schema first.
            schema_version: Version for auto-published schema.
            rkey: Optional explicit record key.

        Returns:
            The AT URI of the created dataset record.

        Raises:
            ValueError: If schema_uri is not provided and auto_publish_schema is False.
        """
        # Ensure we have a schema reference before writing the dataset record.
        if schema_uri is None:
            if not auto_publish_schema:
                raise ValueError(
                    "schema_uri is required when auto_publish_schema=False"
                )
            schema_uri = str(
                self._schema_publisher.publish(
                    dataset.sample_type,
                    version=schema_version,
                )
            )

        return self._create_dataset_record(
            urls=[dataset.url],
            schema_uri=schema_uri,
            name=name,
            description=description,
            tags=tags,
            license=license,
            metadata=dataset.metadata,
            rkey=rkey,
        )

    def publish_with_urls(
        self,
        urls: list[str],
        schema_uri: str,
        *,
        name: str,
        description: Optional[str] = None,
        tags: Optional[list[str]] = None,
        license: Optional[str] = None,
        metadata: Optional[dict] = None,
        rkey: Optional[str] = None,
    ) -> AtUri:
        """Publish a dataset record with explicit URLs.

        This method allows publishing a dataset record without having a
        Dataset object, useful for registering existing WebDataset files.

        Args:
            urls: List of WebDataset URLs with brace notation.
            schema_uri: AT URI of the schema record.
            name: Human-readable dataset name.
            description: Human-readable description.
            tags: Searchable tags for discovery.
            license: SPDX license identifier.
            metadata: Arbitrary metadata dictionary (msgpack-encoded into the
                record when not None).
            rkey: Optional explicit record key.

        Returns:
            The AT URI of the created dataset record.
        """
        return self._create_dataset_record(
            urls=urls,
            schema_uri=schema_uri,
            name=name,
            description=description,
            tags=tags,
            license=license,
            metadata=metadata,
            rkey=rkey,
        )

    def _create_dataset_record(
        self,
        *,
        urls: list[str],
        schema_uri: str,
        name: str,
        description: Optional[str],
        tags: Optional[list[str]],
        license: Optional[str],
        metadata: Optional[dict],
        rkey: Optional[str],
    ) -> AtUri:
        """Build an external-storage DatasetRecord and create it via the client.

        Shared by ``publish`` and ``publish_with_urls`` so both write records
        with identical structure.
        """
        storage = StorageLocation(
            kind="external",
            urls=urls,
        )

        # Metadata is msgpack-encoded; None is passed through unchanged.
        metadata_bytes: Optional[bytes] = (
            None if metadata is None else msgpack.packb(metadata)
        )

        dataset_record = DatasetRecord(
            name=name,
            schema_ref=schema_uri,
            storage=storage,
            description=description,
            tags=tags or [],
            license=license,
            metadata=metadata_bytes,
        )

        return self.client.create_record(
            collection=f"{LEXICON_NAMESPACE}.record",
            record=dataset_record.to_record(),
            rkey=rkey,
            validate=False,  # PDS doesn't know our lexicon
        )
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class DatasetLoader:
    """Loads dataset records from ATProto.

    This class fetches dataset index records and can create Dataset objects
    from them. Note that loading a dataset requires having the corresponding
    Python class for the sample type.

    Example:
        >>> client = AtmosphereClient()
        >>> loader = DatasetLoader(client)
        >>>
        >>> # List available datasets
        >>> datasets = loader.list_all()
        >>> for ds in datasets:
        ...     print(ds["name"], ds["schemaRef"])
        >>>
        >>> # Get a specific dataset record
        >>> record = loader.get("at://did:plc:abc/ac.foundation.dataset.record/xyz")
    """

    def __init__(self, client: AtmosphereClient):
        """Initialize the dataset loader.

        Args:
            client: AtmosphereClient instance.
        """
        self.client = client

    def get(self, uri: str | AtUri) -> dict:
        """Fetch a dataset record by AT URI.

        Args:
            uri: The AT URI of the dataset record.

        Returns:
            The dataset record as a dictionary.

        Raises:
            ValueError: If the record is not a dataset record.
        """
        record = self.client.get_record(uri)

        expected_type = f"{LEXICON_NAMESPACE}.record"
        if record.get("$type") != expected_type:
            raise ValueError(
                f"Record at {uri} is not a dataset record. "
                f"Expected $type='{expected_type}', got '{record.get('$type')}'"
            )

        return record

    def list_all(
        self,
        repo: Optional[str] = None,
        limit: int = 100,
    ) -> list[dict]:
        """List dataset records from a repository.

        Args:
            repo: The DID of the repository. Defaults to authenticated user.
            limit: Maximum number of records to return.

        Returns:
            List of dataset records.
        """
        return self.client.list_datasets(repo=repo, limit=limit)

    @staticmethod
    def _urls_from_record(record: dict) -> list[str]:
        """Extract the external WebDataset URLs from an already-fetched record.

        Raises:
            ValueError: If the storage type is not external URLs.
        """
        storage = record.get("storage", {})
        storage_type = storage.get("$type", "")

        if "storageExternal" in storage_type:
            return storage.get("urls", [])
        if "storageBlobs" in storage_type:
            # NOTE(review): no get_blobs() method exists on this class; the
            # message may refer to an API that is not implemented yet.
            raise ValueError(
                "Dataset uses blob storage, not external URLs. "
                "Use get_blobs() instead."
            )
        raise ValueError(f"Unknown storage type: {storage_type}")

    def get_urls(self, uri: str | AtUri) -> list[str]:
        """Get the WebDataset URLs from a dataset record.

        Args:
            uri: The AT URI of the dataset record.

        Returns:
            List of WebDataset URLs.

        Raises:
            ValueError: If the storage type is not external URLs.
        """
        return self._urls_from_record(self.get(uri))

    def get_metadata(self, uri: str | AtUri) -> Optional[dict]:
        """Get the metadata from a dataset record.

        Metadata is decoded with msgpack, the encoding DatasetPublisher uses
        when writing records.

        Args:
            uri: The AT URI of the dataset record.

        Returns:
            The metadata dictionary, or None if no metadata.
        """
        record = self.get(uri)
        # NOTE(review): assumes the client returns the ``metadata`` field as
        # raw bytes — TODO confirm against AtmosphereClient.get_record.
        metadata_bytes = record.get("metadata")

        if metadata_bytes is None:
            return None

        return msgpack.unpackb(metadata_bytes, raw=False)

    def to_dataset(
        self,
        uri: str | AtUri,
        sample_type: Type[ST],
    ) -> "Dataset[ST]":
        """Create a Dataset object from an ATProto record.

        This method creates a Dataset instance from a published record.
        You must provide the sample type class, which should match the
        schema referenced by the record. Only the first URL in the record is
        used; a ``UserWarning`` is emitted when the record lists more.

        Args:
            uri: The AT URI of the dataset record.
            sample_type: The Python class for the sample type.

        Returns:
            A Dataset instance configured from the record.

        Raises:
            ValueError: If the storage type is not external URLs, or the
                record has no URLs.

        Example:
            >>> loader = DatasetLoader(client)
            >>> dataset = loader.to_dataset(uri, MySampleType)
            >>> for batch in dataset.shuffled(batch_size=32):
            ...     process(batch)
        """
        import warnings

        # Import here to avoid circular import
        from ..dataset import Dataset

        # Fetch the record once and derive both the URLs and metadataUrl from it.
        record = self.get(uri)
        urls = self._urls_from_record(record)
        if not urls:
            raise ValueError("Dataset record has no URLs")

        if len(urls) > 1:
            # Multi-URL support could be added later; make the truncation visible.
            warnings.warn(
                f"Dataset record {uri} lists {len(urls)} URLs; "
                "only the first is used.",
                stacklevel=2,
            )
        url = urls[0]

        metadata_url = record.get("metadataUrl")

        return Dataset[sample_type](url, metadata_url=metadata_url)
|
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
"""Schema publishing and loading for ATProto.
|
|
2
|
+
|
|
3
|
+
This module provides classes for publishing PackableSample schemas to ATProto
|
|
4
|
+
and loading them back. Schemas are published as ``ac.foundation.dataset.sampleSchema``
|
|
5
|
+
records.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from dataclasses import fields, is_dataclass
|
|
9
|
+
from typing import Type, TypeVar, Optional, Union, get_type_hints, get_origin, get_args
|
|
10
|
+
import types
|
|
11
|
+
|
|
12
|
+
from .client import AtmosphereClient
|
|
13
|
+
from ._types import (
|
|
14
|
+
AtUri,
|
|
15
|
+
SchemaRecord,
|
|
16
|
+
FieldDef,
|
|
17
|
+
FieldType,
|
|
18
|
+
LEXICON_NAMESPACE,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Import for type checking only to avoid circular imports
|
|
22
|
+
from typing import TYPE_CHECKING
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from ..dataset import PackableSample
|
|
25
|
+
|
|
26
|
+
# Generic sample-type variable for ``Type[ST]`` annotations. The bound is a
# string forward reference because ``PackableSample`` is only imported under
# ``TYPE_CHECKING`` (to avoid a circular import at runtime).
ST = TypeVar("ST", bound="PackableSample")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class SchemaPublisher:
    """Publishes PackableSample schemas to ATProto.

    This class introspects a PackableSample class to extract its field
    definitions and publishes them as an ATProto schema record.

    Example:
        >>> @atdata.packable
        ... class MySample:
        ...     image: NDArray
        ...     label: str
        ...
        >>> client = AtmosphereClient()
        >>> client.login("handle", "password")
        >>>
        >>> publisher = SchemaPublisher(client)
        >>> uri = publisher.publish(MySample, version="1.0.0")
        >>> print(uri)
        at://did:plc:.../ac.foundation.dataset.sampleSchema/...
    """

    def __init__(self, client: AtmosphereClient):
        """Initialize the schema publisher.

        Args:
            client: Authenticated AtmosphereClient instance.
        """
        self.client = client

    def publish(
        self,
        sample_type: Type[ST],
        *,
        name: Optional[str] = None,
        version: str = "1.0.0",
        description: Optional[str] = None,
        metadata: Optional[dict] = None,
        rkey: Optional[str] = None,
    ) -> AtUri:
        """Publish a PackableSample schema to ATProto.

        Args:
            sample_type: The PackableSample class to publish.
            name: Human-readable name. Defaults to the class name.
            version: Semantic version string (e.g., '1.0.0').
            description: Human-readable description.
            metadata: Arbitrary metadata dictionary.
            rkey: Optional explicit record key. If not provided, a TID is generated.

        Returns:
            The AT URI of the created schema record.

        Raises:
            ValueError: If sample_type is not a dataclass.
            TypeError: If a field type is not supported.

        Any exception raised by ``client.create_record`` propagates unchanged.
        """
        if not is_dataclass(sample_type):
            raise ValueError(f"{sample_type.__name__} must be a dataclass (use @packable)")

        schema_record = self._build_schema_record(
            sample_type,
            name=name,
            version=version,
            description=description,
            metadata=metadata,
        )

        return self.client.create_record(
            collection=f"{LEXICON_NAMESPACE}.sampleSchema",
            record=schema_record.to_record(),
            rkey=rkey,
            validate=False,  # PDS doesn't know our lexicon
        )

    def _build_schema_record(
        self,
        sample_type: Type[ST],
        *,
        name: Optional[str],
        version: str,
        description: Optional[str],
        metadata: Optional[dict],
    ) -> SchemaRecord:
        """Build a SchemaRecord from a PackableSample class.

        Resolved type hints are preferred over the raw ``Field.type`` so that
        string annotations are evaluated when possible.
        """
        type_hints = get_type_hints(sample_type)
        field_defs = [
            self._field_to_def(f.name, type_hints.get(f.name, f.type))
            for f in fields(sample_type)
        ]

        return SchemaRecord(
            name=name or sample_type.__name__,
            version=version,
            description=description,
            fields=field_defs,
            metadata=metadata,
        )

    def _field_to_def(self, name: str, python_type) -> FieldDef:
        """Convert a Python field annotation to a FieldDef.

        ``Optional[T]`` / ``T | None`` is unwrapped to ``T`` with
        ``optional=True``.

        Raises:
            TypeError: For unions of more than one non-None type, or any type
                rejected by ``_python_type_to_field_type``.
        """
        is_optional = False

        # Handle Union types (including Optional, which is Union[T, None]).
        if get_origin(python_type) is Union or isinstance(python_type, types.UnionType):
            args = get_args(python_type)
            non_none_args = [a for a in args if a is not type(None)]
            is_optional = len(non_none_args) < len(args)
            if len(non_none_args) == 1:
                python_type = non_none_args[0]
            elif len(non_none_args) > 1:
                # Complex union type - not fully supported yet
                raise TypeError(f"Complex union types not supported: {python_type}")

        return FieldDef(
            name=name,
            field_type=self._python_type_to_field_type(python_type),
            optional=is_optional,
        )

    def _python_type_to_field_type(self, python_type) -> FieldType:
        """Map a Python type annotation to a FieldType.

        Supports ``str``/``int``/``float``/``bool``/``bytes`` primitives,
        ``list``/``list[T]`` arrays, and NumPy arrays (detected heuristically
        from the annotation's string form).

        Raises:
            TypeError: For nested dataclasses and any other unsupported type.
        """
        # Handle primitives
        for primitive in (str, int, float, bool, bytes):
            if python_type is primitive:
                return FieldType(kind="primitive", primitive=primitive.__name__)

        # Lists are checked before the NDArray heuristic: ``list[NDArray]``
        # stringifies to something containing "ndarray" and must not be
        # mistaken for a single array. A bare builtin ``list`` has no origin,
        # so it is matched explicitly.
        if get_origin(python_type) is list or python_type is list:
            args = get_args(python_type)
            if args:
                items = self._python_type_to_field_type(args[0])
                return FieldType(kind="array", items=items)
            # Untyped list: items default to strings.
            return FieldType(kind="array", items=FieldType(kind="primitive", primitive="str"))

        # NDArray from numpy.typing is a special generic alias; detect it (and
        # plain np.ndarray) from the string form since numpy is not imported here.
        if "ndarray" in str(python_type).lower():
            dtype = "float32"  # Default when no dtype parameter is present
            args = get_args(python_type)
            # NDArray[np.float64] or similar: the dtype is the last type argument.
            if args and args[-1] is not None:
                dtype = self._numpy_dtype_to_string(args[-1])
            return FieldType(kind="ndarray", dtype=dtype, shape=None)

        # Check for nested PackableSample (not yet supported)
        if is_dataclass(python_type):
            raise TypeError(
                f"Nested dataclass types not yet supported: {python_type.__name__}. "
                "Publish nested types separately and use references."
            )

        raise TypeError(f"Unsupported type for schema field: {python_type}")

    def _numpy_dtype_to_string(self, dtype) -> str:
        """Convert a numpy dtype annotation to a dtype name string.

        Matching is by substring of ``str(dtype)``. Names are tried in an order
        where more specific names come first (``uint8`` before ``int8``,
        ``complex128``/``complex64`` before the float/int names), so an
        unsigned dtype is never reported as its signed counterpart.

        Returns:
            The matched dtype name, or ``"float32"`` if nothing matches.
        """
        dtype_str = str(dtype)
        known_dtypes = (
            "complex128",
            "complex64",
            "float16",
            "float32",
            "float64",
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "int8",
            "int16",
            "int32",
            "int64",
            "bool",
        )

        for candidate in known_dtypes:
            if candidate in dtype_str:
                return candidate

        return "float32"  # Default fallback
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class SchemaLoader:
    """Read-side counterpart of SchemaPublisher.

    Retrieves individual ``sampleSchema`` records by AT URI and enumerates the
    schema records stored in a repository.

    Example:
        >>> client = AtmosphereClient()
        >>> client.login("handle", "password")
        >>>
        >>> loader = SchemaLoader(client)
        >>> schema = loader.get("at://did:plc:.../ac.foundation.dataset.sampleSchema/...")
        >>> print(schema["name"])
        MySample
    """

    def __init__(self, client: AtmosphereClient):
        """Store the client used for all record lookups.

        Args:
            client: AtmosphereClient instance (authentication optional for reads).
        """
        self.client = client

    def get(self, uri: str | AtUri) -> dict:
        """Return the schema record stored at ``uri``.

        Args:
            uri: The AT URI of the schema record.

        Returns:
            The schema record as a dictionary.

        Raises:
            ValueError: If the record's ``$type`` is not the sampleSchema type.
            atproto.exceptions.AtProtocolError: If record not found.
        """
        fetched = self.client.get_record(uri)
        schema_type = f"{LEXICON_NAMESPACE}.sampleSchema"
        actual_type = fetched.get("$type")

        if actual_type == schema_type:
            return fetched

        raise ValueError(
            f"Record at {uri} is not a schema record. "
            f"Expected $type='{schema_type}', got '{actual_type}'"
        )

    def list_all(
        self,
        repo: Optional[str] = None,
        limit: int = 100,
    ) -> list[dict]:
        """Enumerate schema records in a repository.

        Args:
            repo: The DID of the repository. Defaults to authenticated user.
            limit: Maximum number of records to return.

        Returns:
            List of schema records, as returned by the client.
        """
        schemas = self.client.list_schemas(repo=repo, limit=limit)
        return schemas
|