juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,378 @@
1
+ """End-to-End integration tests for the complete JuniperData workflow.
2
+
3
+ These tests verify the full flow:
4
+ 1. Start JuniperData service (via TestClient)
5
+ 2. Create dataset via REST API
6
+ 3. Download NPZ artifact
7
+ 4. Verify data integrity (shapes, dtypes, determinism)
8
+
9
+ Marked with @pytest.mark.slow for weekly CI runs.
10
+ """
11
+
12
+ import io
13
+
14
+ import numpy as np
15
+ import pytest
16
+ from fastapi.testclient import TestClient
17
+
18
+ from juniper_data.api.app import create_app
19
+ from juniper_data.api.routes import datasets
20
+ from juniper_data.api.settings import Settings
21
+ from juniper_data.storage.memory import InMemoryDatasetStore
22
+
23
+
24
@pytest.fixture
def e2e_store() -> InMemoryDatasetStore:
    """Provide an empty in-memory dataset store, built anew for every test."""
    store = InMemoryDatasetStore()
    return store
28
+
29
+
30
@pytest.fixture
def e2e_settings(tmp_path) -> Settings:
    """Create E2E test settings whose storage path lives in a per-test temp directory.

    ``tmp_path`` is pytest's built-in per-test unique directory. Using it instead
    of a fixed ``/tmp/...`` path prevents collisions between parallel runs (e.g.
    pytest-xdist) or different users, and avoids a predictable shared temp path.
    """
    return Settings(storage_path=str(tmp_path / "juniper_data_e2e_test"))
34
+
35
+
36
@pytest.fixture
def e2e_client(e2e_store: InMemoryDatasetStore, e2e_settings: Settings) -> TestClient:
    """Build the app and return a test client whose dataset routes use the in-memory store.

    The store is handed to the ``datasets`` routes module after ``create_app``
    has run. The client is returned without being entered as a context manager.
    """
    application = create_app(settings=e2e_settings)
    # Point the datasets routes module at this test's store.
    datasets.set_store(e2e_store)
    client = TestClient(application)
    return client
42
+
43
+
44
@pytest.mark.integration
@pytest.mark.slow
class TestE2EModernAlgorithm:
    """E2E tests for the modern spiral generation algorithm."""

    @pytest.fixture
    def modern_request(self) -> dict:
        """Request for modern algorithm spiral dataset."""
        return {
            "generator": "spiral",
            "params": {
                "n_spirals": 2,
                "n_points_per_spiral": 100,
                "seed": 42,
                "algorithm": "modern",
                "noise": 0.1,
                "train_ratio": 0.8,
                "test_ratio": 0.2,
            },
            "persist": True,
        }

    @staticmethod
    def _create_and_download(client: TestClient, request: dict) -> tuple[str, bytes]:
        """POST ``request`` to create a dataset, then fetch its artifact.

        Returns ``(dataset_id, npz_bytes)``. Both status codes are asserted so a
        failed call is reported directly, not as a ``KeyError`` on the JSON body
        or as an error while parsing a non-NPZ error response.
        """
        create_response = client.post("/v1/datasets", json=request)
        assert create_response.status_code == 201, create_response.text
        dataset_id = create_response.json()["dataset_id"]

        artifact_response = client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_response.status_code == 200, artifact_response.text
        return dataset_id, artifact_response.content

    def test_e2e_create_download_verify_modern(self, e2e_client: TestClient, modern_request: dict) -> None:
        """Complete E2E flow: create dataset, download NPZ, verify integrity."""
        create_response = e2e_client.post("/v1/datasets", json=modern_request)
        assert create_response.status_code == 201
        dataset_id = create_response.json()["dataset_id"]

        artifact_response = e2e_client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_response.status_code == 200
        assert artifact_response.headers["content-type"] == "application/octet-stream"

        # Derive the expected sizes from the request so they cannot drift from the fixture.
        params = modern_request["params"]
        n_spirals = params["n_spirals"]
        n_total = n_spirals * params["n_points_per_spiral"]
        n_train = int(n_total * params["train_ratio"])
        n_test = n_total - n_train
        expected_shapes = {
            "X_train": (n_train, 2),
            "y_train": (n_train, n_spirals),
            "X_test": (n_test, 2),
            "y_test": (n_test, n_spirals),
            "X_full": (n_total, 2),
            "y_full": (n_total, n_spirals),
        }

        with np.load(io.BytesIO(artifact_response.content)) as data:
            for key, shape in expected_shapes.items():
                assert key in data.files, f"Missing key: {key}"
                array = data[key]
                assert array.dtype == np.float32, f"{key} has dtype {array.dtype}"
                assert array.shape == shape, f"{key} has shape {array.shape}, expected {shape}"

    def test_e2e_deterministic_with_seed(self, e2e_client: TestClient, modern_request: dict) -> None:
        """Same seed produces identical data (determinism verification).

        Both requests map to the same dataset ID. Without deleting the first
        dataset, the second request could be answered from the already-stored
        artifact and the comparison would not exercise the generator twice.
        The first dataset is therefore deleted before the request is repeated.
        """
        dataset_id1, artifact1 = self._create_and_download(e2e_client, modern_request)

        delete_response = e2e_client.delete(f"/v1/datasets/{dataset_id1}")
        assert delete_response.status_code == 204

        dataset_id2, artifact2 = self._create_and_download(e2e_client, modern_request)

        assert dataset_id1 == dataset_id2

        with np.load(io.BytesIO(artifact1)) as data1, np.load(io.BytesIO(artifact2)) as data2:
            np.testing.assert_array_equal(data1["X_full"], data2["X_full"])
            np.testing.assert_array_equal(data1["y_full"], data2["y_full"])

    def test_e2e_different_seed_different_data(self, e2e_client: TestClient, modern_request: dict) -> None:
        """Different seeds produce different data."""
        modern_request["params"]["seed"] = 42
        dataset_id1, artifact1 = self._create_and_download(e2e_client, modern_request)

        modern_request["params"]["seed"] = 123
        dataset_id2, artifact2 = self._create_and_download(e2e_client, modern_request)

        assert dataset_id1 != dataset_id2

        with np.load(io.BytesIO(artifact1)) as data1, np.load(io.BytesIO(artifact2)) as data2:
            assert not np.array_equal(data1["X_full"], data2["X_full"])
