juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95) hide show
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,373 @@
1
+ """PostgreSQL-backed dataset storage for metadata with file system artifacts."""
2
+
3
+ import io
4
+ import json
5
+ from datetime import datetime # noqa: F401 - used by DatasetMeta serialization
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import numpy as np
10
+
11
+ from juniper_data.core.models import DatasetMeta
12
+
13
+ from .base import DatasetStore
14
+
15
try:
    import psycopg2
    from psycopg2.extras import RealDictCursor
except ImportError:
    # psycopg2 is an optional dependency: keep this module importable and let
    # PostgresDatasetStore raise a helpful ImportError when instantiated.
    psycopg2 = None  # type: ignore[assignment]
    POSTGRES_AVAILABLE = False
else:
    POSTGRES_AVAILABLE = True
23
+
24
+
25
class PostgresDatasetStore(DatasetStore):
    """PostgreSQL-backed dataset storage.

    Stores metadata in the ``datasets`` PostgreSQL table and artifacts as
    compressed NPZ files on the local filesystem. Suitable for production
    deployments with database-backed metadata.

    Every operation opens its own connection and closes it again when the
    operation finishes (connection-per-request); no pool is held.

    Requires the `psycopg2` package: pip install psycopg2-binary
    """

    SCHEMA_SQL = """
        CREATE TABLE IF NOT EXISTS datasets (
            dataset_id VARCHAR(255) PRIMARY KEY,
            generator VARCHAR(100) NOT NULL,
            generator_version VARCHAR(50) NOT NULL,
            params JSONB NOT NULL,
            n_samples INTEGER NOT NULL,
            n_features INTEGER NOT NULL,
            n_classes INTEGER NOT NULL,
            n_train INTEGER NOT NULL,
            n_test INTEGER NOT NULL,
            class_distribution JSONB NOT NULL,
            artifact_formats TEXT[] NOT NULL DEFAULT ARRAY['npz'],
            created_at TIMESTAMP WITH TIME ZONE NOT NULL,
            checksum VARCHAR(64),
            tags TEXT[] NOT NULL DEFAULT ARRAY[]::TEXT[],
            ttl_seconds INTEGER,
            expires_at TIMESTAMP WITH TIME ZONE,
            last_accessed_at TIMESTAMP WITH TIME ZONE,
            access_count INTEGER NOT NULL DEFAULT 0
        );

        CREATE INDEX IF NOT EXISTS idx_datasets_generator ON datasets(generator);
        CREATE INDEX IF NOT EXISTS idx_datasets_created_at ON datasets(created_at);
        CREATE INDEX IF NOT EXISTS idx_datasets_expires_at ON datasets(expires_at);
    """

    # Insert-or-replace. ``created_at`` is deliberately absent from the
    # DO UPDATE list so re-saving a dataset keeps its original creation time.
    _UPSERT_SQL = """
        INSERT INTO datasets (
            dataset_id, generator, generator_version, params, n_samples,
            n_features, n_classes, n_train, n_test, class_distribution,
            artifact_formats, created_at, checksum, tags, ttl_seconds,
            expires_at, last_accessed_at, access_count
        ) VALUES (
            %(dataset_id)s, %(generator)s, %(generator_version)s, %(params)s::jsonb,
            %(n_samples)s, %(n_features)s, %(n_classes)s, %(n_train)s, %(n_test)s,
            %(class_distribution)s::jsonb, %(artifact_formats)s, %(created_at)s,
            %(checksum)s, %(tags)s, %(ttl_seconds)s, %(expires_at)s,
            %(last_accessed_at)s, %(access_count)s
        ) ON CONFLICT (dataset_id) DO UPDATE SET
            generator = EXCLUDED.generator,
            generator_version = EXCLUDED.generator_version,
            params = EXCLUDED.params,
            n_samples = EXCLUDED.n_samples,
            n_features = EXCLUDED.n_features,
            n_classes = EXCLUDED.n_classes,
            n_train = EXCLUDED.n_train,
            n_test = EXCLUDED.n_test,
            class_distribution = EXCLUDED.class_distribution,
            artifact_formats = EXCLUDED.artifact_formats,
            checksum = EXCLUDED.checksum,
            tags = EXCLUDED.tags,
            ttl_seconds = EXCLUDED.ttl_seconds,
            expires_at = EXCLUDED.expires_at,
            last_accessed_at = EXCLUDED.last_accessed_at,
            access_count = EXCLUDED.access_count
    """

    # Full metadata update of an existing row (``created_at`` is left as is).
    _UPDATE_SQL = """
        UPDATE datasets SET
            generator = %(generator)s,
            generator_version = %(generator_version)s,
            params = %(params)s::jsonb,
            n_samples = %(n_samples)s,
            n_features = %(n_features)s,
            n_classes = %(n_classes)s,
            n_train = %(n_train)s,
            n_test = %(n_test)s,
            class_distribution = %(class_distribution)s::jsonb,
            artifact_formats = %(artifact_formats)s,
            checksum = %(checksum)s,
            tags = %(tags)s,
            ttl_seconds = %(ttl_seconds)s,
            expires_at = %(expires_at)s,
            last_accessed_at = %(last_accessed_at)s,
            access_count = %(access_count)s
        WHERE dataset_id = %(dataset_id)s
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5432,
        database: str = "juniper_data",
        user: str = "postgres",
        password: str | None = None,
        artifact_path: Path | None = None,
        connection_string: str | None = None,
        auto_create_schema: bool = True,
    ) -> None:
        """Initialize PostgreSQL connection parameters and the artifact directory.

        Args:
            host: PostgreSQL server hostname.
            port: PostgreSQL server port.
            database: Database name.
            user: Database user.
            password: Database password.
            artifact_path: Directory for storing NPZ artifacts
                (default ``./data/datasets``; created if missing).
            connection_string: Optional full connection string (overrides other params).
            auto_create_schema: Automatically create tables if they don't exist.

        Raises:
            ImportError: If psycopg2 package is not installed.
        """
        if not POSTGRES_AVAILABLE:
            raise ImportError("psycopg2 package not installed. Install with: pip install psycopg2-binary")

        self._artifact_path = artifact_path or Path("./data/datasets")
        self._artifact_path.mkdir(parents=True, exist_ok=True)

        if connection_string:
            self._conn_params: dict[str, Any] = {"dsn": connection_string}
        else:
            self._conn_params = {
                "host": host,
                "port": str(port),
                "database": database,
                "user": user,
                "password": password or "",
            }

        if auto_create_schema:
            self._create_schema()

    def _get_connection(self) -> Any:
        """Open and return a new database connection."""
        return psycopg2.connect(**self._conn_params)

    def _transaction(self, **cursor_kwargs: Any) -> Any:
        """Return a context manager yielding ``(conn, cursor)`` for one unit of work.

        psycopg2's ``with conn:`` only ends the transaction (commit on success,
        rollback on error) and leaves the connection open. This helper also
        closes the connection on exit, so the connection-per-request pattern
        does not leak server connections.

        Args:
            **cursor_kwargs: Forwarded to ``conn.cursor()`` (e.g. ``cursor_factory``).

        Returns:
            A context manager whose value is a ``(connection, cursor)`` tuple.
        """
        from contextlib import contextmanager

        @contextmanager
        def _scope() -> Any:
            raw = self._get_connection()
            try:
                with raw as conn, conn.cursor(**cursor_kwargs) as cur:
                    yield conn, cur
            finally:
                raw.close()

        return _scope()

    def _create_schema(self) -> None:
        """Create the ``datasets`` table and its indexes if they don't exist."""
        with self._transaction() as (conn, cur):
            cur.execute(self.SCHEMA_SQL)
            conn.commit()

    def _artifact_file(self, dataset_id: str) -> Path:
        """Get the artifact file path for a dataset."""
        return self._artifact_path / f"{dataset_id}.npz"

    def _write_artifact(self, dataset_id: str, arrays: dict[str, np.ndarray]) -> None:
        """Serialize ``arrays`` to compressed NPZ and write the file atomically.

        The bytes go to a sibling ``.tmp`` file first and are then renamed
        over the target. Readers therefore never observe a half-written
        artifact.
        """
        buffer = io.BytesIO()
        np.savez_compressed(buffer, **arrays)  # type: ignore[arg-type]
        target = self._artifact_file(dataset_id)
        staging = target.with_name(target.name + ".tmp")
        staging.write_bytes(buffer.getvalue())
        staging.replace(target)

    def _meta_to_row(self, meta: DatasetMeta) -> dict[str, Any]:
        """Convert DatasetMeta to a parameter dict for the SQL statements.

        ``params`` and ``class_distribution`` are JSON-encoded because the
        statements cast them with ``::jsonb``.
        """
        return {
            "dataset_id": meta.dataset_id,
            "generator": meta.generator,
            "generator_version": meta.generator_version,
            "params": json.dumps(meta.params),
            "n_samples": meta.n_samples,
            "n_features": meta.n_features,
            "n_classes": meta.n_classes,
            "n_train": meta.n_train,
            "n_test": meta.n_test,
            "class_distribution": json.dumps(meta.class_distribution),
            "artifact_formats": meta.artifact_formats,
            "created_at": meta.created_at,
            "checksum": meta.checksum,
            "tags": meta.tags,
            "ttl_seconds": meta.ttl_seconds,
            "expires_at": meta.expires_at,
            "last_accessed_at": meta.last_accessed_at,
            "access_count": meta.access_count,
        }

    @staticmethod
    def _json_field(value: Any) -> Any:
        """Return a JSONB column value as a Python object.

        A value that is already a dict (decoded by the driver) is returned
        unchanged; anything else is parsed with ``json.loads``.
        """
        return value if isinstance(value, dict) else json.loads(value)

    def _row_to_meta(self, row: dict) -> DatasetMeta:
        """Convert a database row (column name -> value) to DatasetMeta."""
        return DatasetMeta(
            dataset_id=row["dataset_id"],
            generator=row["generator"],
            generator_version=row["generator_version"],
            params=self._json_field(row["params"]),
            n_samples=row["n_samples"],
            n_features=row["n_features"],
            n_classes=row["n_classes"],
            n_train=row["n_train"],
            n_test=row["n_test"],
            class_distribution=self._json_field(row["class_distribution"]),
            artifact_formats=list(row["artifact_formats"]),
            created_at=row["created_at"],
            checksum=row["checksum"],
            tags=list(row["tags"]) if row["tags"] else [],
            ttl_seconds=row["ttl_seconds"],
            expires_at=row["expires_at"],
            last_accessed_at=row["last_accessed_at"],
            access_count=row["access_count"],
        )

    def save(
        self,
        dataset_id: str,
        meta: DatasetMeta,
        arrays: dict[str, np.ndarray],
    ) -> None:
        """Save dataset metadata to PostgreSQL and arrays to the filesystem.

        An existing dataset with the same id is overwritten, except that its
        ``created_at`` is preserved.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Dataset metadata.
            arrays: Dictionary of numpy arrays.
        """
        row = self._meta_to_row(meta)
        # Key the row by the same id as the artifact file so metadata and
        # artifact can never end up stored under different ids.
        row["dataset_id"] = dataset_id

        # Write the artifact first. A serialization or I/O failure then leaves
        # the database untouched instead of committing a row with no artifact.
        self._write_artifact(dataset_id, arrays)

        with self._transaction() as (conn, cur):
            cur.execute(self._UPSERT_SQL, row)
            conn.commit()

    def get_meta(self, dataset_id: str) -> DatasetMeta | None:
        """Get dataset metadata from PostgreSQL.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            Dataset metadata if found, None otherwise.
        """
        with self._transaction(cursor_factory=RealDictCursor) as (_conn, cur):
            cur.execute("SELECT * FROM datasets WHERE dataset_id = %s", (dataset_id,))
            row = cur.fetchone()

        return None if row is None else self._row_to_meta(dict(row))

    def get_artifact_bytes(self, dataset_id: str) -> bytes | None:
        """Get dataset artifact bytes from the filesystem.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            NPZ bytes if found, None otherwise.
        """
        # EAFP: a concurrent delete between an existence check and the read
        # would otherwise surface as FileNotFoundError.
        try:
            return self._artifact_file(dataset_id).read_bytes()
        except FileNotFoundError:
            return None

    def exists(self, dataset_id: str) -> bool:
        """Check if a metadata row for the dataset exists in PostgreSQL.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset exists, False otherwise.
        """
        with self._transaction() as (_conn, cur):
            cur.execute("SELECT 1 FROM datasets WHERE dataset_id = %s", (dataset_id,))
            return cur.fetchone() is not None

    def delete(self, dataset_id: str) -> bool:
        """Delete dataset from PostgreSQL and its artifact from the filesystem.

        The artifact file is removed even when no metadata row existed.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if a metadata row was deleted, False if it didn't exist.
        """
        with self._transaction() as (conn, cur):
            cur.execute(
                "DELETE FROM datasets WHERE dataset_id = %s RETURNING dataset_id",
                (dataset_id,),
            )
            deleted = cur.fetchone() is not None
            conn.commit()

        self._artifact_file(dataset_id).unlink(missing_ok=True)
        return deleted

    def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
        """List dataset IDs from PostgreSQL, newest first.

        Args:
            limit: Maximum number of dataset IDs to return.
            offset: Number of dataset IDs to skip.

        Returns:
            List of dataset IDs ordered by ``created_at`` descending.
        """
        with self._transaction() as (_conn, cur):
            cur.execute(
                "SELECT dataset_id FROM datasets ORDER BY created_at DESC LIMIT %s OFFSET %s",
                (limit, offset),
            )
            rows = cur.fetchall()

        return [row[0] for row in rows]

    def update_meta(self, dataset_id: str, meta: DatasetMeta) -> bool:
        """Update dataset metadata in PostgreSQL.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Updated dataset metadata.

        Returns:
            True if the dataset was updated, False if it didn't exist.
        """
        row = self._meta_to_row(meta)
        # Target the row named by the argument, not by ``meta.dataset_id``.
        row["dataset_id"] = dataset_id

        with self._transaction() as (conn, cur):
            cur.execute(self._UPDATE_SQL, row)
            updated = cur.rowcount > 0
            conn.commit()

        return updated

    def list_all_metadata(self) -> list[DatasetMeta]:
        """List all dataset metadata from PostgreSQL, newest first.

        Returns:
            List of all DatasetMeta objects ordered by ``created_at`` descending.
        """
        with self._transaction(cursor_factory=RealDictCursor) as (_conn, cur):
            cur.execute("SELECT * FROM datasets ORDER BY created_at DESC")
            rows = cur.fetchall()

        return [self._row_to_meta(dict(row)) for row in rows]

    def close(self) -> None:
        """Close database connections (no-op: each operation closes its own connection)."""
