atdata-0.3.1b1-py3-none-any.whl → atdata-0.3.2b1-py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries; it is provided for informational purposes only.
- atdata/__init__.py +2 -0
- atdata/_hf_api.py +13 -0
- atdata/_logging.py +43 -0
- atdata/_protocols.py +18 -1
- atdata/_sources.py +24 -4
- atdata/atmosphere/__init__.py +48 -10
- atdata/atmosphere/_lexicon_types.py +595 -0
- atdata/atmosphere/_types.py +71 -243
- atdata/atmosphere/lens.py +49 -41
- atdata/atmosphere/records.py +282 -90
- atdata/atmosphere/schema.py +78 -50
- atdata/atmosphere/store.py +62 -59
- atdata/dataset.py +201 -135
- atdata/index/_entry.py +6 -2
- atdata/index/_index.py +396 -109
- atdata/lexicons/__init__.py +9 -3
- atdata/lexicons/ac.foundation.dataset.lens.json +2 -0
- atdata/lexicons/ac.foundation.dataset.record.json +22 -1
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +26 -4
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +1 -1
- atdata/lexicons/ac.foundation.dataset.storageHttp.json +45 -0
- atdata/lexicons/ac.foundation.dataset.storageS3.json +61 -0
- atdata/manifest/__init__.py +4 -0
- atdata/manifest/_proxy.py +321 -0
- atdata/repository.py +59 -9
- atdata/stores/_disk.py +19 -11
- atdata/stores/_s3.py +134 -112
- {atdata-0.3.1b1.dist-info → atdata-0.3.2b1.dist-info}/METADATA +1 -1
- {atdata-0.3.1b1.dist-info → atdata-0.3.2b1.dist-info}/RECORD +37 -33
- {atdata-0.3.1b1.dist-info → atdata-0.3.2b1.dist-info}/WHEEL +0 -0
- {atdata-0.3.1b1.dist-info → atdata-0.3.2b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.3.1b1.dist-info → atdata-0.3.2b1.dist-info}/licenses/LICENSE +0 -0
atdata/atmosphere/schema.py
CHANGED
```diff
@@ -9,17 +9,11 @@ from dataclasses import fields, is_dataclass
 from typing import Type, TypeVar, Optional, get_type_hints, get_origin, get_args
 
 from .client import Atmosphere
-from ._types import (
-
-    SchemaRecord,
-    FieldDef,
-    FieldType,
-    LEXICON_NAMESPACE,
-)
+from ._types import AtUri, LEXICON_NAMESPACE
+from ._lexicon_types import LexSchemaRecord, JsonSchemaFormat
 from .._type_utils import (
     unwrap_optional,
     is_ndarray_type,
-    extract_ndarray_dtype,
 )
 
 # Import for type checking only to avoid circular imports
@@ -86,27 +80,32 @@ class SchemaPublisher:
             ValueError: If sample_type is not a dataclass or client is not authenticated.
             TypeError: If a field type is not supported.
         """
+        from atdata._logging import log_operation
+
         if not is_dataclass(sample_type):
             raise ValueError(
                 f"{sample_type.__name__} must be a dataclass (use @packable)"
             )
 
-
-
-
-
-
-
-
-
+        with log_operation(
+            "SchemaPublisher.publish", schema=sample_type.__name__, version=version
+        ):
+            # Build the schema record
+            schema_record = self._build_schema_record(
+                sample_type,
+                name=name,
+                version=version,
+                description=description,
+                metadata=metadata,
+            )
 
-
-
-
-
-
-
-
+            # Publish to ATProto
+            return self.client.create_record(
+                collection=f"{LEXICON_NAMESPACE}.schema",
+                record=schema_record.to_record(),
+                rkey=rkey,
+                validate=False,  # PDS doesn't know our lexicon
+            )
 
     def _build_schema_record(
         self,
@@ -116,57 +115,74 @@ class SchemaPublisher:
         version: str,
         description: Optional[str],
         metadata: Optional[dict],
-    ) ->
-        """Build a
-        field_defs = []
+    ) -> LexSchemaRecord:
+        """Build a LexSchemaRecord from a PackableSample class."""
         type_hints = get_type_hints(sample_type)
+        properties: dict[str, dict] = {}
+        required_fields: list[str] = []
+        has_ndarray = False
 
         for f in fields(sample_type):
             field_type = type_hints.get(f.name, f.type)
-
-
-
-
+            field_type, is_optional = unwrap_optional(field_type)
+            prop = self._python_type_to_json_schema(field_type)
+            properties[f.name] = prop
+            if not is_optional:
+                required_fields.append(f.name)
+            if is_ndarray_type(field_type):
+                has_ndarray = True
+
+        schema_body = {
+            "$schema": "http://json-schema.org/draft-07/schema#",
+            "type": "object",
+            "properties": properties,
+        }
+        if required_fields:
+            schema_body["required"] = required_fields
+
+        array_format_versions = None
+        if has_ndarray:
+            array_format_versions = {"ndarrayBytes": "1.0.0"}
+
+        return LexSchemaRecord(
             name=name or sample_type.__name__,
             version=version,
+            schema_type="jsonSchema",
+            schema=JsonSchemaFormat(
+                schema_body=schema_body,
+                array_format_versions=array_format_versions,
+            ),
             description=description,
-            fields=field_defs,
             metadata=metadata,
         )
 
-    def
-        """
-        python_type, is_optional = unwrap_optional(python_type)
-        field_type = self._python_type_to_field_type(python_type)
-        return FieldDef(name=name, field_type=field_type, optional=is_optional)
-
-    def _python_type_to_field_type(self, python_type) -> FieldType:
-        """Map a Python type to a FieldType."""
+    def _python_type_to_json_schema(self, python_type) -> dict:
+        """Map a Python type to a JSON Schema property definition."""
         if python_type is str:
-            return
+            return {"type": "string"}
         if python_type is int:
-            return
+            return {"type": "integer"}
         if python_type is float:
-            return
+            return {"type": "number"}
         if python_type is bool:
-            return
+            return {"type": "boolean"}
         if python_type is bytes:
-            return
+            return {"type": "string", "format": "byte", "contentEncoding": "base64"}
 
         if is_ndarray_type(python_type):
-            return
-
-
+            return {
+                "$ref": "https://foundation.ac/schemas/atdata-ndarray-bytes/1.0.0#/$defs/ndarray"
+            }
 
         origin = get_origin(python_type)
         if origin is list:
            args = get_args(python_type)
            items = (
-                self.
+                self._python_type_to_json_schema(args[0])
                if args
-                else
+                else {"type": "string"}
            )
-            return
+            return {"type": "array", "items": items}
 
        if is_dataclass(python_type):
            raise TypeError(
@@ -224,6 +240,18 @@ class SchemaLoader:
 
         return record
 
+    def get_typed(self, uri: str | AtUri) -> LexSchemaRecord:
+        """Fetch a schema record and return as a typed object.
+
+        Args:
+            uri: The AT URI of the schema record.
+
+        Returns:
+            LexSchemaRecord instance.
+        """
+        record = self.get(uri)
+        return LexSchemaRecord.from_record(record)
+
     def list_all(
         self,
         repo: Optional[str] = None,
```
atdata/atmosphere/store.py
CHANGED
```diff
@@ -19,11 +19,14 @@ Examples:
 
 from __future__ import annotations
 
-import tempfile
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
 
-
+#: Maximum size in bytes for a single PDS blob upload (50 MB).
+PDS_BLOB_LIMIT_BYTES: int = 50_000_000
+
+#: Maximum total dataset size in bytes for atmosphere uploads (1 GB).
+PDS_TOTAL_DATASET_LIMIT_BYTES: int = 1_000_000_000
 
 if TYPE_CHECKING:
     from ..dataset import Dataset
@@ -31,6 +34,25 @@ if TYPE_CHECKING:
     from .client import Atmosphere
 
 
+class ShardUploadResult(list):
+    """Return type for ``PDSBlobStore.write_shards()``.
+
+    Extends ``list[str]`` (AT URIs) so it satisfies the ``AbstractDataStore``
+    protocol, while also carrying the raw blob reference dicts needed to
+    create ``storageBlobs`` records.
+
+    Attributes:
+        blob_refs: Blob reference dicts as returned by
+            ``Atmosphere.upload_blob()``.
+    """
+
+    blob_refs: list[dict]
+
+    def __init__(self, urls: list[str], blob_refs: list[dict]) -> None:
+        super().__init__(urls)
+        self.blob_refs = blob_refs
+
+
 @dataclass
 class PDSBlobStore:
     """PDS blob store implementing AbstractDataStore protocol.
@@ -59,78 +81,54 @@
         ds: "Dataset",
         *,
         prefix: str,
-        maxcount: int = 10000,
-        maxsize: float = 3e9,
         **kwargs: Any,
-    ) ->
-        """
+    ) -> "ShardUploadResult":
+        """Upload existing dataset shards as PDS blobs.
 
-
-        to the authenticated user's PDS.
+        Reads the tar archives already written to disk by the caller and
+        uploads each as a blob to the authenticated user's PDS. This
+        avoids re-serializing samples that have already been written.
 
         Args:
-            ds: The Dataset to
-            prefix: Logical path prefix
-
-            maxsize: Maximum shard size in bytes (default: 3GB, PDS limit).
-            **kwargs: Additional args passed to wds.ShardWriter.
+            ds: The Dataset whose shards to upload.
+            prefix: Logical path prefix (unused, kept for protocol compat).
+            **kwargs: Unused, kept for protocol compatibility.
 
         Returns:
-
-            ``
+            A ``ShardUploadResult`` (behaves as ``list[str]`` of AT URIs)
+            with a ``blob_refs`` attribute containing the raw blob reference
+            dicts needed for ``storageBlobs`` records.
 
         Raises:
             ValueError: If not authenticated.
-            RuntimeError: If no shards
-
-        Note:
-            PDS blobs have size limits (typically 50MB-5GB depending on PDS).
-            Adjust maxcount/maxsize to stay within limits.
+            RuntimeError: If no shards are found on the dataset.
         """
         if not self.client.did:
             raise ValueError("Client must be authenticated to upload blobs")
 
         did = self.client.did
         blob_urls: list[str] = []
+        blob_refs: list[dict] = []
+
+        shard_paths = ds.list_shards()
+        if not shard_paths:
+            raise RuntimeError("No shards to upload")
+
+        for shard_url in shard_paths:
+            with open(shard_url, "rb") as f:
+                shard_data = f.read()
+
+            blob_ref = self.client.upload_blob(
+                shard_data,
+                mime_type="application/x-tar",
+            )
+
+            blob_refs.append(blob_ref)
+            cid = blob_ref["ref"]["$link"]
+            at_uri = f"at://{did}/blob/{cid}"
+            blob_urls.append(at_uri)
 
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            shard_pattern = f"{temp_dir}/shard-%06d.tar"
-            written_files: list[str] = []
-
-            # Track written files via custom post callback
-            def track_file(fname: str) -> None:
-                written_files.append(fname)
-
-            with wds.writer.ShardWriter(
-                shard_pattern,
-                maxcount=maxcount,
-                maxsize=maxsize,
-                post=track_file,
-                **kwargs,
-            ) as sink:
-                for sample in ds.ordered(batch_size=None):
-                    sink.write(sample.as_wds)
-
-            if not written_files:
-                raise RuntimeError("No shards written")
-
-            # Upload each shard as a blob
-            for shard_path in written_files:
-                with open(shard_path, "rb") as f:
-                    shard_data = f.read()
-
-                blob_ref = self.client.upload_blob(
-                    shard_data,
-                    mime_type="application/x-tar",
-                )
-
-                # Extract CID from blob reference
-                cid = blob_ref["ref"]["$link"]
-                at_uri = f"at://{did}/blob/{cid}"
-                blob_urls.append(at_uri)
-
-        return blob_urls
+        return ShardUploadResult(blob_urls, blob_refs)
 
     def read_url(self, url: str) -> str:
         """Resolve an AT URI blob reference to an HTTP URL.
@@ -200,4 +198,9 @@
         return BlobSource(blob_refs=blob_refs)
 
 
-__all__ = [
+__all__ = [
+    "PDS_BLOB_LIMIT_BYTES",
+    "PDS_TOTAL_DATASET_LIMIT_BYTES",
+    "PDSBlobStore",
+    "ShardUploadResult",
+]
```
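With this change, `write_shards()` returns a `ShardUploadResult` instead of a plain list. A short sketch of how the result behaves, using made-up AT URIs and blob reference values purely for illustration:

```python
from atdata.atmosphere.store import ShardUploadResult

# Placeholder values; real URIs and blob refs come from Atmosphere.upload_blob().
urls = ["at://did:plc:example/blob/bafyexamplecid"]
refs = [{"ref": {"$link": "bafyexamplecid"}, "mimeType": "application/x-tar"}]

result = ShardUploadResult(urls, refs)

assert list(result) == urls                 # behaves as a list[str] of AT URIs
assert result.blob_refs[0]["ref"]["$link"] == "bafyexamplecid"
```

Keeping the list behavior preserves the `AbstractDataStore` return contract, while the extra `blob_refs` attribute lets callers build `storageBlobs` records without re-fetching blob metadata.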