juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
"""PostgreSQL-backed dataset storage for metadata with file system artifacts."""
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
import json
|
|
5
|
+
from datetime import datetime # noqa: F401 - used by DatasetMeta serialization
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from juniper_data.core.models import DatasetMeta
|
|
12
|
+
|
|
13
|
+
from .base import DatasetStore
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import psycopg2
|
|
17
|
+
from psycopg2.extras import RealDictCursor
|
|
18
|
+
|
|
19
|
+
POSTGRES_AVAILABLE = True
|
|
20
|
+
except ImportError:
|
|
21
|
+
POSTGRES_AVAILABLE = False
|
|
22
|
+
psycopg2 = None # type: ignore[assignment]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class PostgresDatasetStore(DatasetStore):
    """PostgreSQL-backed dataset storage.

    Stores metadata in PostgreSQL and artifacts on the local filesystem.
    Suitable for production deployments with database-backed metadata.

    Every operation opens its own connection and closes it again when it is
    done (connection-per-request); nothing is pooled on the instance.

    Requires the `psycopg2` package: pip install psycopg2-binary
    """

    SCHEMA_SQL = """
    CREATE TABLE IF NOT EXISTS datasets (
        dataset_id VARCHAR(255) PRIMARY KEY,
        generator VARCHAR(100) NOT NULL,
        generator_version VARCHAR(50) NOT NULL,
        params JSONB NOT NULL,
        n_samples INTEGER NOT NULL,
        n_features INTEGER NOT NULL,
        n_classes INTEGER NOT NULL,
        n_train INTEGER NOT NULL,
        n_test INTEGER NOT NULL,
        class_distribution JSONB NOT NULL,
        artifact_formats TEXT[] NOT NULL DEFAULT ARRAY['npz'],
        created_at TIMESTAMP WITH TIME ZONE NOT NULL,
        checksum VARCHAR(64),
        tags TEXT[] NOT NULL DEFAULT ARRAY[]::TEXT[],
        ttl_seconds INTEGER,
        expires_at TIMESTAMP WITH TIME ZONE,
        last_accessed_at TIMESTAMP WITH TIME ZONE,
        access_count INTEGER NOT NULL DEFAULT 0
    );

    CREATE INDEX IF NOT EXISTS idx_datasets_generator ON datasets(generator);
    CREATE INDEX IF NOT EXISTS idx_datasets_created_at ON datasets(created_at);
    CREATE INDEX IF NOT EXISTS idx_datasets_expires_at ON datasets(expires_at);
    """

    # Upsert keyed on dataset_id. created_at is deliberately absent from the
    # DO UPDATE list, so re-saving a dataset keeps its original creation time.
    _UPSERT_SQL = """
    INSERT INTO datasets (
        dataset_id, generator, generator_version, params, n_samples,
        n_features, n_classes, n_train, n_test, class_distribution,
        artifact_formats, created_at, checksum, tags, ttl_seconds,
        expires_at, last_accessed_at, access_count
    ) VALUES (
        %(dataset_id)s, %(generator)s, %(generator_version)s, %(params)s::jsonb,
        %(n_samples)s, %(n_features)s, %(n_classes)s, %(n_train)s, %(n_test)s,
        %(class_distribution)s::jsonb, %(artifact_formats)s, %(created_at)s,
        %(checksum)s, %(tags)s, %(ttl_seconds)s, %(expires_at)s,
        %(last_accessed_at)s, %(access_count)s
    ) ON CONFLICT (dataset_id) DO UPDATE SET
        generator = EXCLUDED.generator,
        generator_version = EXCLUDED.generator_version,
        params = EXCLUDED.params,
        n_samples = EXCLUDED.n_samples,
        n_features = EXCLUDED.n_features,
        n_classes = EXCLUDED.n_classes,
        n_train = EXCLUDED.n_train,
        n_test = EXCLUDED.n_test,
        class_distribution = EXCLUDED.class_distribution,
        artifact_formats = EXCLUDED.artifact_formats,
        checksum = EXCLUDED.checksum,
        tags = EXCLUDED.tags,
        ttl_seconds = EXCLUDED.ttl_seconds,
        expires_at = EXCLUDED.expires_at,
        last_accessed_at = EXCLUDED.last_accessed_at,
        access_count = EXCLUDED.access_count
    """

    # Plain UPDATE of every mutable column; created_at is left untouched.
    _UPDATE_SQL = """
    UPDATE datasets SET
        generator = %(generator)s,
        generator_version = %(generator_version)s,
        params = %(params)s::jsonb,
        n_samples = %(n_samples)s,
        n_features = %(n_features)s,
        n_classes = %(n_classes)s,
        n_train = %(n_train)s,
        n_test = %(n_test)s,
        class_distribution = %(class_distribution)s::jsonb,
        artifact_formats = %(artifact_formats)s,
        checksum = %(checksum)s,
        tags = %(tags)s,
        ttl_seconds = %(ttl_seconds)s,
        expires_at = %(expires_at)s,
        last_accessed_at = %(last_accessed_at)s,
        access_count = %(access_count)s
    WHERE dataset_id = %(dataset_id)s
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5432,
        database: str = "juniper_data",
        user: str = "postgres",
        password: str | None = None,
        artifact_path: Path | None = None,
        connection_string: str | None = None,
        auto_create_schema: bool = True,
    ) -> None:
        """Initialize PostgreSQL connection settings and the artifact directory.

        Args:
            host: PostgreSQL server hostname.
            port: PostgreSQL server port.
            database: Database name.
            user: Database user.
            password: Database password.
            artifact_path: Path for storing NPZ artifacts
                (defaults to ``./data/datasets``; created if missing).
            connection_string: Optional full connection string (overrides other params).
            auto_create_schema: Automatically create tables if they don't exist.

        Raises:
            ImportError: If psycopg2 package is not installed.
        """
        if not POSTGRES_AVAILABLE:
            raise ImportError("psycopg2 package not installed. Install with: pip install psycopg2-binary")

        self._artifact_path = artifact_path or Path("./data/datasets")
        self._artifact_path.mkdir(parents=True, exist_ok=True)

        if connection_string:
            self._conn_params: dict[str, Any] = {"dsn": connection_string}
        else:
            self._conn_params = {
                "host": host,
                "port": str(port),
                "database": database,
                "user": user,
                "password": password or "",
            }

        if auto_create_schema:
            self._create_schema()

    def _get_connection(self) -> Any:
        """Get a new database connection."""
        return psycopg2.connect(**self._conn_params)

    def _session(self) -> Any:
        """Return a context manager yielding a connection that is closed on exit.

        psycopg2's own connection context manager only ends the transaction
        (commit on success, rollback on error) and leaves the connection open,
        so relying on it alone leaks one connection per call. This wrapper keeps
        that transaction behaviour and additionally closes the connection.
        """
        # Local import so the helper is self-contained within the class.
        from contextlib import contextmanager

        @contextmanager
        def session() -> Any:
            raw = self._get_connection()
            try:
                with raw as conn:
                    yield conn
            finally:
                raw.close()

        return session()

    def _create_schema(self) -> None:
        """Create database schema if it doesn't exist."""
        with self._session() as conn:
            with conn.cursor() as cur:
                cur.execute(self.SCHEMA_SQL)
            conn.commit()

    def _artifact_file(self, dataset_id: str) -> Path:
        """Get the artifact file path for a dataset."""
        return self._artifact_path / f"{dataset_id}.npz"

    def _meta_to_row(self, meta: DatasetMeta) -> dict:
        """Convert DatasetMeta to database row dict.

        The two JSONB columns (``params`` and ``class_distribution``) are
        serialised to JSON text; the SQL casts them back with ``::jsonb``.
        """
        return {
            "dataset_id": meta.dataset_id,
            "generator": meta.generator,
            "generator_version": meta.generator_version,
            "params": json.dumps(meta.params),
            "n_samples": meta.n_samples,
            "n_features": meta.n_features,
            "n_classes": meta.n_classes,
            "n_train": meta.n_train,
            "n_test": meta.n_test,
            "class_distribution": json.dumps(meta.class_distribution),
            "artifact_formats": meta.artifact_formats,
            "created_at": meta.created_at,
            "checksum": meta.checksum,
            "tags": meta.tags,
            "ttl_seconds": meta.ttl_seconds,
            "expires_at": meta.expires_at,
            "last_accessed_at": meta.last_accessed_at,
            "access_count": meta.access_count,
        }

    def _row_to_meta(self, row: dict) -> DatasetMeta:
        """Convert database row to DatasetMeta.

        JSONB columns are accepted either already decoded (a dict) or as JSON
        text, which is decoded here.
        """
        params = row["params"]
        if not isinstance(params, dict):
            params = json.loads(params)

        class_distribution = row["class_distribution"]
        if not isinstance(class_distribution, dict):
            class_distribution = json.loads(class_distribution)

        return DatasetMeta(
            dataset_id=row["dataset_id"],
            generator=row["generator"],
            generator_version=row["generator_version"],
            params=params,
            n_samples=row["n_samples"],
            n_features=row["n_features"],
            n_classes=row["n_classes"],
            n_train=row["n_train"],
            n_test=row["n_test"],
            class_distribution=class_distribution,
            artifact_formats=list(row["artifact_formats"]),
            created_at=row["created_at"],
            checksum=row["checksum"],
            tags=list(row["tags"]) if row["tags"] else [],
            ttl_seconds=row["ttl_seconds"],
            expires_at=row["expires_at"],
            last_accessed_at=row["last_accessed_at"],
            access_count=row["access_count"],
        )

    def save(
        self,
        dataset_id: str,
        meta: DatasetMeta,
        arrays: dict[str, np.ndarray],
    ) -> None:
        """Save dataset to PostgreSQL and filesystem.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Dataset metadata.
            arrays: Dictionary of numpy arrays.
        """
        row = self._meta_to_row(meta)
        # Key the row on the explicit argument so the metadata row and the
        # artifact file (which is named after dataset_id) can never disagree.
        row["dataset_id"] = dataset_id

        # Serialise before touching the database: a failure here must not
        # leave a committed metadata row without an artifact.
        buffer = io.BytesIO()
        np.savez_compressed(buffer, **arrays)  # type: ignore[arg-type]
        payload = buffer.getvalue()

        with self._session() as conn:
            with conn.cursor() as cur:
                cur.execute(self._UPSERT_SQL, row)
            conn.commit()

        self._artifact_file(dataset_id).write_bytes(payload)

    def get_meta(self, dataset_id: str) -> DatasetMeta | None:
        """Get dataset metadata from PostgreSQL.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            Dataset metadata if found, None otherwise.
        """
        with self._session() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute("SELECT * FROM datasets WHERE dataset_id = %s", (dataset_id,))
                row = cur.fetchone()

        if row is None:
            return None

        return self._row_to_meta(dict(row))

    def get_artifact_bytes(self, dataset_id: str) -> bytes | None:
        """Get dataset artifact bytes from filesystem.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            NPZ bytes if found, None otherwise.
        """
        artifact_path = self._artifact_file(dataset_id)
        if not artifact_path.exists():
            return None
        return artifact_path.read_bytes()

    def exists(self, dataset_id: str) -> bool:
        """Check if dataset exists in PostgreSQL.

        Only the metadata table is consulted; the artifact file is not checked.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset exists, False otherwise.
        """
        with self._session() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT 1 FROM datasets WHERE dataset_id = %s", (dataset_id,))
                return cur.fetchone() is not None

    def delete(self, dataset_id: str) -> bool:
        """Delete dataset from PostgreSQL and filesystem.

        The artifact file is removed whether or not a metadata row existed.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset was deleted, False if it didn't exist.
        """
        with self._session() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM datasets WHERE dataset_id = %s RETURNING dataset_id",
                    (dataset_id,),
                )
                deleted = cur.fetchone() is not None
            conn.commit()

        # missing_ok avoids the exists()/unlink() race when two deletes overlap.
        self._artifact_file(dataset_id).unlink(missing_ok=True)

        return deleted

    def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
        """List dataset IDs from PostgreSQL, newest first.

        Args:
            limit: Maximum number of dataset IDs to return.
            offset: Number of dataset IDs to skip.

        Returns:
            List of dataset IDs.
        """
        with self._session() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT dataset_id FROM datasets ORDER BY created_at DESC LIMIT %s OFFSET %s",
                    (limit, offset),
                )
                rows = cur.fetchall()

        return [row[0] for row in rows]

    def update_meta(self, dataset_id: str, meta: DatasetMeta) -> bool:
        """Update dataset metadata in PostgreSQL.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Updated dataset metadata.

        Returns:
            True if the dataset was updated, False if it didn't exist.
        """
        row = self._meta_to_row(meta)
        # Target the row named by the argument, not whatever id meta carries.
        row["dataset_id"] = dataset_id

        with self._session() as conn:
            with conn.cursor() as cur:
                cur.execute(self._UPDATE_SQL, row)
                updated = cur.rowcount > 0
            conn.commit()

        return updated

    def list_all_metadata(self) -> list[DatasetMeta]:
        """List all dataset metadata from PostgreSQL, newest first.

        Returns:
            List of all DatasetMeta objects.
        """
        with self._session() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute("SELECT * FROM datasets ORDER BY created_at DESC")
                rows = cur.fetchall()

        return [self._row_to_meta(dict(row)) for row in rows]

    def close(self) -> None:
        """Close database connections (no-op for connection-per-request pattern)."""
        pass
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""Redis-backed dataset storage for caching and distributed deployments."""
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
import json
|
|
5
|
+
from datetime import datetime # noqa: F401 - used by DatasetMeta serialization
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from juniper_data.core.models import DatasetMeta
|
|
11
|
+
|
|
12
|
+
from .base import DatasetStore
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import redis
|
|
16
|
+
|
|
17
|
+
REDIS_AVAILABLE = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
REDIS_AVAILABLE = False
|
|
20
|
+
redis = None # type: ignore[assignment]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RedisDatasetStore(DatasetStore):
    """Redis-backed dataset storage.

    Uses Redis for both metadata (as JSON) and artifact storage (as bytes).
    Suitable for caching layers and distributed deployments.

    Each dataset occupies two keys: ``<prefix><dataset_id>:meta`` and
    ``<prefix><dataset_id>:artifact``.

    Requires the `redis` package: pip install redis
    """

    # Key suffixes appended after "<prefix><dataset_id>".
    _META_SUFFIX = ":meta"
    _ARTIFACT_SUFFIX = ":artifact"

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        key_prefix: str = "juniper:dataset:",
        default_ttl: int | None = None,
        connection_pool: Any | None = None,
    ) -> None:
        """Initialize Redis connection.

        Args:
            host: Redis server hostname.
            port: Redis server port.
            db: Redis database number.
            password: Redis password (optional).
            key_prefix: Prefix for all Redis keys.
            default_ttl: Default TTL for stored datasets in seconds (optional).
            connection_pool: Optional existing Redis connection pool; when given,
                host/port/db/password are not used.

        Raises:
            ImportError: If redis package is not installed.
        """
        if not REDIS_AVAILABLE:
            raise ImportError("Redis package not installed. Install with: pip install redis")

        self._key_prefix = key_prefix
        self._default_ttl = default_ttl

        if connection_pool:
            self._client: redis.Redis[bytes] = redis.Redis(connection_pool=connection_pool)
        else:
            self._client = redis.Redis(
                host=host,
                port=port,
                db=db,
                password=password,
                # Artifacts are binary, so responses must stay as raw bytes.
                decode_responses=False,
            )

    def _meta_key(self, dataset_id: str) -> str:
        """Get Redis key for metadata."""
        return f"{self._key_prefix}{dataset_id}{self._META_SUFFIX}"

    def _artifact_key(self, dataset_id: str) -> str:
        """Get Redis key for artifact data."""
        return f"{self._key_prefix}{dataset_id}{self._ARTIFACT_SUFFIX}"

    def _encode_meta(self, meta: DatasetMeta) -> bytes:
        """Encode metadata to JSON bytes."""
        data = meta.model_dump(mode="json")
        return json.dumps(data).encode("utf-8")

    def _decode_meta(self, data: bytes) -> DatasetMeta:
        """Decode metadata from JSON bytes."""
        parsed = json.loads(data.decode("utf-8"))
        return DatasetMeta(**parsed)

    def _encode_arrays(self, arrays: dict[str, np.ndarray]) -> bytes:
        """Encode arrays to NPZ bytes."""
        buffer = io.BytesIO()
        np.savez_compressed(buffer, **arrays)  # type: ignore[arg-type]
        return buffer.getvalue()

    def save(
        self,
        dataset_id: str,
        meta: DatasetMeta,
        arrays: dict[str, np.ndarray],
    ) -> None:
        """Save dataset to Redis.

        Metadata and artifact are written through one pipeline and share the
        same TTL: ``meta.ttl_seconds`` if set, otherwise the store's default.
        With no TTL the keys are stored without expiry.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Dataset metadata.
            arrays: Dictionary of numpy arrays.
        """
        meta_bytes = self._encode_meta(meta)
        artifact_bytes = self._encode_arrays(arrays)

        meta_key = self._meta_key(dataset_id)
        artifact_key = self._artifact_key(dataset_id)

        ttl = meta.ttl_seconds or self._default_ttl

        pipe = self._client.pipeline()
        if ttl:
            pipe.setex(meta_key, ttl, meta_bytes)
            pipe.setex(artifact_key, ttl, artifact_bytes)
        else:
            pipe.set(meta_key, meta_bytes)
            pipe.set(artifact_key, artifact_bytes)
        pipe.execute()

    def get_meta(self, dataset_id: str) -> DatasetMeta | None:
        """Get dataset metadata from Redis.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            Dataset metadata if found, None otherwise.
        """
        data = self._client.get(self._meta_key(dataset_id))
        if data is None:
            return None
        return self._decode_meta(data)

    def get_artifact_bytes(self, dataset_id: str) -> bytes | None:
        """Get dataset artifact bytes from Redis.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            NPZ bytes if found, None otherwise.
        """
        return self._client.get(self._artifact_key(dataset_id))

    def exists(self, dataset_id: str) -> bool:
        """Check if dataset exists in Redis.

        Only the metadata key is checked.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset exists, False otherwise.
        """
        return bool(self._client.exists(self._meta_key(dataset_id)))

    def delete(self, dataset_id: str) -> bool:
        """Delete dataset from Redis.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset was deleted, False if it didn't exist.
        """
        removed = self._client.delete(
            self._meta_key(dataset_id),
            self._artifact_key(dataset_id),
        )
        return removed > 0

    def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
        """List dataset IDs from Redis, sorted alphabetically.

        All matching keys are scanned and sorted before the page is cut, so
        the cost grows with the total number of datasets, not with ``limit``.

        Args:
            limit: Maximum number of dataset IDs to return.
            offset: Number of dataset IDs to skip.

        Returns:
            List of dataset IDs.
        """
        pattern = f"{self._key_prefix}*{self._META_SUFFIX}"
        start = len(self._key_prefix)
        end = -len(self._META_SUFFIX)

        dataset_ids = sorted(
            (key.decode("utf-8") if isinstance(key, bytes) else key)[start:end]
            for key in self._client.scan_iter(match=pattern)
        )
        return dataset_ids[offset : offset + limit]

    def update_meta(self, dataset_id: str, meta: DatasetMeta) -> bool:
        """Update dataset metadata in Redis, preserving the key's remaining TTL.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Updated dataset metadata.

        Returns:
            True if the dataset was updated, False if it didn't exist.
        """
        meta_key = self._meta_key(dataset_id)
        if not self._client.exists(meta_key):
            return False

        ttl = self._client.ttl(meta_key)
        if ttl == -2:
            # The key expired between the exists() check and here; writing now
            # would resurrect metadata whose artifact is already gone.
            return False

        meta_bytes = self._encode_meta(meta)

        if ttl >= 0:
            # TTL is reported in whole seconds, so 0 means "under a second
            # left". Keep an expiry (minimum 1s) rather than falling through
            # to a plain SET, which would make the key permanent.
            self._client.setex(meta_key, max(ttl, 1), meta_bytes)
        else:
            # -1: the key has no expiry.
            self._client.set(meta_key, meta_bytes)

        return True

    def list_all_metadata(self) -> list[DatasetMeta]:
        """List all dataset metadata from Redis.

        Returns:
            List of all DatasetMeta objects.
        """
        pattern = f"{self._key_prefix}*{self._META_SUFFIX}"

        metadata: list[DatasetMeta] = []
        for key in self._client.scan_iter(match=pattern):
            data = self._client.get(key)
            # A key may expire or be deleted between the scan and the GET.
            if data is not None:
                metadata.append(self._decode_meta(data))

        return metadata

    def ping(self) -> bool:
        """Check if Redis connection is alive.

        Returns:
            True if connected, False otherwise.
        """
        try:
            return self._client.ping()
        except Exception:
            # Deliberately broad: any failure here means "not reachable".
            return False

    def flush_prefix(self) -> int:
        """Delete all keys with this store's prefix.

        WARNING: This deletes all datasets in this store.

        Returns:
            Number of keys deleted.
        """
        pattern = f"{self._key_prefix}*"
        keys = list(self._client.scan_iter(match=pattern))
        if keys:
            return self._client.delete(*keys)
        return 0
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Test suite for Juniper Data."""
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Common pytest fixtures for juniper_data tests."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from juniper_data.generators.spiral import SpiralGenerator, SpiralParams
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.fixture
def default_spiral_params() -> SpiralParams:
    """Provide a SpiralParams instance built with no overrides."""
    params = SpiralParams()
    return params
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.fixture
def two_spiral_params() -> SpiralParams:
    """Provide seeded parameters: 2 spirals, 100 points per spiral."""
    settings = {"n_spirals": 2, "n_points_per_spiral": 100, "seed": 42}
    return SpiralParams(**settings)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.fixture
