atdata-0.2.2b1-py3-none-any.whl → atdata-0.3.0b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +31 -1
- atdata/_cid.py +29 -35
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +33 -17
- atdata/_hf_api.py +109 -59
- atdata/_logging.py +70 -0
- atdata/_protocols.py +74 -132
- atdata/_schema_codec.py +38 -41
- atdata/_sources.py +57 -64
- atdata/_stub_manager.py +31 -26
- atdata/_type_utils.py +47 -7
- atdata/atmosphere/__init__.py +31 -24
- atdata/atmosphere/_types.py +11 -11
- atdata/atmosphere/client.py +11 -8
- atdata/atmosphere/lens.py +27 -30
- atdata/atmosphere/records.py +34 -39
- atdata/atmosphere/schema.py +35 -31
- atdata/atmosphere/store.py +16 -20
- atdata/cli/__init__.py +163 -168
- atdata/cli/diagnose.py +12 -8
- atdata/cli/inspect.py +69 -0
- atdata/cli/local.py +5 -2
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +678 -533
- atdata/lens.py +85 -83
- atdata/local/__init__.py +71 -0
- atdata/local/_entry.py +157 -0
- atdata/local/_index.py +940 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/local/_s3.py +349 -0
- atdata/local/_schema.py +380 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +20 -24
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/testing.py +337 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +5 -1
- atdata-0.3.0b1.dist-info/RECORD +54 -0
- atdata/local.py +0 -1707
- atdata-0.2.2b1.dist-info/RECORD +0 -28
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/_logging.py
ADDED
@@ -0,0 +1,70 @@
+"""Pluggable logging for atdata.
+
+Provides a thin abstraction over Python's stdlib ``logging`` module that can
+be replaced with ``structlog`` or any other logger implementing the standard
+``debug``/``info``/``warning``/``error`` interface.
+
+Usage::
+
+    # Default: stdlib logging (no config needed)
+    from atdata._logging import get_logger
+    log = get_logger()
+    log.info("processing shard", extra={"shard": "data-000.tar"})
+
+    # Plug in structlog (or any compatible logger):
+    import structlog
+    import atdata
+    atdata.configure_logging(structlog.get_logger())
+
+The module also exports a lightweight ``LoggerProtocol`` for type checking
+custom logger implementations.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any, Protocol, runtime_checkable
+
+
+@runtime_checkable
+class LoggerProtocol(Protocol):
+    """Minimal interface that a pluggable logger must satisfy."""
+
+    def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
+    def info(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
+    def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
+    def error(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
+
+
+# ---------------------------------------------------------------------------
+# Module-level state
+# ---------------------------------------------------------------------------
+
+_logger: LoggerProtocol = logging.getLogger("atdata")
+
+
+def configure_logging(logger: LoggerProtocol) -> None:
+    """Replace the default logger with a custom implementation.
+
+    The provided logger must implement ``debug``, ``info``, ``warning``, and
+    ``error`` methods. Both ``structlog`` bound loggers and stdlib
+    ``logging.Logger`` instances satisfy this interface.
+
+    Args:
+        logger: A logger instance implementing :class:`LoggerProtocol`.
+
+    Examples:
+        >>> import structlog
+        >>> atdata.configure_logging(structlog.get_logger())
+    """
+    global _logger
+    _logger = logger
+
+
+def get_logger() -> LoggerProtocol:
+    """Return the currently configured logger.
+
+    Returns the stdlib ``logging.getLogger("atdata")`` by default, or
+    whatever was last set via :func:`configure_logging`.
+    """
+    return _logger
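The module keeps the active logger in module-level state, so any object with the four standard methods can be swapped in. Below is a minimal sketch of a custom structured logger that satisfies ``LoggerProtocol``; the ``JsonLogger`` class and its output format are hypothetical illustrations, not part of atdata — only the ``atdata._logging`` names are from the diff above.

# Sketch: plugging a hypothetical custom logger into atdata.
import json
import sys
from typing import Any

from atdata._logging import LoggerProtocol, configure_logging, get_logger


class JsonLogger:
    """Writes one JSON object per call; structurally satisfies LoggerProtocol."""

    def _emit(self, level: str, msg: str, **kwargs: Any) -> None:
        sys.stderr.write(json.dumps({"level": level, "msg": msg, **kwargs}) + "\n")

    def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
        self._emit("debug", msg, **kwargs)

    def info(self, msg: str, *args: Any, **kwargs: Any) -> None:
        self._emit("info", msg, **kwargs)

    def warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
        self._emit("warning", msg, **kwargs)

    def error(self, msg: str, *args: Any, **kwargs: Any) -> None:
        self._emit("error", msg, **kwargs)


# LoggerProtocol is runtime_checkable, so conformance is verifiable at runtime
# without inheritance.
assert isinstance(JsonLogger(), LoggerProtocol)

configure_logging(JsonLogger())
get_logger().info("processing shard", shard="data-000.tar")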
atdata/_protocols.py
CHANGED
@@ -10,7 +10,7 @@ formalize that common interface.
 Note:
     Protocol methods use ``...`` (Ellipsis) as the body per PEP 544. This is
     the standard Python syntax for Protocol definitions - these are interface
-    specifications, not stub implementations. Concrete classes (
+    specifications, not stub implementations. Concrete classes (Index,
     AtmosphereIndex, etc.) provide the actual implementations.
 
 Protocols:
@@ -19,22 +19,19 @@ Protocols:
     AbstractIndex: Protocol for index operations (schemas, datasets, lenses)
     AbstractDataStore: Protocol for data storage operations
 
-
-
-
-
-
-
-
-
-    >>> process_datasets(local_index)
-    >>> process_datasets(atmosphere_index)
+Examples:
+    >>> def process_datasets(index: AbstractIndex) -> None:
+    ...     for entry in index.list_datasets():
+    ...         print(f"{entry.name}: {entry.data_urls}")
+    ...
+    >>> # Works with either Index or AtmosphereIndex
+    >>> process_datasets(local_index)
+    >>> process_datasets(atmosphere_index)
 """
 
 from typing import (
     IO,
     Any,
-    ClassVar,
     Iterator,
     Optional,
     Protocol,
@@ -67,39 +64,29 @@ class Packable(Protocol):
     - Schema publishing (class introspection via dataclass fields)
     - Serialization/deserialization (packed, from_bytes)
 
-
-
-
-
-
-
-
-
-
-
-        ...     instance = sample_type.from_bytes(data)
-        ...     print(instance.packed)
+    Examples:
+        >>> @packable
+        ... class MySample:
+        ...     name: str
+        ...     value: int
+        ...
+        >>> def process(sample_type: Type[Packable]) -> None:
+        ...     # Type checker knows sample_type has from_bytes, packed, etc.
+        ...     instance = sample_type.from_bytes(data)
+        ...     print(instance.packed)
     """
 
     @classmethod
-    def from_data(cls, data: dict[str, Any]) -> "Packable":
-        """Create instance from unpacked msgpack data dictionary."""
-        ...
+    def from_data(cls, data: dict[str, Any]) -> "Packable": ...
 
     @classmethod
-    def from_bytes(cls, bs: bytes) -> "Packable":
-        """Create instance from raw msgpack bytes."""
-        ...
+    def from_bytes(cls, bs: bytes) -> "Packable": ...
 
     @property
-    def packed(self) -> bytes:
-        """Pack this sample's data into msgpack bytes."""
-        ...
+    def packed(self) -> bytes: ...
 
     @property
-    def as_wds(self) -> dict[str, Any]:
-        """WebDataset-compatible representation with __key__ and msgpack."""
-        ...
+    def as_wds(self) -> dict[str, Any]: ...
 
 
 ##
@@ -121,16 +108,14 @@ class IndexEntry(Protocol):
     """
 
     @property
-    def name(self) -> str:
-        """Human-readable dataset name."""
-        ...
+    def name(self) -> str: ...
 
     @property
     def schema_ref(self) -> str:
-        """
+        """Schema reference string.
 
-
-
+        Local: ``local://schemas/{module.Class}@{version}``
+        Atmosphere: ``at://did:plc:.../ac.foundation.dataset.sampleSchema/...``
         """
         ...
 
@@ -144,9 +129,7 @@ class IndexEntry(Protocol):
         ...
 
     @property
-    def metadata(self) -> Optional[dict]:
-        """Arbitrary metadata dictionary, or None if not set."""
-        ...
+    def metadata(self) -> Optional[dict]: ...
 
 
 ##
@@ -154,7 +137,7 @@ class IndexEntry(Protocol):
 
 
 class AbstractIndex(Protocol):
-    """Protocol for index operations - implemented by
+    """Protocol for index operations - implemented by Index and AtmosphereIndex.
 
     This protocol defines the common interface for managing dataset metadata:
     - Publishing and retrieving schemas
@@ -169,21 +152,19 @@ class AbstractIndex(Protocol):
     - ``data_store``: An AbstractDataStore for reading/writing dataset shards.
       If present, ``load_dataset`` will use it for S3 credential resolution.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-        ...     for entry in index.list_datasets():
-        ...         print(f"{entry.name} -> {entry.schema_ref}")
+    Examples:
+        >>> def publish_and_list(index: AbstractIndex) -> None:
+        ...     # Publish schemas for different types
+        ...     schema1 = index.publish_schema(ImageSample, version="1.0.0")
+        ...     schema2 = index.publish_schema(TextSample, version="1.0.0")
+        ...
+        ...     # Insert datasets of different types
+        ...     index.insert_dataset(image_ds, name="images")
+        ...     index.insert_dataset(text_ds, name="texts")
+        ...
+        ...     # List all datasets (mixed types)
+        ...     for entry in index.list_datasets():
+        ...         print(f"{entry.name} -> {entry.schema_ref}")
     """
 
     @property
@@ -246,21 +227,9 @@ class AbstractIndex(Protocol):
         ...
 
     @property
-    def datasets(self) -> Iterator[IndexEntry]:
-        """Lazily iterate over all dataset entries in this index.
-
-        Yields:
-            IndexEntry for each dataset (may be of different sample types).
-        """
-        ...
+    def datasets(self) -> Iterator[IndexEntry]: ...
 
-    def list_datasets(self) -> list[IndexEntry]:
-        """Get all dataset entries as a materialized list.
-
-        Returns:
-            List of IndexEntry for each dataset.
-        """
-        ...
+    def list_datasets(self) -> list[IndexEntry]: ...
 
     # Schema operations
 
@@ -306,21 +275,9 @@ class AbstractIndex(Protocol):
         ...
 
     @property
-    def schemas(self) -> Iterator[dict]:
-        """Lazily iterate over all schema records in this index.
-
-        Yields:
-            Schema records as dictionaries.
-        """
-        ...
+    def schemas(self) -> Iterator[dict]: ...
 
-    def list_schemas(self) -> list[dict]:
-        """Get all schema records as a materialized list.
-
-        Returns:
-            List of schema records as dictionaries.
-        """
-        ...
+    def list_schemas(self) -> list[dict]: ...
 
     def decode_schema(self, ref: str) -> Type[Packable]:
         """Reconstruct a Python Packable type from a stored schema.
@@ -341,14 +298,12 @@ class AbstractIndex(Protocol):
             KeyError: If schema not found.
             ValueError: If schema cannot be decoded (unsupported field types).
 
-
-
-
-
-
-
-            >>> for sample in ds.ordered():
-            ...     print(sample)  # sample is instance of SampleType
+        Examples:
+            >>> entry = index.get_dataset("my-dataset")
+            >>> SampleType = index.decode_schema(entry.schema_ref)
+            >>> ds = Dataset[SampleType](entry.data_urls[0])
+            >>> for sample in ds.ordered():
+            ...     print(sample)  # sample is instance of SampleType
         """
         ...
 
@@ -368,13 +323,11 @@ class AbstractDataStore(Protocol):
     flexible deployment: local index with S3 storage, atmosphere index with
     S3 storage, or atmosphere index with PDS blobs.
 
-
-
-
-
-
-        >>> print(urls)
-        ['s3://my-bucket/training/v1/shard-000000.tar', ...]
+    Examples:
+        >>> store = S3DataStore(credentials, bucket="my-bucket")
+        >>> urls = store.write_shards(dataset, prefix="training/v1")
+        >>> print(urls)
+        ['s3://my-bucket/training/v1/shard-000000.tar', ...]
     """
 
     def write_shards(
@@ -412,14 +365,7 @@ class AbstractDataStore(Protocol):
         """
         ...
 
-    def supports_streaming(self) -> bool:
-        """Whether this store supports streaming reads.
-
-        Returns:
-            True if the store supports efficient streaming (like S3),
-            False if data must be fully downloaded first.
-        """
-        ...
+    def supports_streaming(self) -> bool: ...
 
 
 ##
@@ -443,18 +389,16 @@ class DataSource(Protocol):
     - ATProto blob streaming
    - Any other source that can provide file-like objects
 
-
-
-
-
-
-
-
-
-
-
-        >>> for sample in ds.ordered():
-        ...     print(sample)
+    Examples:
+        >>> source = S3Source(
+        ...     bucket="my-bucket",
+        ...     keys=["data-000.tar", "data-001.tar"],
+        ...     endpoint="https://r2.example.com",
+        ...     credentials=creds,
+        ... )
+        >>> ds = Dataset[MySample](source)
+        >>> for sample in ds.ordered():
+        ...     print(sample)
     """
 
     @property
@@ -467,12 +411,10 @@ class DataSource(Protocol):
         Yields:
             Tuple of (shard_identifier, file_like_stream).
 
-
-
-
-
-            ...     print(f"Processing {shard_id}")
-            ...     data = stream.read()
+        Examples:
+            >>> for shard_id, stream in source.shards:
+            ...     print(f"Processing {shard_id}")
+            ...     data = stream.read()
         """
         ...
 
@@ -496,13 +438,13 @@ class DataSource(Protocol):
         only its assigned shards rather than iterating all shards.
 
         Args:
-            shard_id: Shard identifier from
+            shard_id: Shard identifier from list_shards().
 
         Returns:
            File-like stream for reading the shard.
 
        Raises:
-            KeyError: If shard_id is not in
+            KeyError: If shard_id is not in list_shards().
        """
        ...

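These are structural protocols: a class conforms by matching the member signatures, with no shared base class. A minimal sketch under that assumption follows; ``InMemoryEntry`` and ``summarize`` are hypothetical illustrations, not atdata APIs, and only ``AbstractIndex``, ``list_datasets``, and the entry properties shown in the diff are taken from the package.

# Sketch: duck-typed conformance to the slimmed-down protocols above.
from typing import Optional

from atdata._protocols import AbstractIndex


class InMemoryEntry:
    """Hypothetical entry: matching properties are enough, no base class."""

    def __init__(self, name: str, schema_ref: str, data_urls: list[str]) -> None:
        self._name = name
        self._schema_ref = schema_ref
        self._data_urls = data_urls

    @property
    def name(self) -> str:
        return self._name

    @property
    def schema_ref(self) -> str:
        return self._schema_ref

    @property
    def data_urls(self) -> list[str]:
        return self._data_urls

    @property
    def metadata(self) -> Optional[dict]:
        return None


def summarize(index: AbstractIndex) -> list[str]:
    # Accepts Index, AtmosphereIndex, or anything else matching the protocol.
    return [f"{entry.name} -> {entry.schema_ref}" for entry in index.list_datasets()]

Because the protocol methods now end in a bare ``...``, type checkers treat them purely as interface declarations; the behavior lives entirely in the concrete classes.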
atdata/_schema_codec.py
CHANGED
@@ -9,19 +9,17 @@ The schema format follows the ATProto record structure defined in
 ``atmosphere/_types.py``, with field types supporting primitives, ndarrays,
 arrays, and schema references.
 
-
-
-
-
-
-
-
-
-
-
-
-    >>> ImageSample = schema_to_type(schema)
-    >>> sample = ImageSample(image=np.zeros((64, 64)), label="cat")
+Examples:
+    >>> schema = {
+    ...     "name": "ImageSample",
+    ...     "version": "1.0.0",
+    ...     "fields": [
+    ...         {"name": "image", "fieldType": {"$type": "...#ndarray", "dtype": "float32"}, "optional": False},
+    ...         {"name": "label", "fieldType": {"$type": "...#primitive", "primitive": "str"}, "optional": False},
+    ...     ]
+    ... }
+    >>> ImageSample = schema_to_type(schema)
+    >>> sample = ImageSample(image=np.zeros((64, 64)), label="cat")
 """
 
 from dataclasses import field, make_dataclass
@@ -30,13 +28,14 @@ import hashlib
 
 from numpy.typing import NDArray
 
-# Import PackableSample for inheritance
+# Import PackableSample for inheritance in dynamic class generation
 from .dataset import PackableSample
+from ._protocols import Packable
 
 
 # Type cache to avoid regenerating identical types
 # Uses insertion order (Python 3.7+) for simple FIFO eviction
-_type_cache: dict[str, Type[
+_type_cache: dict[str, Type[Packable]] = {}
 _TYPE_CACHE_MAX_SIZE = 256
 
 
@@ -132,7 +131,7 @@ def schema_to_type(
     schema: dict,
     *,
     use_cache: bool = True,
-) -> Type[
+) -> Type[Packable]:
     """Generate a PackableSample subclass from a schema record.
 
     This function dynamically creates a dataclass that inherits from PackableSample,
@@ -151,14 +150,12 @@ def schema_to_type(
     Raises:
         ValueError: If schema is malformed or contains unsupported types.
 
-
-
-
-
-
-
-        >>> for sample in ds.ordered():
-        ...     print(sample)
+    Examples:
+        >>> schema = index.get_schema("local://schemas/MySample@1.0.0")
+        >>> MySample = schema_to_type(schema)
+        >>> ds = Dataset[MySample]("data.tar")
+        >>> for sample in ds.ordered():
+        ...     print(sample)
     """
     # Check cache first
     if use_cache:
@@ -207,7 +204,9 @@ def schema_to_type(
        namespace={
            "__post_init__": lambda self: PackableSample.__post_init__(self),
            "__schema_version__": version,
-            "__schema_ref__": schema.get(
+            "__schema_ref__": schema.get(
+                "$ref", None
+            ),  # Store original ref if available
        },
    )
 
@@ -243,7 +242,9 @@ def _field_type_to_stub_str(field_type: dict, optional: bool = False) -> str:
 
    if kind == "primitive":
        primitive = field_type.get("primitive", "str")
-        py_type =
+        py_type = (
+            primitive  # str, int, float, bool, bytes are all valid Python type names
+        )
    elif kind == "ndarray":
        py_type = "NDArray[Any]"
    elif kind == "array":
@@ -282,14 +283,12 @@ def generate_stub(schema: dict) -> str:
    Returns:
        String content for a .pyi stub file.
 
-
-
-
-
-
-
-        >>> with open("stubs/my_sample.pyi", "w") as f:
-        ...     f.write(stub_content)
+    Examples:
+        >>> schema = index.get_schema("atdata://local/sampleSchema/MySample@1.0.0")
+        >>> stub_content = generate_stub(schema.to_dict())
+        >>> # Save to a stubs directory configured in your IDE
+        >>> with open("stubs/my_sample.pyi", "w") as f:
+        ...     f.write(stub_content)
    """
    name = schema.get("name", "UnknownSample")
    version = schema.get("version", "1.0.0")
@@ -360,12 +359,10 @@ def generate_module(schema: dict) -> str:
    Returns:
        String content for a .py module file.
 
-
-
-
-
-        >>> module_content = generate_module(schema.to_dict())
-        >>> # The module can be imported after being saved
+    Examples:
+        >>> schema = index.get_schema("atdata://local/sampleSchema/MySample@1.0.0")
+        >>> module_content = generate_module(schema.to_dict())
+        >>> # The module can be imported after being saved
    """
    name = schema.get("name", "UnknownSample")
    version = schema.get("version", "1.0.0")
@@ -424,7 +421,7 @@ def clear_type_cache() -> None:
    _type_cache.clear()
 
 
-def get_cached_types() -> dict[str, Type[
+def get_cached_types() -> dict[str, Type[Packable]]:
    """Get a copy of the current type cache.
 
    Returns: