juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,423 @@
|
|
|
1
|
+
"""Unit tests for CachedDatasetStore."""
|
|
2
|
+
|
|
3
|
+
from datetime import UTC, datetime
|
|
4
|
+
from unittest.mock import MagicMock
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pytest
|
|
8
|
+
|
|
9
|
+
from juniper_data.core.models import DatasetMeta
|
|
10
|
+
from juniper_data.storage import CachedDatasetStore, InMemoryDatasetStore
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.fixture
def primary_store() -> InMemoryDatasetStore:
    """Provide a fresh in-memory store that plays the role of the primary backend."""
    store = InMemoryDatasetStore()
    return store
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@pytest.fixture
def cache_store() -> InMemoryDatasetStore:
    """Provide a fresh in-memory store that plays the role of the cache layer."""
    store = InMemoryDatasetStore()
    return store
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.fixture
def sample_meta() -> DatasetMeta:
    """Build a fully populated ``DatasetMeta`` describing a 100-sample, 2-class dataset."""
    # Timestamp is taken once, timezone-aware, at fixture construction time.
    now_utc = datetime.now(UTC)
    meta = DatasetMeta(
        dataset_id="test-dataset",
        generator="test",
        generator_version="1.0.0",
        params={"seed": 42},
        n_samples=100,
        n_features=2,
        n_classes=2,
        n_train=80,
        n_test=20,
        class_distribution={"0": 50, "1": 50},
        created_at=now_utc,
    )
    return meta
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@pytest.fixture
def sample_arrays() -> dict[str, np.ndarray]:
    """Create deterministic float32 train/test arrays (80 and 20 rows, 2 columns each).

    A locally seeded generator is used instead of the global ``np.random``
    state so that every test run sees the same data and the fixture neither
    depends on nor disturbs other tests' random state.
    """
    rng = np.random.default_rng(42)
    return {
        "X_train": rng.standard_normal((80, 2)).astype(np.float32),
        "y_train": rng.standard_normal((80, 2)).astype(np.float32),
        "X_test": rng.standard_normal((20, 2)).astype(np.float32),
        "y_test": rng.standard_normal((20, 2)).astype(np.float32),
    }
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class TestCachedDatasetStore:
    """Tests for CachedDatasetStore.

    Each test wraps a primary store and a cache store in a
    ``CachedDatasetStore``. Real ``InMemoryDatasetStore`` instances are used
    for the happy paths; a ``MagicMock(spec=InMemoryDatasetStore)`` stands in
    for the cache whenever a cache failure has to be simulated.
    """

    def test_save_writes_to_both_stores(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """Save should write to both primary and cache."""
        cached = CachedDatasetStore(primary_store, cache_store, write_through=True)

        cached.save("test-1", sample_meta, sample_arrays)

        assert primary_store.exists("test-1")
        assert cache_store.exists("test-1")

    def test_save_writes_only_to_primary_when_not_write_through(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """Save without write-through should only write to primary."""
        cached = CachedDatasetStore(primary_store, cache_store, write_through=False)

        cached.save("test-1", sample_meta, sample_arrays)

        assert primary_store.exists("test-1")
        assert not cache_store.exists("test-1")

    def test_get_meta_returns_from_cache_first(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """get_meta should return from cache if available."""
        cached = CachedDatasetStore(primary_store, cache_store)

        primary_store.save("test-1", sample_meta, sample_arrays)

        # Store deliberately different metadata in the cache so the source of
        # the returned value can be told apart by its ``generator`` field.
        cache_meta = DatasetMeta(
            dataset_id="test-1",
            generator="cached",
            generator_version="2.0.0",
            params={},
            n_samples=200,
            n_features=2,
            n_classes=2,
            n_train=160,
            n_test=40,
            class_distribution={"0": 100, "1": 100},
            created_at=datetime.now(UTC),
        )
        cache_store.save("test-1", cache_meta, sample_arrays)

        result = cached.get_meta("test-1")
        assert result is not None
        assert result.generator == "cached"

    def test_get_meta_falls_back_to_primary(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """get_meta should fall back to primary if not in cache."""
        cached = CachedDatasetStore(primary_store, cache_store)

        primary_store.save("test-1", sample_meta, sample_arrays)

        result = cached.get_meta("test-1")
        assert result is not None
        assert result.generator == "test"

    def test_get_artifact_populates_cache(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """get_artifact_bytes should populate cache from primary."""
        cached = CachedDatasetStore(primary_store, cache_store, write_through=False)

        primary_store.save("test-1", sample_meta, sample_arrays)
        assert not cache_store.exists("test-1")

        artifact = cached.get_artifact_bytes("test-1")
        assert artifact is not None

        assert cache_store.exists("test-1")

    def test_delete_removes_from_both_stores(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """delete should remove from both stores."""
        cached = CachedDatasetStore(primary_store, cache_store)

        cached.save("test-1", sample_meta, sample_arrays)
        assert primary_store.exists("test-1")
        assert cache_store.exists("test-1")

        result = cached.delete("test-1")
        assert result
        assert not primary_store.exists("test-1")
        assert not cache_store.exists("test-1")

    def test_exists_checks_both_stores(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """exists should check cache first, then primary."""
        cached = CachedDatasetStore(primary_store, cache_store)

        # Present only in the cache.
        cache_store.save("test-1", sample_meta, sample_arrays)
        assert cached.exists("test-1")

        # Present only in the primary.
        primary_store.save("test-2", sample_meta, sample_arrays)
        assert cached.exists("test-2")

        # Present in neither store.
        assert not cached.exists("test-3")

    def test_list_datasets_uses_primary(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """list_datasets should use primary store."""
        cached = CachedDatasetStore(primary_store, cache_store)

        primary_store.save("test-1", sample_meta, sample_arrays)
        cache_store.save("cache-only", sample_meta, sample_arrays)

        datasets = cached.list_datasets()
        assert "test-1" in datasets
        assert "cache-only" not in datasets

    def test_invalidate_cache(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """invalidate_cache should remove from cache only."""
        cached = CachedDatasetStore(primary_store, cache_store)

        cached.save("test-1", sample_meta, sample_arrays)
        assert cache_store.exists("test-1")

        result = cached.invalidate_cache("test-1")
        assert result
        assert not cache_store.exists("test-1")
        assert primary_store.exists("test-1")

    def test_warm_cache(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """warm_cache should populate cache from primary."""
        cached = CachedDatasetStore(primary_store, cache_store, write_through=False)

        primary_store.save("test-1", sample_meta, sample_arrays)
        primary_store.save("test-2", sample_meta, sample_arrays)
        assert not cache_store.exists("test-1")
        assert not cache_store.exists("test-2")

        count = cached.warm_cache()
        assert count == 2
        assert cache_store.exists("test-1")
        assert cache_store.exists("test-2")

    def test_warm_cache_specific_ids(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """warm_cache with specific IDs should only cache those."""
        cached = CachedDatasetStore(primary_store, cache_store, write_through=False)

        primary_store.save("test-1", sample_meta, sample_arrays)
        primary_store.save("test-2", sample_meta, sample_arrays)

        count = cached.warm_cache(["test-1"])
        assert count == 1
        assert cache_store.exists("test-1")
        assert not cache_store.exists("test-2")

    def test_update_meta_propagates_to_cache(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """update_meta should write to both primary and cache when primary succeeds."""
        cached = CachedDatasetStore(primary_store, cache_store, write_through=True)
        cached.save("test-1", sample_meta, sample_arrays)

        updated_meta = sample_meta.model_copy(update={"generator": "updated"})
        result = cached.update_meta("test-1", updated_meta)

        assert result is True

        # Guard against None first so a missing entry fails with a clear
        # assertion instead of an AttributeError on ``.generator``.
        primary_meta = primary_store.get_meta("test-1")
        cache_meta = cache_store.get_meta("test-1")
        assert primary_meta is not None
        assert cache_meta is not None
        assert primary_meta.generator == "updated"
        assert cache_meta.generator == "updated"

    def test_update_meta_only_primary_when_primary_fails(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
    ) -> None:
        """update_meta should return False when primary has no dataset."""
        cached = CachedDatasetStore(primary_store, cache_store)

        result = cached.update_meta("nonexistent", sample_meta)

        assert result is False

    def test_list_all_metadata_delegates_to_primary(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """list_all_metadata should return data from primary store."""
        cached = CachedDatasetStore(primary_store, cache_store)
        primary_store.save("test-1", sample_meta, sample_arrays)

        result = cached.list_all_metadata()

        assert len(result) == 1
        assert result[0].dataset_id == sample_meta.dataset_id

    def test_warm_cache_skips_on_error(
        self,
        primary_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """warm_cache should continue when individual dataset fails."""
        primary_store.save("test-1", sample_meta, sample_arrays)
        primary_store.save("test-2", sample_meta, sample_arrays)

        failing_cache = MagicMock(spec=InMemoryDatasetStore)
        call_count = 0

        def save_side_effect(dataset_id: str, meta: DatasetMeta, arrays: dict) -> None:
            # Fail only the first cache write; later writes succeed.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise RuntimeError("cache write failed")

        failing_cache.save.side_effect = save_side_effect
        cached = CachedDatasetStore(primary_store, failing_cache)

        count = cached.warm_cache()

        # Two datasets in primary, one cache write failed -> one warmed.
        assert count == 1

    def test_invalidate_cache_returns_false_on_error(
        self,
        primary_store: InMemoryDatasetStore,
    ) -> None:
        """invalidate_cache should return False when cache.delete raises."""
        failing_cache = MagicMock(spec=InMemoryDatasetStore)
        failing_cache.delete.side_effect = RuntimeError("cache error")
        cached = CachedDatasetStore(primary_store, failing_cache)

        result = cached.invalidate_cache("test-1")

        assert result is False

    def test_get_artifact_bytes_returns_none_when_not_found(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
    ) -> None:
        """get_artifact_bytes should return None when not in either store."""
        cached = CachedDatasetStore(primary_store, cache_store)

        result = cached.get_artifact_bytes("nonexistent")

        assert result is None

    def test_get_artifact_bytes_from_cache(
        self,
        primary_store: InMemoryDatasetStore,
        cache_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """get_artifact_bytes should return from cache when available."""
        cached = CachedDatasetStore(primary_store, cache_store, write_through=True)
        cached.save("test-1", sample_meta, sample_arrays)

        # Remove the dataset from the primary behind the wrapper's back, so a
        # non-None result can only have been served by the cache. (The former
        # ``... or True`` assertion was a tautology and verified nothing.)
        primary_store.delete("test-1")
        assert not primary_store.exists("test-1")
        assert cache_store.exists("test-1")

        result = cached.get_artifact_bytes("test-1")

        assert result is not None

    def test_save_suppresses_cache_error(
        self,
        primary_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """save should catch exceptions from cache store."""
        failing_cache = MagicMock(spec=InMemoryDatasetStore)
        failing_cache.save.side_effect = RuntimeError("cache write failed")
        cached = CachedDatasetStore(primary_store, failing_cache, write_through=True)

        # Must not raise even though the cache write fails.
        cached.save("test-1", sample_meta, sample_arrays)

        assert primary_store.exists("test-1")

    def test_get_meta_suppresses_cache_error(
        self,
        primary_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """get_meta should catch cache exceptions and fall back to primary."""
        failing_cache = MagicMock(spec=InMemoryDatasetStore)
        failing_cache.get_meta.side_effect = RuntimeError("cache read failed")
        cached = CachedDatasetStore(primary_store, failing_cache)

        primary_store.save("test-1", sample_meta, sample_arrays)

        result = cached.get_meta("test-1")

        assert result is not None
        assert result.generator == "test"

    def test_exists_suppresses_cache_error(
        self,
        primary_store: InMemoryDatasetStore,
        sample_meta: DatasetMeta,
        sample_arrays: dict[str, np.ndarray],
    ) -> None:
        """exists should catch cache exceptions and check primary."""
        failing_cache = MagicMock(spec=InMemoryDatasetStore)
        failing_cache.exists.side_effect = RuntimeError("cache error")
        cached = CachedDatasetStore(primary_store, failing_cache)

        primary_store.save("test-1", sample_meta, sample_arrays)

        assert cached.exists("test-1") is True
        assert cached.exists("nonexistent") is False
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""Unit tests for the checkerboard dataset generator."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from juniper_data.generators.checkerboard import (
|
|
7
|
+
VERSION,
|
|
8
|
+
CheckerboardGenerator,
|
|
9
|
+
CheckerboardParams,
|
|
10
|
+
get_schema,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestCheckerboardParams:
    """Validation tests for the CheckerboardParams model."""

    def test_default_params(self) -> None:
        """Constructing with no arguments yields the documented defaults."""
        defaults = CheckerboardParams()

        assert defaults.n_samples == 200
        assert defaults.n_squares == 4
        assert defaults.x_range == (0.0, 1.0)
        assert defaults.y_range == (0.0, 1.0)
        assert defaults.noise == 0.0
        assert defaults.train_ratio == 0.8
        assert defaults.test_ratio == 0.2
        assert defaults.shuffle is True

    def test_custom_params(self) -> None:
        """Explicitly supplied values are stored unchanged."""
        custom = CheckerboardParams(
            n_samples=300,
            n_squares=8,
            x_range=(-1.0, 1.0),
            y_range=(-2.0, 2.0),
            noise=0.1,
            seed=42,
        )

        assert custom.n_samples == 300
        assert custom.n_squares == 8
        assert custom.x_range == (-1.0, 1.0)
        assert custom.y_range == (-2.0, 2.0)
        assert custom.noise == 0.1
        assert custom.seed == 42

    def test_invalid_n_squares_too_low(self) -> None:
        """n_squares=1 is rejected (must be at least 2)."""
        with pytest.raises(ValueError):
            CheckerboardParams(n_squares=1)

    def test_invalid_n_squares_too_high(self) -> None:
        """n_squares=17 is rejected (must be at most 16)."""
        with pytest.raises(ValueError):
            CheckerboardParams(n_squares=17)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class TestCheckerboardGenerator:
    """Tests for CheckerboardGenerator.

    All tests call ``CheckerboardGenerator.generate`` with a fixed ``seed`` and
    inspect the returned dictionary of arrays.
    """

    def test_generate_returns_expected_keys(self) -> None:
        """Generated data should contain all expected keys."""
        params = CheckerboardParams(seed=42)
        result = CheckerboardGenerator.generate(params)

        expected_keys = {"X_train", "y_train", "X_test", "y_test", "X_full", "y_full"}
        assert set(result.keys()) == expected_keys

    def test_generate_shapes(self) -> None:
        """Generated arrays should have correct shapes."""
        params = CheckerboardParams(n_samples=150, seed=42)
        result = CheckerboardGenerator.generate(params)

        assert result["X_full"].shape == (150, 2)
        assert result["y_full"].shape == (150, 2)

    def test_generate_dtypes(self) -> None:
        """Generated arrays should have float32 dtype."""
        params = CheckerboardParams(seed=42)
        result = CheckerboardGenerator.generate(params)

        assert result["X_train"].dtype == np.float32
        assert result["y_train"].dtype == np.float32
        assert result["X_full"].dtype == np.float32
        assert result["y_full"].dtype == np.float32

    def test_determinism_with_seed(self) -> None:
        """Same seed should produce identical results."""
        params = CheckerboardParams(seed=123)

        result1 = CheckerboardGenerator.generate(params)
        result2 = CheckerboardGenerator.generate(params)

        np.testing.assert_array_equal(result1["X_full"], result2["X_full"])
        np.testing.assert_array_equal(result1["y_full"], result2["y_full"])

    def test_different_seeds_produce_different_data(self) -> None:
        """Different seeds should produce different results."""
        params1 = CheckerboardParams(seed=42)
        params2 = CheckerboardParams(seed=43)

        result1 = CheckerboardGenerator.generate(params1)
        result2 = CheckerboardGenerator.generate(params2)

        assert not np.allclose(result1["X_full"], result2["X_full"])

    def test_one_hot_labels(self) -> None:
        """Labels should be valid one-hot encoded."""
        params = CheckerboardParams(seed=42)
        result = CheckerboardGenerator.generate(params)

        labels = result["y_full"]
        row_sums = labels.sum(axis=1)
        np.testing.assert_array_almost_equal(row_sums, np.ones(len(row_sums)))

        # Every row must contain exactly one 1.0 and exactly one 0.0
        # (checked for all rows at once instead of a per-row Python loop).
        assert np.all((labels == 1.0).sum(axis=1) == 1)
        assert np.all((labels == 0.0).sum(axis=1) == 1)

    def test_points_in_range(self) -> None:
        """Points should be within specified range (no noise)."""
        params = CheckerboardParams(
            n_samples=100,
            x_range=(0.0, 1.0),
            y_range=(0.0, 1.0),
            noise=0.0,
            seed=42,
        )
        result = CheckerboardGenerator.generate(params)

        assert result["X_full"][:, 0].min() >= 0.0
        assert result["X_full"][:, 0].max() <= 1.0
        assert result["X_full"][:, 1].min() >= 0.0
        assert result["X_full"][:, 1].max() <= 1.0

    def test_checkerboard_pattern(self) -> None:
        """Adjacent squares should have different classes."""
        params = CheckerboardParams(
            n_samples=1000,
            n_squares=4,
            x_range=(0.0, 1.0),
            y_range=(0.0, 1.0),
            noise=0.0,
            seed=42,
            shuffle=False,
        )
        result = CheckerboardGenerator.generate(params)

        points = result["X_full"]
        # Collapse the one-hot rows to a class index per sample.
        classes = result["y_full"].argmax(axis=1)
        x_col, y_col = points[:, 0], points[:, 1]

        # With 4 squares on [0, 1] each square is 0.25 wide. Strict
        # inequalities keep points sitting exactly on a boundary out of both
        # regions, so each mask covers the interior of a single square.
        in_square_00 = (x_col < 0.25) & (y_col < 0.25)
        in_square_01 = (x_col < 0.25) & (y_col > 0.25) & (y_col < 0.5)

        # Previously the test silently passed when a region was empty (and
        # asserted nothing even when it was not); require data in both.
        assert in_square_00.any(), "no samples fell in square (0, 0)"
        assert in_square_01.any(), "no samples fell in square (0, 1)"

        classes_00 = np.unique(classes[in_square_00])
        classes_01 = np.unique(classes[in_square_01])

        # Without noise every point inside one square shares a single class...
        assert classes_00.size == 1
        assert classes_01.size == 1
        # ...and vertically adjacent squares must alternate.
        assert classes_00[0] != classes_01[0]

    def test_train_test_split_ratio(self) -> None:
        """Train/test split should respect configured ratios."""
        params = CheckerboardParams(
            n_samples=100,
            train_ratio=0.7,
            test_ratio=0.3,
            seed=42,
        )
        result = CheckerboardGenerator.generate(params)

        assert len(result["X_train"]) == 70
        assert len(result["X_test"]) == 30

    def test_generate_with_noise(self) -> None:
        """Noise should displace points from their grid positions."""
        params_no_noise = CheckerboardParams(n_samples=200, noise=0.0, seed=42)
        params_with_noise = CheckerboardParams(n_samples=200, noise=0.1, seed=42)

        result_no_noise = CheckerboardGenerator.generate(params_no_noise)
        result_with_noise = CheckerboardGenerator.generate(params_with_noise)

        assert not np.allclose(result_no_noise["X_full"], result_with_noise["X_full"])

    def test_generate_custom_range(self) -> None:
        """Custom x_range and y_range should be respected."""
        params = CheckerboardParams(
            n_samples=100,
            x_range=(-5.0, 5.0),
            y_range=(-3.0, 3.0),
            noise=0.0,
            seed=42,
        )
        result = CheckerboardGenerator.generate(params)

        assert result["X_full"][:, 0].min() >= -5.0
        assert result["X_full"][:, 0].max() <= 5.0
        assert result["X_full"][:, 1].min() >= -3.0
        assert result["X_full"][:, 1].max() <= 3.0
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class TestGetSchema:
    """Checks on the dictionary returned by get_schema()."""

    def test_returns_dict(self) -> None:
        """The schema object is a plain dict."""
        assert isinstance(get_schema(), dict)

    def test_schema_has_properties(self) -> None:
        """The schema exposes a top-level "properties" entry."""
        assert "properties" in get_schema()

    def test_schema_includes_all_params(self) -> None:
        """Every known parameter name appears under "properties"."""
        required_names = {
            "n_samples",
            "n_squares",
            "x_range",
            "y_range",
            "noise",
            "seed",
            "train_ratio",
            "test_ratio",
            "shuffle",
        }
        declared_names = set(get_schema()["properties"].keys())

        # Extra properties are allowed; only the listed ones are required.
        assert required_names <= declared_names
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class TestVersion:
    """Checks on the module-level VERSION constant."""

    def test_version_format(self) -> None:
        """VERSION is three dot-separated, all-digit components."""
        components = VERSION.split(".")

        assert len(components) == 3
        non_numeric = [piece for piece in components if not piece.isdigit()]
        assert not non_numeric
|