145
+
146
+
147
@pytest.mark.integration
@pytest.mark.slow
class TestE2ELegacyCascorAlgorithm:
    """E2E tests for the legacy_cascor spiral generation algorithm."""

    @pytest.fixture
    def legacy_request(self) -> dict:
        """Request for legacy_cascor algorithm spiral dataset."""
        return {
            "generator": "spiral",
            "params": {
                "n_spirals": 2,
                "n_points_per_spiral": 100,
                "seed": 42,
                "algorithm": "legacy_cascor",
                "radius": 10.0,
                "origin": [0.0, 0.0],
                "noise": 0.1,
                "train_ratio": 0.8,
                "test_ratio": 0.2,
            },
            "persist": True,
        }

    @staticmethod
    def _create_and_download(client: TestClient, request: dict) -> tuple[str, bytes]:
        """POST ``request`` to create a dataset, then fetch its artifact.

        Returns ``(dataset_id, npz_bytes)``. Status codes are asserted so failures
        are reported at the HTTP call rather than during JSON/NPZ parsing.
        """
        create_response = client.post("/v1/datasets", json=request)
        assert create_response.status_code == 201, create_response.text
        dataset_id = create_response.json()["dataset_id"]

        artifact_response = client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_response.status_code == 200, artifact_response.text
        return dataset_id, artifact_response.content

    def test_e2e_create_download_verify_legacy(self, e2e_client: TestClient, legacy_request: dict) -> None:
        """Complete E2E flow for legacy_cascor algorithm."""
        _, artifact = self._create_and_download(e2e_client, legacy_request)

        params = legacy_request["params"]
        n_spirals = params["n_spirals"]
        n_total = n_spirals * params["n_points_per_spiral"]

        with np.load(io.BytesIO(artifact)) as data:
            expected_keys = ["X_train", "y_train", "X_test", "y_test", "X_full", "y_full"]
            for key in expected_keys:
                assert key in data.files, f"Missing key: {key}"

            X_full = data["X_full"]
            y_full = data["y_full"]

            assert X_full.dtype == np.float32
            assert y_full.dtype == np.float32

            assert X_full.shape == (n_total, 2)
            assert y_full.shape == (n_total, n_spirals)

    def test_e2e_legacy_vs_modern_different(self, e2e_client: TestClient) -> None:
        """Legacy and modern algorithms produce different data with same seed."""
        base_params = {
            "n_spirals": 2,
            "n_points_per_spiral": 50,
            "seed": 42,
            "noise": 0.1,
        }

        modern_request = {
            "generator": "spiral",
            "params": {**base_params, "algorithm": "modern"},
            "persist": True,
        }
        legacy_request = {
            "generator": "spiral",
            "params": {**base_params, "algorithm": "legacy_cascor", "radius": 10.0},
            "persist": True,
        }

        modern_id, modern_artifact = self._create_and_download(e2e_client, modern_request)
        legacy_id, legacy_artifact = self._create_and_download(e2e_client, legacy_request)

        assert modern_id != legacy_id

        with np.load(io.BytesIO(modern_artifact)) as modern_data, np.load(io.BytesIO(legacy_artifact)) as legacy_data:
            assert not np.array_equal(modern_data["X_full"], legacy_data["X_full"])
