juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
"""End-to-End integration tests for the complete JuniperData workflow.
|
|
2
|
+
|
|
3
|
+
These tests verify the full flow:
|
|
4
|
+
1. Start JuniperData service (via TestClient)
|
|
5
|
+
2. Create dataset via REST API
|
|
6
|
+
3. Download NPZ artifact
|
|
7
|
+
4. Verify data integrity (shapes, dtypes, determinism)
|
|
8
|
+
|
|
9
|
+
Marked with @pytest.mark.slow for weekly CI runs.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import io
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pytest
|
|
16
|
+
from fastapi.testclient import TestClient
|
|
17
|
+
|
|
18
|
+
from juniper_data.api.app import create_app
|
|
19
|
+
from juniper_data.api.routes import datasets
|
|
20
|
+
from juniper_data.api.settings import Settings
|
|
21
|
+
from juniper_data.storage.memory import InMemoryDatasetStore
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.fixture
def e2e_store() -> InMemoryDatasetStore:
    """Provide an empty in-memory dataset store, built anew for each E2E test."""
    fresh_store = InMemoryDatasetStore()
    return fresh_store
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@pytest.fixture
def e2e_settings(tmp_path) -> Settings:
    """Create E2E test settings rooted in a per-test temporary directory.

    Uses pytest's built-in ``tmp_path`` fixture instead of a fixed
    ``/tmp/...`` path. As a result, concurrent runs (pytest-xdist workers,
    parallel CI jobs) cannot collide, nothing is left behind in the shared
    system temp directory, and the tests also work on platforms without
    ``/tmp``.

    Args:
        tmp_path: Unique per-test temporary directory supplied by pytest.

    Returns:
        Settings whose ``storage_path`` points inside ``tmp_path``.
    """
    return Settings(storage_path=str(tmp_path / "juniper_data_e2e_test"))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@pytest.fixture
def e2e_client(e2e_store: InMemoryDatasetStore, e2e_settings: Settings) -> TestClient:
    """Build a TestClient for a fresh app whose dataset routes use ``e2e_store``."""
    application = create_app(settings=e2e_settings)
    # Register the fixture's store with the datasets routes module after the
    # app is built, so the in-memory store is the one the routes see.
    datasets.set_store(e2e_store)
    client = TestClient(application)
    return client
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@pytest.mark.integration
@pytest.mark.slow
class TestE2EModernAlgorithm:
    """E2E tests for the modern spiral generation algorithm."""

    @pytest.fixture
    def modern_request(self) -> dict:
        """Request for modern algorithm spiral dataset."""
        return {
            "generator": "spiral",
            "params": {
                "n_spirals": 2,
                "n_points_per_spiral": 100,
                "seed": 42,
                "algorithm": "modern",
                "noise": 0.1,
                "train_ratio": 0.8,
                "test_ratio": 0.2,
            },
            "persist": True,
        }

    @staticmethod
    def _create_and_download(client: TestClient, request: dict) -> tuple[str, bytes]:
        """POST ``request``, download its NPZ artifact, and return ``(dataset_id, npz_bytes)``.

        Status codes are asserted here so that a failed create surfaces as an
        HTTP status mismatch rather than a confusing ``KeyError`` when the JSON
        body is indexed.
        """
        create_response = client.post("/v1/datasets", json=request)
        assert create_response.status_code == 201, create_response.text
        dataset_id = create_response.json()["dataset_id"]

        artifact_response = client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_response.status_code == 200, artifact_response.text
        return dataset_id, artifact_response.content

    def test_e2e_create_download_verify_modern(self, e2e_client: TestClient, modern_request: dict) -> None:
        """Complete E2E flow: create dataset, download NPZ, verify integrity."""
        create_response = e2e_client.post("/v1/datasets", json=modern_request)
        assert create_response.status_code == 201
        dataset_id = create_response.json()["dataset_id"]

        artifact_response = e2e_client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_response.status_code == 200
        assert artifact_response.headers["content-type"] == "application/octet-stream"

        # Derive expectations from the request so they cannot drift from it.
        params = modern_request["params"]
        n_spirals = params["n_spirals"]
        n_total = n_spirals * params["n_points_per_spiral"]
        n_train = int(n_total * params["train_ratio"])
        n_test = n_total - n_train
        expected_shapes = {
            "X_train": (n_train, 2),
            "y_train": (n_train, n_spirals),
            "X_test": (n_test, 2),
            "y_test": (n_test, n_spirals),
            "X_full": (n_total, 2),
            "y_full": (n_total, n_spirals),
        }

        with np.load(io.BytesIO(artifact_response.content)) as data:
            for key, shape in expected_shapes.items():
                assert key in data.files, f"Missing key: {key}"
                array = data[key]
                assert array.dtype == np.float32, f"{key} has dtype {array.dtype}"
                assert array.shape == shape, f"{key} has shape {array.shape}, expected {shape}"

    def test_e2e_deterministic_with_seed(self, e2e_client: TestClient, modern_request: dict) -> None:
        """Same seed produces identical data (determinism verification).

        Identical requests yield the same ``dataset_id`` (asserted below). The
        first dataset is therefore deleted before the second request.
        Otherwise both downloads would read back the same stored artifact, and
        the comparison could never detect a non-deterministic generator.
        """
        dataset_id1, artifact1 = self._create_and_download(e2e_client, modern_request)

        delete_response = e2e_client.delete(f"/v1/datasets/{dataset_id1}")
        assert delete_response.status_code == 204

        dataset_id2, artifact2 = self._create_and_download(e2e_client, modern_request)

        assert dataset_id1 == dataset_id2

        with np.load(io.BytesIO(artifact1)) as data1, np.load(io.BytesIO(artifact2)) as data2:
            np.testing.assert_array_equal(data1["X_full"], data2["X_full"])
            np.testing.assert_array_equal(data1["y_full"], data2["y_full"])

    def test_e2e_different_seed_different_data(self, e2e_client: TestClient, modern_request: dict) -> None:
        """Different seeds produce different data."""
        # The fixture dict is per-test, so mutating it between requests is safe.
        modern_request["params"]["seed"] = 42
        dataset_id1, artifact1 = self._create_and_download(e2e_client, modern_request)

        modern_request["params"]["seed"] = 123
        dataset_id2, artifact2 = self._create_and_download(e2e_client, modern_request)

        assert dataset_id1 != dataset_id2

        with np.load(io.BytesIO(artifact1)) as data1, np.load(io.BytesIO(artifact2)) as data2:
            assert not np.array_equal(data1["X_full"], data2["X_full"])
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@pytest.mark.integration
@pytest.mark.slow
class TestE2ELegacyCascorAlgorithm:
    """E2E tests for the legacy_cascor spiral generation algorithm."""

    # Arrays every NPZ artifact is expected to contain.
    EXPECTED_KEYS = ("X_train", "y_train", "X_test", "y_test", "X_full", "y_full")

    @pytest.fixture
    def legacy_request(self) -> dict:
        """Request for legacy_cascor algorithm spiral dataset."""
        return {
            "generator": "spiral",
            "params": {
                "n_spirals": 2,
                "n_points_per_spiral": 100,
                "seed": 42,
                "algorithm": "legacy_cascor",
                "radius": 10.0,
                "origin": [0.0, 0.0],
                "noise": 0.1,
                "train_ratio": 0.8,
                "test_ratio": 0.2,
            },
            "persist": True,
        }

    @staticmethod
    def _create_and_download(client: TestClient, request: dict) -> tuple[str, bytes]:
        """POST ``request``, download its NPZ artifact, and return ``(dataset_id, npz_bytes)``.

        Status codes are asserted so that a failed create is reported as a
        status mismatch rather than a ``KeyError`` on the response body.
        """
        create_response = client.post("/v1/datasets", json=request)
        assert create_response.status_code == 201, create_response.text
        dataset_id = create_response.json()["dataset_id"]

        artifact_response = client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_response.status_code == 200, artifact_response.text
        return dataset_id, artifact_response.content

    def test_e2e_create_download_verify_legacy(self, e2e_client: TestClient, legacy_request: dict) -> None:
        """Complete E2E flow for legacy_cascor algorithm."""
        _, artifact = self._create_and_download(e2e_client, legacy_request)

        # Derive expected sizes from the request so they cannot drift from it.
        params = legacy_request["params"]
        n_spirals = params["n_spirals"]
        n_total = n_spirals * params["n_points_per_spiral"]

        with np.load(io.BytesIO(artifact)) as data:
            for key in self.EXPECTED_KEYS:
                assert key in data.files, f"Missing key: {key}"

            X_full = data["X_full"]
            y_full = data["y_full"]

            assert X_full.dtype == np.float32
            assert y_full.dtype == np.float32

            assert X_full.shape == (n_total, 2)
            assert y_full.shape == (n_total, n_spirals)

    def test_e2e_legacy_vs_modern_different(self, e2e_client: TestClient) -> None:
        """Legacy and modern algorithms produce different data with same seed."""
        base_params = {
            "n_spirals": 2,
            "n_points_per_spiral": 50,
            "seed": 42,
            "noise": 0.1,
        }

        modern_request = {
            "generator": "spiral",
            "params": {**base_params, "algorithm": "modern"},
            "persist": True,
        }
        legacy_request = {
            "generator": "spiral",
            "params": {**base_params, "algorithm": "legacy_cascor", "radius": 10.0},
            "persist": True,
        }

        modern_id, modern_bytes = self._create_and_download(e2e_client, modern_request)
        legacy_id, legacy_bytes = self._create_and_download(e2e_client, legacy_request)

        assert modern_id != legacy_id

        with np.load(io.BytesIO(modern_bytes)) as modern_data, np.load(io.BytesIO(legacy_bytes)) as legacy_data:
            assert not np.array_equal(modern_data["X_full"], legacy_data["X_full"])
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
@pytest.mark.integration
@pytest.mark.slow
class TestE2EDataContract:
    """E2E tests verifying the NPZ data contract for consumers."""

    @pytest.fixture
    def contract_request(self) -> dict:
        """Standard request for data contract verification."""
        return {
            "generator": "spiral",
            "params": {
                "n_spirals": 2,
                "n_points_per_spiral": 50,
                "seed": 12345,
                "train_ratio": 0.7,
                "test_ratio": 0.3,
            },
            "persist": True,
        }

    @staticmethod
    def _create_and_download(client: TestClient, request: dict) -> bytes:
        """POST ``request`` and return the raw NPZ bytes of the created dataset.

        Status codes are asserted so that a failed create is reported as a
        status mismatch rather than a ``KeyError`` on the response body.
        """
        create_response = client.post("/v1/datasets", json=request)
        assert create_response.status_code == 201, create_response.text
        dataset_id = create_response.json()["dataset_id"]

        artifact_response = client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_response.status_code == 200, artifact_response.text
        return artifact_response.content

    def test_e2e_npz_keys_contract(self, e2e_client: TestClient, contract_request: dict) -> None:
        """Verify NPZ contains exactly the expected keys per data contract."""
        artifact = self._create_and_download(e2e_client, contract_request)

        with np.load(io.BytesIO(artifact)) as data:
            expected_keys = {"X_train", "y_train", "X_test", "y_test", "X_full", "y_full"}
            actual_keys = set(data.files)
            assert actual_keys == expected_keys, f"Keys mismatch: expected {expected_keys}, got {actual_keys}"

    def test_e2e_feature_dimensions(self, e2e_client: TestClient, contract_request: dict) -> None:
        """Verify features have 2 dimensions (x, y coordinates)."""
        artifact = self._create_and_download(e2e_client, contract_request)

        with np.load(io.BytesIO(artifact)) as data:
            for key in ("X_train", "X_test", "X_full"):
                assert data[key].shape[1] == 2, f"{key} has shape {data[key].shape}"

    def test_e2e_one_hot_labels(self, e2e_client: TestClient, contract_request: dict) -> None:
        """Verify labels are one-hot encoded with correct class count."""
        artifact = self._create_and_download(e2e_client, contract_request)
        n_spirals = contract_request["params"]["n_spirals"]

        with np.load(io.BytesIO(artifact)) as data:
            y_full = data["y_full"]

        assert y_full.shape[1] == n_spirals

        # One-hot: every row sums to 1 and contains only 0s and 1s.
        row_sums = y_full.sum(axis=1)
        np.testing.assert_array_almost_equal(row_sums, np.ones(len(y_full)))
        assert set(np.unique(y_full)) == {0.0, 1.0}

    def test_e2e_train_test_split_ratios(self, e2e_client: TestClient, contract_request: dict) -> None:
        """Verify the train/test split matches the ratio requested in ``contract_request``."""
        artifact = self._create_and_download(e2e_client, contract_request)
        # Read the ratio from the request rather than duplicating the literal.
        expected_train_ratio = contract_request["params"]["train_ratio"]

        with np.load(io.BytesIO(artifact)) as data:
            n_train = len(data["X_train"])
            n_test = len(data["X_test"])
            n_full = len(data["X_full"])

        assert n_train + n_test == n_full

        # Tolerance allows for integer rounding of the split sizes.
        actual_train_ratio = n_train / n_full
        assert abs(actual_train_ratio - expected_train_ratio) < 0.05

    def test_e2e_metadata_consistency(self, e2e_client: TestClient, contract_request: dict) -> None:
        """Verify metadata matches actual data dimensions."""
        create_response = e2e_client.post("/v1/datasets", json=contract_request)
        assert create_response.status_code == 201, create_response.text
        body = create_response.json()
        dataset_id = body["dataset_id"]
        meta = body["meta"]

        artifact_response = e2e_client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_response.status_code == 200, artifact_response.text

        with np.load(io.BytesIO(artifact_response.content)) as npz_data:
            assert meta["n_samples"] == len(npz_data["X_full"])
            assert meta["n_train"] == len(npz_data["X_train"])
            assert meta["n_test"] == len(npz_data["X_test"])
            assert meta["n_features"] == npz_data["X_full"].shape[1]
            assert meta["n_classes"] == npz_data["y_full"].shape[1]
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
@pytest.mark.integration
@pytest.mark.slow
class TestE2EErrorHandling:
    """End-to-end checks that the API rejects bad input and missing resources."""

    def test_e2e_invalid_generator_name(self, e2e_client: TestClient) -> None:
        """An unknown generator name is rejected (400 or 404) with an error detail."""
        payload = {"generator": "nonexistent_generator", "params": {}, "persist": True}

        resp = e2e_client.post("/v1/datasets", json=payload)

        assert resp.status_code in {400, 404}
        body = resp.json()
        assert "detail" in body

    def test_e2e_invalid_params(self, e2e_client: TestClient) -> None:
        """A negative spiral count is rejected with 400 or 422."""
        bad_params = {"n_spirals": -1, "n_points_per_spiral": 100}
        payload = {"generator": "spiral", "params": bad_params, "persist": True}

        resp = e2e_client.post("/v1/datasets", json=payload)

        assert resp.status_code in {400, 422}

    def test_e2e_nonexistent_dataset_artifact(self, e2e_client: TestClient) -> None:
        """Downloading the artifact of an unknown dataset yields 404."""
        missing_url = "/v1/datasets/nonexistent-id-12345/artifact"
        assert e2e_client.get(missing_url).status_code == 404

    def test_e2e_delete_and_verify_gone(self, e2e_client: TestClient) -> None:
        """After deletion neither the dataset nor its artifact can be fetched."""
        payload = {
            "generator": "spiral",
            "params": {"n_spirals": 2, "n_points_per_spiral": 10, "seed": 1},
            "persist": True,
        }
        new_id = e2e_client.post("/v1/datasets", json=payload).json()["dataset_id"]
        dataset_url = f"/v1/datasets/{new_id}"

        assert e2e_client.get(dataset_url).status_code == 200
        assert e2e_client.delete(dataset_url).status_code == 204
        assert e2e_client.get(dataset_url).status_code == 404
        assert e2e_client.get(f"{dataset_url}/artifact").status_code == 404
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
"""Integration tests for dataset lifecycle management API endpoints (DATA-016).
|
|
2
|
+
|
|
3
|
+
Tests for:
|
|
4
|
+
- POST /v1/datasets with tags and TTL
|
|
5
|
+
- GET /v1/datasets/filter
|
|
6
|
+
- POST /v1/datasets/batch-delete
|
|
7
|
+
- PATCH /v1/datasets/{id}/tags
|
|
8
|
+
- GET /v1/datasets/stats
|
|
9
|
+
- POST /v1/datasets/cleanup-expired
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from datetime import UTC, datetime, timedelta
|
|
13
|
+
|
|
14
|
+
import pytest
|
|
15
|
+
from fastapi.testclient import TestClient
|
|
16
|
+
|
|
17
|
+
from juniper_data.api.app import create_app
|
|
18
|
+
from juniper_data.api.routes import datasets
|
|
19
|
+
from juniper_data.api.settings import Settings
|
|
20
|
+
from juniper_data.storage.memory import InMemoryDatasetStore
|
|
21
|
+
|
|
22
|
+
# from typing import Dict
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.fixture
def lifecycle_store() -> InMemoryDatasetStore:
    """Provide an empty in-memory dataset store, built anew for each lifecycle test."""
    fresh_store = InMemoryDatasetStore()
    return fresh_store
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@pytest.fixture
def lifecycle_settings(tmp_path) -> Settings:
    """Create lifecycle test settings rooted in a per-test temporary directory.

    Uses pytest's built-in ``tmp_path`` fixture instead of a fixed
    ``/tmp/...`` path. As a result, concurrent runs (pytest-xdist workers,
    parallel CI jobs) cannot collide, nothing is left behind in the shared
    system temp directory, and the tests also work on platforms without
    ``/tmp``.

    Args:
        tmp_path: Unique per-test temporary directory supplied by pytest.

    Returns:
        Settings whose ``storage_path`` points inside ``tmp_path``.
    """
    return Settings(storage_path=str(tmp_path / "juniper_data_lifecycle_test"))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@pytest.fixture
