atdata 0.2.0a1__py3-none-any.whl → 0.2.2b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +43 -10
- atdata/_cid.py +150 -0
- atdata/_hf_api.py +692 -0
- atdata/_protocols.py +519 -0
- atdata/_schema_codec.py +442 -0
- atdata/_sources.py +515 -0
- atdata/_stub_manager.py +529 -0
- atdata/_type_utils.py +90 -0
- atdata/atmosphere/__init__.py +278 -7
- atdata/atmosphere/_types.py +9 -7
- atdata/atmosphere/client.py +146 -6
- atdata/atmosphere/lens.py +29 -25
- atdata/atmosphere/records.py +197 -30
- atdata/atmosphere/schema.py +41 -98
- atdata/atmosphere/store.py +208 -0
- atdata/cli/__init__.py +213 -0
- atdata/cli/diagnose.py +165 -0
- atdata/cli/local.py +280 -0
- atdata/dataset.py +482 -167
- atdata/lens.py +61 -57
- atdata/local.py +1400 -185
- atdata/promote.py +199 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/METADATA +105 -14
- atdata-0.2.2b1.dist-info/RECORD +28 -0
- atdata-0.2.0a1.dist-info/RECORD +0 -16
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.0a1.dist-info → atdata-0.2.2b1.dist-info}/licenses/LICENSE +0 -0
atdata/local.py
CHANGED
|
@@ -6,11 +6,13 @@ This module provides a local storage backend for atdata datasets using:
|
|
|
6
6
|
|
|
7
7
|
The main classes are:
|
|
8
8
|
- Repo: Manages dataset storage in S3 with Redis indexing
|
|
9
|
-
-
|
|
10
|
-
-
|
|
9
|
+
- LocalIndex: Redis-backed index for tracking dataset metadata
|
|
10
|
+
- LocalDatasetEntry: Index entry representing a stored dataset
|
|
11
11
|
|
|
12
12
|
This is intended for development and small-scale deployment before
|
|
13
|
-
migrating to the full atproto PDS infrastructure.
|
|
13
|
+
migrating to the full atproto PDS infrastructure. The implementation
|
|
14
|
+
uses ATProto-compatible CIDs for content addressing, enabling seamless
|
|
15
|
+
promotion from local storage to the atmosphere (ATProto network).
|
|
14
16
|
"""
|
|
15
17
|
|
|
16
18
|
##
|
|
@@ -20,8 +22,16 @@ from atdata import (
|
|
|
20
22
|
PackableSample,
|
|
21
23
|
Dataset,
|
|
22
24
|
)
|
|
25
|
+
from atdata._cid import generate_cid
|
|
26
|
+
from atdata._type_utils import (
|
|
27
|
+
numpy_dtype_to_string,
|
|
28
|
+
PRIMITIVE_TYPE_MAP,
|
|
29
|
+
unwrap_optional,
|
|
30
|
+
is_ndarray_type,
|
|
31
|
+
extract_ndarray_dtype,
|
|
32
|
+
)
|
|
33
|
+
from atdata._protocols import IndexEntry, AbstractDataStore, Packable
|
|
23
34
|
|
|
24
|
-
import os
|
|
25
35
|
from pathlib import Path
|
|
26
36
|
from uuid import uuid4
|
|
27
37
|
from tempfile import TemporaryDirectory
|
|
@@ -38,51 +48,513 @@ import webdataset as wds
|
|
|
38
48
|
|
|
39
49
|
from dataclasses import (
|
|
40
50
|
dataclass,
|
|
41
|
-
asdict,
|
|
42
51
|
field,
|
|
43
52
|
)
|
|
44
53
|
from typing import (
|
|
45
54
|
Any,
|
|
46
|
-
Optional,
|
|
47
|
-
Dict,
|
|
48
55
|
Type,
|
|
49
56
|
TypeVar,
|
|
50
57
|
Generator,
|
|
58
|
+
Iterator,
|
|
51
59
|
BinaryIO,
|
|
60
|
+
Union,
|
|
61
|
+
Optional,
|
|
62
|
+
Literal,
|
|
52
63
|
cast,
|
|
64
|
+
get_type_hints,
|
|
65
|
+
get_origin,
|
|
66
|
+
get_args,
|
|
53
67
|
)
|
|
68
|
+
from dataclasses import fields, is_dataclass
|
|
69
|
+
from datetime import datetime, timezone
|
|
70
|
+
import json
|
|
71
|
+
import warnings
|
|
54
72
|
|
|
55
73
|
T = TypeVar( 'T', bound = PackableSample )
|
|
56
74
|
|
|
75
|
+
# Redis key prefixes for index entries and schemas
|
|
76
|
+
REDIS_KEY_DATASET_ENTRY = "LocalDatasetEntry"
|
|
77
|
+
REDIS_KEY_SCHEMA = "LocalSchema"
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class SchemaNamespace:
|
|
81
|
+
"""Namespace for accessing loaded schema types as attributes.
|
|
82
|
+
|
|
83
|
+
This class provides a module-like interface for accessing dynamically
|
|
84
|
+
loaded schema types. After calling ``index.load_schema(uri)``, the
|
|
85
|
+
schema's class becomes available as an attribute on this namespace.
|
|
86
|
+
|
|
87
|
+
Example:
|
|
88
|
+
::
|
|
89
|
+
|
|
90
|
+
>>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
|
|
91
|
+
>>> MyType = index.types.MySample
|
|
92
|
+
>>> sample = MyType(field1="hello", field2=42)
|
|
93
|
+
|
|
94
|
+
The namespace supports:
|
|
95
|
+
- Attribute access: ``index.types.MySample``
|
|
96
|
+
- Iteration: ``for name in index.types: ...``
|
|
97
|
+
- Length: ``len(index.types)``
|
|
98
|
+
- Contains check: ``"MySample" in index.types``
|
|
99
|
+
|
|
100
|
+
Note:
|
|
101
|
+
For full IDE autocomplete support, import from the generated module::
|
|
102
|
+
|
|
103
|
+
# After load_schema with auto_stubs=True
|
|
104
|
+
from local.MySample_1_0_0 import MySample
|
|
105
|
+
sample = MySample(name="hello", value=42) # IDE knows signature!
|
|
106
|
+
|
|
107
|
+
Add ``index.stub_dir`` to your IDE's extraPaths for imports to resolve.
|
|
108
|
+
"""
|
|
109
|
+
|
|
110
|
+
def __init__(self) -> None:
|
|
111
|
+
self._types: dict[str, Type[Packable]] = {}
|
|
112
|
+
|
|
113
|
+
def _register(self, name: str, cls: Type[Packable]) -> None:
|
|
114
|
+
"""Register a schema type in the namespace."""
|
|
115
|
+
self._types[name] = cls
|
|
116
|
+
|
|
117
|
+
def __getattr__(self, name: str) -> Any:
|
|
118
|
+
# Returns Any to avoid IDE complaints about unknown attributes.
|
|
119
|
+
# For full IDE support, import from the generated module instead.
|
|
120
|
+
if name.startswith("_"):
|
|
121
|
+
raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
|
|
122
|
+
if name not in self._types:
|
|
123
|
+
raise AttributeError(
|
|
124
|
+
f"Schema '{name}' not loaded. "
|
|
125
|
+
f"Call index.load_schema() first to load the schema."
|
|
126
|
+
)
|
|
127
|
+
return self._types[name]
|
|
128
|
+
|
|
129
|
+
def __dir__(self) -> list[str]:
|
|
130
|
+
return list(self._types.keys()) + ["_types", "_register", "get"]
|
|
131
|
+
|
|
132
|
+
def __iter__(self) -> Iterator[str]:
|
|
133
|
+
return iter(self._types)
|
|
134
|
+
|
|
135
|
+
def __len__(self) -> int:
|
|
136
|
+
return len(self._types)
|
|
137
|
+
|
|
138
|
+
def __contains__(self, name: str) -> bool:
|
|
139
|
+
return name in self._types
|
|
140
|
+
|
|
141
|
+
def __repr__(self) -> str:
|
|
142
|
+
if not self._types:
|
|
143
|
+
return "SchemaNamespace(empty)"
|
|
144
|
+
names = ", ".join(sorted(self._types.keys()))
|
|
145
|
+
return f"SchemaNamespace({names})"
|
|
146
|
+
|
|
147
|
+
def get(self, name: str, default: T | None = None) -> Type[Packable] | T | None:
|
|
148
|
+
"""Get a type by name, returning default if not found.
|
|
149
|
+
|
|
150
|
+
Args:
|
|
151
|
+
name: The schema class name to look up.
|
|
152
|
+
default: Value to return if not found (default: None).
|
|
153
|
+
|
|
154
|
+
Returns:
|
|
155
|
+
The schema class, or default if not loaded.
|
|
156
|
+
"""
|
|
157
|
+
return self._types.get(name, default)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
##
|
|
161
|
+
# Schema types
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@dataclass
|
|
165
|
+
class SchemaFieldType:
|
|
166
|
+
"""Schema field type definition for local storage.
|
|
167
|
+
|
|
168
|
+
Represents a type in the schema type system, supporting primitives,
|
|
169
|
+
ndarrays, arrays, and references to other schemas.
|
|
170
|
+
"""
|
|
171
|
+
|
|
172
|
+
kind: Literal["primitive", "ndarray", "ref", "array"]
|
|
173
|
+
"""The category of type."""
|
|
174
|
+
|
|
175
|
+
primitive: Optional[str] = None
|
|
176
|
+
"""For kind='primitive': one of 'str', 'int', 'float', 'bool', 'bytes'."""
|
|
177
|
+
|
|
178
|
+
dtype: Optional[str] = None
|
|
179
|
+
"""For kind='ndarray': numpy dtype string (e.g., 'float32')."""
|
|
180
|
+
|
|
181
|
+
ref: Optional[str] = None
|
|
182
|
+
"""For kind='ref': URI of referenced schema."""
|
|
183
|
+
|
|
184
|
+
items: Optional["SchemaFieldType"] = None
|
|
185
|
+
"""For kind='array': type of array elements."""
|
|
186
|
+
|
|
187
|
+
@classmethod
|
|
188
|
+
def from_dict(cls, data: dict) -> "SchemaFieldType":
|
|
189
|
+
"""Create from a dictionary (e.g., from Redis storage)."""
|
|
190
|
+
type_str = data.get("$type", "")
|
|
191
|
+
if "#" in type_str:
|
|
192
|
+
kind = type_str.split("#")[-1]
|
|
193
|
+
else:
|
|
194
|
+
kind = data.get("kind", "primitive")
|
|
195
|
+
|
|
196
|
+
items = None
|
|
197
|
+
if "items" in data and data["items"]:
|
|
198
|
+
items = cls.from_dict(data["items"])
|
|
199
|
+
|
|
200
|
+
return cls(
|
|
201
|
+
kind=kind, # type: ignore[arg-type]
|
|
202
|
+
primitive=data.get("primitive"),
|
|
203
|
+
dtype=data.get("dtype"),
|
|
204
|
+
ref=data.get("ref"),
|
|
205
|
+
items=items,
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
def to_dict(self) -> dict:
|
|
209
|
+
"""Convert to dictionary for storage."""
|
|
210
|
+
result: dict[str, Any] = {"$type": f"local#{self.kind}"}
|
|
211
|
+
if self.kind == "primitive":
|
|
212
|
+
result["primitive"] = self.primitive
|
|
213
|
+
elif self.kind == "ndarray":
|
|
214
|
+
result["dtype"] = self.dtype
|
|
215
|
+
elif self.kind == "ref":
|
|
216
|
+
result["ref"] = self.ref
|
|
217
|
+
elif self.kind == "array" and self.items:
|
|
218
|
+
result["items"] = self.items.to_dict()
|
|
219
|
+
return result
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@dataclass
|
|
223
|
+
class SchemaField:
|
|
224
|
+
"""Schema field definition for local storage."""
|
|
225
|
+
|
|
226
|
+
name: str
|
|
227
|
+
"""Field name."""
|
|
228
|
+
|
|
229
|
+
field_type: SchemaFieldType
|
|
230
|
+
"""Type of this field."""
|
|
231
|
+
|
|
232
|
+
optional: bool = False
|
|
233
|
+
"""Whether this field can be None."""
|
|
234
|
+
|
|
235
|
+
@classmethod
|
|
236
|
+
def from_dict(cls, data: dict) -> "SchemaField":
|
|
237
|
+
"""Create from a dictionary."""
|
|
238
|
+
return cls(
|
|
239
|
+
name=data["name"],
|
|
240
|
+
field_type=SchemaFieldType.from_dict(data["fieldType"]),
|
|
241
|
+
optional=data.get("optional", False),
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
def to_dict(self) -> dict:
|
|
245
|
+
"""Convert to dictionary for storage."""
|
|
246
|
+
return {
|
|
247
|
+
"name": self.name,
|
|
248
|
+
"fieldType": self.field_type.to_dict(),
|
|
249
|
+
"optional": self.optional,
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
def __getitem__(self, key: str) -> Any:
|
|
253
|
+
"""Dict-style access for backwards compatibility."""
|
|
254
|
+
if key == "name":
|
|
255
|
+
return self.name
|
|
256
|
+
elif key == "fieldType":
|
|
257
|
+
return self.field_type.to_dict()
|
|
258
|
+
elif key == "optional":
|
|
259
|
+
return self.optional
|
|
260
|
+
raise KeyError(key)
|
|
261
|
+
|
|
262
|
+
def get(self, key: str, default: Any = None) -> Any:
|
|
263
|
+
"""Dict-style get() for backwards compatibility."""
|
|
264
|
+
try:
|
|
265
|
+
return self[key]
|
|
266
|
+
except KeyError:
|
|
267
|
+
return default
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@dataclass
|
|
271
|
+
class LocalSchemaRecord:
|
|
272
|
+
"""Schema record for local storage.
|
|
273
|
+
|
|
274
|
+
Represents a PackableSample schema stored in the local index.
|
|
275
|
+
Aligns with the atmosphere SchemaRecord structure for seamless promotion.
|
|
276
|
+
"""
|
|
277
|
+
|
|
278
|
+
name: str
|
|
279
|
+
"""Schema name (typically the class name)."""
|
|
280
|
+
|
|
281
|
+
version: str
|
|
282
|
+
"""Semantic version string (e.g., '1.0.0')."""
|
|
283
|
+
|
|
284
|
+
fields: list[SchemaField]
|
|
285
|
+
"""List of field definitions."""
|
|
286
|
+
|
|
287
|
+
ref: str
|
|
288
|
+
"""Schema reference URI (atdata://local/sampleSchema/{name}@{version})."""
|
|
289
|
+
|
|
290
|
+
description: Optional[str] = None
|
|
291
|
+
"""Human-readable description."""
|
|
292
|
+
|
|
293
|
+
created_at: Optional[datetime] = None
|
|
294
|
+
"""When this schema was published."""
|
|
295
|
+
|
|
296
|
+
@classmethod
|
|
297
|
+
def from_dict(cls, data: dict) -> "LocalSchemaRecord":
|
|
298
|
+
"""Create from a dictionary (e.g., from Redis storage)."""
|
|
299
|
+
created_at = None
|
|
300
|
+
if "createdAt" in data:
|
|
301
|
+
try:
|
|
302
|
+
created_at = datetime.fromisoformat(data["createdAt"])
|
|
303
|
+
except (ValueError, TypeError):
|
|
304
|
+
created_at = None # Invalid datetime format, leave as None
|
|
305
|
+
|
|
306
|
+
return cls(
|
|
307
|
+
name=data["name"],
|
|
308
|
+
version=data["version"],
|
|
309
|
+
fields=[SchemaField.from_dict(f) for f in data.get("fields", [])],
|
|
310
|
+
ref=data.get("$ref", ""),
|
|
311
|
+
description=data.get("description"),
|
|
312
|
+
created_at=created_at,
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
def to_dict(self) -> dict:
|
|
316
|
+
"""Convert to dictionary for storage."""
|
|
317
|
+
result: dict[str, Any] = {
|
|
318
|
+
"name": self.name,
|
|
319
|
+
"version": self.version,
|
|
320
|
+
"fields": [f.to_dict() for f in self.fields],
|
|
321
|
+
"$ref": self.ref,
|
|
322
|
+
}
|
|
323
|
+
if self.description:
|
|
324
|
+
result["description"] = self.description
|
|
325
|
+
if self.created_at:
|
|
326
|
+
result["createdAt"] = self.created_at.isoformat()
|
|
327
|
+
return result
|
|
328
|
+
|
|
329
|
+
def __getitem__(self, key: str) -> Any:
|
|
330
|
+
"""Dict-style access for backwards compatibility."""
|
|
331
|
+
if key == "name":
|
|
332
|
+
return self.name
|
|
333
|
+
elif key == "version":
|
|
334
|
+
return self.version
|
|
335
|
+
elif key == "fields":
|
|
336
|
+
return self.fields # Returns list of SchemaField (also subscriptable)
|
|
337
|
+
elif key == "$ref":
|
|
338
|
+
return self.ref
|
|
339
|
+
elif key == "description":
|
|
340
|
+
return self.description
|
|
341
|
+
elif key == "createdAt":
|
|
342
|
+
return self.created_at.isoformat() if self.created_at else None
|
|
343
|
+
raise KeyError(key)
|
|
344
|
+
|
|
345
|
+
def __contains__(self, key: str) -> bool:
|
|
346
|
+
"""Support 'in' operator for backwards compatibility."""
|
|
347
|
+
return key in ("name", "version", "fields", "$ref", "description", "createdAt")
|
|
348
|
+
|
|
349
|
+
def get(self, key: str, default: Any = None) -> Any:
|
|
350
|
+
"""Dict-style get() for backwards compatibility."""
|
|
351
|
+
try:
|
|
352
|
+
return self[key]
|
|
353
|
+
except KeyError:
|
|
354
|
+
return default
|
|
355
|
+
|
|
57
356
|
|
|
58
357
|
##
|
|
59
358
|
# Helpers
|
|
60
359
|
|
|
61
|
-
def _kind_str_for_sample_type( st: Type[
|
|
62
|
-
"""
|
|
360
|
+
def _kind_str_for_sample_type( st: Type[Packable] ) -> str:
|
|
361
|
+
"""Return fully-qualified 'module.name' string for a sample type."""
|
|
362
|
+
return f'{st.__module__}.{st.__name__}'
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def _create_s3_write_callbacks(
|
|
366
|
+
credentials: dict[str, Any],
|
|
367
|
+
temp_dir: str,
|
|
368
|
+
written_shards: list[str],
|
|
369
|
+
fs: S3FileSystem | None,
|
|
370
|
+
cache_local: bool,
|
|
371
|
+
add_s3_prefix: bool = False,
|
|
372
|
+
) -> tuple:
|
|
373
|
+
"""Create opener and post callbacks for ShardWriter with S3 upload.
|
|
63
374
|
|
|
64
375
|
Args:
|
|
65
|
-
|
|
376
|
+
credentials: S3 credentials dict.
|
|
377
|
+
temp_dir: Temporary directory for local caching.
|
|
378
|
+
written_shards: List to append written shard paths to.
|
|
379
|
+
fs: S3FileSystem for direct writes (used when cache_local=False).
|
|
380
|
+
cache_local: If True, write locally then copy to S3.
|
|
381
|
+
add_s3_prefix: If True, prepend 's3://' to shard paths.
|
|
66
382
|
|
|
67
383
|
Returns:
|
|
68
|
-
|
|
384
|
+
Tuple of (writer_opener, writer_post) callbacks.
|
|
69
385
|
"""
|
|
70
|
-
|
|
386
|
+
if cache_local:
|
|
387
|
+
import boto3
|
|
388
|
+
|
|
389
|
+
s3_client_kwargs = {
|
|
390
|
+
'aws_access_key_id': credentials['AWS_ACCESS_KEY_ID'],
|
|
391
|
+
'aws_secret_access_key': credentials['AWS_SECRET_ACCESS_KEY']
|
|
392
|
+
}
|
|
393
|
+
if 'AWS_ENDPOINT' in credentials:
|
|
394
|
+
s3_client_kwargs['endpoint_url'] = credentials['AWS_ENDPOINT']
|
|
395
|
+
s3_client = boto3.client('s3', **s3_client_kwargs)
|
|
396
|
+
|
|
397
|
+
def _writer_opener(p: str):
|
|
398
|
+
local_path = Path(temp_dir) / p
|
|
399
|
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
|
400
|
+
return open(local_path, 'wb')
|
|
401
|
+
|
|
402
|
+
def _writer_post(p: str):
|
|
403
|
+
local_path = Path(temp_dir) / p
|
|
404
|
+
path_parts = Path(p).parts
|
|
405
|
+
bucket = path_parts[0]
|
|
406
|
+
key = str(Path(*path_parts[1:]))
|
|
407
|
+
|
|
408
|
+
with open(local_path, 'rb') as f_in:
|
|
409
|
+
s3_client.put_object(Bucket=bucket, Key=key, Body=f_in.read())
|
|
410
|
+
|
|
411
|
+
local_path.unlink()
|
|
412
|
+
if add_s3_prefix:
|
|
413
|
+
written_shards.append(f"s3://{p}")
|
|
414
|
+
else:
|
|
415
|
+
written_shards.append(p)
|
|
71
416
|
|
|
72
|
-
|
|
73
|
-
|
|
417
|
+
return _writer_opener, _writer_post
|
|
418
|
+
else:
|
|
419
|
+
assert fs is not None, "S3FileSystem required when cache_local=False"
|
|
74
420
|
|
|
75
|
-
|
|
421
|
+
def _direct_opener(s: str):
|
|
422
|
+
return cast(BinaryIO, fs.open(f's3://{s}', 'wb'))
|
|
423
|
+
|
|
424
|
+
def _direct_post(s: str):
|
|
425
|
+
if add_s3_prefix:
|
|
426
|
+
written_shards.append(f"s3://{s}")
|
|
427
|
+
else:
|
|
428
|
+
written_shards.append(s)
|
|
429
|
+
|
|
430
|
+
return _direct_opener, _direct_post
|
|
431
|
+
|
|
432
|
+
##
|
|
433
|
+
# Schema helpers
|
|
434
|
+
|
|
435
|
+
# URI scheme prefixes
|
|
436
|
+
_ATDATA_URI_PREFIX = "atdata://local/sampleSchema/"
|
|
437
|
+
_LEGACY_URI_PREFIX = "local://schemas/"
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
def _schema_ref_from_type(sample_type: Type[Packable], version: str) -> str:
|
|
441
|
+
"""Generate 'atdata://local/sampleSchema/{name}@{version}' reference."""
|
|
442
|
+
return _make_schema_ref(sample_type.__name__, version)
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _make_schema_ref(name: str, version: str) -> str:
|
|
446
|
+
"""Generate schema reference URI from name and version."""
|
|
447
|
+
return f"{_ATDATA_URI_PREFIX}{name}@{version}"
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def _parse_schema_ref(ref: str) -> tuple[str, str]:
|
|
451
|
+
"""Parse schema reference into (name, version).
|
|
452
|
+
|
|
453
|
+
Supports both new format: 'atdata://local/sampleSchema/{name}@{version}'
|
|
454
|
+
and legacy format: 'local://schemas/{module.Class}@{version}'
|
|
455
|
+
"""
|
|
456
|
+
if ref.startswith(_ATDATA_URI_PREFIX):
|
|
457
|
+
path = ref[len(_ATDATA_URI_PREFIX):]
|
|
458
|
+
elif ref.startswith(_LEGACY_URI_PREFIX):
|
|
459
|
+
path = ref[len(_LEGACY_URI_PREFIX):]
|
|
460
|
+
else:
|
|
461
|
+
raise ValueError(f"Invalid schema reference: {ref}")
|
|
462
|
+
|
|
463
|
+
if "@" not in path:
|
|
464
|
+
raise ValueError(f"Schema reference must include version (@version): {ref}")
|
|
465
|
+
|
|
466
|
+
name, version = path.rsplit("@", 1)
|
|
467
|
+
# For legacy format, extract just the class name from module.Class
|
|
468
|
+
if "." in name:
|
|
469
|
+
name = name.rsplit(".", 1)[1]
|
|
470
|
+
return name, version
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def _parse_semver(version: str) -> tuple[int, int, int]:
|
|
474
|
+
"""Parse semantic version string into (major, minor, patch) tuple."""
|
|
475
|
+
parts = version.split(".")
|
|
476
|
+
if len(parts) != 3:
|
|
477
|
+
raise ValueError(f"Invalid semver format: {version}")
|
|
478
|
+
return int(parts[0]), int(parts[1]), int(parts[2])
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def _increment_patch(version: str) -> str:
|
|
482
|
+
"""Increment patch version: 1.0.0 -> 1.0.1"""
|
|
483
|
+
major, minor, patch = _parse_semver(version)
|
|
484
|
+
return f"{major}.{minor}.{patch + 1}"
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def _python_type_to_field_type(python_type: Any) -> dict:
|
|
488
|
+
"""Convert Python type annotation to schema field type dict."""
|
|
489
|
+
if python_type in PRIMITIVE_TYPE_MAP:
|
|
490
|
+
return {"$type": "local#primitive", "primitive": PRIMITIVE_TYPE_MAP[python_type]}
|
|
491
|
+
|
|
492
|
+
if is_ndarray_type(python_type):
|
|
493
|
+
return {"$type": "local#ndarray", "dtype": extract_ndarray_dtype(python_type)}
|
|
494
|
+
|
|
495
|
+
origin = get_origin(python_type)
|
|
496
|
+
if origin is list:
|
|
497
|
+
args = get_args(python_type)
|
|
498
|
+
items = _python_type_to_field_type(args[0]) if args else {"$type": "local#primitive", "primitive": "str"}
|
|
499
|
+
return {"$type": "local#array", "items": items}
|
|
500
|
+
|
|
501
|
+
if is_dataclass(python_type):
|
|
502
|
+
raise TypeError(
|
|
503
|
+
f"Nested dataclass types not yet supported: {python_type.__name__}. "
|
|
504
|
+
"Publish nested types separately and use references."
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
raise TypeError(f"Unsupported type for schema field: {python_type}")
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def _build_schema_record(
|
|
511
|
+
sample_type: Type[Packable],
|
|
512
|
+
*,
|
|
513
|
+
version: str,
|
|
514
|
+
description: str | None = None,
|
|
515
|
+
) -> dict:
|
|
516
|
+
"""Build a schema record dict from a PackableSample type.
|
|
76
517
|
|
|
77
518
|
Args:
|
|
78
|
-
|
|
519
|
+
sample_type: The PackableSample subclass to introspect.
|
|
520
|
+
version: Semantic version string.
|
|
521
|
+
description: Optional human-readable description. If None, uses the
|
|
522
|
+
class docstring.
|
|
79
523
|
|
|
80
524
|
Returns:
|
|
81
|
-
|
|
525
|
+
Schema record dict suitable for Redis storage.
|
|
526
|
+
|
|
527
|
+
Raises:
|
|
528
|
+
ValueError: If sample_type is not a dataclass.
|
|
529
|
+
TypeError: If a field type is not supported.
|
|
82
530
|
"""
|
|
531
|
+
if not is_dataclass(sample_type):
|
|
532
|
+
raise ValueError(f"{sample_type.__name__} must be a dataclass (use @packable)")
|
|
533
|
+
|
|
534
|
+
# Use docstring as fallback for description
|
|
535
|
+
if description is None:
|
|
536
|
+
description = sample_type.__doc__
|
|
537
|
+
|
|
538
|
+
field_defs = []
|
|
539
|
+
type_hints = get_type_hints(sample_type)
|
|
540
|
+
|
|
541
|
+
for f in fields(sample_type):
|
|
542
|
+
field_type = type_hints.get(f.name, f.type)
|
|
543
|
+
field_type, is_optional = unwrap_optional(field_type)
|
|
544
|
+
field_type_dict = _python_type_to_field_type(field_type)
|
|
545
|
+
|
|
546
|
+
field_defs.append({
|
|
547
|
+
"name": f.name,
|
|
548
|
+
"fieldType": field_type_dict,
|
|
549
|
+
"optional": is_optional,
|
|
550
|
+
})
|
|
551
|
+
|
|
83
552
|
return {
|
|
84
|
-
|
|
85
|
-
|
|
553
|
+
"name": sample_type.__name__,
|
|
554
|
+
"version": version,
|
|
555
|
+
"fields": field_defs,
|
|
556
|
+
"description": description,
|
|
557
|
+
"createdAt": datetime.now(timezone.utc).isoformat(),
|
|
86
558
|
}
|
|
87
559
|
|
|
88
560
|
|
|
@@ -90,80 +562,168 @@ def _decode_bytes_dict( d: dict[bytes, bytes] ) -> dict[str, str]:
|
|
|
90
562
|
# Redis object model
|
|
91
563
|
|
|
92
564
|
@dataclass
|
|
93
|
-
class
|
|
94
|
-
"""Index entry for a dataset stored in the repository.
|
|
565
|
+
class LocalDatasetEntry:
|
|
566
|
+
"""Index entry for a dataset stored in the local repository.
|
|
95
567
|
|
|
96
|
-
|
|
97
|
-
|
|
568
|
+
Implements the IndexEntry protocol for compatibility with AbstractIndex.
|
|
569
|
+
Uses dual identity: a content-addressable CID (ATProto-compatible) and
|
|
570
|
+
a human-readable name.
|
|
571
|
+
|
|
572
|
+
The CID is generated from the entry's content (schema_ref + data_urls),
|
|
573
|
+
ensuring the same data produces the same CID whether stored locally or
|
|
574
|
+
in the atmosphere. This enables seamless promotion from local to ATProto.
|
|
575
|
+
|
|
576
|
+
Attributes:
|
|
577
|
+
name: Human-readable name for this dataset.
|
|
578
|
+
schema_ref: Reference to the schema for this dataset.
|
|
579
|
+
data_urls: WebDataset URLs for the data.
|
|
580
|
+
metadata: Arbitrary metadata dictionary, or None if not set.
|
|
98
581
|
"""
|
|
99
582
|
##
|
|
100
583
|
|
|
101
|
-
|
|
102
|
-
"""
|
|
584
|
+
name: str
|
|
585
|
+
"""Human-readable name for this dataset."""
|
|
586
|
+
|
|
587
|
+
schema_ref: str
|
|
588
|
+
"""Reference to the schema for this dataset."""
|
|
589
|
+
|
|
590
|
+
data_urls: list[str]
|
|
591
|
+
"""WebDataset URLs for the data."""
|
|
592
|
+
|
|
593
|
+
metadata: dict | None = None
|
|
594
|
+
"""Arbitrary metadata dictionary, or None if not set."""
|
|
595
|
+
|
|
596
|
+
_cid: str | None = field(default=None, repr=False)
|
|
597
|
+
"""Content identifier (ATProto-compatible CID). Generated from content if not provided."""
|
|
598
|
+
|
|
599
|
+
# Legacy field for backwards compatibility during migration
|
|
600
|
+
_legacy_uuid: str | None = field(default=None, repr=False)
|
|
601
|
+
"""Legacy UUID for backwards compatibility with existing Redis entries."""
|
|
103
602
|
|
|
104
|
-
|
|
105
|
-
|
|
603
|
+
def __post_init__(self):
|
|
604
|
+
"""Generate CID from content if not provided."""
|
|
605
|
+
if self._cid is None:
|
|
606
|
+
self._cid = self._generate_cid()
|
|
106
607
|
|
|
107
|
-
|
|
108
|
-
|
|
608
|
+
def _generate_cid(self) -> str:
|
|
609
|
+
"""Generate ATProto-compatible CID from entry content."""
|
|
610
|
+
# CID is based on schema_ref and data_urls - the identity of the dataset
|
|
611
|
+
content = {
|
|
612
|
+
"schema_ref": self.schema_ref,
|
|
613
|
+
"data_urls": self.data_urls,
|
|
614
|
+
}
|
|
615
|
+
return generate_cid(content)
|
|
109
616
|
|
|
110
|
-
|
|
111
|
-
|
|
617
|
+
@property
|
|
618
|
+
def cid(self) -> str:
|
|
619
|
+
"""Content identifier (ATProto-compatible CID)."""
|
|
620
|
+
assert self._cid is not None
|
|
621
|
+
return self._cid
|
|
622
|
+
|
|
623
|
+
# Legacy compatibility
|
|
624
|
+
|
|
625
|
+
@property
|
|
626
|
+
def wds_url(self) -> str:
|
|
627
|
+
"""Legacy property: returns first data URL for backwards compatibility."""
|
|
628
|
+
return self.data_urls[0] if self.data_urls else ""
|
|
629
|
+
|
|
630
|
+
@property
|
|
631
|
+
def sample_kind(self) -> str:
|
|
632
|
+
"""Legacy property: returns schema_ref for backwards compatibility."""
|
|
633
|
+
return self.schema_ref
|
|
112
634
|
|
|
113
|
-
def write_to(
|
|
635
|
+
def write_to(self, redis: Redis):
|
|
114
636
|
"""Persist this index entry to Redis.
|
|
115
637
|
|
|
116
|
-
Stores the entry as a Redis hash with key '
|
|
638
|
+
Stores the entry as a Redis hash with key '{REDIS_KEY_DATASET_ENTRY}:{cid}'.
|
|
117
639
|
|
|
118
640
|
Args:
|
|
119
641
|
redis: Redis connection to write to.
|
|
120
642
|
"""
|
|
121
|
-
save_key = f'
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
643
|
+
save_key = f'{REDIS_KEY_DATASET_ENTRY}:{self.cid}'
|
|
644
|
+
data = {
|
|
645
|
+
'name': self.name,
|
|
646
|
+
'schema_ref': self.schema_ref,
|
|
647
|
+
'data_urls': msgpack.packb(self.data_urls), # Serialize list
|
|
648
|
+
'cid': self.cid,
|
|
649
|
+
}
|
|
650
|
+
if self.metadata is not None:
|
|
651
|
+
data['metadata'] = msgpack.packb(self.metadata)
|
|
652
|
+
if self._legacy_uuid is not None:
|
|
653
|
+
data['legacy_uuid'] = self._legacy_uuid
|
|
654
|
+
|
|
655
|
+
redis.hset(save_key, mapping=data) # type: ignore[arg-type]
|
|
656
|
+
|
|
657
|
+
@classmethod
|
|
658
|
+
def from_redis(cls, redis: Redis, cid: str) -> "LocalDatasetEntry":
|
|
659
|
+
"""Load an entry from Redis by CID.
|
|
660
|
+
|
|
661
|
+
Args:
|
|
662
|
+
redis: Redis connection to read from.
|
|
663
|
+
cid: Content identifier of the entry to load.
|
|
664
|
+
|
|
665
|
+
Returns:
|
|
666
|
+
LocalDatasetEntry loaded from Redis.
|
|
667
|
+
|
|
668
|
+
Raises:
|
|
669
|
+
KeyError: If entry not found.
|
|
670
|
+
"""
|
|
671
|
+
save_key = f'{REDIS_KEY_DATASET_ENTRY}:{cid}'
|
|
672
|
+
raw_data = redis.hgetall(save_key)
|
|
673
|
+
if not raw_data:
|
|
674
|
+
raise KeyError(f"{REDIS_KEY_DATASET_ENTRY} not found: {cid}")
|
|
675
|
+
|
|
676
|
+
# Decode string fields, keep binary fields as bytes for msgpack
|
|
677
|
+
raw_data_typed = cast(dict[bytes, bytes], raw_data)
|
|
678
|
+
name = raw_data_typed[b'name'].decode('utf-8')
|
|
679
|
+
schema_ref = raw_data_typed[b'schema_ref'].decode('utf-8')
|
|
680
|
+
cid_value = raw_data_typed.get(b'cid', b'').decode('utf-8') or None
|
|
681
|
+
legacy_uuid = raw_data_typed.get(b'legacy_uuid', b'').decode('utf-8') or None
|
|
682
|
+
|
|
683
|
+
# Deserialize msgpack fields (stored as raw bytes)
|
|
684
|
+
data_urls = msgpack.unpackb(raw_data_typed[b'data_urls'])
|
|
685
|
+
metadata = None
|
|
686
|
+
if b'metadata' in raw_data_typed:
|
|
687
|
+
metadata = msgpack.unpackb(raw_data_typed[b'metadata'])
|
|
688
|
+
|
|
689
|
+
return cls(
|
|
690
|
+
name=name,
|
|
691
|
+
schema_ref=schema_ref,
|
|
692
|
+
data_urls=data_urls,
|
|
693
|
+
metadata=metadata,
|
|
694
|
+
_cid=cid_value,
|
|
695
|
+
_legacy_uuid=legacy_uuid,
|
|
696
|
+
)
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
# Backwards compatibility alias
|
|
700
|
+
BasicIndexEntry = LocalDatasetEntry
|
|
126
701
|
|
|
127
702
|
def _s3_env( credentials_path: str | Path ) -> dict[str, Any]:
|
|
128
|
-
"""Load S3 credentials from
|
|
703
|
+
"""Load S3 credentials from .env file.
|
|
129
704
|
|
|
130
705
|
Args:
|
|
131
|
-
credentials_path: Path to .env file containing
|
|
706
|
+
credentials_path: Path to .env file containing AWS_ENDPOINT,
|
|
707
|
+
AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.
|
|
132
708
|
|
|
133
709
|
Returns:
|
|
134
|
-
|
|
710
|
+
Dict with the three required credential keys.
|
|
135
711
|
|
|
136
712
|
Raises:
|
|
137
|
-
|
|
713
|
+
ValueError: If any required key is missing from the .env file.
|
|
138
714
|
"""
|
|
139
|
-
##
|
|
140
715
|
credentials_path = Path( credentials_path )
|
|
141
716
|
env_values = dotenv_values( credentials_path )
|
|
142
|
-
assert 'AWS_ENDPOINT' in env_values
|
|
143
|
-
assert 'AWS_ACCESS_KEY_ID' in env_values
|
|
144
|
-
assert 'AWS_SECRET_ACCESS_KEY' in env_values
|
|
145
717
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
'AWS_ACCESS_KEY_ID',
|
|
151
|
-
'AWS_SECRET_ACCESS_KEY',
|
|
152
|
-
)
|
|
153
|
-
}
|
|
154
|
-
|
|
155
|
-
def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem:
|
|
156
|
-
"""Create an S3FileSystem from credentials.
|
|
718
|
+
required_keys = ('AWS_ENDPOINT', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY')
|
|
719
|
+
missing = [k for k in required_keys if k not in env_values]
|
|
720
|
+
if missing:
|
|
721
|
+
raise ValueError(f"Missing required keys in {credentials_path}: {', '.join(missing)}")
|
|
157
722
|
|
|
158
|
-
|
|
159
|
-
creds: Either a path to a .env file with credentials, or a dict
|
|
160
|
-
containing AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and optionally
|
|
161
|
-
AWS_ENDPOINT.
|
|
723
|
+
return {k: env_values[k] for k in required_keys}
|
|
162
724
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
"""
|
|
166
|
-
##
|
|
725
|
+
def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem:
|
|
726
|
+
"""Create S3FileSystem from credentials dict or .env file path."""
|
|
167
727
|
if not isinstance( creds, dict ):
|
|
168
728
|
creds = _s3_env( creds )
|
|
169
729
|
|
|
@@ -184,6 +744,13 @@ def _s3_from_credentials( creds: str | Path | dict ) -> S3FileSystem:
|
|
|
184
744
|
class Repo:
|
|
185
745
|
"""Repository for storing and managing atdata datasets.
|
|
186
746
|
|
|
747
|
+
.. deprecated::
|
|
748
|
+
Use :class:`Index` with :class:`S3DataStore` instead::
|
|
749
|
+
|
|
750
|
+
store = S3DataStore(credentials, bucket="my-bucket")
|
|
751
|
+
index = Index(redis=redis, data_store=store)
|
|
752
|
+
entry = index.insert_dataset(ds, name="my-dataset")
|
|
753
|
+
|
|
187
754
|
Provides storage of datasets in S3-compatible object storage with Redis-based
|
|
188
755
|
indexing. Datasets are stored as WebDataset tar files with optional metadata.
|
|
189
756
|
|
|
@@ -197,17 +764,17 @@ class Repo:
|
|
|
197
764
|
|
|
198
765
|
##
|
|
199
766
|
|
|
200
|
-
def __init__(
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
#
|
|
207
|
-
**kwargs
|
|
208
|
-
) -> None:
|
|
767
|
+
def __init__(
|
|
768
|
+
self,
|
|
769
|
+
s3_credentials: str | Path | dict[str, Any] | None = None,
|
|
770
|
+
hive_path: str | Path | None = None,
|
|
771
|
+
redis: Redis | None = None,
|
|
772
|
+
) -> None:
|
|
209
773
|
"""Initialize a repository.
|
|
210
774
|
|
|
775
|
+
.. deprecated::
|
|
776
|
+
Use Index with S3DataStore instead.
|
|
777
|
+
|
|
211
778
|
Args:
|
|
212
779
|
s3_credentials: Path to .env file with S3 credentials, or dict with
|
|
213
780
|
AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.
|
|
@@ -215,11 +782,18 @@ class Repo:
|
|
|
215
782
|
hive_path: Path within the S3 bucket to store datasets.
|
|
216
783
|
Required if s3_credentials is provided.
|
|
217
784
|
redis: Redis connection for indexing. If None, creates a new connection.
|
|
218
|
-
**kwargs: Additional arguments (reserved for future use).
|
|
219
785
|
|
|
220
786
|
Raises:
|
|
221
787
|
ValueError: If hive_path is not provided when s3_credentials is set.
|
|
222
788
|
"""
|
|
789
|
+
warnings.warn(
|
|
790
|
+
"Repo is deprecated. Use Index with S3DataStore instead:\n"
|
|
791
|
+
" store = S3DataStore(credentials, bucket='my-bucket')\n"
|
|
792
|
+
" index = Index(redis=redis, data_store=store)\n"
|
|
793
|
+
" entry = index.insert_dataset(ds, name='my-dataset')",
|
|
794
|
+
DeprecationWarning,
|
|
795
|
+
stacklevel=2,
|
|
796
|
+
)
|
|
223
797
|
|
|
224
798
|
if s3_credentials is None:
|
|
225
799
|
self.s3_credentials = None
|
|
@@ -241,19 +815,21 @@ class Repo:
|
|
|
241
815
|
else:
|
|
242
816
|
self.hive_path = None
|
|
243
817
|
self.hive_bucket = None
|
|
244
|
-
|
|
818
|
+
|
|
245
819
|
#
|
|
246
820
|
|
|
247
821
|
self.index = Index( redis = redis )
|
|
248
822
|
|
|
249
823
|
##
|
|
250
824
|
|
|
251
|
-
def insert(
|
|
252
|
-
|
|
825
|
+
def insert(self,
|
|
826
|
+
ds: Dataset[T],
|
|
827
|
+
*,
|
|
828
|
+
name: str,
|
|
253
829
|
cache_local: bool = False,
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
830
|
+
schema_ref: str | None = None,
|
|
831
|
+
**kwargs
|
|
832
|
+
) -> tuple[LocalDatasetEntry, Dataset[T]]:
|
|
257
833
|
"""Insert a dataset into the repository.
|
|
258
834
|
|
|
259
835
|
Writes the dataset to S3 as WebDataset tar files, stores metadata,
|
|
@@ -261,23 +837,25 @@ class Repo:
|
|
|
261
837
|
|
|
262
838
|
Args:
|
|
263
839
|
ds: The dataset to insert.
|
|
840
|
+
name: Human-readable name for the dataset.
|
|
264
841
|
cache_local: If True, write to local temporary storage first, then
|
|
265
842
|
copy to S3. This can be faster for some workloads.
|
|
843
|
+
schema_ref: Optional schema reference. If None, generates from sample type.
|
|
266
844
|
**kwargs: Additional arguments passed to wds.ShardWriter.
|
|
267
845
|
|
|
268
846
|
Returns:
|
|
269
847
|
A tuple of (index_entry, new_dataset) where:
|
|
270
|
-
- index_entry:
|
|
848
|
+
- index_entry: LocalDatasetEntry for the stored dataset
|
|
271
849
|
- new_dataset: Dataset object pointing to the stored copy
|
|
272
850
|
|
|
273
851
|
Raises:
|
|
274
|
-
|
|
852
|
+
ValueError: If S3 credentials or hive_path are not configured.
|
|
275
853
|
RuntimeError: If no shards were written.
|
|
276
854
|
"""
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
855
|
+
if self.s3_credentials is None:
|
|
856
|
+
raise ValueError("S3 credentials required for insert(). Initialize Repo with s3_credentials.")
|
|
857
|
+
if self.hive_bucket is None or self.hive_path is None:
|
|
858
|
+
raise ValueError("hive_path required for insert(). Initialize Repo with hive_path.")
|
|
281
859
|
|
|
282
860
|
new_uuid = str( uuid4() )
|
|
283
861
|
|
|
@@ -305,58 +883,25 @@ class Repo:
|
|
|
305
883
|
/ f'atdata--{new_uuid}--%06d.tar'
|
|
306
884
|
).as_posix()
|
|
307
885
|
|
|
886
|
+
written_shards: list[str] = []
|
|
308
887
|
with TemporaryDirectory() as temp_dir:
|
|
888
|
+
writer_opener, writer_post = _create_s3_write_callbacks(
|
|
889
|
+
credentials=self.s3_credentials,
|
|
890
|
+
temp_dir=temp_dir,
|
|
891
|
+
written_shards=written_shards,
|
|
892
|
+
fs=hive_fs,
|
|
893
|
+
cache_local=cache_local,
|
|
894
|
+
add_s3_prefix=False,
|
|
895
|
+
)
|
|
309
896
|
|
|
310
|
-
if cache_local:
|
|
311
|
-
# For cache_local, we need to use boto3 directly to avoid s3fs async issues with moto
|
|
312
|
-
import boto3
|
|
313
|
-
|
|
314
|
-
# Create boto3 client from credentials
|
|
315
|
-
s3_client_kwargs = {
|
|
316
|
-
'aws_access_key_id': self.s3_credentials['AWS_ACCESS_KEY_ID'],
|
|
317
|
-
'aws_secret_access_key': self.s3_credentials['AWS_SECRET_ACCESS_KEY']
|
|
318
|
-
}
|
|
319
|
-
if 'AWS_ENDPOINT' in self.s3_credentials:
|
|
320
|
-
s3_client_kwargs['endpoint_url'] = self.s3_credentials['AWS_ENDPOINT']
|
|
321
|
-
s3_client = boto3.client('s3', **s3_client_kwargs)
|
|
322
|
-
|
|
323
|
-
def _writer_opener( p: str ):
|
|
324
|
-
local_cache_path = Path( temp_dir ) / p
|
|
325
|
-
local_cache_path.parent.mkdir( parents = True, exist_ok = True )
|
|
326
|
-
return open( local_cache_path, 'wb' )
|
|
327
|
-
writer_opener = _writer_opener
|
|
328
|
-
|
|
329
|
-
def _writer_post( p: str ):
|
|
330
|
-
local_cache_path = Path( temp_dir ) / p
|
|
331
|
-
|
|
332
|
-
# Copy to S3 using boto3 client (avoids s3fs async issues)
|
|
333
|
-
path_parts = Path( p ).parts
|
|
334
|
-
bucket = path_parts[0]
|
|
335
|
-
key = str( Path( *path_parts[1:] ) )
|
|
336
|
-
|
|
337
|
-
with open( local_cache_path, 'rb' ) as f_in:
|
|
338
|
-
s3_client.put_object( Bucket=bucket, Key=key, Body=f_in.read() )
|
|
339
|
-
|
|
340
|
-
# Delete local cache file
|
|
341
|
-
local_cache_path.unlink()
|
|
342
|
-
|
|
343
|
-
written_shards.append( p )
|
|
344
|
-
writer_post = _writer_post
|
|
345
|
-
|
|
346
|
-
else:
|
|
347
|
-
# Use s3:// prefix to ensure s3fs treats paths as S3 paths
|
|
348
|
-
writer_opener = lambda s: cast( BinaryIO, hive_fs.open( f's3://{s}', 'wb' ) )
|
|
349
|
-
writer_post = lambda s: written_shards.append( s )
|
|
350
|
-
|
|
351
|
-
written_shards = []
|
|
352
897
|
with wds.writer.ShardWriter(
|
|
353
898
|
shard_pattern,
|
|
354
|
-
opener
|
|
355
|
-
post
|
|
899
|
+
opener=writer_opener,
|
|
900
|
+
post=writer_post,
|
|
356
901
|
**kwargs,
|
|
357
902
|
) as sink:
|
|
358
|
-
for sample in ds.ordered(
|
|
359
|
-
sink.write(
|
|
903
|
+
for sample in ds.ordered(batch_size=None):
|
|
904
|
+
sink.write(sample.as_wds)
|
|
360
905
|
|
|
361
906
|
# Make a new Dataset object for the written dataset copy
|
|
362
907
|
if len( written_shards ) == 0:
|
|
@@ -379,12 +924,17 @@ class Repo:
|
|
|
379
924
|
new_dataset_url = shard_s3_format.format( shard_id = shard_id_braced )
|
|
380
925
|
|
|
381
926
|
new_dataset = Dataset[ds.sample_type](
|
|
382
|
-
url
|
|
383
|
-
metadata_url
|
|
927
|
+
url=new_dataset_url,
|
|
928
|
+
metadata_url=metadata_path.as_posix(),
|
|
384
929
|
)
|
|
385
930
|
|
|
386
|
-
# Add to index
|
|
387
|
-
new_entry = self.index.add_entry(
|
|
931
|
+
# Add to index (use ds._metadata to avoid network requests)
|
|
932
|
+
new_entry = self.index.add_entry(
|
|
933
|
+
new_dataset,
|
|
934
|
+
name=name,
|
|
935
|
+
schema_ref=schema_ref,
|
|
936
|
+
metadata=ds._metadata,
|
|
937
|
+
)
|
|
388
938
|
|
|
389
939
|
return new_entry, new_dataset
|
|
390
940
|
|
|
@@ -392,24 +942,43 @@ class Repo:
|
|
|
392
942
|
class Index:
|
|
393
943
|
"""Redis-backed index for tracking datasets in a repository.
|
|
394
944
|
|
|
395
|
-
Maintains a registry of
|
|
396
|
-
enumeration and lookup
|
|
945
|
+
Implements the AbstractIndex protocol. Maintains a registry of
|
|
946
|
+
LocalDatasetEntry objects in Redis, allowing enumeration and lookup
|
|
947
|
+
of stored datasets.
|
|
948
|
+
|
|
949
|
+
When initialized with a data_store, insert_dataset() will write dataset
|
|
950
|
+
shards to storage before indexing. Without a data_store, insert_dataset()
|
|
951
|
+
only indexes existing URLs.
|
|
397
952
|
|
|
398
953
|
Attributes:
|
|
399
954
|
_redis: Redis connection for index storage.
|
|
955
|
+
_data_store: Optional AbstractDataStore for writing dataset shards.
|
|
400
956
|
"""
|
|
401
957
|
|
|
402
958
|
##
|
|
403
959
|
|
|
404
|
-
def __init__(
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
960
|
+
def __init__(
|
|
961
|
+
self,
|
|
962
|
+
redis: Redis | None = None,
|
|
963
|
+
data_store: AbstractDataStore | None = None,
|
|
964
|
+
auto_stubs: bool = False,
|
|
965
|
+
stub_dir: Path | str | None = None,
|
|
966
|
+
**kwargs,
|
|
967
|
+
) -> None:
|
|
408
968
|
"""Initialize an index.
|
|
409
969
|
|
|
410
970
|
Args:
|
|
411
971
|
redis: Redis connection to use. If None, creates a new connection
|
|
412
972
|
using the provided kwargs.
|
|
973
|
+
data_store: Optional data store for writing dataset shards.
|
|
974
|
+
If provided, insert_dataset() will write shards to this store.
|
|
975
|
+
If None, insert_dataset() only indexes existing URLs.
|
|
976
|
+
auto_stubs: If True, automatically generate .pyi stub files when
|
|
977
|
+
schemas are accessed via get_schema() or decode_schema().
|
|
978
|
+
This enables IDE autocomplete for dynamically decoded types.
|
|
979
|
+
stub_dir: Directory to write stub files. Only used if auto_stubs
|
|
980
|
+
is True or if this parameter is provided (which implies auto_stubs).
|
|
981
|
+
Defaults to ~/.atdata/stubs/ if not specified.
|
|
413
982
|
**kwargs: Additional arguments passed to Redis() constructor if
|
|
414
983
|
redis is None.
|
|
415
984
|
"""
|
|
@@ -418,75 +987,721 @@ class Index:
|
|
|
418
987
|
if redis is not None:
|
|
419
988
|
self._redis = redis
|
|
420
989
|
else:
|
|
421
|
-
self._redis: Redis = Redis(
|
|
990
|
+
self._redis: Redis = Redis(**kwargs)
|
|
991
|
+
|
|
992
|
+
self._data_store = data_store
|
|
993
|
+
|
|
994
|
+
# Initialize stub manager if auto-stubs enabled
|
|
995
|
+
# Providing stub_dir implies auto_stubs=True
|
|
996
|
+
if auto_stubs or stub_dir is not None:
|
|
997
|
+
from ._stub_manager import StubManager
|
|
998
|
+
self._stub_manager: StubManager | None = StubManager(stub_dir=stub_dir)
|
|
999
|
+
else:
|
|
1000
|
+
self._stub_manager = None
|
|
1001
|
+
|
|
1002
|
+
# Initialize schema namespace for load_schema/schemas API
|
|
1003
|
+
self._schema_namespace = SchemaNamespace()
|
|
422
1004
|
|
|
423
1005
|
@property
|
|
424
|
-
def
|
|
425
|
-
"""
|
|
1006
|
+
def data_store(self) -> AbstractDataStore | None:
|
|
1007
|
+
"""The data store for writing shards, or None if index-only."""
|
|
1008
|
+
return self._data_store
|
|
1009
|
+
|
|
1010
|
+
@property
|
|
1011
|
+
def stub_dir(self) -> Path | None:
|
|
1012
|
+
"""Directory where stub files are written, or None if auto-stubs disabled.
|
|
1013
|
+
|
|
1014
|
+
Use this path to configure your IDE for type checking support:
|
|
1015
|
+
- VS Code/Pylance: Add to python.analysis.extraPaths in settings.json
|
|
1016
|
+
- PyCharm: Mark as Sources Root
|
|
1017
|
+
- mypy: Add to mypy_path in mypy.ini
|
|
1018
|
+
"""
|
|
1019
|
+
if self._stub_manager is not None:
|
|
1020
|
+
return self._stub_manager.stub_dir
|
|
1021
|
+
return None
|
|
1022
|
+
|
|
1023
|
+
@property
|
|
1024
|
+
def types(self) -> SchemaNamespace:
|
|
1025
|
+
"""Namespace for accessing loaded schema types.
|
|
1026
|
+
|
|
1027
|
+
After calling :meth:`load_schema`, schema types become available
|
|
1028
|
+
as attributes on this namespace.
|
|
1029
|
+
|
|
1030
|
+
Example:
|
|
1031
|
+
::
|
|
1032
|
+
|
|
1033
|
+
>>> index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
|
|
1034
|
+
>>> MyType = index.types.MySample
|
|
1035
|
+
>>> sample = MyType(name="hello", value=42)
|
|
1036
|
+
|
|
1037
|
+
Returns:
|
|
1038
|
+
SchemaNamespace containing all loaded schema types.
|
|
1039
|
+
"""
|
|
1040
|
+
return self._schema_namespace
|
|
1041
|
+
|
|
1042
|
+
def load_schema(self, ref: str) -> Type[Packable]:
|
|
1043
|
+
"""Load a schema and make it available in the types namespace.
|
|
1044
|
+
|
|
1045
|
+
This method decodes the schema, optionally generates a Python module
|
|
1046
|
+
for IDE support (if auto_stubs is enabled), and registers the type
|
|
1047
|
+
in the :attr:`types` namespace for easy access.
|
|
1048
|
+
|
|
1049
|
+
Args:
|
|
1050
|
+
ref: Schema reference string (atdata://local/sampleSchema/... or
|
|
1051
|
+
legacy local://schemas/...).
|
|
1052
|
+
|
|
1053
|
+
Returns:
|
|
1054
|
+
The decoded PackableSample subclass. Also available via
|
|
1055
|
+
``index.types.<ClassName>`` after this call.
|
|
1056
|
+
|
|
1057
|
+
Raises:
|
|
1058
|
+
KeyError: If schema not found.
|
|
1059
|
+
ValueError: If schema cannot be decoded.
|
|
1060
|
+
|
|
1061
|
+
Example:
|
|
1062
|
+
::
|
|
1063
|
+
|
|
1064
|
+
>>> # Load and use immediately
|
|
1065
|
+
>>> MyType = index.load_schema("atdata://local/sampleSchema/MySample@1.0.0")
|
|
1066
|
+
>>> sample = MyType(name="hello", value=42)
|
|
1067
|
+
>>>
|
|
1068
|
+
>>> # Or access later via namespace
|
|
1069
|
+
>>> index.load_schema("atdata://local/sampleSchema/OtherType@1.0.0")
|
|
1070
|
+
>>> other = index.types.OtherType(data="test")
|
|
1071
|
+
"""
|
|
1072
|
+
# Decode the schema (uses generated module if auto_stubs enabled)
|
|
1073
|
+
cls = self.decode_schema(ref)
|
|
1074
|
+
|
|
1075
|
+
# Register in namespace using the class name
|
|
1076
|
+
self._schema_namespace._register(cls.__name__, cls)
|
|
1077
|
+
|
|
1078
|
+
return cls
|
|
1079
|
+
|
|
1080
|
+
def get_import_path(self, ref: str) -> str | None:
|
|
1081
|
+
"""Get the import path for a schema's generated module.
|
|
1082
|
+
|
|
1083
|
+
When auto_stubs is enabled, this returns the import path that can
|
|
1084
|
+
be used to import the schema type with full IDE support.
|
|
1085
|
+
|
|
1086
|
+
Args:
|
|
1087
|
+
ref: Schema reference string.
|
|
1088
|
+
|
|
1089
|
+
Returns:
|
|
1090
|
+
Import path like "local.MySample_1_0_0", or None if auto_stubs
|
|
1091
|
+
is disabled.
|
|
1092
|
+
|
|
1093
|
+
Example:
|
|
1094
|
+
::
|
|
1095
|
+
|
|
1096
|
+
>>> index = LocalIndex(auto_stubs=True)
|
|
1097
|
+
>>> ref = index.publish_schema(MySample, version="1.0.0")
|
|
1098
|
+
>>> index.load_schema(ref)
|
|
1099
|
+
>>> print(index.get_import_path(ref))
|
|
1100
|
+
local.MySample_1_0_0
|
|
1101
|
+
>>> # Then in your code:
|
|
1102
|
+
>>> # from local.MySample_1_0_0 import MySample
|
|
1103
|
+
"""
|
|
1104
|
+
if self._stub_manager is None:
|
|
1105
|
+
return None
|
|
1106
|
+
|
|
1107
|
+
from ._stub_manager import _extract_authority
|
|
1108
|
+
|
|
1109
|
+
name, version = _parse_schema_ref(ref)
|
|
1110
|
+
schema_dict = self.get_schema(ref)
|
|
1111
|
+
authority = _extract_authority(schema_dict.get("$ref"))
|
|
1112
|
+
|
|
1113
|
+
safe_version = version.replace(".", "_")
|
|
1114
|
+
module_name = f"{name}_{safe_version}"
|
|
1115
|
+
|
|
1116
|
+
return f"{authority}.{module_name}"
|
|
1117
|
+
|
|
1118
|
+
def list_entries(self) -> list[LocalDatasetEntry]:
|
|
1119
|
+
"""Get all index entries as a materialized list.
|
|
426
1120
|
|
|
427
1121
|
Returns:
|
|
428
|
-
List of all
|
|
1122
|
+
List of all LocalDatasetEntry objects in the index.
|
|
429
1123
|
"""
|
|
430
|
-
return list(
|
|
1124
|
+
return list(self.entries)
|
|
431
1125
|
|
|
1126
|
+
# Legacy alias for backwards compatibility
|
|
432
1127
|
@property
|
|
433
|
-
def
|
|
1128
|
+
def all_entries(self) -> list[LocalDatasetEntry]:
|
|
1129
|
+
"""Get all index entries as a list (deprecated, use list_entries())."""
|
|
1130
|
+
return self.list_entries()
|
|
1131
|
+
|
|
1132
|
+
@property
|
|
1133
|
+
def entries(self) -> Generator[LocalDatasetEntry, None, None]:
|
|
434
1134
|
"""Iterate over all index entries.
|
|
435
1135
|
|
|
436
|
-
Scans Redis for
|
|
1136
|
+
Scans Redis for LocalDatasetEntry keys and yields them one at a time.
|
|
437
1137
|
|
|
438
1138
|
Yields:
|
|
439
|
-
|
|
1139
|
+
LocalDatasetEntry objects from the index.
|
|
440
1140
|
"""
|
|
441
|
-
|
|
442
|
-
for key in self._redis.scan_iter(
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
return
|
|
455
|
-
|
|
456
|
-
def add_entry( self, ds: Dataset,
|
|
457
|
-
uuid: str | None = None,
|
|
458
|
-
) -> BasicIndexEntry:
|
|
1141
|
+
prefix = f'{REDIS_KEY_DATASET_ENTRY}:'
|
|
1142
|
+
for key in self._redis.scan_iter(match=f'{prefix}*'):
|
|
1143
|
+
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
|
|
1144
|
+
cid = key_str[len(prefix):]
|
|
1145
|
+
yield LocalDatasetEntry.from_redis(self._redis, cid)
|
|
1146
|
+
|
|
1147
|
+
def add_entry(self,
|
|
1148
|
+
ds: Dataset,
|
|
1149
|
+
*,
|
|
1150
|
+
name: str,
|
|
1151
|
+
schema_ref: str | None = None,
|
|
1152
|
+
metadata: dict | None = None,
|
|
1153
|
+
) -> LocalDatasetEntry:
|
|
459
1154
|
"""Add a dataset to the index.
|
|
460
1155
|
|
|
461
|
-
Creates a
|
|
1156
|
+
Creates a LocalDatasetEntry for the dataset and persists it to Redis.
|
|
462
1157
|
|
|
463
1158
|
Args:
|
|
464
1159
|
ds: The dataset to add to the index.
|
|
465
|
-
|
|
1160
|
+
name: Human-readable name for the dataset.
|
|
1161
|
+
schema_ref: Optional schema reference. If None, generates from sample type.
|
|
1162
|
+
metadata: Optional metadata dictionary. If None, uses ds._metadata if available.
|
|
466
1163
|
|
|
467
1164
|
Returns:
|
|
468
|
-
The created
|
|
1165
|
+
The created LocalDatasetEntry object.
|
|
469
1166
|
"""
|
|
470
1167
|
##
|
|
471
|
-
|
|
1168
|
+
if schema_ref is None:
|
|
1169
|
+
schema_ref = f"local://schemas/{_kind_str_for_sample_type(ds.sample_type)}@1.0.0"
|
|
1170
|
+
|
|
1171
|
+
# Normalize URL to list
|
|
1172
|
+
data_urls = [ds.url]
|
|
1173
|
+
|
|
1174
|
+
# Use provided metadata, or fall back to dataset's cached metadata
|
|
1175
|
+
# (avoid triggering network requests via ds.metadata property)
|
|
1176
|
+
entry_metadata = metadata if metadata is not None else ds._metadata
|
|
1177
|
+
|
|
1178
|
+
entry = LocalDatasetEntry(
|
|
1179
|
+
name=name,
|
|
1180
|
+
schema_ref=schema_ref,
|
|
1181
|
+
data_urls=data_urls,
|
|
1182
|
+
metadata=entry_metadata,
|
|
1183
|
+
)
|
|
1184
|
+
|
|
1185
|
+
entry.write_to(self._redis)
|
|
1186
|
+
|
|
1187
|
+
return entry
|
|
1188
|
+
|
|
1189
|
+
def get_entry(self, cid: str) -> LocalDatasetEntry:
|
|
1190
|
+
"""Get an entry by its CID.
|
|
1191
|
+
|
|
1192
|
+
Args:
|
|
1193
|
+
cid: Content identifier of the entry.
|
|
1194
|
+
|
|
1195
|
+
Returns:
|
|
1196
|
+
LocalDatasetEntry for the given CID.
|
|
1197
|
+
|
|
1198
|
+
Raises:
|
|
1199
|
+
KeyError: If entry not found.
|
|
1200
|
+
"""
|
|
1201
|
+
return LocalDatasetEntry.from_redis(self._redis, cid)
|
|
1202
|
+
|
|
1203
|
+
def get_entry_by_name(self, name: str) -> LocalDatasetEntry:
|
|
1204
|
+
"""Get an entry by its human-readable name.
|
|
1205
|
+
|
|
1206
|
+
Args:
|
|
1207
|
+
name: Human-readable name of the entry.
|
|
1208
|
+
|
|
1209
|
+
Returns:
|
|
1210
|
+
LocalDatasetEntry with the given name.
|
|
1211
|
+
|
|
1212
|
+
Raises:
|
|
1213
|
+
KeyError: If no entry with that name exists.
|
|
1214
|
+
"""
|
|
1215
|
+
for entry in self.entries:
|
|
1216
|
+
if entry.name == name:
|
|
1217
|
+
return entry
|
|
1218
|
+
raise KeyError(f"No entry with name: {name}")
|
|
1219
|
+
|
|
1220
|
+
# AbstractIndex protocol methods
|
|
1221
|
+
|
|
1222
|
+
def insert_dataset(
|
|
1223
|
+
self,
|
|
1224
|
+
ds: Dataset,
|
|
1225
|
+
*,
|
|
1226
|
+
name: str,
|
|
1227
|
+
schema_ref: str | None = None,
|
|
1228
|
+
**kwargs,
|
|
1229
|
+
) -> LocalDatasetEntry:
|
|
1230
|
+
"""Insert a dataset into the index (AbstractIndex protocol).
|
|
1231
|
+
|
|
1232
|
+
If a data_store was provided at initialization, writes dataset shards
|
|
1233
|
+
to storage first, then indexes the new URLs. Otherwise, indexes the
|
|
1234
|
+
dataset's existing URL.
|
|
1235
|
+
|
|
1236
|
+
Args:
|
|
1237
|
+
ds: The Dataset to register.
|
|
1238
|
+
name: Human-readable name for the dataset.
|
|
1239
|
+
schema_ref: Optional schema reference.
|
|
1240
|
+
**kwargs: Additional options:
|
|
1241
|
+
- metadata: Optional metadata dict
|
|
1242
|
+
- prefix: Storage prefix (default: dataset name)
|
|
1243
|
+
- cache_local: If True, cache writes locally first
|
|
1244
|
+
|
|
1245
|
+
Returns:
|
|
1246
|
+
IndexEntry for the inserted dataset.
|
|
1247
|
+
"""
|
|
1248
|
+
metadata = kwargs.get('metadata')
|
|
1249
|
+
|
|
1250
|
+
if self._data_store is not None:
|
|
1251
|
+
# Write shards to data store, then index the new URLs
|
|
1252
|
+
prefix = kwargs.get('prefix', name)
|
|
1253
|
+
cache_local = kwargs.get('cache_local', False)
|
|
1254
|
+
|
|
1255
|
+
written_urls = self._data_store.write_shards(
|
|
1256
|
+
ds,
|
|
1257
|
+
prefix=prefix,
|
|
1258
|
+
cache_local=cache_local,
|
|
1259
|
+
)
|
|
472
1260
|
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
1261
|
+
# Generate schema_ref if not provided
|
|
1262
|
+
if schema_ref is None:
|
|
1263
|
+
schema_ref = _schema_ref_from_type(ds.sample_type, version="1.0.0")
|
|
1264
|
+
|
|
1265
|
+
# Create entry with the written URLs
|
|
1266
|
+
entry_metadata = metadata if metadata is not None else ds._metadata
|
|
1267
|
+
entry = LocalDatasetEntry(
|
|
1268
|
+
name=name,
|
|
1269
|
+
schema_ref=schema_ref,
|
|
1270
|
+
data_urls=written_urls,
|
|
1271
|
+
metadata=entry_metadata,
|
|
478
1272
|
)
|
|
1273
|
+
entry.write_to(self._redis)
|
|
1274
|
+
return entry
|
|
1275
|
+
|
|
1276
|
+
# No data store - just index the existing URL
|
|
1277
|
+
return self.add_entry(ds, name=name, schema_ref=schema_ref, metadata=metadata)
|
|
1278
|
+
|
|
1279
|
+
def get_dataset(self, ref: str) -> LocalDatasetEntry:
|
|
1280
|
+
"""Get a dataset entry by name (AbstractIndex protocol).
|
|
1281
|
+
|
|
1282
|
+
Args:
|
|
1283
|
+
ref: Dataset name.
|
|
1284
|
+
|
|
1285
|
+
Returns:
|
|
1286
|
+
IndexEntry for the dataset.
|
|
1287
|
+
|
|
1288
|
+
Raises:
|
|
1289
|
+
KeyError: If dataset not found.
|
|
1290
|
+
"""
|
|
1291
|
+
return self.get_entry_by_name(ref)
|
|
1292
|
+
|
|
1293
|
+
@property
|
|
1294
|
+
def datasets(self) -> Generator[LocalDatasetEntry, None, None]:
|
|
1295
|
+
"""Lazily iterate over all dataset entries (AbstractIndex protocol).
|
|
1296
|
+
|
|
1297
|
+
Yields:
|
|
1298
|
+
IndexEntry for each dataset.
|
|
1299
|
+
"""
|
|
1300
|
+
return self.entries
|
|
1301
|
+
|
|
1302
|
+
def list_datasets(self) -> list[LocalDatasetEntry]:
|
|
1303
|
+
"""Get all dataset entries as a materialized list (AbstractIndex protocol).
|
|
1304
|
+
|
|
1305
|
+
Returns:
|
|
1306
|
+
List of IndexEntry for each dataset.
|
|
1307
|
+
"""
|
|
1308
|
+
return self.list_entries()
|
|
1309
|
+
|
|
1310
|
+
# Schema operations
|
|
1311
|
+
|
|
1312
|
+
def _get_latest_schema_version(self, name: str) -> str | None:
|
|
1313
|
+
"""Get the latest version for a schema by name, or None if not found."""
|
|
1314
|
+
latest_version: tuple[int, int, int] | None = None
|
|
1315
|
+
latest_version_str: str | None = None
|
|
1316
|
+
|
|
1317
|
+
prefix = f'{REDIS_KEY_SCHEMA}:'
|
|
1318
|
+
for key in self._redis.scan_iter(match=f'{prefix}*'):
|
|
1319
|
+
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
|
|
1320
|
+
schema_id = key_str[len(prefix):]
|
|
1321
|
+
|
|
1322
|
+
if "@" not in schema_id:
|
|
1323
|
+
continue
|
|
1324
|
+
|
|
1325
|
+
schema_name, version_str = schema_id.rsplit("@", 1)
|
|
1326
|
+
# Handle legacy format: module.Class -> Class
|
|
1327
|
+
if "." in schema_name:
|
|
1328
|
+
schema_name = schema_name.rsplit(".", 1)[1]
|
|
1329
|
+
|
|
1330
|
+
if schema_name != name:
|
|
1331
|
+
continue
|
|
1332
|
+
|
|
1333
|
+
try:
|
|
1334
|
+
version_tuple = _parse_semver(version_str)
|
|
1335
|
+
if latest_version is None or version_tuple > latest_version:
|
|
1336
|
+
latest_version = version_tuple
|
|
1337
|
+
latest_version_str = version_str
|
|
1338
|
+
except ValueError:
|
|
1339
|
+
continue
|
|
1340
|
+
|
|
1341
|
+
return latest_version_str
+
+    def publish_schema(
+        self,
+        sample_type: type,
+        *,
+        version: str | None = None,
+        description: str | None = None,
+    ) -> str:
+        """Publish a schema for a sample type to Redis.
+
+        Args:
+            sample_type: A Packable type (@packable-decorated or PackableSample subclass).
+            version: Semantic version string (e.g., '1.0.0'). If None,
+                auto-increments from the latest published version (patch bump),
+                or starts at '1.0.0' if no previous version exists.
+            description: Optional human-readable description. If None, uses
+                the class docstring.
+
+        Returns:
+            Schema reference string: 'atdata://local/sampleSchema/{name}@{version}'.
+
+        Raises:
+            ValueError: If sample_type is not a dataclass.
+            TypeError: If sample_type doesn't satisfy the Packable protocol,
+                or if a field type is not supported.
+        """
+        # Validate that sample_type satisfies Packable protocol at runtime
+        # This catches non-packable types early with a clear error message
+        try:
+            # Check protocol compliance by verifying required methods exist
+            if not (hasattr(sample_type, 'from_data') and
+                    hasattr(sample_type, 'from_bytes') and
+                    callable(getattr(sample_type, 'from_data', None)) and
+                    callable(getattr(sample_type, 'from_bytes', None))):
+                raise TypeError(
+                    f"{sample_type.__name__} does not satisfy the Packable protocol. "
+                    "Use @packable decorator or inherit from PackableSample."
+                )
+        except AttributeError:
+            raise TypeError(
+                f"sample_type must be a class, got {type(sample_type).__name__}"
+            )
+
+        # Auto-increment version if not specified
+        if version is None:
+            latest = self._get_latest_schema_version(sample_type.__name__)
+            if latest is None:
+                version = "1.0.0"
+            else:
+                version = _increment_patch(latest)
+
+        schema_record = _build_schema_record(
+            sample_type,
+            version=version,
+            description=description,
+        )
+
+        schema_ref = _schema_ref_from_type(sample_type, version)
+        name, _ = _parse_schema_ref(schema_ref)
+
+        # Store in Redis
+        redis_key = f"{REDIS_KEY_SCHEMA}:{name}@{version}"
+        schema_json = json.dumps(schema_record)
+        self._redis.set(redis_key, schema_json)
+
+        return schema_ref
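A sketch of publishing a schema and relying on the automatic patch bump (the @packable decorator and PackableSample are named in the docstring above, but the import path, decorator stacking, and Index construction shown here are assumptions):

# Illustrative sketch; assumes a running Redis and default Index construction.
from dataclasses import dataclass
from atdata import packable          # assumed export of the @packable decorator
from atdata.local import Index

@packable
@dataclass
class MySample:
    """Toy sample with a single text field."""
    text: str

index = Index()
ref_a = index.publish_schema(MySample)                  # first publish -> ...MySample@1.0.0
ref_b = index.publish_schema(MySample)                  # auto patch bump -> ...MySample@1.0.1
ref_c = index.publish_schema(
    MySample, version="2.0.0", description="Reworked text field"
)                                                       # explicit version pin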
+
+    def get_schema(self, ref: str) -> dict:
+        """Get a schema record by reference (AbstractIndex protocol).
+
+        Args:
+            ref: Schema reference string. Supports both new format
+                (atdata://local/sampleSchema/{name}@{version}) and legacy
+                format (local://schemas/{module.Class}@{version}).
+
+        Returns:
+            Schema record as a dictionary with keys 'name', 'version',
+            'fields', '$ref', etc.
+
+        Raises:
+            KeyError: If schema not found.
+            ValueError: If reference format is invalid.
+        """
+        name, version = _parse_schema_ref(ref)
+        redis_key = f"{REDIS_KEY_SCHEMA}:{name}@{version}"
+
+        schema_json = self._redis.get(redis_key)
+        if schema_json is None:
+            raise KeyError(f"Schema not found: {ref}")
+
+        if isinstance(schema_json, bytes):
+            schema_json = schema_json.decode('utf-8')
+
+        schema = json.loads(schema_json)
+        schema['$ref'] = _make_schema_ref(name, version)
+
+        # Auto-generate stub if enabled
+        if self._stub_manager is not None:
+            record = LocalSchemaRecord.from_dict(schema)
+            self._stub_manager.ensure_stub(record)
+
+        return schema
+
+    def get_schema_record(self, ref: str) -> LocalSchemaRecord:
+        """Get a schema record as LocalSchemaRecord object.
+
+        Use this when you need the full LocalSchemaRecord with typed properties.
+        For Protocol-compliant dict access, use get_schema() instead.
+
+        Args:
+            ref: Schema reference string.
+
+        Returns:
+            LocalSchemaRecord with schema details.
+
+        Raises:
+            KeyError: If schema not found.
+            ValueError: If reference format is invalid.
+        """
+        schema = self.get_schema(ref)
+        return LocalSchemaRecord.from_dict(schema)
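get_schema returns the plain dict form (protocol-friendly), while get_schema_record wraps the same record in a LocalSchemaRecord. A short sketch (only the 'name', 'version', and '$ref' keys used here are documented above; the Index construction is assumed):

# Both calls resolve the same Redis record; only the return shape differs.
from atdata.local import Index

index = Index()                           # assumed defaults
ref = "atdata://local/sampleSchema/MySample@1.0.0"

schema = index.get_schema(ref)            # dict: 'name', 'version', 'fields', '$ref', ...
print(schema["name"], schema["version"])

record = index.get_schema_record(ref)     # typed LocalSchemaRecord view of the same data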
+
+    @property
+    def schemas(self) -> Generator[LocalSchemaRecord, None, None]:
+        """Iterate over all schema records in this index.
+
+        Yields:
+            LocalSchemaRecord for each schema.
+        """
+        prefix = f'{REDIS_KEY_SCHEMA}:'
+        for key in self._redis.scan_iter(match=f'{prefix}*'):
+            key_str = key.decode('utf-8') if isinstance(key, bytes) else key
+            # Extract name@version from key
+            schema_id = key_str[len(prefix):]
+
+            schema_json = self._redis.get(key)
+            if schema_json is None:
+                continue
+
+            if isinstance(schema_json, bytes):
+                schema_json = schema_json.decode('utf-8')
+
+            schema = json.loads(schema_json)
+            # Handle legacy keys that have module.Class format
+            if "." in schema_id.split("@")[0]:
+                name = schema_id.split("@")[0].rsplit(".", 1)[1]
+                version = schema_id.split("@")[1]
+                schema['$ref'] = _make_schema_ref(name, version)
+            else:
+                # schema_id is already "name@version"
+                name, version = schema_id.rsplit("@", 1)
+                schema['$ref'] = _make_schema_ref(name, version)
+            yield LocalSchemaRecord.from_dict(schema)
+
+    def list_schemas(self) -> list[dict]:
+        """Get all schema records as a materialized list (AbstractIndex protocol).
+
+        Returns:
+            List of schema records as dictionaries.
+        """
+        return [record.to_dict() for record in self.schemas]
+
+    def decode_schema(self, ref: str) -> Type[Packable]:
+        """Reconstruct a Python PackableSample type from a stored schema.
+
+        This method enables loading datasets without knowing the sample type
+        ahead of time. The index retrieves the schema record and dynamically
+        generates a PackableSample subclass matching the schema definition.
+
+        If auto_stubs is enabled, a Python module will be generated and the
+        class will be imported from it, providing full IDE autocomplete support.
+        The returned class has proper type information that IDEs can understand.
+
+        Args:
+            ref: Schema reference string (atdata://local/sampleSchema/... or
+                legacy local://schemas/...).
+
+        Returns:
+            A PackableSample subclass - either imported from a generated module
+            (if auto_stubs is enabled) or dynamically created.
+
+        Raises:
+            KeyError: If schema not found.
+            ValueError: If schema cannot be decoded.
+        """
+        schema_dict = self.get_schema(ref)
+
+        # If auto_stubs is enabled, generate module and import class from it
+        if self._stub_manager is not None:
+            cls = self._stub_manager.ensure_module(schema_dict)
+            if cls is not None:
+                return cls
+
+        # Fall back to dynamic type generation
+        from atdata._schema_codec import schema_to_type
+        return schema_to_type(schema_dict)
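decode_schema closes the loop: a dataset can be read back without importing the original sample class. A hedged sketch (the 'text' field assumes the toy MySample schema from the earlier sketch; real field names come from whatever schema was published):

# Rebuild the sample type from the stored schema, then construct an instance.
from atdata.local import Index

index = Index()                                    # assumed defaults
ref = "atdata://local/sampleSchema/MySample@1.0.0"

SampleType = index.decode_schema(ref)              # PackableSample subclass generated from the schema
sample = SampleType(text="hello")                  # constructor mirrors the schema's fields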
+
+    def decode_schema_as(self, ref: str, type_hint: type[T]) -> type[T]:
+        """Decode a schema with explicit type hint for IDE support.
+
+        This is a typed wrapper around decode_schema() that preserves the
+        type information for IDE autocomplete. Use this when you have a
+        stub file for the schema and want full IDE support.
+
+        Args:
+            ref: Schema reference string.
+            type_hint: The stub type to use for type hints. Import this from
+                the generated stub file.
+
+        Returns:
+            The decoded type, cast to match the type_hint for IDE support.
+
+        Example:
+            ::
+
+                >>> # After enabling auto_stubs and configuring IDE extraPaths:
+                >>> from local.MySample_1_0_0 import MySample
+                >>>
+                >>> # This gives full IDE autocomplete:
+                >>> DecodedType = index.decode_schema_as(ref, MySample)
+                >>> sample = DecodedType(text="hello", value=42)  # IDE knows signature!
+
+        Note:
+            The type_hint is only used for static type checking - at runtime,
+            the actual decoded type from the schema is returned. Ensure the
+            stub matches the schema to avoid runtime surprises.
+        """
+        from typing import cast
+        return cast(type[T], self.decode_schema(ref))
+
+    def clear_stubs(self) -> int:
+        """Remove all auto-generated stub files.
+
+        Only works if auto_stubs was enabled when creating the Index.
+
+        Returns:
+            Number of stub files removed, or 0 if auto_stubs is disabled.
+        """
+        if self._stub_manager is not None:
+            return self._stub_manager.clear_stubs()
+        return 0
+
+
+# Backwards compatibility alias
+LocalIndex = Index
+
+
+class S3DataStore:
+    """S3-compatible data store implementing AbstractDataStore protocol.
+
+    Handles writing dataset shards to S3-compatible object storage and
+    resolving URLs for reading.
+
+    Attributes:
+        credentials: S3 credentials dictionary.
+        bucket: Target bucket name.
+        _fs: S3FileSystem instance.
+    """
+
+    def __init__(
+        self,
+        credentials: str | Path | dict[str, Any],
+        *,
+        bucket: str,
+    ) -> None:
+        """Initialize an S3 data store.
+
+        Args:
+            credentials: Path to .env file or dict with AWS_ACCESS_KEY_ID,
+                AWS_SECRET_ACCESS_KEY, and optionally AWS_ENDPOINT.
+            bucket: Name of the S3 bucket for storage.
+        """
+        if isinstance(credentials, dict):
+            self.credentials = credentials
         else:
-
-
-
-
-
+            self.credentials = _s3_env(credentials)
+
+        self.bucket = bucket
+        self._fs = _s3_from_credentials(self.credentials)
+
+    def write_shards(
+        self,
+        ds: Dataset,
+        *,
+        prefix: str,
+        cache_local: bool = False,
+        **kwargs,
+    ) -> list[str]:
+        """Write dataset shards to S3.
+
+        Args:
+            ds: The Dataset to write.
+            prefix: Path prefix within bucket (e.g., 'datasets/mnist/v1').
+            cache_local: If True, write locally first then copy to S3.
+            **kwargs: Additional args passed to wds.ShardWriter (e.g., maxcount).
+
+        Returns:
+            List of S3 URLs for the written shards.
+
+        Raises:
+            RuntimeError: If no shards were written.
+        """
+        new_uuid = str(uuid4())
+        shard_pattern = f"{self.bucket}/{prefix}/data--{new_uuid}--%06d.tar"
+
+        written_shards: list[str] = []
+
+        with TemporaryDirectory() as temp_dir:
+            writer_opener, writer_post = _create_s3_write_callbacks(
+                credentials=self.credentials,
+                temp_dir=temp_dir,
+                written_shards=written_shards,
+                fs=self._fs,
+                cache_local=cache_local,
+                add_s3_prefix=True,
             )
 
-
+            with wds.writer.ShardWriter(
+                shard_pattern,
+                opener=writer_opener,
+                post=writer_post,
+                **kwargs,
+            ) as sink:
+                for sample in ds.ordered(batch_size=None):
+                    sink.write(sample.as_wds)
+
+        if len(written_shards) == 0:
+            raise RuntimeError("No shards written")
+
+        return written_shards
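A sketch of writing a dataset's shards through S3DataStore (the bucket, endpoint, and credential values are placeholders, and `ds` stands for a Dataset built elsewhere; maxcount is forwarded to wds.ShardWriter as noted above):

# Illustrative sketch; all credential values below are placeholders.
from atdata.local import S3DataStore

store = S3DataStore(
    {
        "AWS_ACCESS_KEY_ID": "<access-key>",
        "AWS_SECRET_ACCESS_KEY": "<secret-key>",
        "AWS_ENDPOINT": "https://example-endpoint.invalid",  # optional, for R2/MinIO-style stores
    },
    bucket="my-datasets",
)

# `ds` is an atdata Dataset constructed elsewhere; maxcount caps samples per shard.
urls = store.write_shards(ds, prefix="datasets/mnist/v1", maxcount=10_000)
print(urls)  # e.g. ['s3://my-datasets/datasets/mnist/v1/data--<uuid>--000000.tar', ...]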
+
+    def read_url(self, url: str) -> str:
+        """Resolve an S3 URL for reading/streaming.
 
-
+        For S3-compatible stores with custom endpoints (like Cloudflare R2,
+        MinIO, etc.), converts s3:// URLs to HTTPS URLs that WebDataset can
+        stream directly.
+
+        For standard AWS S3 (no custom endpoint), URLs are returned unchanged
+        since WebDataset's built-in s3fs integration handles them.
+
+        Args:
+            url: S3 URL to resolve (e.g., 's3://bucket/path/file.tar').
+
+        Returns:
+            HTTPS URL if custom endpoint is configured, otherwise unchanged.
+            Example: 's3://bucket/path' -> 'https://endpoint.com/bucket/path'
+        """
+        endpoint = self.credentials.get('AWS_ENDPOINT')
+        if endpoint and url.startswith('s3://'):
+            # s3://bucket/path -> https://endpoint/bucket/path
+            path = url[5:]  # Remove 's3://' prefix
+            endpoint = endpoint.rstrip('/')
+            return f"{endpoint}/{path}"
+        return url
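The rewrite in read_url is pure string manipulation, so it can be illustrated without touching S3 (the endpoint and bucket below are made up):

# Standalone illustration of the same s3:// -> https:// rewrite.
endpoint = "https://example-endpoint.invalid"
url = "s3://my-datasets/datasets/mnist/v1/data--abc--000000.tar"

if endpoint and url.startswith("s3://"):
    resolved = f"{endpoint.rstrip('/')}/{url[5:]}"
else:
    resolved = url

print(resolved)
# -> https://example-endpoint.invalid/my-datasets/datasets/mnist/v1/data--abc--000000.tar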
+
+    def supports_streaming(self) -> bool:
+        """S3 supports streaming reads.
+
+        Returns:
+            True.
+        """
+        return True
 
 
 #