229
+
230
+
231
@pytest.mark.integration
@pytest.mark.slow
class TestE2EDataContract:
    """E2E tests verifying the NPZ data contract for consumers."""

    @pytest.fixture
    def contract_request(self) -> dict:
        """Standard request for data contract verification."""
        return {
            "generator": "spiral",
            "params": {
                "n_spirals": 2,
                "n_points_per_spiral": 50,
                "seed": 12345,
                "train_ratio": 0.7,
                "test_ratio": 0.3,
            },
            "persist": True,
        }

    @staticmethod
    def _download_artifact(client: TestClient, request: dict) -> bytes:
        """Create the dataset described by ``request`` and return its NPZ artifact bytes.

        Status codes are asserted so a failed call fails here with the server's
        response text, not later as a ``KeyError`` or an NPZ parse error.
        """
        create_response = client.post("/v1/datasets", json=request)
        assert create_response.status_code == 201, create_response.text
        dataset_id = create_response.json()["dataset_id"]

        artifact_response = client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_response.status_code == 200, artifact_response.text
        return artifact_response.content

    def test_e2e_npz_keys_contract(self, e2e_client: TestClient, contract_request: dict) -> None:
        """Verify NPZ contains exactly the expected keys per data contract."""
        artifact = self._download_artifact(e2e_client, contract_request)

        with np.load(io.BytesIO(artifact)) as data:
            expected_keys = {"X_train", "y_train", "X_test", "y_test", "X_full", "y_full"}
            actual_keys = set(data.files)
            assert actual_keys == expected_keys, f"Keys mismatch: expected {expected_keys}, got {actual_keys}"

    def test_e2e_feature_dimensions(self, e2e_client: TestClient, contract_request: dict) -> None:
        """Verify features have 2 dimensions (x, y coordinates)."""
        artifact = self._download_artifact(e2e_client, contract_request)

        with np.load(io.BytesIO(artifact)) as data:
            for key in ("X_train", "X_test", "X_full"):
                assert data[key].shape[1] == 2, f"{key} has shape {data[key].shape}"

    def test_e2e_one_hot_labels(self, e2e_client: TestClient, contract_request: dict) -> None:
        """Verify labels are one-hot encoded with correct class count."""
        artifact = self._download_artifact(e2e_client, contract_request)

        with np.load(io.BytesIO(artifact)) as data:
            y_full = data["y_full"]
            n_spirals = contract_request["params"]["n_spirals"]

            assert y_full.shape[1] == n_spirals

            # One-hot: every row sums to exactly one ...
            row_sums = y_full.sum(axis=1)
            np.testing.assert_array_almost_equal(row_sums, np.ones(len(y_full)))

            # ... and the only values present are 0 and 1.
            assert set(np.unique(y_full)) == {0.0, 1.0}

    def test_e2e_train_test_split_ratios(self, e2e_client: TestClient, contract_request: dict) -> None:
        """Verify train/test split matches requested ratios."""
        artifact = self._download_artifact(e2e_client, contract_request)

        with np.load(io.BytesIO(artifact)) as data:
            n_train = len(data["X_train"])
            n_test = len(data["X_test"])
            n_full = len(data["X_full"])

        assert n_train + n_test == n_full

        # Read the ratio from the request so the check cannot drift from the fixture.
        expected_train_ratio = contract_request["params"]["train_ratio"]
        actual_train_ratio = n_train / n_full
        assert abs(actual_train_ratio - expected_train_ratio) < 0.05

    def test_e2e_metadata_consistency(self, e2e_client: TestClient, contract_request: dict) -> None:
        """Verify metadata matches actual data dimensions."""
        create_response = e2e_client.post("/v1/datasets", json=contract_request)
        assert create_response.status_code == 201, create_response.text
        body = create_response.json()
        dataset_id = body["dataset_id"]
        meta = body["meta"]

        artifact_response = e2e_client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_response.status_code == 200, artifact_response.text

        with np.load(io.BytesIO(artifact_response.content)) as npz_data:
            assert meta["n_samples"] == len(npz_data["X_full"])
            assert meta["n_train"] == len(npz_data["X_train"])
            assert meta["n_test"] == len(npz_data["X_test"])
            assert meta["n_features"] == npz_data["X_full"].shape[1]
            assert meta["n_classes"] == npz_data["y_full"].shape[1]