@@ -0,0 +1,264 @@
1
+ """Redis-backed dataset storage for caching and distributed deployments."""
2
+
3
+ import io
4
+ import json
5
+ from datetime import datetime # noqa: F401 - used by DatasetMeta serialization
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+
10
+ from juniper_data.core.models import DatasetMeta
11
+
12
+ from .base import DatasetStore
13
+
14
try:
    import redis
except ImportError:
    # redis is an optional dependency: keep this module importable and let
    # RedisDatasetStore raise a helpful ImportError when instantiated.
    redis = None  # type: ignore[assignment]
    REDIS_AVAILABLE = False
else:
    REDIS_AVAILABLE = True
21
+
22
+
23
class RedisDatasetStore(DatasetStore):
    """Redis-backed dataset storage.

    Uses Redis for both metadata (as JSON) and artifact storage (as NPZ bytes).
    Each dataset occupies two keys, ``<prefix><id>:meta`` and
    ``<prefix><id>:artifact``. They are written in one pipeline with the same
    TTL. Suitable for caching layers and distributed deployments.

    Requires the `redis` package: pip install redis
    """

    _META_SUFFIX = ":meta"
    _ARTIFACT_SUFFIX = ":artifact"
    # Characters with special meaning in Redis glob-style MATCH patterns.
    _GLOB_SPECIAL = "\\*?[]"

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        key_prefix: str = "juniper:dataset:",
        default_ttl: int | None = None,
        connection_pool: Any | None = None,
    ) -> None:
        """Initialize Redis connection.

        Args:
            host: Redis server hostname.
            port: Redis server port.
            db: Redis database number.
            password: Redis password (optional).
            key_prefix: Prefix for all Redis keys (matched literally, including
                any glob metacharacters it contains).
            default_ttl: Default TTL for stored datasets in seconds (optional).
            connection_pool: Optional existing Redis connection pool.

        Raises:
            ImportError: If redis package is not installed.
        """
        if not REDIS_AVAILABLE:
            raise ImportError("Redis package not installed. Install with: pip install redis")

        self._key_prefix = key_prefix
        self._default_ttl = default_ttl

        if connection_pool:
            self._client: redis.Redis[bytes] = redis.Redis(connection_pool=connection_pool)
        else:
            self._client = redis.Redis(
                host=host,
                port=port,
                db=db,
                password=password,
                decode_responses=False,
            )

    def _meta_key(self, dataset_id: str) -> str:
        """Get Redis key for metadata."""
        return f"{self._key_prefix}{dataset_id}{self._META_SUFFIX}"

    def _artifact_key(self, dataset_id: str) -> str:
        """Get Redis key for artifact data."""
        return f"{self._key_prefix}{dataset_id}{self._ARTIFACT_SUFFIX}"

    def _match_pattern(self, suffix: str = "") -> str:
        """Build a SCAN MATCH pattern for keys under this store's prefix.

        Glob metacharacters in the prefix are backslash-escaped, so the prefix
        is matched literally. Without this, a prefix such as ``"ds[1]:"``
        would match unrelated keys (dangerous for ``flush_prefix``) and miss
        its own keys.

        Args:
            suffix: Literal text that matching keys must end with.
        """
        escaped = "".join("\\" + ch if ch in self._GLOB_SPECIAL else ch for ch in self._key_prefix)
        return f"{escaped}*{suffix}"

    @staticmethod
    def _key_to_str(key: bytes | str) -> str:
        """Return a Redis key as ``str`` (keys are bytes unless responses are decoded)."""
        return key.decode("utf-8") if isinstance(key, bytes) else key

    def _encode_meta(self, meta: DatasetMeta) -> bytes:
        """Encode metadata to JSON bytes."""
        data = meta.model_dump(mode="json")
        return json.dumps(data).encode("utf-8")

    def _decode_meta(self, data: bytes) -> DatasetMeta:
        """Decode metadata from JSON bytes."""
        parsed = json.loads(data.decode("utf-8"))
        return DatasetMeta(**parsed)

    def _encode_arrays(self, arrays: dict[str, np.ndarray]) -> bytes:
        """Encode arrays to compressed NPZ bytes."""
        buffer = io.BytesIO()
        np.savez_compressed(buffer, **arrays)  # type: ignore[arg-type]
        return buffer.getvalue()

    def save(
        self,
        dataset_id: str,
        meta: DatasetMeta,
        arrays: dict[str, np.ndarray],
    ) -> None:
        """Save dataset to Redis.

        The TTL is ``meta.ttl_seconds`` when truthy, else the store's
        ``default_ttl``. If neither is set, the keys never expire.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Dataset metadata.
            arrays: Dictionary of numpy arrays.
        """
        meta_bytes = self._encode_meta(meta)
        artifact_bytes = self._encode_arrays(arrays)

        meta_key = self._meta_key(dataset_id)
        artifact_key = self._artifact_key(dataset_id)

        ttl = meta.ttl_seconds or self._default_ttl

        pipe = self._client.pipeline()
        if ttl:
            pipe.setex(meta_key, ttl, meta_bytes)
            pipe.setex(artifact_key, ttl, artifact_bytes)
        else:
            pipe.set(meta_key, meta_bytes)
            pipe.set(artifact_key, artifact_bytes)
        pipe.execute()

    def get_meta(self, dataset_id: str) -> DatasetMeta | None:
        """Get dataset metadata from Redis.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            Dataset metadata if found, None otherwise.
        """
        data = self._client.get(self._meta_key(dataset_id))
        if data is None:
            return None
        return self._decode_meta(data)

    def get_artifact_bytes(self, dataset_id: str) -> bytes | None:
        """Get dataset artifact bytes from Redis.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            NPZ bytes if found, None otherwise.
        """
        return self._client.get(self._artifact_key(dataset_id))

    def exists(self, dataset_id: str) -> bool:
        """Check if the dataset's metadata key exists in Redis.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset exists, False otherwise.
        """
        return bool(self._client.exists(self._meta_key(dataset_id)))

    def delete(self, dataset_id: str) -> bool:
        """Delete dataset metadata and artifact keys from Redis.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if at least one key was deleted, False if neither existed.
        """
        meta_key = self._meta_key(dataset_id)
        artifact_key = self._artifact_key(dataset_id)

        deleted = self._client.delete(meta_key, artifact_key)
        return deleted > 0

    def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
        """List dataset IDs from Redis in ascending lexicographic order.

        Args:
            limit: Maximum number of dataset IDs to return.
            offset: Number of dataset IDs to skip.

        Returns:
            List of dataset IDs.
        """
        keys = self._client.scan_iter(match=self._match_pattern(self._META_SUFFIX))
        dataset_ids = sorted(
            self._key_to_str(key).removeprefix(self._key_prefix).removesuffix(self._META_SUFFIX)
            for key in keys
        )
        return dataset_ids[offset : offset + limit]

    def update_meta(self, dataset_id: str, meta: DatasetMeta) -> bool:
        """Update dataset metadata in Redis, keeping its remaining TTL.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Updated dataset metadata.

        Returns:
            True if the dataset was updated, False if it didn't exist.
        """
        meta_key = self._meta_key(dataset_id)
        if not self._client.exists(meta_key):
            return False

        # Redis TTL: >0 seconds remaining, -1 no expiry, -2 key does not exist.
        remaining = self._client.ttl(meta_key)
        if remaining == -2:
            # The key expired or was deleted after the ``exists`` check.
            # Writing now would resurrect metadata whose artifact may be gone.
            return False

        meta_bytes = self._encode_meta(meta)
        if remaining > 0:
            self._client.setex(meta_key, remaining, meta_bytes)
        else:
            self._client.set(meta_key, meta_bytes)

        return True

    def list_all_metadata(self) -> list[DatasetMeta]:
        """List all dataset metadata from Redis.

        Keys that disappear between the scan and the read are skipped.

        Returns:
            List of all DatasetMeta objects.
        """
        keys = list(self._client.scan_iter(match=self._match_pattern(self._META_SUFFIX)))

        metadata: list[DatasetMeta] = []
        for key in keys:
            data = self._client.get(key)
            if data is not None:
                metadata.append(self._decode_meta(data))

        return metadata

    def ping(self) -> bool:
        """Check if Redis connection is alive.

        Returns:
            True if connected, False otherwise (any client error counts as down).
        """
        try:
            return self._client.ping()
        except Exception:
            return False

    def flush_prefix(self) -> int:
        """Delete all keys that start with this store's prefix.

        WARNING: This deletes all datasets in this store.

        Returns:
            Number of keys deleted.
        """
        keys = list(self._client.scan_iter(match=self._match_pattern()))
        if keys:
            return self._client.delete(*keys)
        return 0