def lifecycle_client(lifecycle_store: InMemoryDatasetStore, lifecycle_settings: Settings) -> TestClient:
    """Build a TestClient for a fresh app whose dataset routes use ``lifecycle_store``."""
    application = create_app(settings=lifecycle_settings)
    # Register the fixture's store with the datasets routes module after the
    # app is built; tests also manipulate this same store object directly.
    datasets.set_store(lifecycle_store)
    client = TestClient(application)
    return client
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _create_spiral_request(
|
|
46
|
+
n_points: int = 50,
|
|
47
|
+
seed: int = 42,
|
|
48
|
+
tags: list[str] | None = None,
|
|
49
|
+
ttl_seconds: int | None = None,
|
|
50
|
+
) -> dict:
|
|
51
|
+
"""Create a spiral dataset request."""
|
|
52
|
+
request = {
|
|
53
|
+
"generator": "spiral",
|
|
54
|
+
"params": {"n_spirals": 2, "n_points_per_spiral": n_points, "seed": seed},
|
|
55
|
+
"persist": True,
|
|
56
|
+
}
|
|
57
|
+
if tags:
|
|
58
|
+
request["tags"] = tags
|
|
59
|
+
if ttl_seconds:
|
|
60
|
+
request["ttl_seconds"] = ttl_seconds
|
|
61
|
+
return request
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.mark.integration
class TestCreateDatasetWithLifecycle:
    """Creation-time lifecycle options: tags and time-to-live."""

    def test_create_dataset_with_tags(self, lifecycle_client: TestClient) -> None:
        """Tags supplied at creation are echoed back in the dataset metadata."""
        payload = _create_spiral_request(tags=["train", "experiment-1"])
        resp = lifecycle_client.post("/v1/datasets", json=payload)

        assert resp.status_code == 201
        returned_tags = resp.json()["meta"]["tags"]
        for expected_tag in ("train", "experiment-1"):
            assert expected_tag in returned_tags

    def test_create_dataset_with_ttl(self, lifecycle_client: TestClient) -> None:
        """A TTL supplied at creation is stored and yields an expiry timestamp."""
        payload = _create_spiral_request(ttl_seconds=3600)
        resp = lifecycle_client.post("/v1/datasets", json=payload)

        assert resp.status_code == 201
        returned_meta = resp.json()["meta"]
        assert returned_meta["ttl_seconds"] == 3600
        assert returned_meta["expires_at"] is not None
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@pytest.mark.integration
class TestFilterDatasets:
    """Tests for the filter datasets endpoint."""

    @pytest.fixture
    def populated_client(self, lifecycle_client: TestClient) -> TestClient:
        """Create four spiral datasets with distinct sizes and tag combinations.

        Each request uses two spirals, so samples = 2 * n_points:
            seed 1: 100 samples, tags {train, v1}
            seed 2: 200 samples, tags {train, v2}
            seed 3: 300 samples, tags {test, v1}
            seed 4: 400 samples, tags {test, v2}
        """
        requests = [
            _create_spiral_request(n_points=50, seed=1, tags=["train", "v1"]),
            _create_spiral_request(n_points=100, seed=2, tags=["train", "v2"]),
            _create_spiral_request(n_points=150, seed=3, tags=["test", "v1"]),
            _create_spiral_request(n_points=200, seed=4, tags=["test", "v2"]),
        ]
        for req in requests:
            response = lifecycle_client.post("/v1/datasets", json=req)
            # Fail fast here: otherwise a creation error would surface later
            # as a misleading count mismatch in the filter assertions.
            assert response.status_code == 201, response.text
        return lifecycle_client

    def test_filter_by_generator(self, populated_client: TestClient) -> None:
        """Filter by generator name."""
        response = populated_client.get("/v1/datasets/filter?generator=spiral")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 4
        assert all(d["generator"] == "spiral" for d in data["datasets"])

    def test_filter_by_tags_any(self, populated_client: TestClient) -> None:
        """Filter by tags with any match."""
        response = populated_client.get("/v1/datasets/filter?tags=train&tags_match=any")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2

    def test_filter_by_multiple_tags_any(self, populated_client: TestClient) -> None:
        """``tags_match=any`` with several tags matches datasets carrying at least one of them.

        Unlike the single-tag case, this distinguishes ``any`` from ``all``:
        seeds 1, 2 and 3 carry ``train`` or ``v1``, whereas only seed 1 carries both.
        """
        response = populated_client.get("/v1/datasets/filter?tags=train,v1&tags_match=any")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 3

    def test_filter_by_tags_all(self, populated_client: TestClient) -> None:
        """Filter by tags with all match."""
        response = populated_client.get("/v1/datasets/filter?tags=train,v1&tags_match=all")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1

    def test_filter_by_sample_count(self, populated_client: TestClient) -> None:
        """Filter by sample count range (only the 300-sample dataset falls in 250..350)."""
        response = populated_client.get("/v1/datasets/filter?min_samples=250&max_samples=350")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1

    def test_filter_with_pagination(self, populated_client: TestClient) -> None:
        """Filter with pagination."""
        response = populated_client.get("/v1/datasets/filter?limit=2&offset=0")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 4
        assert len(data["datasets"]) == 2
        assert data["limit"] == 2
        assert data["offset"] == 0
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@pytest.mark.integration
class TestBatchDelete:
    """Behaviour of ``POST /v1/datasets/batch-delete``."""

    def test_batch_delete_existing(self, lifecycle_client: TestClient) -> None:
        """Deleting a subset of existing datasets removes exactly that subset."""
        created_ids = []
        for s in range(3):
            created = lifecycle_client.post("/v1/datasets", json=_create_spiral_request(seed=s))
            created_ids.append(created.json()["dataset_id"])
        to_delete = created_ids[:2]
        survivor = created_ids[2]

        resp = lifecycle_client.post("/v1/datasets/batch-delete", json={"dataset_ids": to_delete})

        assert resp.status_code == 200
        result = resp.json()
        assert len(result["deleted"]) == 2
        assert result["not_found"] == []
        assert result["total_deleted"] == 2

        for gone_id in to_delete:
            assert lifecycle_client.get(f"/v1/datasets/{gone_id}").status_code == 404

        assert lifecycle_client.get(f"/v1/datasets/{survivor}").status_code == 200

    def test_batch_delete_mixed(self, lifecycle_client: TestClient) -> None:
        """Unknown IDs are reported as not found while real ones are deleted."""
        created = lifecycle_client.post("/v1/datasets", json=_create_spiral_request(seed=42))
        real_id = created.json()["dataset_id"]
        bogus_ids = ["fake-id-1", "fake-id-2"]

        resp = lifecycle_client.post("/v1/datasets/batch-delete", json={"dataset_ids": [real_id, *bogus_ids]})

        assert resp.status_code == 200
        result = resp.json()
        assert result["deleted"] == [real_id]
        assert set(result["not_found"]) == set(bogus_ids)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