322
+
323
+
324
@pytest.mark.integration
@pytest.mark.slow
class TestE2EErrorHandling:
    """E2E tests for error handling scenarios."""

    def test_e2e_invalid_generator_name(self, e2e_client: TestClient) -> None:
        """Invalid generator name returns error (400 or 404)."""
        request = {
            "generator": "nonexistent_generator",
            "params": {},
            "persist": True,
        }
        response = e2e_client.post("/v1/datasets", json=request)
        assert response.status_code in (400, 404)
        assert "detail" in response.json()

    def test_e2e_invalid_params(self, e2e_client: TestClient) -> None:
        """Invalid parameters return 400/422."""
        request = {
            "generator": "spiral",
            "params": {
                "n_spirals": -1,
                "n_points_per_spiral": 100,
            },
            "persist": True,
        }
        response = e2e_client.post("/v1/datasets", json=request)
        assert response.status_code in (400, 422)

    def test_e2e_nonexistent_dataset_artifact(self, e2e_client: TestClient) -> None:
        """Requesting artifact for nonexistent dataset returns 404."""
        response = e2e_client.get("/v1/datasets/nonexistent-id-12345/artifact")
        assert response.status_code == 404

    def test_e2e_delete_and_verify_gone(self, e2e_client: TestClient) -> None:
        """Deleted dataset cannot be retrieved."""
        request = {
            "generator": "spiral",
            "params": {"n_spirals": 2, "n_points_per_spiral": 10, "seed": 1},
            "persist": True,
        }
        create_response = e2e_client.post("/v1/datasets", json=request)
        # Fail here if creation fails; otherwise the error shows up as a KeyError below.
        assert create_response.status_code == 201, create_response.text
        dataset_id = create_response.json()["dataset_id"]

        get_response = e2e_client.get(f"/v1/datasets/{dataset_id}")
        assert get_response.status_code == 200

        delete_response = e2e_client.delete(f"/v1/datasets/{dataset_id}")
        assert delete_response.status_code == 204

        get_after_delete = e2e_client.get(f"/v1/datasets/{dataset_id}")
        assert get_after_delete.status_code == 404

        artifact_after_delete = e2e_client.get(f"/v1/datasets/{dataset_id}/artifact")
        assert artifact_after_delete.status_code == 404
@@ -0,0 +1,304 @@
1
+ """Integration tests for dataset lifecycle management API endpoints (DATA-016).
2
+
3
+ Tests for:
4
+ - POST /v1/datasets with tags and TTL
5
+ - GET /v1/datasets/filter
6
+ - POST /v1/datasets/batch-delete
7
+ - PATCH /v1/datasets/{id}/tags
8
+ - GET /v1/datasets/stats
9
+ - POST /v1/datasets/cleanup-expired
10
+ """
11
+
12
+ from datetime import UTC, datetime, timedelta
13
+
14
+ import pytest
15
+ from fastapi.testclient import TestClient
16
+
17
+ from juniper_data.api.app import create_app
18
+ from juniper_data.api.routes import datasets
19
+ from juniper_data.api.settings import Settings
20
+ from juniper_data.storage.memory import InMemoryDatasetStore
21
+
22
+ # from typing import Dict
23
+
24
+
25
@pytest.fixture
def lifecycle_store() -> InMemoryDatasetStore:
    """Provide an empty in-memory dataset store, built anew for every lifecycle test."""
    store = InMemoryDatasetStore()
    return store
