atdata 0.2.2b1__py3-none-any.whl → 0.3.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +31 -1
- atdata/_cid.py +29 -35
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +33 -17
- atdata/_hf_api.py +109 -59
- atdata/_logging.py +70 -0
- atdata/_protocols.py +74 -132
- atdata/_schema_codec.py +38 -41
- atdata/_sources.py +57 -64
- atdata/_stub_manager.py +31 -26
- atdata/_type_utils.py +47 -7
- atdata/atmosphere/__init__.py +31 -24
- atdata/atmosphere/_types.py +11 -11
- atdata/atmosphere/client.py +11 -8
- atdata/atmosphere/lens.py +27 -30
- atdata/atmosphere/records.py +34 -39
- atdata/atmosphere/schema.py +35 -31
- atdata/atmosphere/store.py +16 -20
- atdata/cli/__init__.py +163 -168
- atdata/cli/diagnose.py +12 -8
- atdata/cli/inspect.py +69 -0
- atdata/cli/local.py +5 -2
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +678 -533
- atdata/lens.py +85 -83
- atdata/local/__init__.py +71 -0
- atdata/local/_entry.py +157 -0
- atdata/local/_index.py +940 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/local/_s3.py +349 -0
- atdata/local/_schema.py +380 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +20 -24
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/testing.py +337 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +5 -1
- atdata-0.3.0b1.dist-info/RECORD +54 -0
- atdata/local.py +0 -1707
- atdata-0.2.2b1.dist-info/RECORD +0 -28
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
atdata/promote.py
CHANGED
|
@@ -4,27 +4,25 @@ This module provides functionality to promote locally-indexed datasets to the
|
|
|
4
4
|
ATProto atmosphere network. This enables sharing datasets with the broader
|
|
5
5
|
federation while maintaining schema consistency.
|
|
6
6
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
>>> entry = local_index.get_dataset("my-dataset")
|
|
21
|
-
>>> at_uri = promote_to_atmosphere(entry, local_index, client)
|
|
7
|
+
Examples:
|
|
8
|
+
>>> from atdata.local import Index, Repo
|
|
9
|
+
>>> from atdata.atmosphere import AtmosphereClient, AtmosphereIndex
|
|
10
|
+
>>> from atdata.promote import promote_to_atmosphere
|
|
11
|
+
>>>
|
|
12
|
+
>>> # Setup
|
|
13
|
+
>>> local_index = Index()
|
|
14
|
+
>>> client = AtmosphereClient()
|
|
15
|
+
>>> client.login("handle.bsky.social", "app-password")
|
|
16
|
+
>>>
|
|
17
|
+
>>> # Promote a dataset
|
|
18
|
+
>>> entry = local_index.get_dataset("my-dataset")
|
|
19
|
+
>>> at_uri = promote_to_atmosphere(entry, local_index, client)
|
|
22
20
|
"""
|
|
23
21
|
|
|
24
22
|
from typing import TYPE_CHECKING, Type
|
|
25
23
|
|
|
26
24
|
if TYPE_CHECKING:
|
|
27
|
-
from .local import LocalDatasetEntry, Index
|
|
25
|
+
from .local import LocalDatasetEntry, Index
|
|
28
26
|
from .atmosphere import AtmosphereClient
|
|
29
27
|
from ._protocols import AbstractDataStore, Packable
|
|
30
28
|
|
|
@@ -96,7 +94,7 @@ def _find_or_publish_schema(
|
|
|
96
94
|
|
|
97
95
|
def promote_to_atmosphere(
|
|
98
96
|
local_entry: "LocalDatasetEntry",
|
|
99
|
-
local_index: "
|
|
97
|
+
local_index: "Index",
|
|
100
98
|
atmosphere_client: "AtmosphereClient",
|
|
101
99
|
*,
|
|
102
100
|
data_store: "AbstractDataStore | None" = None,
|
|
@@ -128,13 +126,11 @@ def promote_to_atmosphere(
|
|
|
128
126
|
KeyError: If schema not found in local index.
|
|
129
127
|
ValueError: If local entry has no data URLs.
|
|
130
128
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
>>> print(uri)
|
|
137
|
-
at://did:plc:abc123/ac.foundation.dataset.datasetIndex/...
|
|
129
|
+
Examples:
|
|
130
|
+
>>> entry = local_index.get_dataset("mnist-train")
|
|
131
|
+
>>> uri = promote_to_atmosphere(entry, local_index, client)
|
|
132
|
+
>>> print(uri)
|
|
133
|
+
at://did:plc:abc123/ac.foundation.dataset.datasetIndex/...
|
|
138
134
|
"""
|
|
139
135
|
from .atmosphere import DatasetPublisher
|
|
140
136
|
from ._schema_codec import schema_to_type
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Storage provider backends for the local Index.
|
|
2
|
+
|
|
3
|
+
This package defines the ``IndexProvider`` abstract base class and concrete
|
|
4
|
+
implementations for Redis, SQLite, and PostgreSQL. The ``Index`` class in
|
|
5
|
+
``atdata.local`` delegates all persistence to an ``IndexProvider``.
|
|
6
|
+
|
|
7
|
+
Providers:
|
|
8
|
+
RedisProvider: Redis-backed storage (existing default).
|
|
9
|
+
SqliteProvider: SQLite file-based storage (zero external dependencies).
|
|
10
|
+
PostgresProvider: PostgreSQL storage (requires ``psycopg``).
|
|
11
|
+
|
|
12
|
+
Examples:
|
|
13
|
+
>>> from atdata.providers import IndexProvider, create_provider
|
|
14
|
+
>>> provider = create_provider("sqlite", path="~/.atdata/index.db")
|
|
15
|
+
>>> from atdata.local import Index
|
|
16
|
+
>>> index = Index(provider=provider)
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from ._base import IndexProvider
|
|
20
|
+
from ._factory import create_provider
|
|
21
|
+
|
|
22
|
+
# Public names re-exported at package level; concrete provider classes
# stay in their submodules and are loaded on demand by ``create_provider``.
__all__ = [
    "IndexProvider",
    "create_provider",
]
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""Abstract base class for index storage providers.
|
|
2
|
+
|
|
3
|
+
The ``IndexProvider`` ABC defines the persistence contract that the ``Index``
|
|
4
|
+
class delegates to. Each provider handles storage and retrieval of two entity
|
|
5
|
+
types — dataset entries and schema records — using whatever backend it wraps.
|
|
6
|
+
|
|
7
|
+
Concrete implementations live in sibling modules:
|
|
8
|
+
``_redis.py``, ``_sqlite.py``, ``_postgres.py``
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from abc import ABC, abstractmethod
|
|
14
|
+
from typing import TYPE_CHECKING, Iterator
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from ..local import LocalDatasetEntry
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class IndexProvider(ABC):
    """Persistence contract that the ``Index`` class delegates to.

    A provider stores and retrieves two kinds of records — dataset
    entries and schema JSON documents. All business logic (CID
    generation, version bumping, schema building) stays in ``Index``;
    an implementation only moves data in and out of its backend.

    Examples:
        >>> from atdata.providers import create_provider
        >>> provider = create_provider("sqlite", path="/tmp/index.db")
        >>> provider.store_schema("MySample", "1.0.0", '{"name": "MySample"}')
        >>> provider.get_schema_json("MySample", "1.0.0")
        '{"name": "MySample"}'
    """

    # -- dataset entry operations --------------------------------------

    @abstractmethod
    def store_entry(self, entry: LocalDatasetEntry) -> None:
        """Upsert a dataset entry, keyed by its ``cid`` property.

        Args:
            entry: The dataset entry to persist.
        """

    @abstractmethod
    def get_entry_by_cid(self, cid: str) -> LocalDatasetEntry:
        """Return the dataset entry stored under *cid*.

        Args:
            cid: Content-addressable identifier.

        Returns:
            The matching ``LocalDatasetEntry``.

        Raises:
            KeyError: If no entry exists for *cid*.
        """

    @abstractmethod
    def get_entry_by_name(self, name: str) -> LocalDatasetEntry:
        """Return the first dataset entry whose name equals *name*.

        Args:
            name: Dataset name.

        Returns:
            The first matching ``LocalDatasetEntry``.

        Raises:
            KeyError: If no entry exists with *name*.
        """

    @abstractmethod
    def iter_entries(self) -> Iterator[LocalDatasetEntry]:
        """Yield every stored dataset entry, in unspecified order."""

    # -- schema operations ---------------------------------------------

    @abstractmethod
    def store_schema(self, name: str, version: str, schema_json: str) -> None:
        """Upsert the JSON document for schema *name* at *version*.

        Args:
            name: Schema name (e.g. ``"MySample"``).
            version: Semantic version string (e.g. ``"1.0.0"``).
            schema_json: JSON-serialized schema record.
        """

    @abstractmethod
    def get_schema_json(self, name: str, version: str) -> str | None:
        """Return the stored JSON for ``name@version``.

        Args:
            name: Schema name.
            version: Semantic version string.

        Returns:
            The JSON string, or ``None`` if no such schema is stored.
        """

    @abstractmethod
    def iter_schemas(self) -> Iterator[tuple[str, str, str]]:
        """Yield ``(name, version, schema_json)`` for every stored schema."""

    @abstractmethod
    def find_latest_version(self, name: str) -> str | None:
        """Return the highest semantic version recorded for *name*.

        Args:
            name: Schema name to search for.

        Returns:
            The latest version string (e.g. ``"1.2.3"``), or ``None``
            when no schema with *name* exists.
        """

    # -- lifecycle ------------------------------------------------------

    def close(self) -> None:
        """Release any resources held by the provider.

        No-op by default; providers that hold connections
        (SQLite, PostgreSQL) should override this.
        """
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Factory for creating index providers by name.
|
|
2
|
+
|
|
3
|
+
Examples:
|
|
4
|
+
>>> from atdata.providers._factory import create_provider
|
|
5
|
+
>>> provider = create_provider("sqlite", path="/tmp/index.db")
|
|
6
|
+
>>> provider = create_provider("redis")
|
|
7
|
+
>>> provider = create_provider("postgres", dsn="postgresql://localhost/mydb")
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from ._base import IndexProvider
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def create_provider(
    name: str,
    *,
    path: str | Path | None = None,
    dsn: str | None = None,
    redis: Any = None,
    **kwargs: Any,
) -> IndexProvider:
    """Build an ``IndexProvider`` for the backend identified by *name*.

    Args:
        name: One of ``"redis"``, ``"sqlite"``, or ``"postgres"``
            (case-insensitive; surrounding whitespace is ignored).
        path: Database file path (SQLite). Defaults to
            ``~/.atdata/index.db`` when *name* is ``"sqlite"``.
        dsn: Connection string (PostgreSQL).
        redis: An existing ``redis.Redis`` connection (Redis). When
            ``None`` and *name* is ``"redis"``, a new connection is
            created from *kwargs*.
        **kwargs: Extra arguments forwarded to the provider constructor
            (e.g. Redis host/port).

    Returns:
        A ready-to-use ``IndexProvider``.

    Raises:
        ValueError: If *name* is not a recognised backend.
    """
    # Normalise once; the error message below reports the normalised form.
    name = name.lower().strip()

    if name == "redis":
        # Provider modules are imported lazily so unused backends
        # (and their dependencies) are never touched.
        from ._redis import RedisProvider
        from redis import Redis as _Redis

        connection = redis if redis is not None else _Redis(**kwargs)
        return RedisProvider(connection)

    if name == "sqlite":
        from ._sqlite import SqliteProvider

        return SqliteProvider(path=path)

    if name in ("postgres", "postgresql"):
        from ._postgres import PostgresProvider

        if dsn is None:
            raise ValueError("dsn is required for the postgres provider")
        return PostgresProvider(dsn=dsn)

    raise ValueError(
        f"Unknown provider {name!r}. Choose from: 'redis', 'sqlite', 'postgres'."
    )
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""PostgreSQL-backed index provider.
|
|
2
|
+
|
|
3
|
+
Stores dataset entries and schema records in PostgreSQL tables.
|
|
4
|
+
Requires the ``psycopg`` (v3) package, which is an optional dependency::
|
|
5
|
+
|
|
6
|
+
pip install "atdata[postgres]"
|
|
7
|
+
|
|
8
|
+
The provider lazily imports ``psycopg`` so that ``import atdata`` never
|
|
9
|
+
fails when the package is absent.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from typing import Iterator
|
|
15
|
+
|
|
16
|
+
import msgpack
|
|
17
|
+
|
|
18
|
+
from ._base import IndexProvider
|
|
19
|
+
from .._type_utils import parse_semver
|
|
20
|
+
|
|
21
|
+
_CREATE_TABLES = """\
|
|
22
|
+
CREATE TABLE IF NOT EXISTS dataset_entries (
|
|
23
|
+
cid TEXT PRIMARY KEY,
|
|
24
|
+
name TEXT NOT NULL,
|
|
25
|
+
schema_ref TEXT NOT NULL,
|
|
26
|
+
data_urls BYTEA NOT NULL,
|
|
27
|
+
metadata BYTEA,
|
|
28
|
+
legacy_uuid TEXT,
|
|
29
|
+
created_at TIMESTAMPTZ DEFAULT now()
|
|
30
|
+
);
|
|
31
|
+
|
|
32
|
+
CREATE INDEX IF NOT EXISTS idx_entries_name
|
|
33
|
+
ON dataset_entries(name);
|
|
34
|
+
|
|
35
|
+
CREATE TABLE IF NOT EXISTS schemas (
|
|
36
|
+
name TEXT NOT NULL,
|
|
37
|
+
version TEXT NOT NULL,
|
|
38
|
+
schema_json TEXT NOT NULL,
|
|
39
|
+
created_at TIMESTAMPTZ DEFAULT now(),
|
|
40
|
+
PRIMARY KEY (name, version)
|
|
41
|
+
);
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class PostgresProvider(IndexProvider):
    """Index provider backed by PostgreSQL.

    Args:
        dsn: PostgreSQL connection string, e.g.
            ``"postgresql://user:pass@host:5432/dbname"``.

    Raises:
        ImportError: If ``psycopg`` is not installed.

    Examples:
        >>> provider = PostgresProvider(dsn="postgresql://localhost/atdata")
        >>> provider.store_schema("MySample", "1.0.0", '{"name":"MySample"}')
    """

    def __init__(self, dsn: str) -> None:
        # psycopg is imported lazily so ``import atdata`` never fails
        # when the optional dependency is absent.
        try:
            import psycopg
        except ImportError as exc:
            raise ImportError(
                "The postgres provider requires the 'psycopg' package. "
                "Install it with: pip install 'atdata[postgres]'"
            ) from exc

        # autocommit=False: each mutating method commits explicitly below.
        self._conn = psycopg.connect(dsn, autocommit=False)
        with self._conn.cursor() as cur:
            # Idempotent DDL (IF NOT EXISTS) — safe on every connect.
            cur.execute(_CREATE_TABLES)
        self._conn.commit()

    # ------------------------------------------------------------------
    # Dataset entry operations
    # ------------------------------------------------------------------

    def store_entry(self, entry: "LocalDatasetEntry") -> None:  # noqa: F821
        """Insert or update *entry*, upserting on its ``cid``."""
        with self._conn.cursor() as cur:
            cur.execute(
                """INSERT INTO dataset_entries
                       (cid, name, schema_ref, data_urls, metadata, legacy_uuid)
                   VALUES (%s, %s, %s, %s, %s, %s)
                   ON CONFLICT (cid) DO UPDATE SET
                       name = EXCLUDED.name,
                       schema_ref = EXCLUDED.schema_ref,
                       data_urls = EXCLUDED.data_urls,
                       metadata = EXCLUDED.metadata,
                       legacy_uuid = EXCLUDED.legacy_uuid""",
                (
                    entry.cid,
                    entry.name,
                    entry.schema_ref,
                    # Lists/dicts are stored msgpack-encoded in BYTEA columns.
                    msgpack.packb(entry.data_urls),
                    msgpack.packb(entry.metadata)
                    if entry.metadata is not None
                    else None,
                    entry._legacy_uuid,
                ),
            )
        self._conn.commit()

    def get_entry_by_cid(self, cid: str) -> "LocalDatasetEntry":  # noqa: F821
        """Return the entry stored under *cid*.

        Raises:
            KeyError: If no row matches *cid*.
        """
        with self._conn.cursor() as cur:
            cur.execute(
                "SELECT cid, name, schema_ref, data_urls, metadata, legacy_uuid "
                "FROM dataset_entries WHERE cid = %s",
                (cid,),
            )
            row = cur.fetchone()
            if row is None:
                raise KeyError(f"LocalDatasetEntry not found: {cid}")
            return _row_to_entry(row)

    def get_entry_by_name(self, name: str) -> "LocalDatasetEntry":  # noqa: F821
        """Return the first entry named *name*.

        Raises:
            KeyError: If no row matches *name*.
        """
        with self._conn.cursor() as cur:
            cur.execute(
                "SELECT cid, name, schema_ref, data_urls, metadata, legacy_uuid "
                "FROM dataset_entries WHERE name = %s LIMIT 1",
                (name,),
            )
            row = cur.fetchone()
            if row is None:
                raise KeyError(f"No entry with name: {name}")
            return _row_to_entry(row)

    def iter_entries(self) -> Iterator["LocalDatasetEntry"]:  # noqa: F821
        """Yield every stored entry (order unspecified)."""
        with self._conn.cursor() as cur:
            cur.execute(
                "SELECT cid, name, schema_ref, data_urls, metadata, legacy_uuid "
                "FROM dataset_entries"
            )
            for row in cur:
                yield _row_to_entry(row)

    # ------------------------------------------------------------------
    # Schema operations
    # ------------------------------------------------------------------

    def store_schema(self, name: str, version: str, schema_json: str) -> None:
        """Insert or update a schema record, upserting on (name, version)."""
        with self._conn.cursor() as cur:
            cur.execute(
                """INSERT INTO schemas (name, version, schema_json)
                   VALUES (%s, %s, %s)
                   ON CONFLICT (name, version) DO UPDATE SET
                       schema_json = EXCLUDED.schema_json""",
                (name, version, schema_json),
            )
        self._conn.commit()

    def get_schema_json(self, name: str, version: str) -> str | None:
        """Return the JSON for ``name@version``, or ``None`` if absent."""
        with self._conn.cursor() as cur:
            cur.execute(
                "SELECT schema_json FROM schemas WHERE name = %s AND version = %s",
                (name, version),
            )
            row = cur.fetchone()
            if row is None:
                return None
            return row[0]

    def iter_schemas(self) -> Iterator[tuple[str, str, str]]:
        """Yield ``(name, version, schema_json)`` for every stored schema."""
        with self._conn.cursor() as cur:
            cur.execute("SELECT name, version, schema_json FROM schemas")
            for row in cur:
                yield row[0], row[1], row[2]

    def find_latest_version(self, name: str) -> str | None:
        """Return the highest semantic version stored for *name*.

        Versions are compared with ``parse_semver`` in Python (not SQL)
        so ordering matches the other providers; unparseable version
        strings are skipped.
        """
        with self._conn.cursor() as cur:
            cur.execute(
                "SELECT version FROM schemas WHERE name = %s",
                (name,),
            )
            latest: tuple[int, int, int] | None = None
            latest_str: str | None = None
            for (version_str,) in cur:
                try:
                    v = parse_semver(version_str)
                    if latest is None or v > latest:
                        latest = v
                        latest_str = version_str
                except ValueError:
                    continue
            return latest_str

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the PostgreSQL connection."""
        self._conn.close()
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# ------------------------------------------------------------------
|
|
196
|
+
# Helpers
|
|
197
|
+
# ------------------------------------------------------------------
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _row_to_entry(row: tuple) -> "LocalDatasetEntry":  # noqa: F821
    """Build a ``LocalDatasetEntry`` from a ``dataset_entries`` row.

    The row layout is ``(cid, name, schema_ref, data_urls, metadata,
    legacy_uuid)``, matching every SELECT in ``PostgresProvider``.
    """
    # Local import avoids a circular dependency with ``atdata.local``.
    from ..local import LocalDatasetEntry

    entry_cid, entry_name, ref, urls_blob, meta_blob, legacy_id = row

    # ``metadata`` is a nullable BYTEA column; only unpack when present.
    if meta_blob is None:
        metadata = None
    else:
        metadata = msgpack.unpackb(bytes(meta_blob))

    return LocalDatasetEntry(
        name=entry_name,
        schema_ref=ref,
        data_urls=msgpack.unpackb(bytes(urls_blob)),
        metadata=metadata,
        _cid=entry_cid,
        _legacy_uuid=legacy_id,
    )
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Redis-backed index provider.
|
|
2
|
+
|
|
3
|
+
This module extracts the Redis persistence logic that was previously
|
|
4
|
+
inlined in ``atdata.local.Index`` and ``LocalDatasetEntry`` into a
|
|
5
|
+
standalone ``IndexProvider`` implementation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Iterator
|
|
11
|
+
|
|
12
|
+
import msgpack
|
|
13
|
+
from redis import Redis
|
|
14
|
+
|
|
15
|
+
from ._base import IndexProvider
|
|
16
|
+
from .._type_utils import parse_semver
|
|
17
|
+
|
|
18
|
+
# Redis key prefixes — kept in sync with local.py constants
_KEY_DATASET_ENTRY = "LocalDatasetEntry"  # hash per entry at "<prefix>:<cid>"
_KEY_SCHEMA = "LocalSchema"  # string per schema at "<prefix>:<name>@<version>"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RedisProvider(IndexProvider):
    """Index provider backed by a Redis connection.

    This reproduces the exact storage layout used by the original
    ``Index`` class so that existing Redis data is fully compatible.

    Args:
        redis: An active ``redis.Redis`` connection.
    """

    def __init__(self, redis: Redis) -> None:
        self._redis = redis

    @property
    def redis(self) -> Redis:
        """The underlying Redis connection (for advanced use / migration)."""
        return self._redis

    # ------------------------------------------------------------------
    # Dataset entry operations
    # ------------------------------------------------------------------

    def store_entry(self, entry: "LocalDatasetEntry") -> None:  # noqa: F821
        """Persist *entry* as a Redis hash keyed by its CID."""
        save_key = f"{_KEY_DATASET_ENTRY}:{entry.cid}"
        data: dict[str, str | bytes] = {
            "name": entry.name,
            "schema_ref": entry.schema_ref,
            # Lists/dicts go into the hash msgpack-encoded.
            "data_urls": msgpack.packb(entry.data_urls),
            "cid": entry.cid,
        }
        # Optional fields are simply omitted from the hash when absent;
        # the read side (_entry_from_redis_hash) treats missing as None.
        if entry.metadata is not None:
            data["metadata"] = msgpack.packb(entry.metadata)
        if entry._legacy_uuid is not None:
            data["legacy_uuid"] = entry._legacy_uuid

        self._redis.hset(save_key, mapping=data)  # type: ignore[arg-type]

    def get_entry_by_cid(self, cid: str) -> "LocalDatasetEntry":  # noqa: F821
        """Load the entry hash stored under *cid*.

        Raises:
            KeyError: If the hash does not exist (hgetall returns empty).
        """
        save_key = f"{_KEY_DATASET_ENTRY}:{cid}"
        raw_data = self._redis.hgetall(save_key)
        if not raw_data:
            raise KeyError(f"{_KEY_DATASET_ENTRY} not found: {cid}")

        return _entry_from_redis_hash(raw_data)

    def get_entry_by_name(self, name: str) -> "LocalDatasetEntry":  # noqa: F821
        """Return the first entry named *name*.

        Names are not indexed in Redis, so this is a linear scan over
        all stored entries.

        Raises:
            KeyError: If no entry has *name*.
        """
        for entry in self.iter_entries():
            if entry.name == name:
                return entry
        raise KeyError(f"No entry with name: {name}")

    def iter_entries(self) -> Iterator["LocalDatasetEntry"]:  # noqa: F821
        """Yield every stored entry by scanning the key prefix."""
        prefix = f"{_KEY_DATASET_ENTRY}:"
        # SCAN (not KEYS) so iteration doesn't block the Redis server.
        for key in self._redis.scan_iter(match=f"{prefix}*"):
            # Keys may be bytes or str depending on decode_responses.
            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
            cid = key_str[len(prefix) :]
            yield self.get_entry_by_cid(cid)

    # ------------------------------------------------------------------
    # Schema operations
    # ------------------------------------------------------------------

    def store_schema(self, name: str, version: str, schema_json: str) -> None:
        """Store schema JSON as a plain string at ``LocalSchema:name@version``."""
        redis_key = f"{_KEY_SCHEMA}:{name}@{version}"
        self._redis.set(redis_key, schema_json)

    def get_schema_json(self, name: str, version: str) -> str | None:
        """Return the JSON for ``name@version``, or ``None`` if absent."""
        redis_key = f"{_KEY_SCHEMA}:{name}@{version}"
        value = self._redis.get(redis_key)
        if value is None:
            return None
        # Values may be bytes or str depending on decode_responses.
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value  # type: ignore[return-value]

    def iter_schemas(self) -> Iterator[tuple[str, str, str]]:
        """Yield ``(name, version, schema_json)`` by scanning the key prefix."""
        prefix = f"{_KEY_SCHEMA}:"
        for key in self._redis.scan_iter(match=f"{prefix}*"):
            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
            schema_id = key_str[len(prefix) :]

            # Keys lacking the name@version shape are not schema records.
            if "@" not in schema_id:
                continue

            raw_name, version = schema_id.rsplit("@", 1)
            # Handle legacy format: module.Class -> Class
            if "." in raw_name:
                raw_name = raw_name.rsplit(".", 1)[1]

            value = self._redis.get(key)
            # Key may have been deleted between SCAN and GET.
            if value is None:
                continue
            schema_json = value.decode("utf-8") if isinstance(value, bytes) else value
            yield raw_name, version, schema_json  # type: ignore[misc]

    def find_latest_version(self, name: str) -> str | None:
        """Return the highest semantic version stored for *name*.

        Unparseable version strings are skipped.
        """
        latest: tuple[int, int, int] | None = None
        latest_str: str | None = None

        for schema_name, version, _ in self.iter_schemas():
            if schema_name != name:
                continue
            try:
                v = parse_semver(version)
                if latest is None or v > latest:
                    latest = v
                    latest_str = version
            except ValueError:
                continue

        return latest_str

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the Redis connection."""
        self._redis.close()
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# ------------------------------------------------------------------
|
|
145
|
+
# Helpers
|
|
146
|
+
# ------------------------------------------------------------------
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _entry_from_redis_hash(raw_data: dict) -> "LocalDatasetEntry":  # noqa: F821
    """Reconstruct a ``LocalDatasetEntry`` from a Redis hash mapping.

    Mirrors ``RedisProvider.store_entry``: required fields are plain
    bytes, ``data_urls``/``metadata`` are msgpack blobs, and optional
    fields may be missing from the hash entirely.
    """
    # Local import avoids a circular dependency with ``atdata.local``.
    from ..local import LocalDatasetEntry
    from typing import cast

    fields = cast(dict[bytes, bytes], raw_data)

    def _optional_text(key: bytes) -> str | None:
        # A missing key decodes to "" which maps to None, matching the
        # write side where absent optionals are simply not stored.
        return fields.get(key, b"").decode("utf-8") or None

    if b"metadata" in fields:
        metadata = msgpack.unpackb(fields[b"metadata"])
    else:
        metadata = None

    return LocalDatasetEntry(
        name=fields[b"name"].decode("utf-8"),
        schema_ref=fields[b"schema_ref"].decode("utf-8"),
        data_urls=msgpack.unpackb(fields[b"data_urls"]),
        metadata=metadata,
        _cid=_optional_text(b"cid"),
        _legacy_uuid=_optional_text(b"legacy_uuid"),
    )
|