juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
"""Unit tests for dataset lifecycle management features (DATA-016).
|
|
2
|
+
|
|
3
|
+
Tests for:
|
|
4
|
+
- Dataset expiration / TTL
|
|
5
|
+
- Bulk operations (filtering, batch delete)
|
|
6
|
+
- Dataset tagging
|
|
7
|
+
- Usage tracking / access counts
|
|
8
|
+
- Statistics
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from datetime import UTC, datetime, timedelta
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pytest
|
|
15
|
+
|
|
16
|
+
from juniper_data.core.models import DatasetMeta
|
|
17
|
+
from juniper_data.storage.memory import InMemoryDatasetStore
|
|
18
|
+
|
|
19
|
+
# from typing import Dict
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _create_test_meta(
    dataset_id: str,
    generator: str = "spiral",
    n_samples: int = 100,
    tags: list[str] | None = None,
    ttl_seconds: int | None = None,
    created_at: datetime | None = None,
) -> DatasetMeta:
    """Create a test DatasetMeta instance.

    Args:
        dataset_id: Identifier stored on the metadata.
        generator: Generator name recorded on the metadata.
        n_samples: Total sample count, split 80/20 into train/test counts.
        tags: Tags to attach; ``None`` (or an empty list) yields ``[]``.
        ttl_seconds: Time-to-live in seconds. When given, ``expires_at`` is
            the creation time plus this many seconds; otherwise it is ``None``.
        created_at: Creation timestamp; falls back to the current UTC time.

    Returns:
        A ``DatasetMeta`` filled with fixed two-feature, two-class placeholders.
    """
    now = created_at or datetime.now(UTC)
    expires_at = None
    if ttl_seconds is not None:
        expires_at = now + timedelta(seconds=ttl_seconds)

    # Derive the test count by subtraction (the same rule _create_test_arrays
    # uses) so n_train + n_test always equals n_samples; truncating
    # n_samples * 0.2 separately could silently drop a sample.
    n_train = int(n_samples * 0.8)
    n_test = n_samples - n_train

    # Give the remainder of an odd sample count to class "1" so the
    # distribution also sums to n_samples.
    n_class_zero = n_samples // 2
    n_class_one = n_samples - n_class_zero

    return DatasetMeta(
        dataset_id=dataset_id,
        generator=generator,
        generator_version="1.0.0",
        params={"n_spirals": 2},
        n_samples=n_samples,
        n_features=2,
        n_classes=2,
        n_train=n_train,
        n_test=n_test,
        class_distribution={"0": n_class_zero, "1": n_class_one},
        artifact_formats=["npz"],
        created_at=now,
        tags=tags or [],
        ttl_seconds=ttl_seconds,
        expires_at=expires_at,
    )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _create_test_arrays(n_samples: int = 100) -> dict[str, np.ndarray]:
    """Build zero-filled float32 placeholder arrays for an 80/20 split.

    Args:
        n_samples: Number of rows in the ``*_full`` arrays.

    Returns:
        A dict with ``X_``/``y_`` arrays for the ``train``, ``test`` and
        ``full`` splits, each of shape ``(rows, 2)``.
    """
    train_rows = int(n_samples * 0.8)
    rows_per_split = {
        "train": train_rows,
        "test": n_samples - train_rows,
        "full": n_samples,
    }

    arrays: dict[str, np.ndarray] = {}
    for split, rows in rows_per_split.items():
        for prefix in ("X", "y"):
            arrays[f"{prefix}_{split}"] = np.zeros((rows, 2), dtype=np.float32)
    return arrays
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pytest.fixture
def store() -> InMemoryDatasetStore:
    """Provide a new, empty InMemoryDatasetStore for each test."""
    fresh_store = InMemoryDatasetStore()
    return fresh_store
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@pytest.mark.unit
class TestDatasetTags:
    """Tests for dataset tagging functionality."""

    def test_create_dataset_with_tags(self, store: InMemoryDatasetStore) -> None:
        """Dataset can be created with tags."""
        meta = _create_test_meta("ds-1", tags=["train", "spiral", "v1"])
        store.save("ds-1", meta, _create_test_arrays())

        retrieved = store.get_meta("ds-1")
        assert retrieved is not None
        assert retrieved.tags == ["train", "spiral", "v1"]

    def test_update_meta_adds_tags(self, store: InMemoryDatasetStore) -> None:
        """Tags can be added via update_meta."""
        created = datetime.now(UTC)
        original = _create_test_meta("ds-1", tags=["original"], created_at=created)
        store.save("ds-1", original, _create_test_arrays())

        # Update with a separate instance rather than mutating the object that
        # was handed to save(): if the store keeps a reference to the saved
        # object, an in-place mutation would make the assertion below pass even
        # when update_meta persisted nothing.
        updated = _create_test_meta("ds-1", tags=["original", "added"], created_at=created)
        result = store.update_meta("ds-1", updated)
        assert result is True

        retrieved = store.get_meta("ds-1")
        assert retrieved is not None
        assert "added" in retrieved.tags

    def test_update_meta_nonexistent_returns_false(self, store: InMemoryDatasetStore) -> None:
        """update_meta returns False for nonexistent dataset."""
        meta = _create_test_meta("nonexistent")
        result = store.update_meta("nonexistent", meta)
        assert result is False
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@pytest.mark.unit
class TestDatasetTTL:
    """Checks for TTL bookkeeping, is_expired and delete_expired."""

    def test_dataset_with_ttl_has_expires_at(self, store: InMemoryDatasetStore) -> None:
        """A stored dataset keeps its ttl_seconds and has expires_at populated."""
        store.save("ds-1", _create_test_meta("ds-1", ttl_seconds=3600), _create_test_arrays())

        stored_meta = store.get_meta("ds-1")
        assert stored_meta is not None
        assert stored_meta.ttl_seconds == 3600
        assert stored_meta.expires_at is not None

    def test_is_expired_false_for_future_expiry(self, store: InMemoryDatasetStore) -> None:
        """A one-hour TTL counted from now is not yet expired."""
        fresh_meta = _create_test_meta("ds-1", ttl_seconds=3600)
        store.save("ds-1", fresh_meta, _create_test_arrays())

        expired = store.is_expired(fresh_meta)
        assert expired is False

    def test_is_expired_true_for_past_expiry(self, store: InMemoryDatasetStore) -> None:
        """A one-hour TTL on a dataset created two hours ago is expired."""
        two_hours_ago = datetime.now(UTC) - timedelta(hours=2)
        stale_meta = _create_test_meta("ds-1", ttl_seconds=3600, created_at=two_hours_ago)
        store.save("ds-1", stale_meta, _create_test_arrays())

        expired = store.is_expired(stale_meta)
        assert expired is True

    def test_is_expired_false_for_no_ttl(self, store: InMemoryDatasetStore) -> None:
        """A dataset created without a TTL is reported as not expired."""
        eternal_meta = _create_test_meta("ds-1", ttl_seconds=None)
        store.save("ds-1", eternal_meta, _create_test_arrays())

        expired = store.is_expired(eternal_meta)
        assert expired is False

    def test_delete_expired_removes_expired_datasets(self, store: InMemoryDatasetStore) -> None:
        """delete_expired drops the two stale datasets and keeps the other two."""
        two_hours_ago = datetime.now(UTC) - timedelta(hours=2)
        metas_by_id = {
            "expired-1": _create_test_meta("expired-1", ttl_seconds=3600, created_at=two_hours_ago),
            "expired-2": _create_test_meta("expired-2", ttl_seconds=3600, created_at=two_hours_ago),
            "valid-1": _create_test_meta("valid-1", ttl_seconds=3600),
            "no-ttl": _create_test_meta("no-ttl"),
        }
        for dataset_id, meta in metas_by_id.items():
            store.save(dataset_id, meta, _create_test_arrays())

        removed_ids = store.delete_expired()

        assert set(removed_ids) == {"expired-1", "expired-2"}
        for survivor in ("valid-1", "no-ttl"):
            assert store.exists(survivor)
        for removed in ("expired-1", "expired-2"):
            assert not store.exists(removed)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@pytest.mark.unit
class TestDatasetFiltering:
    """Exercise filter_datasets with each of its filter options."""

    @pytest.fixture
    def populated_store(self, store: InMemoryDatasetStore) -> InMemoryDatasetStore:
        """Seed the store with five datasets that differ in generator, size, tags and age."""
        reference_time = datetime.now(UTC)

        # Columns: dataset_id, generator, n_samples, tags, age relative to reference_time.
        seed_rows = [
            ("ds-1", "spiral", 100, ["train", "v1"], timedelta(days=5)),
            ("ds-2", "spiral", 200, ["train", "v2"], timedelta(days=3)),
            ("ds-3", "spiral", 50, ["test", "v1"], timedelta(days=1)),
            ("ds-4", "xor", 100, ["train"], timedelta(hours=12)),
            ("ds-5", "xor", 300, ["train", "v2"], timedelta(hours=1)),
        ]

        for dataset_id, gen, n_samples, tags, age in seed_rows:
            meta = _create_test_meta(
                dataset_id,
                generator=gen,
                n_samples=n_samples,
                tags=tags,
                created_at=reference_time - age,
            )
            store.save(dataset_id, meta, _create_test_arrays(n_samples))

        return store

    def test_filter_by_generator(self, populated_store: InMemoryDatasetStore) -> None:
        """Filtering on generator="spiral" yields the three spiral datasets."""
        matches, match_count = populated_store.filter_datasets(generator="spiral")
        assert match_count == 3
        for entry in matches:
            assert entry.generator == "spiral"

    def test_filter_by_tags_any(self, populated_store: InMemoryDatasetStore) -> None:
        """With tags_match="any", four datasets carry v1 or v2."""
        _, match_count = populated_store.filter_datasets(tags=["v1", "v2"], tags_match="any")
        assert match_count == 4

    def test_filter_by_tags_all(self, populated_store: InMemoryDatasetStore) -> None:
        """With tags_match="all", two datasets carry both train and v2."""
        matches, match_count = populated_store.filter_datasets(tags=["train", "v2"], tags_match="all")
        assert match_count == 2
        assert all("train" in entry.tags for entry in matches)
        assert all("v2" in entry.tags for entry in matches)

    def test_filter_by_created_after(self, populated_store: InMemoryDatasetStore) -> None:
        """Three datasets were created within the last two days."""
        two_days_ago = datetime.now(UTC) - timedelta(days=2)
        _, match_count = populated_store.filter_datasets(created_after=two_days_ago)
        assert match_count == 3

    def test_filter_by_created_before(self, populated_store: InMemoryDatasetStore) -> None:
        """Two datasets were created more than two days ago."""
        two_days_ago = datetime.now(UTC) - timedelta(days=2)
        _, match_count = populated_store.filter_datasets(created_before=two_days_ago)
        assert match_count == 2

    def test_filter_by_sample_count(self, populated_store: InMemoryDatasetStore) -> None:
        """Three datasets have between 100 and 200 samples inclusive."""
        matches, match_count = populated_store.filter_datasets(min_samples=100, max_samples=200)
        assert match_count == 3
        assert all(100 <= entry.n_samples <= 200 for entry in matches)

    def test_filter_pagination(self, populated_store: InMemoryDatasetStore) -> None:
        """Two consecutive pages of size two hold disjoint datasets; total stays five."""
        first_page, match_count = populated_store.filter_datasets(limit=2, offset=0)
        second_page, _ = populated_store.filter_datasets(limit=2, offset=2)

        assert match_count == 5
        assert len(first_page) == 2
        assert len(second_page) == 2

        first_ids = {entry.dataset_id for entry in first_page}
        second_ids = {entry.dataset_id for entry in second_page}
        assert first_ids.isdisjoint(second_ids)

    def test_filter_excludes_expired_by_default(self, store: InMemoryDatasetStore) -> None:
        """With include_expired=False only the unexpired dataset is returned."""
        two_hours_ago = datetime.now(UTC) - timedelta(hours=2)
        stale_meta = _create_test_meta("expired", ttl_seconds=3600, created_at=two_hours_ago)
        live_meta = _create_test_meta("valid")

        store.save("expired", stale_meta, _create_test_arrays())
        store.save("valid", live_meta, _create_test_arrays())

        matches, match_count = store.filter_datasets(include_expired=False)
        assert match_count == 1
        assert matches[0].dataset_id == "valid"

    def test_filter_includes_expired_when_requested(self, store: InMemoryDatasetStore) -> None:
        """With include_expired=True both the expired and the valid dataset are counted."""
        two_hours_ago = datetime.now(UTC) - timedelta(hours=2)
        stale_meta = _create_test_meta("expired", ttl_seconds=3600, created_at=two_hours_ago)
        live_meta = _create_test_meta("valid")

        store.save("expired", stale_meta, _create_test_arrays())
        store.save("valid", live_meta, _create_test_arrays())

        _, match_count = store.filter_datasets(include_expired=True)
        assert match_count == 2
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
@pytest.mark.unit
class TestBatchDelete:
    """Checks for batch_delete and its (deleted, not_found) result pair."""

    def test_batch_delete_existing(self, store: InMemoryDatasetStore) -> None:
        """Deleting three of five stored datasets leaves the other two in place."""
        for dataset_id in [f"ds-{index}" for index in range(5)]:
            store.save(dataset_id, _create_test_meta(dataset_id), _create_test_arrays())

        removed_ids, missing_ids = store.batch_delete(["ds-0", "ds-2", "ds-4"])

        assert set(removed_ids) == {"ds-0", "ds-2", "ds-4"}
        assert missing_ids == []
        assert store.exists("ds-1")
        assert store.exists("ds-3")
        assert not store.exists("ds-0")

    def test_batch_delete_mixed(self, store: InMemoryDatasetStore) -> None:
        """Known IDs are reported as deleted and unknown IDs as not found."""
        store.save("ds-1", _create_test_meta("ds-1"), _create_test_arrays())

        requested_ids = ["ds-1", "nonexistent-1", "nonexistent-2"]
        removed_ids, missing_ids = store.batch_delete(requested_ids)

        assert removed_ids == ["ds-1"]
        assert set(missing_ids) == {"nonexistent-1", "nonexistent-2"}

    def test_batch_delete_all_nonexistent(self, store: InMemoryDatasetStore) -> None:
        """On an empty store every requested ID is reported as not found."""
        removed_ids, missing_ids = store.batch_delete(["fake-1", "fake-2"])

        assert removed_ids == []
        assert set(missing_ids) == {"fake-1", "fake-2"}
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
@pytest.mark.unit
class TestAccessTracking:
    """Checks that record_access maintains last_accessed_at and access_count."""

    def test_record_access_updates_timestamp(self, store: InMemoryDatasetStore) -> None:
        """One recorded access sets last_accessed_at and brings access_count to 1."""
        store.save("ds-1", _create_test_meta("ds-1"), _create_test_arrays())

        store.record_access("ds-1")

        stored_meta = store.get_meta("ds-1")
        assert stored_meta is not None
        assert stored_meta.last_accessed_at is not None
        assert stored_meta.access_count == 1

    def test_record_access_increments_count(self, store: InMemoryDatasetStore) -> None:
        """Three recorded accesses bring access_count to 3."""
        store.save("ds-1", _create_test_meta("ds-1"), _create_test_arrays())

        for _ in range(3):
            store.record_access("ds-1")

        stored_meta = store.get_meta("ds-1")
        assert stored_meta is not None
        assert stored_meta.access_count == 3
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
@pytest.mark.unit
class TestDatasetStats:
    """Checks for the aggregate dictionary returned by get_stats."""

    def test_stats_empty_store(self, store: InMemoryDatasetStore) -> None:
        """An empty store reports zero totals and empty breakdowns."""
        summary = store.get_stats()

        assert summary["total_datasets"] == 0
        assert summary["total_samples"] == 0
        assert summary["by_generator"] == {}
        assert summary["by_tag"] == {}

    def test_stats_populated_store(self, store: InMemoryDatasetStore) -> None:
        """Totals and per-generator / per-tag counts reflect three saved datasets."""
        # Columns: dataset_id, generator, n_samples, tags.
        seed_rows = [
            ("ds-1", "spiral", 100, ["train", "v1"]),
            ("ds-2", "spiral", 200, ["train", "v2"]),
            ("ds-3", "xor", 50, ["test"]),
        ]
        metas = [
            _create_test_meta(dataset_id, generator=gen, n_samples=n_samples, tags=tags)
            for dataset_id, gen, n_samples, tags in seed_rows
        ]
        for (dataset_id, _, n_samples, _), meta in zip(seed_rows, metas):
            store.save(dataset_id, meta, _create_test_arrays(n_samples))

        summary = store.get_stats()

        assert summary["total_datasets"] == 3
        assert summary["total_samples"] == 350
        assert summary["by_generator"] == {"spiral": 2, "xor": 1}
        assert summary["by_tag"] == {"train": 2, "v1": 1, "v2": 1, "test": 1}

    def test_stats_counts_expired(self, store: InMemoryDatasetStore) -> None:
        """expired_count is 1 when one of two datasets is past its TTL."""
        two_hours_ago = datetime.now(UTC) - timedelta(hours=2)
        stale_meta = _create_test_meta("expired", ttl_seconds=3600, created_at=two_hours_ago)
        live_meta = _create_test_meta("valid")

        store.save("expired", stale_meta, _create_test_arrays())
        store.save("valid", live_meta, _create_test_arrays())

        summary = store.get_stats()

        assert summary["expired_count"] == 1
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
@pytest.mark.unit
class TestListAllMetadata:
    """Checks for list_all_metadata on empty and populated stores."""

    def test_list_all_metadata_empty(self, store: InMemoryDatasetStore) -> None:
        """An empty store yields an empty list."""
        assert store.list_all_metadata() == []

    def test_list_all_metadata_returns_all(self, store: InMemoryDatasetStore) -> None:
        """All five saved datasets appear in the listing."""
        expected_ids = [f"ds-{index}" for index in range(5)]
        for dataset_id in expected_ids:
            store.save(dataset_id, _create_test_meta(dataset_id), _create_test_arrays())

        listing = store.list_all_metadata()
        assert len(listing) == 5
        listed_ids = {entry.dataset_id for entry in listing}
        assert listed_ids == {"ds-0", "ds-1", "ds-2", "ds-3", "ds-4"}
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Unit tests for __main__.py entry point."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
from unittest.mock import patch
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
from juniper_data.api.settings import (
|
|
9
|
+
_JUNIPER_DATA_API_HOST_DEFAULT,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.mark.unit
class TestMain:
    """Tests for the main() entry point function."""

    def test_main_import_error_uvicorn_not_installed(self) -> None:
        """Test main returns 1 when uvicorn is not installed."""
        import builtins
        import importlib

        original_import = builtins.__import__

        def mock_import(name, *args, **kwargs):
            # Fail only the uvicorn import; delegate everything else.
            if name == "uvicorn":
                raise ImportError("No module named 'uvicorn'")
            return original_import(name, *args, **kwargs)

        with (
            patch.object(sys, "argv", ["juniper_data"]),
            patch("builtins.print") as mock_print,
            patch.object(builtins, "__import__", side_effect=mock_import),
            # A None entry in sys.modules makes "import uvicorn" raise ImportError
            # even if the package is installed.
            patch.dict(sys.modules, {"uvicorn": None}),
        ):
            from juniper_data import __main__ as main_module

            try:
                importlib.reload(main_module)
                result = main_module.main()
                assert result == 1
                mock_print.assert_called()
            except ImportError as e:
                # If ImportError occurs during test setup, skip with explanation
                pytest.skip(f"Cannot test uvicorn import error scenario: {e}")

    def test_main_parses_host_argument(self) -> None:
        """Test main correctly parses --host argument."""
        # Use an RFC 5737 documentation address so the expected value is the
        # CLI argument itself. Asserting against the default constant would
        # pass even if --host were ignored.
        custom_host = "192.0.2.1"
        assert custom_host != _JUNIPER_DATA_API_HOST_DEFAULT

        with (
            patch("uvicorn.run") as mock_run,
            patch.object(sys, "argv", ["juniper_data", "--host", custom_host]),
        ):
            call_kwargs = self._get_call_args_from_mocked_main_run(mock_run)
            assert call_kwargs[1]["host"] == custom_host

    def test_main_parses_port_argument(self) -> None:
        """Test main correctly parses --port argument."""
        with (
            patch("uvicorn.run") as mock_run,
            patch.object(sys, "argv", ["juniper_data", "--port", "9000"]),
        ):
            call_kwargs = self._get_call_args_from_mocked_main_run(mock_run)
            assert call_kwargs[1]["port"] == 9000

    def test_main_parses_log_level_argument(self) -> None:
        """Test main correctly parses --log-level argument."""
        with (
            patch("uvicorn.run") as mock_run,
            patch.object(sys, "argv", ["juniper_data", "--log-level", "DEBUG"]),
        ):
            call_kwargs = self._get_call_args_from_mocked_main_run(mock_run)
            # The upper-case CLI value is expected to reach uvicorn lower-cased.
            assert call_kwargs[1]["log_level"] == "debug"

    def test_main_parses_reload_argument(self) -> None:
        """Test main correctly parses --reload argument."""
        with (
            patch("uvicorn.run") as mock_run,
            patch.object(sys, "argv", ["juniper_data", "--reload"]),
        ):
            call_kwargs = self._get_call_args_from_mocked_main_run(mock_run)
            assert call_kwargs[1]["reload"] is True

    def test_main_parses_storage_path_argument(self) -> None:
        """Test main correctly parses --storage-path argument and sets env var."""
        with (
            patch("uvicorn.run") as mock_run,
            # patch.dict restores os.environ on exit, so the variable set by
            # main() does not leak into other tests.
            patch.dict("os.environ", {}, clear=False),
            patch.object(sys, "argv", ["juniper_data", "--storage-path", "/custom/path"]),
        ):
            import os

            from juniper_data.__main__ import main

            main()
            assert os.environ.get("JUNIPER_DATA_STORAGE_PATH") == "/custom/path"
            mock_run.assert_called_once()

    def test_main_uses_default_settings_when_no_args(self) -> None:
        """Test main uses settings defaults when no args provided."""
        with (
            patch("uvicorn.run") as mock_run,
            patch.object(sys, "argv", ["juniper_data"]),
        ):
            self._validate_mocked_host_name_and_port_args(mock_run, _JUNIPER_DATA_API_HOST_DEFAULT)

    def test_main_returns_zero_on_success(self) -> None:
        """Test main returns 0 on successful run."""
        with (
            patch("uvicorn.run"),
            patch.object(sys, "argv", ["juniper_data"]),
        ):
            from juniper_data.__main__ import main

            result = main()
            assert result == 0

    def test_main_app_string(self) -> None:
        """Test main passes correct app string to uvicorn."""
        with (
            patch("uvicorn.run") as mock_run,
            patch.object(sys, "argv", ["juniper_data"]),
        ):
            call_args = self._get_call_args_from_mocked_main_run(mock_run)
            # call_args[0] holds the positional arguments of uvicorn.run.
            assert call_args[0][0] == "juniper_data.api.app:app"

    def test_main_combines_custom_and_default_args(self) -> None:
        """Test main combines custom args with settings defaults."""
        with (
            patch("uvicorn.run") as mock_run,
            patch.object(sys, "argv", ["juniper_data", "--host", "localhost"]),
        ):
            self._validate_mocked_host_name_and_port_args(mock_run, "localhost")

    def _validate_mocked_host_name_and_port_args(self, mock_run, arg1):
        """Run main() and assert uvicorn.run got host ``arg1`` and port 8100.

        Args:
            mock_run: The mock that replaces ``uvicorn.run``.
            arg1: The host value expected in the keyword arguments.
        """
        call_kwargs = self._get_call_args_from_mocked_main_run(mock_run)
        assert call_kwargs[1]["host"] == arg1
        assert call_kwargs[1]["port"] == 8100

    def _get_call_args_from_mocked_main_run(self, mock_run):
        """Run main() once and return the ``call_args`` recorded on ``mock_run``.

        Args:
            mock_run: The mock that replaces ``uvicorn.run``.

        Returns:
            The mock's ``call_args``; index 0 is the positional arguments and
            index 1 is the keyword arguments.
        """
        from juniper_data.__main__ import main

        main()
        mock_run.assert_called_once()
        return mock_run.call_args
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Unit tests for SecurityMiddleware."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from fastapi import FastAPI
|
|
5
|
+
from fastapi.testclient import TestClient
|
|
6
|
+
|
|
7
|
+
from juniper_data.api.middleware import EXEMPT_PATHS, SecurityMiddleware
|
|
8
|
+
from juniper_data.api.security import APIKeyAuth, RateLimiter
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.fixture
def app_with_middleware():
    """Return a factory that builds a FastAPI app wrapped in SecurityMiddleware.

    The factory accepts ``api_keys``, ``rate_limit_enabled`` and ``rpm`` and
    registers two GET routes: ``/v1/health`` and ``/v1/datasets``.
    """

    def _create(api_keys=None, rate_limit_enabled=False, rpm=60):
        application = FastAPI()
        application.add_middleware(
            SecurityMiddleware,
            api_key_auth=APIKeyAuth(api_keys),
            rate_limiter=RateLimiter(requests_per_minute=rpm, enabled=rate_limit_enabled),
        )

        @application.get("/v1/health")
        async def health():
            return {"status": "ok"}

        @application.get("/v1/datasets")
        async def datasets():
            return {"data": []}

        return application

    return _create
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.mark.unit
class TestSecurityMiddleware:
    """Checks for API-key auth, rate limiting and exempt paths in SecurityMiddleware."""

    def test_exempt_path_bypasses_security(self, app_with_middleware):
        """/v1/health answers 200 without a key even when keys are configured."""
        http = TestClient(app_with_middleware(api_keys=["secret"]))
        reply = http.get("/v1/health")
        assert reply.status_code == 200

    def test_auth_required_returns_401(self, app_with_middleware):
        """A request without an API key is rejected with 401."""
        http = TestClient(app_with_middleware(api_keys=["secret"]))
        reply = http.get("/v1/datasets")
        assert reply.status_code == 401

    def test_invalid_key_returns_401(self, app_with_middleware):
        """A request with the wrong X-API-Key is rejected with 401."""
        http = TestClient(app_with_middleware(api_keys=["secret"]))
        reply = http.get("/v1/datasets", headers={"X-API-Key": "wrong"})
        assert reply.status_code == 401

    def test_valid_key_passes(self, app_with_middleware):
        """A request with a configured X-API-Key gets 200."""
        http = TestClient(app_with_middleware(api_keys=["secret"]))
        reply = http.get("/v1/datasets", headers={"X-API-Key": "secret"})
        assert reply.status_code == 200

    def test_rate_limit_exceeded_returns_429(self, app_with_middleware):
        """With a limit of two per minute, the third request gets 429."""
        http = TestClient(app_with_middleware(rate_limit_enabled=True, rpm=2))
        replies = [http.get("/v1/datasets") for _ in range(3)]
        assert replies[-1].status_code == 429

    def test_rate_limit_headers_included(self, app_with_middleware):
        """A rate-limited app includes the X-RateLimit-* headers on a 200 reply."""
        http = TestClient(app_with_middleware(rate_limit_enabled=True, rpm=10))
        reply = http.get("/v1/datasets")
        assert reply.status_code == 200
        for header_name in ("X-RateLimit-Limit", "X-RateLimit-Remaining"):
            assert header_name in reply.headers

    def test_is_exempt_checks_known_paths(self):
        """EXEMPT_PATHS contains /v1/health and /docs but not /v1/datasets."""
        assert "/v1/health" in EXEMPT_PATHS
        assert "/docs" in EXEMPT_PATHS
        assert "/v1/datasets" not in EXEMPT_PATHS
|