29
+
30
+
31
@pytest.fixture
def lifecycle_settings(tmp_path) -> Settings:
    """Create lifecycle test settings whose storage path lives in a per-test temp directory.

    ``tmp_path`` is pytest's built-in per-test unique directory. Using it instead
    of a fixed ``/tmp/...`` path prevents collisions between parallel runs or
    different users, and avoids a predictable shared temp path.
    """
    return Settings(storage_path=str(tmp_path / "juniper_data_lifecycle_test"))
35
+
36
+
37
@pytest.fixture
def lifecycle_client(lifecycle_store: InMemoryDatasetStore, lifecycle_settings: Settings) -> TestClient:
    """Build the app and return a test client whose dataset routes use the lifecycle store.

    The store is handed to the ``datasets`` routes module after ``create_app`` has run.
    """
    application = create_app(settings=lifecycle_settings)
    # Point the datasets routes module at this test's store.
    datasets.set_store(lifecycle_store)
    client = TestClient(application)
    return client
43
+
44
+
45
+ def _create_spiral_request(
46
+ n_points: int = 50,
47
+ seed: int = 42,
48
+ tags: list[str] | None = None,
49
+ ttl_seconds: int | None = None,
50
+ ) -> dict:
51
+ """Create a spiral dataset request."""
52
+ request = {
53
+ "generator": "spiral",
54
+ "params": {"n_spirals": 2, "n_points_per_spiral": n_points, "seed": seed},
55
+ "persist": True,
56
+ }
57
+ if tags:
58
+ request["tags"] = tags
59
+ if ttl_seconds:
60
+ request["ttl_seconds"] = ttl_seconds
61
+ return request
62
+
63
+
64
@pytest.mark.integration
class TestCreateDatasetWithLifecycle:
    """Dataset creation requests that carry lifecycle fields (tags and TTL)."""

    @staticmethod
    def _post_and_get_meta(client: TestClient, payload: dict) -> dict:
        """POST ``payload`` to /v1/datasets, require a 201, and return the response's ``meta``."""
        resp = client.post("/v1/datasets", json=payload)
        assert resp.status_code == 201
        return resp.json()["meta"]

    def test_create_dataset_with_tags(self, lifecycle_client: TestClient) -> None:
        """Tags supplied at creation time appear in the returned metadata."""
        payload = _create_spiral_request(tags=["train", "experiment-1"])
        meta = self._post_and_get_meta(lifecycle_client, payload)
        for expected_tag in ("train", "experiment-1"):
            assert expected_tag in meta["tags"]

    def test_create_dataset_with_ttl(self, lifecycle_client: TestClient) -> None:
        """A TTL supplied at creation time is echoed back along with an expiry timestamp."""
        payload = _create_spiral_request(ttl_seconds=3600)
        meta = self._post_and_get_meta(lifecycle_client, payload)
        assert meta["ttl_seconds"] == 3600
        assert meta["expires_at"] is not None
