atdata 0.2.3b1-py3-none-any.whl → 0.3.1b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. atdata/.gitignore +1 -0
  2. atdata/__init__.py +39 -0
  3. atdata/_cid.py +0 -21
  4. atdata/_exceptions.py +168 -0
  5. atdata/_helpers.py +41 -15
  6. atdata/_hf_api.py +95 -11
  7. atdata/_logging.py +70 -0
  8. atdata/_protocols.py +77 -238
  9. atdata/_schema_codec.py +7 -6
  10. atdata/_stub_manager.py +5 -25
  11. atdata/_type_utils.py +28 -2
  12. atdata/atmosphere/__init__.py +31 -20
  13. atdata/atmosphere/_types.py +4 -4
  14. atdata/atmosphere/client.py +64 -12
  15. atdata/atmosphere/lens.py +11 -12
  16. atdata/atmosphere/records.py +12 -12
  17. atdata/atmosphere/schema.py +16 -18
  18. atdata/atmosphere/store.py +6 -7
  19. atdata/cli/__init__.py +161 -175
  20. atdata/cli/diagnose.py +2 -2
  21. atdata/cli/{local.py → infra.py} +11 -11
  22. atdata/cli/inspect.py +69 -0
  23. atdata/cli/preview.py +63 -0
  24. atdata/cli/schema.py +109 -0
  25. atdata/dataset.py +583 -328
  26. atdata/index/__init__.py +54 -0
  27. atdata/index/_entry.py +157 -0
  28. atdata/index/_index.py +1198 -0
  29. atdata/index/_schema.py +380 -0
  30. atdata/lens.py +9 -2
  31. atdata/lexicons/__init__.py +121 -0
  32. atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
  33. atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
  34. atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
  35. atdata/lexicons/ac.foundation.dataset.record.json +96 -0
  36. atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
  37. atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
  38. atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
  39. atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
  40. atdata/lexicons/ndarray_shim.json +16 -0
  41. atdata/local/__init__.py +70 -0
  42. atdata/local/_repo_legacy.py +218 -0
  43. atdata/manifest/__init__.py +28 -0
  44. atdata/manifest/_aggregates.py +156 -0
  45. atdata/manifest/_builder.py +163 -0
  46. atdata/manifest/_fields.py +154 -0
  47. atdata/manifest/_manifest.py +146 -0
  48. atdata/manifest/_query.py +150 -0
  49. atdata/manifest/_writer.py +74 -0
  50. atdata/promote.py +18 -14
  51. atdata/providers/__init__.py +25 -0
  52. atdata/providers/_base.py +140 -0
  53. atdata/providers/_factory.py +69 -0
  54. atdata/providers/_postgres.py +214 -0
  55. atdata/providers/_redis.py +171 -0
  56. atdata/providers/_sqlite.py +191 -0
  57. atdata/repository.py +323 -0
  58. atdata/stores/__init__.py +23 -0
  59. atdata/stores/_disk.py +123 -0
  60. atdata/stores/_s3.py +349 -0
  61. atdata/testing.py +341 -0
  62. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
  63. atdata-0.3.1b1.dist-info/RECORD +67 -0
  64. atdata/local.py +0 -1720
  65. atdata-0.2.3b1.dist-info/RECORD +0 -28
  66. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
  67. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
  68. {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
atdata/manifest/_query.py ADDED
@@ -0,0 +1,150 @@
+"""Query executor for manifest-based dataset queries.
+
+Provides two-phase filtering: shard-level pruning via aggregates,
+then sample-level filtering via the parquet DataFrame.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable
+
+import pandas as pd
+
+from ._manifest import ShardManifest
+
+
+@dataclass(frozen=True)
+class SampleLocation:
+    """Location of a sample within a shard.
+
+    Attributes:
+        shard: Shard identifier or URL.
+        key: WebDataset ``__key__`` for the sample.
+        offset: Byte offset within the tar file.
+
+    Examples:
+        >>> loc = SampleLocation(shard="data/shard-000000", key="sample_00042", offset=52480)
+        >>> loc.shard
+        'data/shard-000000'
+    """
+
+    shard: str
+    key: str
+    offset: int
+
+
+class QueryExecutor:
+    """Executes queries over per-shard manifests.
+
+    Performs two-phase filtering:
+
+    1. **Shard-level**: uses aggregates to skip shards that cannot contain
+       matching samples (e.g., numeric range exclusion, categorical value absence).
+    2. **Sample-level**: applies the predicate to the parquet DataFrame rows.
+
+    Args:
+        manifests: List of ``ShardManifest`` objects to query over.
+
+    Examples:
+        >>> executor = QueryExecutor(manifests)
+        >>> results = executor.query(
+        ...     where=lambda df: (df["confidence"] > 0.9) & (df["label"].isin(["dog", "cat"]))
+        ... )
+        >>> len(results)
+        42
+    """
+
+    def __init__(self, manifests: list[ShardManifest]) -> None:
+        self._manifests = manifests
+
+    def query(
+        self,
+        where: Callable[[pd.DataFrame], pd.Series],
+    ) -> list[SampleLocation]:
+        """Execute a query across all manifests.
+
+        The ``where`` callable receives a pandas DataFrame with the per-sample
+        manifest columns and must return a boolean Series selecting matching rows.
+
+        Args:
+            where: Predicate function. Receives a DataFrame, returns a boolean Series.
+
+        Returns:
+            List of ``SampleLocation`` for all matching samples.
+        """
+        results: list[SampleLocation] = []
+
+        for manifest in self._manifests:
+            if manifest.samples.empty:
+                continue
+
+            mask = where(manifest.samples)
+            matching = manifest.samples[mask]
+
+            for _, row in matching.iterrows():
+                results.append(
+                    SampleLocation(
+                        shard=manifest.shard_id,
+                        key=row["__key__"],
+                        offset=int(row["__offset__"]),
+                    )
+                )
+
+        return results
+
+    @classmethod
+    def from_directory(cls, directory: str | Path) -> QueryExecutor:
+        """Load all manifests from a directory.
+
+        Discovers ``*.manifest.json`` files and loads each with its
+        companion parquet file.
+
+        Args:
+            directory: Path to scan for manifest files.
+
+        Returns:
+            A ``QueryExecutor`` loaded with all discovered manifests.
+
+        Raises:
+            FileNotFoundError: If the directory does not exist.
+        """
+        directory = Path(directory)
+        manifests: list[ShardManifest] = []
+
+        for json_path in sorted(directory.glob("*.manifest.json")):
+            parquet_path = json_path.with_suffix("").with_suffix(".manifest.parquet")
+            if parquet_path.exists():
+                manifests.append(ShardManifest.from_files(json_path, parquet_path))
+            else:
+                manifests.append(ShardManifest.from_json_only(json_path))
+
+        return cls(manifests)
+
+    @classmethod
+    def from_shard_urls(cls, shard_urls: list[str]) -> QueryExecutor:
+        """Load manifests corresponding to a list of shard URLs.
+
+        Derives manifest paths by replacing the ``.tar`` extension with
+        ``.manifest.json`` and ``.manifest.parquet``.
+
+        Args:
+            shard_urls: List of shard file paths or URLs.
+
+        Returns:
+            A ``QueryExecutor`` with manifests for shards that have them.
+        """
+        manifests: list[ShardManifest] = []
+
+        for url in shard_urls:
+            base = url.removesuffix(".tar")
+            json_path = Path(f"{base}.manifest.json")
+            parquet_path = Path(f"{base}.manifest.parquet")
+
+            if json_path.exists() and parquet_path.exists():
+                manifests.append(ShardManifest.from_files(json_path, parquet_path))
+            elif json_path.exists():
+                manifests.append(ShardManifest.from_json_only(json_path))
+
+        return cls(manifests)
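
The hunk above is the whole query path for the new manifest system. Below is a minimal usage sketch based only on the API visible in this hunk; it imports from the private module because the ``atdata.manifest`` re-exports are not shown here, and the directory path and column names are illustrative (the columns match the docstring example above).

from atdata.manifest._query import QueryExecutor

# Load every *.manifest.json (plus its companion parquet) under a directory.
executor = QueryExecutor.from_directory("/data/manifests")  # path is illustrative

# The predicate receives the per-sample DataFrame and returns a boolean Series;
# shards whose aggregates rule out any match are skipped before this runs.
locations = executor.query(
    where=lambda df: (df["confidence"] > 0.9) & (df["label"].isin(["dog", "cat"]))
)

for loc in locations:
    print(loc.shard, loc.key, loc.offset)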
atdata/manifest/_writer.py ADDED
@@ -0,0 +1,74 @@
+"""ManifestWriter for serializing ShardManifest to JSON + parquet files."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+from ._manifest import ShardManifest
+
+
+class ManifestWriter:
+    """Writes a ``ShardManifest`` to companion JSON and parquet files.
+
+    Produces two files alongside each shard:
+
+    - ``{base_path}.manifest.json`` -- header with metadata and aggregates
+    - ``{base_path}.manifest.parquet`` -- per-sample metadata (columnar)
+
+    Args:
+        base_path: The shard path without the ``.tar`` extension.
+
+    Examples:
+        >>> writer = ManifestWriter("/data/shard-000000")
+        >>> json_path, parquet_path = writer.write(manifest)
+    """
+
+    def __init__(self, base_path: str | Path) -> None:
+        self._base_path = Path(base_path)
+
+    @property
+    def json_path(self) -> Path:
+        """Path for the JSON header file."""
+        return self._base_path.with_suffix(".manifest.json")
+
+    @property
+    def parquet_path(self) -> Path:
+        """Path for the parquet per-sample file."""
+        return self._base_path.with_suffix(".manifest.parquet")
+
+    def write(self, manifest: ShardManifest) -> tuple[Path, Path]:
+        """Write the manifest to JSON + parquet files.
+
+        Args:
+            manifest: The ``ShardManifest`` to serialize.
+
+        Returns:
+            Tuple of ``(json_path, parquet_path)``.
+        """
+        json_out = self.json_path
+        parquet_out = self.parquet_path
+
+        # Ensure parent directory exists
+        json_out.parent.mkdir(parents=True, exist_ok=True)
+
+        # Write JSON header + aggregates
+        with open(json_out, "w", encoding="utf-8") as f:
+            json.dump(manifest.header_dict(), f, indent=2)
+
+        # Write per-sample parquet
+        if not manifest.samples.empty:
+            manifest.samples.to_parquet(
+                parquet_out,
+                engine="fastparquet",
+                index=False,
+            )
+        else:
+            # Write an empty parquet with no rows
+            manifest.samples.to_parquet(
+                parquet_out,
+                engine="fastparquet",
+                index=False,
+            )
+
+        return json_out, parquet_out
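
A short sketch of the write half of the round trip implied by the docstrings above. It assumes an existing ``ShardManifest`` instance (its construction lives in ``atdata/manifest/_builder.py``, which is not part of this hunk) and imports from the private module for the same reason as the previous sketch.

from atdata.manifest._writer import ManifestWriter

# `manifest` is assumed to be a ShardManifest built elsewhere (see _builder.py).
writer = ManifestWriter("/data/shard-000000")   # shard path minus the ".tar" suffix
json_path, parquet_path = writer.write(manifest)

print(json_path)     # /data/shard-000000.manifest.json
print(parquet_path)  # /data/shard-000000.manifest.parquet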
atdata/promote.py CHANGED
@@ -5,30 +5,29 @@ ATProto atmosphere network. This enables sharing datasets with the broader
 federation while maintaining schema consistency.
 
 Examples:
-    >>> from atdata.local import LocalIndex, Repo
-    >>> from atdata.atmosphere import AtmosphereClient, AtmosphereIndex
+    >>> from atdata.local import Index, Repo
+    >>> from atdata.atmosphere import Atmosphere
     >>> from atdata.promote import promote_to_atmosphere
    >>>
     >>> # Setup
-    >>> local_index = LocalIndex()
-    >>> client = AtmosphereClient()
-    >>> client.login("handle.bsky.social", "app-password")
+    >>> local_index = Index()
+    >>> atmo = Atmosphere.login("handle.bsky.social", "app-password")
    >>>
     >>> # Promote a dataset
     >>> entry = local_index.get_dataset("my-dataset")
-    >>> at_uri = promote_to_atmosphere(entry, local_index, client)
+    >>> at_uri = promote_to_atmosphere(entry, local_index, atmo)
 """
 
 from typing import TYPE_CHECKING, Type
 
 if TYPE_CHECKING:
-    from .local import LocalDatasetEntry, Index as LocalIndex
-    from .atmosphere import AtmosphereClient
+    from .local import LocalDatasetEntry, Index
+    from .atmosphere import Atmosphere
     from ._protocols import AbstractDataStore, Packable
 
 
 def _find_existing_schema(
-    client: "AtmosphereClient",
+    client: "Atmosphere",
     name: str,
     version: str,
 ) -> str | None:
@@ -55,7 +54,7 @@ def _find_existing_schema(
 def _find_or_publish_schema(
     sample_type: "Type[Packable]",
     version: str,
-    client: "AtmosphereClient",
+    client: "Atmosphere",
     description: str | None = None,
 ) -> str:
     """Find existing schema or publish a new one.
@@ -94,8 +93,8 @@ def _find_or_publish_schema(
 
 def promote_to_atmosphere(
     local_entry: "LocalDatasetEntry",
-    local_index: "LocalIndex",
-    atmosphere_client: "AtmosphereClient",
+    local_index: "Index",
+    atmosphere_client: "Atmosphere",
     *,
     data_store: "AbstractDataStore | None" = None,
     name: str | None = None,
@@ -108,10 +107,15 @@ def promote_to_atmosphere(
     This function takes a locally-indexed dataset and publishes it to ATProto,
     making it discoverable on the federated atmosphere network.
 
+    .. deprecated::
+        Prefer ``Index.promote_entry()`` or ``Index.promote_dataset()``
+        which provide the same functionality through the unified Index
+        interface without requiring separate client and index arguments.
+
     Args:
         local_entry: The LocalDatasetEntry to promote.
         local_index: Local index containing the schema for this entry.
-        atmosphere_client: Authenticated AtmosphereClient.
+        atmosphere_client: Authenticated Atmosphere.
         data_store: Optional data store for copying data to new location.
             If None, the existing data_urls are used as-is.
         name: Override name for the atmosphere record. Defaults to local name.
@@ -128,7 +132,7 @@ def promote_to_atmosphere(
 
     Examples:
         >>> entry = local_index.get_dataset("mnist-train")
-        >>> uri = promote_to_atmosphere(entry, local_index, client)
+        >>> uri = promote_to_atmosphere(entry, local_index, atmo)
         >>> print(uri)
         at://did:plc:abc123/ac.foundation.dataset.datasetIndex/...
     """
atdata/providers/__init__.py ADDED
@@ -0,0 +1,25 @@
+"""Storage provider backends for the local Index.
+
+This package defines the ``IndexProvider`` abstract base class and concrete
+implementations for Redis, SQLite, and PostgreSQL. The ``Index`` class in
+``atdata.local`` delegates all persistence to an ``IndexProvider``.
+
+Providers:
+    RedisProvider: Redis-backed storage (existing default).
+    SqliteProvider: SQLite file-based storage (zero external dependencies).
+    PostgresProvider: PostgreSQL storage (requires ``psycopg``).
+
+Examples:
+    >>> from atdata.providers import IndexProvider, create_provider
+    >>> provider = create_provider("sqlite", path="~/.atdata/index.db")
+    >>> from atdata.local import Index
+    >>> index = Index(provider=provider)
+"""
+
+from ._base import IndexProvider
+from ._factory import create_provider
+
+__all__ = [
+    "IndexProvider",
+    "create_provider",
+]
atdata/providers/_base.py ADDED
@@ -0,0 +1,140 @@
+"""Abstract base class for index storage providers.
+
+The ``IndexProvider`` ABC defines the persistence contract that the ``Index``
+class delegates to. Each provider handles storage and retrieval of two entity
+types — dataset entries and schema records — using whatever backend it wraps.
+
+Concrete implementations live in sibling modules:
+``_redis.py``, ``_sqlite.py``, ``_postgres.py``
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Iterator
+
+if TYPE_CHECKING:
+    from ..local import LocalDatasetEntry
+
+
+class IndexProvider(ABC):
+    """Storage backend for the ``Index`` class.
+
+    Implementations persist ``LocalDatasetEntry`` objects and schema JSON
+    records. The ``Index`` class owns all business logic (CID generation,
+    version bumping, schema building); the provider is a pure persistence
+    layer.
+
+    Examples:
+        >>> from atdata.providers import create_provider
+        >>> provider = create_provider("sqlite", path="/tmp/index.db")
+        >>> provider.store_schema("MySample", "1.0.0", '{"name": "MySample"}')
+        >>> provider.get_schema_json("MySample", "1.0.0")
+        '{"name": "MySample"}'
+    """
+
+    # ------------------------------------------------------------------
+    # Dataset entry operations
+    # ------------------------------------------------------------------
+
+    @abstractmethod
+    def store_entry(self, entry: LocalDatasetEntry) -> None:
+        """Persist a dataset entry (upsert by CID).
+
+        Args:
+            entry: The dataset entry to store. The entry's ``cid`` property
+                is used as the primary key.
+        """
+
+    @abstractmethod
+    def get_entry_by_cid(self, cid: str) -> LocalDatasetEntry:
+        """Load a dataset entry by its content identifier.
+
+        Args:
+            cid: Content-addressable identifier.
+
+        Returns:
+            The matching ``LocalDatasetEntry``.
+
+        Raises:
+            KeyError: If no entry exists for *cid*.
+        """
+
+    @abstractmethod
+    def get_entry_by_name(self, name: str) -> LocalDatasetEntry:
+        """Load a dataset entry by its human-readable name.
+
+        Args:
+            name: Dataset name.
+
+        Returns:
+            The first matching ``LocalDatasetEntry``.
+
+        Raises:
+            KeyError: If no entry exists with *name*.
+        """
+
+    @abstractmethod
+    def iter_entries(self) -> Iterator[LocalDatasetEntry]:
+        """Iterate over all stored dataset entries.
+
+        Yields:
+            ``LocalDatasetEntry`` objects in unspecified order.
+        """
+
+    # ------------------------------------------------------------------
+    # Schema operations
+    # ------------------------------------------------------------------
+
+    @abstractmethod
+    def store_schema(self, name: str, version: str, schema_json: str) -> None:
+        """Persist a schema record (upsert by name + version).
+
+        Args:
+            name: Schema name (e.g. ``"MySample"``).
+            version: Semantic version string (e.g. ``"1.0.0"``).
+            schema_json: JSON-serialized schema record.
+        """
+
+    @abstractmethod
+    def get_schema_json(self, name: str, version: str) -> str | None:
+        """Load a schema's JSON by name and version.
+
+        Args:
+            name: Schema name.
+            version: Semantic version string.
+
+        Returns:
+            The JSON string, or ``None`` if not found.
+        """
+
+    @abstractmethod
+    def iter_schemas(self) -> Iterator[tuple[str, str, str]]:
+        """Iterate over all stored schemas.
+
+        Yields:
+            Tuples of ``(name, version, schema_json)``.
+        """
+
+    @abstractmethod
+    def find_latest_version(self, name: str) -> str | None:
+        """Find the latest semantic version for a schema name.
+
+        Args:
+            name: Schema name to search for.
+
+        Returns:
+            The latest version string (e.g. ``"1.2.3"``), or ``None``
+            if no schema with *name* exists.
+        """
+
+    # ------------------------------------------------------------------
+    # Lifecycle
+    # ------------------------------------------------------------------
+
+    def close(self) -> None:
+        """Release any resources held by the provider.
+
+        The default implementation is a no-op. Providers that hold
+        connections (SQLite, PostgreSQL) should override this.
+        """
atdata/providers/_factory.py ADDED
@@ -0,0 +1,69 @@
+"""Factory for creating index providers by name.
+
+Examples:
+    >>> from atdata.providers._factory import create_provider
+    >>> provider = create_provider("sqlite", path="/tmp/index.db")
+    >>> provider = create_provider("redis")
+    >>> provider = create_provider("postgres", dsn="postgresql://localhost/mydb")
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from ._base import IndexProvider
+
+
+def create_provider(
+    name: str,
+    *,
+    path: str | Path | None = None,
+    dsn: str | None = None,
+    redis: Any = None,
+    **kwargs: Any,
+) -> IndexProvider:
+    """Instantiate an ``IndexProvider`` by backend name.
+
+    Args:
+        name: One of ``"redis"``, ``"sqlite"``, or ``"postgres"``.
+        path: Database file path (SQLite). Defaults to
+            ``~/.atdata/index.db`` when *name* is ``"sqlite"``.
+        dsn: Connection string (PostgreSQL).
+        redis: An existing ``redis.Redis`` connection (Redis). When
+            ``None`` and *name* is ``"redis"``, a new connection is
+            created from *kwargs*.
+        **kwargs: Extra arguments forwarded to the provider constructor
+            (e.g. Redis host/port).
+
+    Returns:
+        A ready-to-use ``IndexProvider``.
+
+    Raises:
+        ValueError: If *name* is not a recognised backend.
+    """
+    name = name.lower().strip()
+
+    if name == "redis":
+        from ._redis import RedisProvider
+        from redis import Redis as _Redis
+
+        if redis is not None:
+            return RedisProvider(redis)
+        return RedisProvider(_Redis(**kwargs))
+
+    if name == "sqlite":
+        from ._sqlite import SqliteProvider
+
+        return SqliteProvider(path=path)
+
+    if name in ("postgres", "postgresql"):
+        from ._postgres import PostgresProvider
+
+        if dsn is None:
+            raise ValueError("dsn is required for the postgres provider")
+        return PostgresProvider(dsn=dsn)
+
+    raise ValueError(
+        f"Unknown provider {name!r}. Choose from: 'redis', 'sqlite', 'postgres'."
+    )
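
To close the loop on the provider layer, a small sketch of backend selection through the factory. It uses only behaviour visible in this hunk plus the ``Index(provider=...)`` constructor shown in the ``atdata/providers/__init__.py`` docstring above; the database path is illustrative.

from atdata.providers import create_provider
from atdata.local import Index

# SQLite needs no external services; the path here is illustrative.
provider = create_provider("sqlite", path="/tmp/atdata-index.db")
index = Index(provider=provider)

# PostgreSQL refuses to construct without a DSN...
try:
    create_provider("postgres")
except ValueError as exc:
    print(exc)  # dsn is required for the postgres provider

# ...and unknown backend names are rejected rather than silently defaulting.
try:
    create_provider("dynamodb")
except ValueError as exc:
    print(exc)  # Unknown provider 'dynamodb'. Choose from: 'redis', 'sqlite', 'postgres'.

# SQLite and PostgreSQL providers hold connections, so release them when done.
provider.close()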