atdata-0.3.1b1-py3-none-any.whl → atdata-0.3.2b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,17 +9,11 @@ from dataclasses import fields, is_dataclass
 from typing import Type, TypeVar, Optional, get_type_hints, get_origin, get_args
 
 from .client import Atmosphere
-from ._types import (
-    AtUri,
-    SchemaRecord,
-    FieldDef,
-    FieldType,
-    LEXICON_NAMESPACE,
-)
+from ._types import AtUri, LEXICON_NAMESPACE
+from ._lexicon_types import LexSchemaRecord, JsonSchemaFormat
 from .._type_utils import (
     unwrap_optional,
     is_ndarray_type,
-    extract_ndarray_dtype,
 )
 
 # Import for type checking only to avoid circular imports
@@ -86,27 +80,32 @@ class SchemaPublisher:
             ValueError: If sample_type is not a dataclass or client is not authenticated.
             TypeError: If a field type is not supported.
         """
+        from atdata._logging import log_operation
+
         if not is_dataclass(sample_type):
            raise ValueError(
                f"{sample_type.__name__} must be a dataclass (use @packable)"
            )
 
-        # Build the schema record
-        schema_record = self._build_schema_record(
-            sample_type,
-            name=name,
-            version=version,
-            description=description,
-            metadata=metadata,
-        )
+        with log_operation(
+            "SchemaPublisher.publish", schema=sample_type.__name__, version=version
+        ):
+            # Build the schema record
+            schema_record = self._build_schema_record(
+                sample_type,
+                name=name,
+                version=version,
+                description=description,
+                metadata=metadata,
+            )
 
-        # Publish to ATProto
-        return self.client.create_record(
-            collection=f"{LEXICON_NAMESPACE}.schema",
-            record=schema_record.to_record(),
-            rkey=rkey,
-            validate=False,  # PDS doesn't know our lexicon
-        )
+            # Publish to ATProto
+            return self.client.create_record(
+                collection=f"{LEXICON_NAMESPACE}.schema",
+                record=schema_record.to_record(),
+                rkey=rkey,
+                validate=False,  # PDS doesn't know our lexicon
+            )
 
     def _build_schema_record(
         self,
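
For orientation, a call into the rewritten publish() could look like the sketch below. The publish keyword arguments, the @packable requirement, and the f"{LEXICON_NAMESPACE}.schema" collection come from this diff; the SchemaPublisher and Atmosphere constructor signatures, the login() call, the import paths, and the ImageSample fields are assumptions for illustration only.

    # Hypothetical usage sketch -- constructors, login(), and import paths are assumed.
    import numpy as np

    from atdata import packable                                 # assumed import path
    from atdata.atmosphere import Atmosphere, SchemaPublisher   # assumed import path


    @packable                              # publish() requires a @packable dataclass
    class ImageSample:
        image: np.ndarray                  # illustrative field names
        label: str


    client = Atmosphere()                                   # assumed constructor
    client.login("alice.example.com", "app-password")       # assumed auth call
    publisher = SchemaPublisher(client)                     # assumed constructor

    # Builds a LexSchemaRecord and writes it to the {LEXICON_NAMESPACE}.schema
    # collection with validate=False, all inside the log_operation() context.
    result = publisher.publish(ImageSample, version="1.0.0", description="Image/label sample")
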
@@ -116,57 +115,74 @@ class SchemaPublisher:
         version: str,
         description: Optional[str],
         metadata: Optional[dict],
-    ) -> SchemaRecord:
-        """Build a SchemaRecord from a PackableSample class."""
-        field_defs = []
+    ) -> LexSchemaRecord:
+        """Build a LexSchemaRecord from a PackableSample class."""
         type_hints = get_type_hints(sample_type)
+        properties: dict[str, dict] = {}
+        required_fields: list[str] = []
+        has_ndarray = False
 
         for f in fields(sample_type):
             field_type = type_hints.get(f.name, f.type)
-            field_def = self._field_to_def(f.name, field_type)
-            field_defs.append(field_def)
-
-        return SchemaRecord(
+            field_type, is_optional = unwrap_optional(field_type)
+            prop = self._python_type_to_json_schema(field_type)
+            properties[f.name] = prop
+            if not is_optional:
+                required_fields.append(f.name)
+            if is_ndarray_type(field_type):
+                has_ndarray = True
+
+        schema_body = {
+            "$schema": "http://json-schema.org/draft-07/schema#",
+            "type": "object",
+            "properties": properties,
+        }
+        if required_fields:
+            schema_body["required"] = required_fields
+
+        array_format_versions = None
+        if has_ndarray:
+            array_format_versions = {"ndarrayBytes": "1.0.0"}
+
+        return LexSchemaRecord(
             name=name or sample_type.__name__,
             version=version,
+            schema_type="jsonSchema",
+            schema=JsonSchemaFormat(
+                schema_body=schema_body,
+                array_format_versions=array_format_versions,
+            ),
             description=description,
-            fields=field_defs,
             metadata=metadata,
         )
 
-    def _field_to_def(self, name: str, python_type) -> FieldDef:
-        """Convert a Python field to a FieldDef."""
-        python_type, is_optional = unwrap_optional(python_type)
-        field_type = self._python_type_to_field_type(python_type)
-        return FieldDef(name=name, field_type=field_type, optional=is_optional)
-
-    def _python_type_to_field_type(self, python_type) -> FieldType:
-        """Map a Python type to a FieldType."""
+    def _python_type_to_json_schema(self, python_type) -> dict:
+        """Map a Python type to a JSON Schema property definition."""
         if python_type is str:
-            return FieldType(kind="primitive", primitive="str")
+            return {"type": "string"}
         if python_type is int:
-            return FieldType(kind="primitive", primitive="int")
+            return {"type": "integer"}
         if python_type is float:
-            return FieldType(kind="primitive", primitive="float")
+            return {"type": "number"}
         if python_type is bool:
-            return FieldType(kind="primitive", primitive="bool")
+            return {"type": "boolean"}
         if python_type is bytes:
-            return FieldType(kind="primitive", primitive="bytes")
+            return {"type": "string", "format": "byte", "contentEncoding": "base64"}
 
         if is_ndarray_type(python_type):
-            return FieldType(
-                kind="ndarray", dtype=extract_ndarray_dtype(python_type), shape=None
-            )
+            return {
+                "$ref": "https://foundation.ac/schemas/atdata-ndarray-bytes/1.0.0#/$defs/ndarray"
+            }
 
         origin = get_origin(python_type)
         if origin is list:
             args = get_args(python_type)
             items = (
-                self._python_type_to_field_type(args[0])
+                self._python_type_to_json_schema(args[0])
                 if args
-                else FieldType(kind="primitive", primitive="str")
+                else {"type": "string"}
             )
-            return FieldType(kind="array", items=items)
+            return {"type": "array", "items": items}
 
         if is_dataclass(python_type):
             raise TypeError(
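
To make the new mapping concrete, the following sketch shows the schema body that _build_schema_record would assemble for a small dataclass. The dataclass and its field names are invented for illustration, and it assumes is_ndarray_type() accepts a plain np.ndarray annotation; the emitted structure follows the code above.

    from dataclasses import dataclass
    from typing import Optional

    import numpy as np


    @dataclass
    class Example:                 # illustrative only
        label: str
        score: Optional[float]     # Optional -> unwrapped, excluded from "required"
        pixels: np.ndarray         # assumed to satisfy is_ndarray_type()


    # _build_schema_record(Example, ...) would produce:
    # schema_body = {
    #     "$schema": "http://json-schema.org/draft-07/schema#",
    #     "type": "object",
    #     "properties": {
    #         "label": {"type": "string"},
    #         "score": {"type": "number"},
    #         "pixels": {"$ref": "https://foundation.ac/schemas/atdata-ndarray-bytes/1.0.0#/$defs/ndarray"},
    #     },
    #     "required": ["label", "pixels"],
    # }
    # and, because an ndarray field is present:
    # array_format_versions = {"ndarrayBytes": "1.0.0"}
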
@@ -224,6 +240,18 @@ class SchemaLoader:
 
         return record
 
+    def get_typed(self, uri: str | AtUri) -> LexSchemaRecord:
+        """Fetch a schema record and return as a typed object.
+
+        Args:
+            uri: The AT URI of the schema record.
+
+        Returns:
+            LexSchemaRecord instance.
+        """
+        record = self.get(uri)
+        return LexSchemaRecord.from_record(record)
+
     def list_all(
         self,
         repo: Optional[str] = None,
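
Callers that previously consumed the raw record dict from get() can now ask for a typed object instead. In the sketch below the loader construction and the AT URI are placeholders; the attribute names come from the LexSchemaRecord constructor call shown earlier in this diff.

    loader = SchemaLoader(client)                             # assumed constructor
    schema_uri = "at://did:plc:example/<collection>/<rkey>"   # placeholder AT URI

    raw = loader.get(schema_uri)            # dict, as before
    typed = loader.get_typed(schema_uri)    # LexSchemaRecord via from_record()

    print(typed.name, typed.version, typed.schema_type)
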
@@ -19,11 +19,14 @@ Examples:
 
 from __future__ import annotations
 
-import tempfile
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
 
-import webdataset as wds
+#: Maximum size in bytes for a single PDS blob upload (50 MB).
+PDS_BLOB_LIMIT_BYTES: int = 50_000_000
+
+#: Maximum total dataset size in bytes for atmosphere uploads (1 GB).
+PDS_TOTAL_DATASET_LIMIT_BYTES: int = 1_000_000_000
 
 if TYPE_CHECKING:
     from ..dataset import Dataset
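
The two constants are published limits rather than checks enforced in this hunk; a caller could use them for a pre-flight size check along the lines of the sketch below (the helper itself is illustrative, not part of the package).

    import os


    def check_shard_sizes(shard_paths: list[str]) -> None:
        """Illustrative pre-flight check against the PDS limits above."""
        total = 0
        for path in shard_paths:
            size = os.path.getsize(path)
            if size > PDS_BLOB_LIMIT_BYTES:
                raise ValueError(f"{path} is {size} bytes, over the 50 MB per-blob limit")
            total += size
        if total > PDS_TOTAL_DATASET_LIMIT_BYTES:
            raise ValueError(f"dataset totals {total} bytes, over the 1 GB upload limit")
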
@@ -31,6 +34,25 @@ if TYPE_CHECKING:
     from .client import Atmosphere
 
 
+class ShardUploadResult(list):
+    """Return type for ``PDSBlobStore.write_shards()``.
+
+    Extends ``list[str]`` (AT URIs) so it satisfies the ``AbstractDataStore``
+    protocol, while also carrying the raw blob reference dicts needed to
+    create ``storageBlobs`` records.
+
+    Attributes:
+        blob_refs: Blob reference dicts as returned by
+            ``Atmosphere.upload_blob()``.
+    """
+
+    blob_refs: list[dict]
+
+    def __init__(self, urls: list[str], blob_refs: list[dict]) -> None:
+        super().__init__(urls)
+        self.blob_refs = blob_refs
+
+
 @dataclass
 class PDSBlobStore:
     """PDS blob store implementing AbstractDataStore protocol.
@@ -59,78 +81,54 @@ class PDSBlobStore:
         ds: "Dataset",
         *,
         prefix: str,
-        maxcount: int = 10000,
-        maxsize: float = 3e9,
         **kwargs: Any,
-    ) -> list[str]:
-        """Write dataset shards as PDS blobs.
+    ) -> "ShardUploadResult":
+        """Upload existing dataset shards as PDS blobs.
 
-        Creates tar archives from the dataset and uploads each as a blob
-        to the authenticated user's PDS.
+        Reads the tar archives already written to disk by the caller and
+        uploads each as a blob to the authenticated user's PDS. This
+        avoids re-serializing samples that have already been written.
 
         Args:
-            ds: The Dataset to write.
-            prefix: Logical path prefix for naming (used in shard names only).
-            maxcount: Maximum samples per shard (default: 10000).
-            maxsize: Maximum shard size in bytes (default: 3GB, PDS limit).
-            **kwargs: Additional args passed to wds.ShardWriter.
+            ds: The Dataset whose shards to upload.
+            prefix: Logical path prefix (unused, kept for protocol compat).
+            **kwargs: Unused, kept for protocol compatibility.
 
         Returns:
-            List of AT URIs for the written blobs, in format:
-            ``at://{did}/blob/{cid}``
+            A ``ShardUploadResult`` (behaves as ``list[str]`` of AT URIs)
+            with a ``blob_refs`` attribute containing the raw blob reference
+            dicts needed for ``storageBlobs`` records.
 
         Raises:
             ValueError: If not authenticated.
-            RuntimeError: If no shards were written.
-
-        Note:
-            PDS blobs have size limits (typically 50MB-5GB depending on PDS).
-            Adjust maxcount/maxsize to stay within limits.
+            RuntimeError: If no shards are found on the dataset.
         """
         if not self.client.did:
            raise ValueError("Client must be authenticated to upload blobs")
 
         did = self.client.did
         blob_urls: list[str] = []
+        blob_refs: list[dict] = []
+
+        shard_paths = ds.list_shards()
+        if not shard_paths:
+            raise RuntimeError("No shards to upload")
+
+        for shard_url in shard_paths:
+            with open(shard_url, "rb") as f:
+                shard_data = f.read()
+
+            blob_ref = self.client.upload_blob(
+                shard_data,
+                mime_type="application/x-tar",
+            )
+
+            blob_refs.append(blob_ref)
+            cid = blob_ref["ref"]["$link"]
+            at_uri = f"at://{did}/blob/{cid}"
+            blob_urls.append(at_uri)
 
-        # Write shards to temp files, upload each as blob
-        with tempfile.TemporaryDirectory() as temp_dir:
-            shard_pattern = f"{temp_dir}/shard-%06d.tar"
-            written_files: list[str] = []
-
-            # Track written files via custom post callback
-            def track_file(fname: str) -> None:
-                written_files.append(fname)
-
-            with wds.writer.ShardWriter(
-                shard_pattern,
-                maxcount=maxcount,
-                maxsize=maxsize,
-                post=track_file,
-                **kwargs,
-            ) as sink:
-                for sample in ds.ordered(batch_size=None):
-                    sink.write(sample.as_wds)
-
-            if not written_files:
-                raise RuntimeError("No shards written")
-
-            # Upload each shard as a blob
-            for shard_path in written_files:
-                with open(shard_path, "rb") as f:
-                    shard_data = f.read()
-
-                blob_ref = self.client.upload_blob(
-                    shard_data,
-                    mime_type="application/x-tar",
-                )
-
-                # Extract CID from blob reference
-                cid = blob_ref["ref"]["$link"]
-                at_uri = f"at://{did}/blob/{cid}"
-                blob_urls.append(at_uri)
-
-        return blob_urls
+        return ShardUploadResult(blob_urls, blob_refs)
 
     def read_url(self, url: str) -> str:
         """Resolve an AT URI blob reference to an HTTP URL.
@@ -200,4 +198,9 @@ class PDSBlobStore:
         return BlobSource(blob_refs=blob_refs)
 
 
-__all__ = ["PDSBlobStore"]
+__all__ = [
+    "PDS_BLOB_LIMIT_BYTES",
+    "PDS_TOTAL_DATASET_LIMIT_BYTES",
+    "PDSBlobStore",
+    "ShardUploadResult",
+]