@pytest.mark.integration
class TestUpdateTags:
    """Behaviour of ``PATCH /v1/datasets/{id}/tags``."""

    def test_add_tags(self, lifecycle_client: TestClient) -> None:
        """New tags are appended without losing the original ones."""
        created = lifecycle_client.post("/v1/datasets", json=_create_spiral_request(tags=["original"]))
        target_id = created.json()["dataset_id"]

        patch_resp = lifecycle_client.patch(
            f"/v1/datasets/{target_id}/tags",
            json={"add_tags": ["new-tag-1", "new-tag-2"]},
        )

        assert patch_resp.status_code == 200
        current_tags = patch_resp.json()["tags"]
        for tag in ("original", "new-tag-1", "new-tag-2"):
            assert tag in current_tags

    def test_remove_tags(self, lifecycle_client: TestClient) -> None:
        """Removing a tag drops it while leaving the others in place."""
        created = lifecycle_client.post("/v1/datasets", json=_create_spiral_request(tags=["keep", "remove"]))
        target_id = created.json()["dataset_id"]

        patch_resp = lifecycle_client.patch(f"/v1/datasets/{target_id}/tags", json={"remove_tags": ["remove"]})

        assert patch_resp.status_code == 200
        current_tags = patch_resp.json()["tags"]
        assert "keep" in current_tags
        assert "remove" not in current_tags

    def test_add_and_remove_tags(self, lifecycle_client: TestClient) -> None:
        """A single request may both add and remove tags."""
        created = lifecycle_client.post("/v1/datasets", json=_create_spiral_request(tags=["a", "b"]))
        target_id = created.json()["dataset_id"]

        changes = {"add_tags": ["c"], "remove_tags": ["a"]}
        patch_resp = lifecycle_client.patch(f"/v1/datasets/{target_id}/tags", json=changes)

        assert patch_resp.status_code == 200
        assert set(patch_resp.json()["tags"]) == {"b", "c"}

    def test_update_tags_not_found(self, lifecycle_client: TestClient) -> None:
        """Patching tags of an unknown dataset yields 404."""
        patch_resp = lifecycle_client.patch("/v1/datasets/nonexistent-id/tags", json={"add_tags": ["test"]})
        assert patch_resp.status_code == 404
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
@pytest.mark.integration
class TestDatasetStats:
    """Behaviour of ``GET /v1/datasets/stats``."""

    def test_stats_empty(self, lifecycle_client: TestClient) -> None:
        """An empty store reports zero datasets and zero samples."""
        resp = lifecycle_client.get("/v1/datasets/stats")

        assert resp.status_code == 200
        stats = resp.json()
        assert stats["total_datasets"] == 0
        assert stats["total_samples"] == 0

    def test_stats_populated(self, lifecycle_client: TestClient) -> None:
        """Aggregates count datasets, samples, generators and tags."""
        # (n_points, seed, tags). Two spirals each -> 100 + 200 + 300 = 600 samples.
        specs = [
            (50, 1, ["train"]),
            (100, 2, ["train", "v2"]),
            (150, 3, ["test"]),
        ]
        for n_points, seed, tags in specs:
            lifecycle_client.post(
                "/v1/datasets",
                json=_create_spiral_request(n_points=n_points, seed=seed, tags=tags),
            )

        resp = lifecycle_client.get("/v1/datasets/stats")

        assert resp.status_code == 200
        stats = resp.json()
        assert stats["total_datasets"] == 3
        assert stats["total_samples"] == 600
        assert stats["by_generator"] == {"spiral": 3}
        assert stats["by_tag"]["train"] == 2
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@pytest.mark.integration
class TestCleanupExpired:
    """Behaviour of ``POST /v1/datasets/cleanup-expired``."""

    def test_cleanup_expired_none(self, lifecycle_client: TestClient) -> None:
        """Without any expired datasets the cleanup deletes nothing."""
        for seed in (1, 2):
            lifecycle_client.post("/v1/datasets", json=_create_spiral_request(seed=seed))

        cleanup_resp = lifecycle_client.post("/v1/datasets/cleanup-expired")

        assert cleanup_resp.status_code == 200
        assert cleanup_resp.json() == []

    def test_cleanup_expired_with_ttl(
        self, lifecycle_client: TestClient, lifecycle_store: InMemoryDatasetStore
    ) -> None:
        """A dataset whose expiry lies in the past is removed by cleanup.

        The expiry is back-dated directly in the store, so the test does not
        have to wait out a real TTL.
        """
        created = lifecycle_client.post("/v1/datasets", json=_create_spiral_request(seed=1, ttl_seconds=3600))
        target_id = created.json()["dataset_id"]

        stored_meta = lifecycle_store.get_meta(target_id)
        assert stored_meta is not None
        stored_meta.expires_at = datetime.now(UTC) - timedelta(hours=1)
        lifecycle_store.update_meta(target_id, stored_meta)

        cleanup_resp = lifecycle_client.post("/v1/datasets/cleanup-expired")

        assert cleanup_resp.status_code == 200
        removed_ids = cleanup_resp.json()
        assert target_id in removed_ids

        assert lifecycle_client.get(f"/v1/datasets/{target_id}").status_code == 404
|