@@ -0,0 +1 @@
1
+ """Test suite for Juniper Data."""
@@ -0,0 +1,68 @@
1
+ """Common pytest fixtures for juniper_data tests."""
2
+
3
+ import numpy as np
4
+ import pytest
5
+
6
+ from juniper_data.generators.spiral import SpiralGenerator, SpiralParams
7
+
8
+
9
@pytest.fixture
def default_spiral_params() -> SpiralParams:
    """Provide a ``SpiralParams`` instance built purely from its defaults."""
    params = SpiralParams()
    return params
13
+
14
+
15
@pytest.fixture
def two_spiral_params() -> SpiralParams:
    """Provide seeded (42) parameters for two spirals of 100 points each."""
    return SpiralParams(seed=42, n_spirals=2, n_points_per_spiral=100)
23
+
24
+
25
@pytest.fixture
def three_spiral_params() -> SpiralParams:
    """Provide seeded (42) parameters for three spirals of 50 points each."""
    return SpiralParams(seed=42, n_spirals=3, n_points_per_spiral=50)
33
+
34
+
35
@pytest.fixture
def minimal_spiral_params() -> SpiralParams:
    """Provide small seeded (42) parameters: two spirals, 10 points each, for fast tests."""
    return SpiralParams(seed=42, n_spirals=2, n_points_per_spiral=10)