87
+
88
+
89
@pytest.mark.integration
class TestFilterDatasets:
    """Tests for the filter datasets endpoint."""

    @pytest.fixture
    def populated_client(self, lifecycle_client: TestClient) -> TestClient:
        """Create four spiral datasets for filtering tests.

        Tags cover {train, test} x {v1, v2}; with two spirals each, sample counts
        are 100, 200, 300 and 400.
        """
        requests = [
            _create_spiral_request(n_points=50, seed=1, tags=["train", "v1"]),
            _create_spiral_request(n_points=100, seed=2, tags=["train", "v2"]),
            _create_spiral_request(n_points=150, seed=3, tags=["test", "v1"]),
            _create_spiral_request(n_points=200, seed=4, tags=["test", "v2"]),
        ]
        for req in requests:
            response = lifecycle_client.post("/v1/datasets", json=req)
            # Fail fast: otherwise a failed creation only shows up later as a wrong ``total``.
            assert response.status_code == 201, response.text
        return lifecycle_client

    def test_filter_by_generator(self, populated_client: TestClient) -> None:
        """Filter by generator name."""
        response = populated_client.get("/v1/datasets/filter?generator=spiral")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 4
        assert all(d["generator"] == "spiral" for d in data["datasets"])

    def test_filter_by_tags_any(self, populated_client: TestClient) -> None:
        """Filter by tags with any match."""
        response = populated_client.get("/v1/datasets/filter?tags=train&tags_match=any")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2

    def test_filter_by_tags_all(self, populated_client: TestClient) -> None:
        """Filter by tags with all match."""
        response = populated_client.get("/v1/datasets/filter?tags=train,v1&tags_match=all")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1

    def test_filter_by_sample_count(self, populated_client: TestClient) -> None:
        """Filter by sample count range (only the 300-sample dataset falls in 250..350)."""
        response = populated_client.get("/v1/datasets/filter?min_samples=250&max_samples=350")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1

    def test_filter_with_pagination(self, populated_client: TestClient) -> None:
        """Filter with pagination."""
        response = populated_client.get("/v1/datasets/filter?limit=2&offset=0")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 4
        assert len(data["datasets"]) == 2
        assert data["limit"] == 2
        assert data["offset"] == 0
149
+
150
+
151
@pytest.mark.integration
class TestBatchDelete:
    """Tests for the batch delete endpoint."""

    @staticmethod
    def _create(client: TestClient, seed: int) -> str:
        """Create a spiral dataset with ``seed``, assert it succeeded, and return its ID."""
        response = client.post("/v1/datasets", json=_create_spiral_request(seed=seed))
        assert response.status_code == 201, response.text
        return response.json()["dataset_id"]

    def test_batch_delete_existing(self, lifecycle_client: TestClient) -> None:
        """Batch delete existing datasets."""
        ids = [self._create(lifecycle_client, seed) for seed in range(3)]

        response = lifecycle_client.post("/v1/datasets/batch-delete", json={"dataset_ids": ids[:2]})

        assert response.status_code == 200
        data = response.json()
        assert len(data["deleted"]) == 2
        assert data["not_found"] == []
        assert data["total_deleted"] == 2

        for deleted_id in ids[:2]:
            get_response = lifecycle_client.get(f"/v1/datasets/{deleted_id}")
            assert get_response.status_code == 404

        # The dataset not named in the request must be untouched.
        get_response = lifecycle_client.get(f"/v1/datasets/{ids[2]}")
        assert get_response.status_code == 200

    def test_batch_delete_mixed(self, lifecycle_client: TestClient) -> None:
        """Batch delete with some nonexistent IDs."""
        existing_id = self._create(lifecycle_client, seed=42)

        response = lifecycle_client.post(
            "/v1/datasets/batch-delete", json={"dataset_ids": [existing_id, "fake-id-1", "fake-id-2"]}
        )

        assert response.status_code == 200
        data = response.json()
        assert data["deleted"] == [existing_id]
        assert set(data["not_found"]) == {"fake-id-1", "fake-id-2"}
        assert data["total_deleted"] == 1

        # Unknown IDs in the batch must not prevent the real one from being removed.
        assert lifecycle_client.get(f"/v1/datasets/{existing_id}").status_code == 404
