atdata 0.2.3b1-py3-none-any.whl → 0.3.1b1-py3-none-any.whl
- atdata/.gitignore +1 -0
- atdata/__init__.py +39 -0
- atdata/_cid.py +0 -21
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +41 -15
- atdata/_hf_api.py +95 -11
- atdata/_logging.py +70 -0
- atdata/_protocols.py +77 -238
- atdata/_schema_codec.py +7 -6
- atdata/_stub_manager.py +5 -25
- atdata/_type_utils.py +28 -2
- atdata/atmosphere/__init__.py +31 -20
- atdata/atmosphere/_types.py +4 -4
- atdata/atmosphere/client.py +64 -12
- atdata/atmosphere/lens.py +11 -12
- atdata/atmosphere/records.py +12 -12
- atdata/atmosphere/schema.py +16 -18
- atdata/atmosphere/store.py +6 -7
- atdata/cli/__init__.py +161 -175
- atdata/cli/diagnose.py +2 -2
- atdata/cli/{local.py → infra.py} +11 -11
- atdata/cli/inspect.py +69 -0
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +583 -328
- atdata/index/__init__.py +54 -0
- atdata/index/_entry.py +157 -0
- atdata/index/_index.py +1198 -0
- atdata/index/_schema.py +380 -0
- atdata/lens.py +9 -2
- atdata/lexicons/__init__.py +121 -0
- atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
- atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
- atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
- atdata/lexicons/ac.foundation.dataset.record.json +96 -0
- atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
- atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
- atdata/lexicons/ndarray_shim.json +16 -0
- atdata/local/__init__.py +70 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +18 -14
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/stores/__init__.py +23 -0
- atdata/stores/_disk.py +123 -0
- atdata/stores/_s3.py +349 -0
- atdata/testing.py +341 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
- atdata-0.3.1b1.dist-info/RECORD +67 -0
- atdata/local.py +0 -1720
- atdata-0.2.3b1.dist-info/RECORD +0 -28
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
atdata/manifest/_query.py
ADDED
@@ -0,0 +1,150 @@
+"""Query executor for manifest-based dataset queries.
+
+Provides two-phase filtering: shard-level pruning via aggregates,
+then sample-level filtering via the parquet DataFrame.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable
+
+import pandas as pd
+
+from ._manifest import ShardManifest
+
+
+@dataclass(frozen=True)
+class SampleLocation:
+    """Location of a sample within a shard.
+
+    Attributes:
+        shard: Shard identifier or URL.
+        key: WebDataset ``__key__`` for the sample.
+        offset: Byte offset within the tar file.
+
+    Examples:
+        >>> loc = SampleLocation(shard="data/shard-000000", key="sample_00042", offset=52480)
+        >>> loc.shard
+        'data/shard-000000'
+    """
+
+    shard: str
+    key: str
+    offset: int
+
+
+class QueryExecutor:
+    """Executes queries over per-shard manifests.
+
+    Performs two-phase filtering:
+
+    1. **Shard-level**: uses aggregates to skip shards that cannot contain
+       matching samples (e.g., numeric range exclusion, categorical value absence).
+    2. **Sample-level**: applies the predicate to the parquet DataFrame rows.
+
+    Args:
+        manifests: List of ``ShardManifest`` objects to query over.
+
+    Examples:
+        >>> executor = QueryExecutor(manifests)
+        >>> results = executor.query(
+        ...     where=lambda df: (df["confidence"] > 0.9) & (df["label"].isin(["dog", "cat"]))
+        ... )
+        >>> len(results)
+        42
+    """
+
+    def __init__(self, manifests: list[ShardManifest]) -> None:
+        self._manifests = manifests
+
+    def query(
+        self,
+        where: Callable[[pd.DataFrame], pd.Series],
+    ) -> list[SampleLocation]:
+        """Execute a query across all manifests.
+
+        The ``where`` callable receives a pandas DataFrame with the per-sample
+        manifest columns and must return a boolean Series selecting matching rows.
+
+        Args:
+            where: Predicate function. Receives a DataFrame, returns a boolean Series.
+
+        Returns:
+            List of ``SampleLocation`` for all matching samples.
+        """
+        results: list[SampleLocation] = []
+
+        for manifest in self._manifests:
+            if manifest.samples.empty:
+                continue
+
+            mask = where(manifest.samples)
+            matching = manifest.samples[mask]
+
+            for _, row in matching.iterrows():
+                results.append(
+                    SampleLocation(
+                        shard=manifest.shard_id,
+                        key=row["__key__"],
+                        offset=int(row["__offset__"]),
+                    )
+                )
+
+        return results
+
+    @classmethod
+    def from_directory(cls, directory: str | Path) -> QueryExecutor:
+        """Load all manifests from a directory.
+
+        Discovers ``*.manifest.json`` files and loads each with its
+        companion parquet file.
+
+        Args:
+            directory: Path to scan for manifest files.
+
+        Returns:
+            A ``QueryExecutor`` loaded with all discovered manifests.
+
+        Raises:
+            FileNotFoundError: If the directory does not exist.
+        """
+        directory = Path(directory)
+        manifests: list[ShardManifest] = []
+
+        for json_path in sorted(directory.glob("*.manifest.json")):
+            parquet_path = json_path.with_suffix("").with_suffix(".manifest.parquet")
+            if parquet_path.exists():
+                manifests.append(ShardManifest.from_files(json_path, parquet_path))
+            else:
+                manifests.append(ShardManifest.from_json_only(json_path))
+
+        return cls(manifests)
+
+    @classmethod
+    def from_shard_urls(cls, shard_urls: list[str]) -> QueryExecutor:
+        """Load manifests corresponding to a list of shard URLs.
+
+        Derives manifest paths by replacing the ``.tar`` extension with
+        ``.manifest.json`` and ``.manifest.parquet``.
+
+        Args:
+            shard_urls: List of shard file paths or URLs.
+
+        Returns:
+            A ``QueryExecutor`` with manifests for shards that have them.
+        """
+        manifests: list[ShardManifest] = []
+
+        for url in shard_urls:
+            base = url.removesuffix(".tar")
+            json_path = Path(f"{base}.manifest.json")
+            parquet_path = Path(f"{base}.manifest.parquet")
+
+            if json_path.exists() and parquet_path.exists():
+                manifests.append(ShardManifest.from_files(json_path, parquet_path))
+            elif json_path.exists():
+                manifests.append(ShardManifest.from_json_only(json_path))
+
+        return cls(manifests)
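A minimal usage sketch of this new query path, assuming shards and their manifest sidecars share a directory; the `data/` path and the `confidence`/`label` column names are illustrative, and the import uses the private module path shown above (the `atdata.manifest` package may re-export the class):

    from atdata.manifest._query import QueryExecutor

    # Discover *.manifest.json sidecars (plus companion parquet) under data/.
    executor = QueryExecutor.from_directory("data")

    # The predicate receives the per-sample DataFrame and returns a boolean
    # mask; which columns exist depends on the manifest's field configuration.
    locations = executor.query(
        where=lambda df: (df["confidence"] > 0.9) & (df["label"] == "dog")
    )

    for loc in locations:
        print(loc.shard, loc.key, loc.offset)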
atdata/manifest/_writer.py
ADDED
@@ -0,0 +1,74 @@
+"""ManifestWriter for serializing ShardManifest to JSON + parquet files."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+from ._manifest import ShardManifest
+
+
+class ManifestWriter:
+    """Writes a ``ShardManifest`` to companion JSON and parquet files.
+
+    Produces two files alongside each shard:
+
+    - ``{base_path}.manifest.json`` -- header with metadata and aggregates
+    - ``{base_path}.manifest.parquet`` -- per-sample metadata (columnar)
+
+    Args:
+        base_path: The shard path without the ``.tar`` extension.
+
+    Examples:
+        >>> writer = ManifestWriter("/data/shard-000000")
+        >>> json_path, parquet_path = writer.write(manifest)
+    """
+
+    def __init__(self, base_path: str | Path) -> None:
+        self._base_path = Path(base_path)
+
+    @property
+    def json_path(self) -> Path:
+        """Path for the JSON header file."""
+        return self._base_path.with_suffix(".manifest.json")
+
+    @property
+    def parquet_path(self) -> Path:
+        """Path for the parquet per-sample file."""
+        return self._base_path.with_suffix(".manifest.parquet")
+
+    def write(self, manifest: ShardManifest) -> tuple[Path, Path]:
+        """Write the manifest to JSON + parquet files.
+
+        Args:
+            manifest: The ``ShardManifest`` to serialize.
+
+        Returns:
+            Tuple of ``(json_path, parquet_path)``.
+        """
+        json_out = self.json_path
+        parquet_out = self.parquet_path
+
+        # Ensure parent directory exists
+        json_out.parent.mkdir(parents=True, exist_ok=True)
+
+        # Write JSON header + aggregates
+        with open(json_out, "w", encoding="utf-8") as f:
+            json.dump(manifest.header_dict(), f, indent=2)
+
+        # Write per-sample parquet
+        if not manifest.samples.empty:
+            manifest.samples.to_parquet(
+                parquet_out,
+                engine="fastparquet",
+                index=False,
+            )
+        else:
+            # Write an empty parquet with no rows
+            manifest.samples.to_parquet(
+                parquet_out,
+                engine="fastparquet",
+                index=False,
+            )
+
+        return json_out, parquet_out
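A hedged sketch of the writer side; the `manifest` object is assumed to be a `ShardManifest` built elsewhere (its construction is not part of this hunk), and the import again uses the private module path:

    from atdata.manifest._writer import ManifestWriter

    # One writer per shard: base_path is the shard path minus ".tar".
    writer = ManifestWriter("/data/shard-000000")
    json_path, parquet_path = writer.write(manifest)
    # -> /data/shard-000000.manifest.json
    # -> /data/shard-000000.manifest.parquet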
atdata/promote.py
CHANGED
@@ -5,30 +5,29 @@ ATProto atmosphere network. This enables sharing datasets with the broader
 federation while maintaining schema consistency.
 
 Examples:
-    >>> from atdata.local import
-    >>> from atdata.atmosphere import
+    >>> from atdata.local import Index, Repo
+    >>> from atdata.atmosphere import Atmosphere
     >>> from atdata.promote import promote_to_atmosphere
     >>>
     >>> # Setup
-    >>> local_index =
-    >>>
-    >>> client.login("handle.bsky.social", "app-password")
+    >>> local_index = Index()
+    >>> atmo = Atmosphere.login("handle.bsky.social", "app-password")
     >>>
     >>> # Promote a dataset
     >>> entry = local_index.get_dataset("my-dataset")
-    >>> at_uri = promote_to_atmosphere(entry, local_index,
+    >>> at_uri = promote_to_atmosphere(entry, local_index, atmo)
 """
 
 from typing import TYPE_CHECKING, Type
 
 if TYPE_CHECKING:
-    from .local import LocalDatasetEntry, Index
-    from .atmosphere import
+    from .local import LocalDatasetEntry, Index
+    from .atmosphere import Atmosphere
     from ._protocols import AbstractDataStore, Packable
 
 
 def _find_existing_schema(
-    client: "
+    client: "Atmosphere",
     name: str,
     version: str,
 ) -> str | None:
@@ -55,7 +54,7 @@ def _find_existing_schema(
 def _find_or_publish_schema(
     sample_type: "Type[Packable]",
     version: str,
-    client: "
+    client: "Atmosphere",
     description: str | None = None,
 ) -> str:
     """Find existing schema or publish a new one.
@@ -94,8 +93,8 @@ def _find_or_publish_schema(
 
 def promote_to_atmosphere(
     local_entry: "LocalDatasetEntry",
-    local_index: "
-    atmosphere_client: "
+    local_index: "Index",
+    atmosphere_client: "Atmosphere",
     *,
     data_store: "AbstractDataStore | None" = None,
     name: str | None = None,
@@ -108,10 +107,15 @@ def promote_to_atmosphere(
     This function takes a locally-indexed dataset and publishes it to ATProto,
     making it discoverable on the federated atmosphere network.
 
+    .. deprecated::
+        Prefer ``Index.promote_entry()`` or ``Index.promote_dataset()``
+        which provide the same functionality through the unified Index
+        interface without requiring separate client and index arguments.
+
     Args:
         local_entry: The LocalDatasetEntry to promote.
         local_index: Local index containing the schema for this entry.
-        atmosphere_client: Authenticated
+        atmosphere_client: Authenticated Atmosphere.
         data_store: Optional data store for copying data to new location.
             If None, the existing data_urls are used as-is.
         name: Override name for the atmosphere record. Defaults to local name.
@@ -128,7 +132,7 @@ def promote_to_atmosphere(
 
     Examples:
         >>> entry = local_index.get_dataset("mnist-train")
-        >>> uri = promote_to_atmosphere(entry, local_index,
+        >>> uri = promote_to_atmosphere(entry, local_index, atmo)
        >>> print(uri)
        at://did:plc:abc123/ac.foundation.dataset.datasetIndex/...
    """
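Pulling the revised docstring together, the updated call pattern is (the handle, app password, and dataset name are placeholders taken from the docstring itself):

    from atdata.local import Index
    from atdata.atmosphere import Atmosphere
    from atdata.promote import promote_to_atmosphere

    local_index = Index()
    atmo = Atmosphere.login("handle.bsky.social", "app-password")

    entry = local_index.get_dataset("my-dataset")
    at_uri = promote_to_atmosphere(entry, local_index, atmo)

Per the new deprecation note, `Index.promote_entry()` / `Index.promote_dataset()` are the preferred entry points going forward.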
atdata/providers/__init__.py
ADDED
@@ -0,0 +1,25 @@
+"""Storage provider backends for the local Index.
+
+This package defines the ``IndexProvider`` abstract base class and concrete
+implementations for Redis, SQLite, and PostgreSQL. The ``Index`` class in
+``atdata.local`` delegates all persistence to an ``IndexProvider``.
+
+Providers:
+    RedisProvider: Redis-backed storage (existing default).
+    SqliteProvider: SQLite file-based storage (zero external dependencies).
+    PostgresProvider: PostgreSQL storage (requires ``psycopg``).
+
+Examples:
+    >>> from atdata.providers import IndexProvider, create_provider
+    >>> provider = create_provider("sqlite", path="~/.atdata/index.db")
+    >>> from atdata.local import Index
+    >>> index = Index(provider=provider)
+"""
+
+from ._base import IndexProvider
+from ._factory import create_provider
+
+__all__ = [
+    "IndexProvider",
+    "create_provider",
+]
atdata/providers/_base.py
ADDED
@@ -0,0 +1,140 @@
+"""Abstract base class for index storage providers.
+
+The ``IndexProvider`` ABC defines the persistence contract that the ``Index``
+class delegates to. Each provider handles storage and retrieval of two entity
+types — dataset entries and schema records — using whatever backend it wraps.
+
+Concrete implementations live in sibling modules:
+``_redis.py``, ``_sqlite.py``, ``_postgres.py``
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Iterator
+
+if TYPE_CHECKING:
+    from ..local import LocalDatasetEntry
+
+
+class IndexProvider(ABC):
+    """Storage backend for the ``Index`` class.
+
+    Implementations persist ``LocalDatasetEntry`` objects and schema JSON
+    records. The ``Index`` class owns all business logic (CID generation,
+    version bumping, schema building); the provider is a pure persistence
+    layer.
+
+    Examples:
+        >>> from atdata.providers import create_provider
+        >>> provider = create_provider("sqlite", path="/tmp/index.db")
+        >>> provider.store_schema("MySample", "1.0.0", '{"name": "MySample"}')
+        >>> provider.get_schema_json("MySample", "1.0.0")
+        '{"name": "MySample"}'
+    """
+
+    # ------------------------------------------------------------------
+    # Dataset entry operations
+    # ------------------------------------------------------------------
+
+    @abstractmethod
+    def store_entry(self, entry: LocalDatasetEntry) -> None:
+        """Persist a dataset entry (upsert by CID).
+
+        Args:
+            entry: The dataset entry to store. The entry's ``cid`` property
+                is used as the primary key.
+        """
+
+    @abstractmethod
+    def get_entry_by_cid(self, cid: str) -> LocalDatasetEntry:
+        """Load a dataset entry by its content identifier.
+
+        Args:
+            cid: Content-addressable identifier.
+
+        Returns:
+            The matching ``LocalDatasetEntry``.
+
+        Raises:
+            KeyError: If no entry exists for *cid*.
+        """
+
+    @abstractmethod
+    def get_entry_by_name(self, name: str) -> LocalDatasetEntry:
+        """Load a dataset entry by its human-readable name.
+
+        Args:
+            name: Dataset name.
+
+        Returns:
+            The first matching ``LocalDatasetEntry``.
+
+        Raises:
+            KeyError: If no entry exists with *name*.
+        """
+
+    @abstractmethod
+    def iter_entries(self) -> Iterator[LocalDatasetEntry]:
+        """Iterate over all stored dataset entries.
+
+        Yields:
+            ``LocalDatasetEntry`` objects in unspecified order.
+        """
+
+    # ------------------------------------------------------------------
+    # Schema operations
+    # ------------------------------------------------------------------
+
+    @abstractmethod
+    def store_schema(self, name: str, version: str, schema_json: str) -> None:
+        """Persist a schema record (upsert by name + version).
+
+        Args:
+            name: Schema name (e.g. ``"MySample"``).
+            version: Semantic version string (e.g. ``"1.0.0"``).
+            schema_json: JSON-serialized schema record.
+        """
+
+    @abstractmethod
+    def get_schema_json(self, name: str, version: str) -> str | None:
+        """Load a schema's JSON by name and version.
+
+        Args:
+            name: Schema name.
+            version: Semantic version string.
+
+        Returns:
+            The JSON string, or ``None`` if not found.
+        """
+
+    @abstractmethod
+    def iter_schemas(self) -> Iterator[tuple[str, str, str]]:
+        """Iterate over all stored schemas.
+
+        Yields:
+            Tuples of ``(name, version, schema_json)``.
+        """
+
+    @abstractmethod
+    def find_latest_version(self, name: str) -> str | None:
+        """Find the latest semantic version for a schema name.
+
+        Args:
+            name: Schema name to search for.
+
+        Returns:
+            The latest version string (e.g. ``"1.2.3"``), or ``None``
+            if no schema with *name* exists.
+        """
+
+    # ------------------------------------------------------------------
+    # Lifecycle
+    # ------------------------------------------------------------------
+
+    def close(self) -> None:
+        """Release any resources held by the provider.
+
+        The default implementation is a no-op. Providers that hold
+        connections (SQLite, PostgreSQL) should override this.
+        """
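To make the contract concrete, here is a hedged sketch of a dict-backed provider (not part of the package, tests-only); it assumes `LocalDatasetEntry` exposes `cid` and `name` attributes, as the ABC docstrings imply:

    from atdata.providers import IndexProvider

    class MemoryProvider(IndexProvider):
        # Illustrative in-memory provider; no durability guarantees.

        def __init__(self):
            self._entries = {}   # cid -> LocalDatasetEntry
            self._schemas = {}   # (name, version) -> schema_json

        def store_entry(self, entry):
            self._entries[entry.cid] = entry   # upsert by CID

        def get_entry_by_cid(self, cid):
            return self._entries[cid]          # raises KeyError if absent

        def get_entry_by_name(self, name):
            for entry in self._entries.values():
                if entry.name == name:
                    return entry
            raise KeyError(name)

        def iter_entries(self):
            yield from self._entries.values()

        def store_schema(self, name, version, schema_json):
            self._schemas[(name, version)] = schema_json

        def get_schema_json(self, name, version):
            return self._schemas.get((name, version))

        def iter_schemas(self):
            for (name, version), schema_json in self._schemas.items():
                yield name, version, schema_json

        def find_latest_version(self, name):
            versions = [v for (n, v) in self._schemas if n == name]
            if not versions:
                return None
            # Naive "X.Y.Z" ordering; a real provider should parse semver properly.
            return max(versions, key=lambda v: tuple(int(p) for p in v.split(".")))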
atdata/providers/_factory.py
ADDED
@@ -0,0 +1,69 @@
+"""Factory for creating index providers by name.
+
+Examples:
+    >>> from atdata.providers._factory import create_provider
+    >>> provider = create_provider("sqlite", path="/tmp/index.db")
+    >>> provider = create_provider("redis")
+    >>> provider = create_provider("postgres", dsn="postgresql://localhost/mydb")
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from ._base import IndexProvider
+
+
+def create_provider(
+    name: str,
+    *,
+    path: str | Path | None = None,
+    dsn: str | None = None,
+    redis: Any = None,
+    **kwargs: Any,
+) -> IndexProvider:
+    """Instantiate an ``IndexProvider`` by backend name.
+
+    Args:
+        name: One of ``"redis"``, ``"sqlite"``, or ``"postgres"``.
+        path: Database file path (SQLite). Defaults to
+            ``~/.atdata/index.db`` when *name* is ``"sqlite"``.
+        dsn: Connection string (PostgreSQL).
+        redis: An existing ``redis.Redis`` connection (Redis). When
+            ``None`` and *name* is ``"redis"``, a new connection is
+            created from *kwargs*.
+        **kwargs: Extra arguments forwarded to the provider constructor
+            (e.g. Redis host/port).
+
+    Returns:
+        A ready-to-use ``IndexProvider``.
+
+    Raises:
+        ValueError: If *name* is not a recognised backend.
+    """
+    name = name.lower().strip()
+
+    if name == "redis":
+        from ._redis import RedisProvider
+        from redis import Redis as _Redis
+
+        if redis is not None:
+            return RedisProvider(redis)
+        return RedisProvider(_Redis(**kwargs))
+
+    if name == "sqlite":
+        from ._sqlite import SqliteProvider
+
+        return SqliteProvider(path=path)
+
+    if name in ("postgres", "postgresql"):
+        from ._postgres import PostgresProvider
+
+        if dsn is None:
+            raise ValueError("dsn is required for the postgres provider")
+        return PostgresProvider(dsn=dsn)
+
+    raise ValueError(
+        f"Unknown provider {name!r}. Choose from: 'redis', 'sqlite', 'postgres'."
+    )
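Finally, a brief sketch of backend selection at startup using the factory; the `ATDATA_BACKEND` and `ATDATA_PG_DSN` environment variable names are illustrative, not part of the package:

    import os
    from atdata.providers import create_provider

    backend = os.environ.get("ATDATA_BACKEND", "sqlite")

    if backend == "postgres":
        provider = create_provider("postgres", dsn=os.environ["ATDATA_PG_DSN"])
    elif backend == "redis":
        # Extra kwargs are forwarded to redis.Redis(...)
        provider = create_provider("redis", host="localhost", port=6379)
    else:
        provider = create_provider("sqlite")  # defaults to ~/.atdata/index.db

    # ... hand the provider to Index(provider=provider), then:
    provider.close()  # release held connections (no-op by default)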