43
+
44
+
45
@pytest.fixture
def generated_two_spiral_dataset(two_spiral_params: SpiralParams) -> dict[str, np.ndarray]:
    """Arrays produced by ``SpiralGenerator`` from the two-spiral parameters."""
    dataset = SpiralGenerator.generate(two_spiral_params)
    return dataset
49
+
50
+
51
@pytest.fixture
def generated_three_spiral_dataset(three_spiral_params: SpiralParams) -> dict[str, np.ndarray]:
    """Arrays produced by ``SpiralGenerator`` from the three-spiral parameters."""
    dataset = SpiralGenerator.generate(three_spiral_params)
    return dataset
55
+
56
+
57
@pytest.fixture
def generated_minimal_dataset(minimal_spiral_params: SpiralParams) -> dict[str, np.ndarray]:
    """Arrays produced by ``SpiralGenerator`` from the minimal parameters (fast)."""
    dataset = SpiralGenerator.generate(minimal_spiral_params)
    return dataset
61
+
62
+
63
@pytest.fixture
def sample_arrays() -> dict[str, np.ndarray]:
    """Ten tiny samples for split/shuffle tests.

    ``X`` is a (10, 2) float32 grid holding 0..19 in row-major order. ``y`` is
    a (10, 2) float32 one-hot matrix whose rows alternate between class 0
    and class 1.
    """
    n_rows = 10
    features = np.arange(2 * n_rows, dtype=np.float32).reshape(n_rows, 2)
    labels = np.arange(n_rows) % 2
    one_hot = np.eye(2, dtype=np.float32)[labels]
    return {"X": features, "y": one_hot}