190
+
191
+
192
@pytest.mark.integration
class TestUpdateTags:
    """PATCH /v1/datasets/{id}/tags: adding, removing, and combining tag edits."""

    @staticmethod
    def _create_with_tags(client: TestClient, tags: list[str]) -> str:
        """Create a spiral dataset carrying ``tags`` and return its ID."""
        created = client.post("/v1/datasets", json=_create_spiral_request(tags=tags))
        return created.json()["dataset_id"]

    @staticmethod
    def _patch_tags(client: TestClient, dataset_id: str, body: dict):
        """Send ``body`` to the tags endpoint of ``dataset_id`` and return the raw response."""
        return client.patch(f"/v1/datasets/{dataset_id}/tags", json=body)

    def test_add_tags(self, lifecycle_client: TestClient) -> None:
        """Added tags are merged with the tags set at creation."""
        ds_id = self._create_with_tags(lifecycle_client, ["original"])

        resp = self._patch_tags(lifecycle_client, ds_id, {"add_tags": ["new-tag-1", "new-tag-2"]})

        assert resp.status_code == 200
        current = resp.json()["tags"]
        for tag in ("original", "new-tag-1", "new-tag-2"):
            assert tag in current

    def test_remove_tags(self, lifecycle_client: TestClient) -> None:
        """Removed tags disappear while the others remain."""
        ds_id = self._create_with_tags(lifecycle_client, ["keep", "remove"])

        resp = self._patch_tags(lifecycle_client, ds_id, {"remove_tags": ["remove"]})

        assert resp.status_code == 200
        current = resp.json()["tags"]
        assert "keep" in current
        assert "remove" not in current

    def test_add_and_remove_tags(self, lifecycle_client: TestClient) -> None:
        """A single request can both add and remove tags."""
        ds_id = self._create_with_tags(lifecycle_client, ["a", "b"])

        resp = self._patch_tags(lifecycle_client, ds_id, {"add_tags": ["c"], "remove_tags": ["a"]})

        assert resp.status_code == 200
        assert set(resp.json()["tags"]) == {"b", "c"}

    def test_update_tags_not_found(self, lifecycle_client: TestClient) -> None:
        """Patching tags on an unknown dataset ID yields 404."""
        resp = self._patch_tags(lifecycle_client, "nonexistent-id", {"add_tags": ["test"]})
        assert resp.status_code == 404
240
+
241
+
242
@pytest.mark.integration
class TestDatasetStats:
    """GET /v1/datasets/stats on empty and populated stores."""

    def test_stats_empty(self, lifecycle_client: TestClient) -> None:
        """An empty store reports zero datasets and zero samples."""
        stats = lifecycle_client.get("/v1/datasets/stats")

        assert stats.status_code == 200
        body = stats.json()
        assert body["total_datasets"] == 0
        assert body["total_samples"] == 0

    def test_stats_populated(self, lifecycle_client: TestClient) -> None:
        """Totals and per-generator/per-tag counts reflect the created datasets."""
        # (points per spiral, seed, tags); with two spirals this is 100 + 200 + 300 samples.
        specs = [
            (50, 1, ["train"]),
            (100, 2, ["train", "v2"]),
            (150, 3, ["test"]),
        ]
        for n_points, seed, tags in specs:
            lifecycle_client.post(
                "/v1/datasets", json=_create_spiral_request(n_points=n_points, seed=seed, tags=tags)
            )

        stats = lifecycle_client.get("/v1/datasets/stats")

        assert stats.status_code == 200
        body = stats.json()
        assert body["total_datasets"] == 3
        assert body["total_samples"] == 600
        assert body["by_generator"] == {"spiral": 3}
        assert body["by_tag"]["train"] == 2
269
+
270
+
271
@pytest.mark.integration
class TestCleanupExpired:
    """POST /v1/datasets/cleanup-expired with and without expired datasets."""

    def test_cleanup_expired_none(self, lifecycle_client: TestClient) -> None:
        """Datasets without a TTL are never reported as cleaned up."""
        for seed in (1, 2):
            lifecycle_client.post("/v1/datasets", json=_create_spiral_request(seed=seed))

        result = lifecycle_client.post("/v1/datasets/cleanup-expired")

        assert result.status_code == 200
        assert result.json() == []

    def test_cleanup_expired_with_ttl(
        self, lifecycle_client: TestClient, lifecycle_store: InMemoryDatasetStore
    ) -> None:
        """A dataset whose expiry is in the past is deleted by cleanup.

        The expiry is backdated by editing the metadata in the store directly.
        """
        created = lifecycle_client.post("/v1/datasets", json=_create_spiral_request(seed=1, ttl_seconds=3600))
        ds_id = created.json()["dataset_id"]

        # Move expires_at one hour into the past so the dataset counts as expired.
        stored_meta = lifecycle_store.get_meta(ds_id)
        assert stored_meta is not None
        stored_meta.expires_at = datetime.now(UTC) - timedelta(hours=1)
        lifecycle_store.update_meta(ds_id, stored_meta)

        result = lifecycle_client.post("/v1/datasets/cleanup-expired")

        assert result.status_code == 200
        assert ds_id in result.json()

        lookup = lifecycle_client.get(f"/v1/datasets/{ds_id}")
        assert lookup.status_code == 404