def three_spiral_params() -> SpiralParams:
    """Provide seeded parameters: 3 spirals, 50 points per spiral."""
    settings = {"n_spirals": 3, "n_points_per_spiral": 50, "seed": 42}
    return SpiralParams(**settings)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.fixture
def minimal_spiral_params() -> SpiralParams:
    """Provide small seeded parameters (2 spirals, 10 points each) for fast tests."""
    settings = {"n_spirals": 2, "n_points_per_spiral": 10, "seed": 42}
    return SpiralParams(**settings)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@pytest.fixture
def generated_two_spiral_dataset(two_spiral_params: SpiralParams) -> dict[str, np.ndarray]:
    """Run SpiralGenerator on the two-spiral parameter fixture."""
    dataset = SpiralGenerator.generate(two_spiral_params)
    return dataset
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@pytest.fixture
def generated_three_spiral_dataset(three_spiral_params: SpiralParams) -> dict[str, np.ndarray]:
    """Run SpiralGenerator on the three-spiral parameter fixture."""
    dataset = SpiralGenerator.generate(three_spiral_params)
    return dataset
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@pytest.fixture
def generated_minimal_dataset(minimal_spiral_params: SpiralParams) -> dict[str, np.ndarray]:
    """Run SpiralGenerator on the minimal parameter fixture."""
    dataset = SpiralGenerator.generate(minimal_spiral_params)
    return dataset
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@pytest.fixture
def sample_arrays() -> dict[str, np.ndarray]:
    """Provide a small fixed X/y pair for split/shuffle testing.

    X holds the values 0..19 as a (10, 2) float32 array; y holds 10 one-hot
    float32 rows whose class alternates 0, 1, 0, 1, ...
    """
    features = np.arange(20)
    features = features.reshape(10, 2)
    features = features.astype(np.float32)

    # Index the 2x2 identity by the alternating labels to one-hot encode them.
    labels = np.arange(10) % 2
    one_hot = np.eye(2, dtype=np.float32)[labels]

    return {"X": features, "y": one_hot}
|