juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,394 @@
1
+ """Unit tests for dataset lifecycle management features (DATA-016).
2
+
3
+ Tests for:
4
+ - Dataset expiration / TTL
5
+ - Bulk operations (filtering, batch delete)
6
+ - Dataset tagging
7
+ - Usage tracking / access counts
8
+ - Statistics
9
+ """
10
+
11
+ from datetime import UTC, datetime, timedelta
12
+
13
+ import numpy as np
14
+ import pytest
15
+
16
+ from juniper_data.core.models import DatasetMeta
17
+ from juniper_data.storage.memory import InMemoryDatasetStore
18
+
19
+ # from typing import Dict
20
+
21
+
22
def _create_test_meta(
    dataset_id: str,
    generator: str = "spiral",
    n_samples: int = 100,
    tags: list[str] | None = None,
    ttl_seconds: int | None = None,
    created_at: datetime | None = None,
) -> DatasetMeta:
    """Create a test DatasetMeta instance.

    Args:
        dataset_id: Identifier stored on the metadata.
        generator: Generator name recorded in the metadata.
        n_samples: Total sample count; the train/test split and the class
            distribution are derived from it.
        tags: Optional tag list (an empty list when omitted).
        ttl_seconds: Optional time-to-live. When given, ``expires_at`` is set
            to ``created_at + ttl_seconds``.
        created_at: Creation timestamp; defaults to the current UTC time.
            Passing a time in the past lets tests build already-expired
            datasets.

    Returns:
        A DatasetMeta whose split sizes match the arrays produced by
        ``_create_test_arrays(n_samples)``.
    """
    now = created_at or datetime.now(UTC)
    expires_at = None if ttl_seconds is None else now + timedelta(seconds=ttl_seconds)

    # Derive the split exactly like _create_test_arrays so that
    # n_train + n_test == n_samples for every n. Truncating both n * 0.8 and
    # n * 0.2 can drop a sample (e.g. n=7 gives 5 + 1 instead of 5 + 2).
    n_train = int(n_samples * 0.8)
    n_test = n_samples - n_train
    # Give any odd remainder to class "1" so the distribution sums to n_samples.
    n_class0 = n_samples // 2
    n_class1 = n_samples - n_class0

    return DatasetMeta(
        dataset_id=dataset_id,
        generator=generator,
        generator_version="1.0.0",
        params={"n_spirals": 2},
        n_samples=n_samples,
        n_features=2,
        n_classes=2,
        n_train=n_train,
        n_test=n_test,
        class_distribution={"0": n_class0, "1": n_class1},
        artifact_formats=["npz"],
        created_at=now,
        tags=tags or [],
        ttl_seconds=ttl_seconds,
        expires_at=expires_at,
    )
53
+
54
+
55
def _create_test_arrays(n_samples: int = 100) -> dict[str, np.ndarray]:
    """Build zero-filled float32 feature/label arrays for a dataset.

    Rows are split 80/20 (train gets ``int(n_samples * 0.8)``, test gets the
    remainder) and a ``*_full`` pair covers all ``n_samples`` rows. Every
    array has two columns.

    Returns:
        Mapping with keys X_train, y_train, X_test, y_test, X_full, y_full
        (in that order), each a distinct ``np.zeros`` array.
    """
    train_rows = int(n_samples * 0.8)
    rows_per_split = {
        "train": train_rows,
        "test": n_samples - train_rows,
        "full": n_samples,
    }

    arrays: dict[str, np.ndarray] = {}
    for split_name, rows in rows_per_split.items():
        for prefix in ("X", "y"):
            arrays[f"{prefix}_{split_name}"] = np.zeros((rows, 2), dtype=np.float32)
    return arrays
67
+
68
+
69
@pytest.fixture
def store() -> InMemoryDatasetStore:
    """Provide an empty in-memory dataset store, built anew for each test."""
    fresh_store = InMemoryDatasetStore()
    return fresh_store
73
+
74
+
75
@pytest.mark.unit
class TestDatasetTags:
    """Tests for dataset tagging functionality."""

    def test_create_dataset_with_tags(self, store: InMemoryDatasetStore) -> None:
        """Dataset can be created with tags."""
        meta = _create_test_meta("ds-1", tags=["train", "spiral", "v1"])
        store.save("ds-1", meta, _create_test_arrays())

        retrieved = store.get_meta("ds-1")
        assert retrieved is not None
        assert retrieved.tags == ["train", "spiral", "v1"]

    def test_update_meta_adds_tags(self, store: InMemoryDatasetStore) -> None:
        """Tags can be added via update_meta."""
        store.save("ds-1", _create_test_meta("ds-1", tags=["original"]), _create_test_arrays())

        # Use a separate metadata object instead of mutating the one handed to
        # save(): if the store keeps a reference to it, an in-place mutation
        # would make this test pass even when update_meta() does nothing.
        updated = _create_test_meta("ds-1", tags=["original", "added"])
        result = store.update_meta("ds-1", updated)
        assert result is True

        retrieved = store.get_meta("ds-1")
        assert retrieved is not None
        assert "original" in retrieved.tags
        assert "added" in retrieved.tags

    def test_update_meta_nonexistent_returns_false(self, store: InMemoryDatasetStore) -> None:
        """update_meta returns False for nonexistent dataset and creates nothing."""
        meta = _create_test_meta("nonexistent")
        result = store.update_meta("nonexistent", meta)
        assert result is False
        # A rejected update must not implicitly create the dataset.
        assert not store.exists("nonexistent")
106
+
107
+
108
@pytest.mark.unit
class TestDatasetTTL:
    """Expiration (TTL) behaviour: expires_at, is_expired and delete_expired."""

    @staticmethod
    def _two_hours_ago() -> datetime:
        """Return a creation time old enough that a one-hour TTL has lapsed."""
        return datetime.now(UTC) - timedelta(hours=2)

    def test_dataset_with_ttl_has_expires_at(self, store: InMemoryDatasetStore) -> None:
        """A dataset saved with a TTL reports that TTL and an expiry time."""
        store.save("ds-1", _create_test_meta("ds-1", ttl_seconds=3600), _create_test_arrays())

        fetched = store.get_meta("ds-1")
        assert fetched is not None
        assert fetched.ttl_seconds == 3600
        assert fetched.expires_at is not None

    def test_is_expired_false_for_future_expiry(self, store: InMemoryDatasetStore) -> None:
        """An expiry time in the future means the dataset is still live."""
        fresh = _create_test_meta("ds-1", ttl_seconds=3600)
        store.save("ds-1", fresh, _create_test_arrays())

        assert store.is_expired(fresh) is False

    def test_is_expired_true_for_past_expiry(self, store: InMemoryDatasetStore) -> None:
        """An expiry time in the past marks the dataset as expired."""
        stale = _create_test_meta("ds-1", ttl_seconds=3600, created_at=self._two_hours_ago())
        store.save("ds-1", stale, _create_test_arrays())

        assert store.is_expired(stale) is True

    def test_is_expired_false_for_no_ttl(self, store: InMemoryDatasetStore) -> None:
        """Without a TTL a dataset is never considered expired."""
        permanent = _create_test_meta("ds-1", ttl_seconds=None)
        store.save("ds-1", permanent, _create_test_arrays())

        assert store.is_expired(permanent) is False

    def test_delete_expired_removes_expired_datasets(self, store: InMemoryDatasetStore) -> None:
        """delete_expired drops only the datasets whose TTL has lapsed."""
        stale_time = self._two_hours_ago()
        fixtures = {
            "expired-1": _create_test_meta("expired-1", ttl_seconds=3600, created_at=stale_time),
            "expired-2": _create_test_meta("expired-2", ttl_seconds=3600, created_at=stale_time),
            "valid-1": _create_test_meta("valid-1", ttl_seconds=3600),
            "no-ttl": _create_test_meta("no-ttl"),
        }
        for ds_id, ds_meta in fixtures.items():
            store.save(ds_id, ds_meta, _create_test_arrays())

        removed = store.delete_expired()

        assert set(removed) == {"expired-1", "expired-2"}
        for survivor in ("valid-1", "no-ttl"):
            assert store.exists(survivor)
        for gone in ("expired-1", "expired-2"):
            assert not store.exists(gone)
164
+
165
+
166
@pytest.mark.unit
class TestDatasetFiltering:
    """Tests for dataset filtering functionality."""

    @pytest.fixture
    def populated_store(self, store: InMemoryDatasetStore) -> InMemoryDatasetStore:
        """Create a store with multiple datasets for filtering tests.

        Layout (id, generator, n_samples, tags, age at fixture time):
            ds-1  spiral  100  [train, v1]  5 days
            ds-2  spiral  200  [train, v2]  3 days
            ds-3  spiral   50  [test, v1]   1 day
            ds-4  xor     100  [train]      12 hours
            ds-5  xor     300  [train, v2]  1 hour
        """
        now = datetime.now(UTC)

        datasets = [
            ("ds-1", "spiral", 100, ["train", "v1"], now - timedelta(days=5)),
            ("ds-2", "spiral", 200, ["train", "v2"], now - timedelta(days=3)),
            ("ds-3", "spiral", 50, ["test", "v1"], now - timedelta(days=1)),
            ("ds-4", "xor", 100, ["train"], now - timedelta(hours=12)),
            ("ds-5", "xor", 300, ["train", "v2"], now - timedelta(hours=1)),
        ]

        for dataset_id, gen, n_samples, tags, created in datasets:
            meta = _create_test_meta(dataset_id, generator=gen, n_samples=n_samples, tags=tags, created_at=created)
            store.save(dataset_id, meta, _create_test_arrays(n_samples))

        return store

    @staticmethod
    def _ids(datasets) -> set[str]:
        """Return the set of dataset IDs in a filter result page."""
        return {d.dataset_id for d in datasets}

    def test_filter_by_generator(self, populated_store: InMemoryDatasetStore) -> None:
        """Filter datasets by generator name."""
        datasets, total = populated_store.filter_datasets(generator="spiral")
        assert total == 3
        assert all(d.generator == "spiral" for d in datasets)
        assert self._ids(datasets) == {"ds-1", "ds-2", "ds-3"}

    def test_filter_by_tags_any(self, populated_store: InMemoryDatasetStore) -> None:
        """Filter datasets by tags (any match)."""
        datasets, total = populated_store.filter_datasets(tags=["v1", "v2"], tags_match="any")
        assert total == 4
        # ds-4 is the only dataset carrying neither "v1" nor "v2".
        assert self._ids(datasets) == {"ds-1", "ds-2", "ds-3", "ds-5"}

    def test_filter_by_tags_all(self, populated_store: InMemoryDatasetStore) -> None:
        """Filter datasets by tags (all must match)."""
        datasets, total = populated_store.filter_datasets(tags=["train", "v2"], tags_match="all")
        assert total == 2
        for d in datasets:
            assert "train" in d.tags
            assert "v2" in d.tags
        assert self._ids(datasets) == {"ds-2", "ds-5"}

    def test_filter_by_created_after(self, populated_store: InMemoryDatasetStore) -> None:
        """Filter datasets created after a date."""
        cutoff = datetime.now(UTC) - timedelta(days=2)
        datasets, total = populated_store.filter_datasets(created_after=cutoff)
        assert total == 3
        assert self._ids(datasets) == {"ds-3", "ds-4", "ds-5"}

    def test_filter_by_created_before(self, populated_store: InMemoryDatasetStore) -> None:
        """Filter datasets created before a date."""
        cutoff = datetime.now(UTC) - timedelta(days=2)
        datasets, total = populated_store.filter_datasets(created_before=cutoff)
        assert total == 2
        assert self._ids(datasets) == {"ds-1", "ds-2"}

    def test_filter_by_sample_count(self, populated_store: InMemoryDatasetStore) -> None:
        """Filter datasets by sample count range."""
        datasets, total = populated_store.filter_datasets(min_samples=100, max_samples=200)
        assert total == 3
        for d in datasets:
            assert 100 <= d.n_samples <= 200
        # Both bounds are inclusive: ds-1/ds-4 sit on 100, ds-2 on 200.
        assert self._ids(datasets) == {"ds-1", "ds-2", "ds-4"}

    def test_filter_pagination(self, populated_store: InMemoryDatasetStore) -> None:
        """Filter with pagination."""
        datasets_page1, total = populated_store.filter_datasets(limit=2, offset=0)
        datasets_page2, _ = populated_store.filter_datasets(limit=2, offset=2)

        assert total == 5
        assert len(datasets_page1) == 2
        assert len(datasets_page2) == 2

        ids_page1 = self._ids(datasets_page1)
        ids_page2 = self._ids(datasets_page2)
        assert ids_page1.isdisjoint(ids_page2)

    def test_filter_excludes_expired_by_default(self, store: InMemoryDatasetStore) -> None:
        """Expired datasets are excluded by default."""
        past_time = datetime.now(UTC) - timedelta(hours=2)
        meta_expired = _create_test_meta("expired", ttl_seconds=3600, created_at=past_time)
        meta_valid = _create_test_meta("valid")

        store.save("expired", meta_expired, _create_test_arrays())
        store.save("valid", meta_valid, _create_test_arrays())

        datasets, total = store.filter_datasets(include_expired=False)
        assert total == 1
        assert datasets[0].dataset_id == "valid"

    def test_filter_includes_expired_when_requested(self, store: InMemoryDatasetStore) -> None:
        """Expired datasets are included when requested."""
        past_time = datetime.now(UTC) - timedelta(hours=2)
        meta_expired = _create_test_meta("expired", ttl_seconds=3600, created_at=past_time)
        meta_valid = _create_test_meta("valid")

        store.save("expired", meta_expired, _create_test_arrays())
        store.save("valid", meta_valid, _create_test_arrays())

        datasets, total = store.filter_datasets(include_expired=True)
        assert total == 2
        assert self._ids(datasets) == {"expired", "valid"}
264
+
265
+
266
@pytest.mark.unit
class TestBatchDelete:
    """Tests for batch delete functionality."""

    def test_batch_delete_existing(self, store: InMemoryDatasetStore) -> None:
        """Batch delete existing datasets."""
        for i in range(5):
            meta = _create_test_meta(f"ds-{i}")
            store.save(f"ds-{i}", meta, _create_test_arrays())

        deleted, not_found = store.batch_delete(["ds-0", "ds-2", "ds-4"])

        assert set(deleted) == {"ds-0", "ds-2", "ds-4"}
        assert not_found == []
        # Every reported deletion must really be gone, and untouched IDs remain.
        for ds_id in ("ds-0", "ds-2", "ds-4"):
            assert not store.exists(ds_id)
        for ds_id in ("ds-1", "ds-3"):
            assert store.exists(ds_id)

    def test_batch_delete_mixed(self, store: InMemoryDatasetStore) -> None:
        """Batch delete with some nonexistent IDs."""
        meta = _create_test_meta("ds-1")
        store.save("ds-1", meta, _create_test_arrays())

        deleted, not_found = store.batch_delete(["ds-1", "nonexistent-1", "nonexistent-2"])

        assert deleted == ["ds-1"]
        assert set(not_found) == {"nonexistent-1", "nonexistent-2"}
        assert not store.exists("ds-1")

    def test_batch_delete_all_nonexistent(self, store: InMemoryDatasetStore) -> None:
        """Batch delete with all nonexistent IDs."""
        deleted, not_found = store.batch_delete(["fake-1", "fake-2"])

        assert deleted == []
        assert set(not_found) == {"fake-1", "fake-2"}
300
+
301
+
302
@pytest.mark.unit
class TestAccessTracking:
    """record_access bookkeeping: last_accessed_at and access_count."""

    @staticmethod
    def _store_single(store: InMemoryDatasetStore) -> None:
        """Save one dataset, "ds-1", with default test metadata and arrays."""
        store.save("ds-1", _create_test_meta("ds-1"), _create_test_arrays())

    def test_record_access_updates_timestamp(self, store: InMemoryDatasetStore) -> None:
        """One record_access call sets last_accessed_at and counts one access."""
        self._store_single(store)

        store.record_access("ds-1")

        tracked = store.get_meta("ds-1")
        assert tracked is not None
        assert tracked.last_accessed_at is not None
        assert tracked.access_count == 1

    def test_record_access_increments_count(self, store: InMemoryDatasetStore) -> None:
        """Each record_access call adds one to access_count."""
        self._store_single(store)

        for _ in range(3):
            store.record_access("ds-1")

        tracked = store.get_meta("ds-1")
        assert tracked is not None
        assert tracked.access_count == 3
330
+
331
+
332
@pytest.mark.unit
class TestDatasetStats:
    """Aggregate statistics reported by get_stats()."""

    def test_stats_empty_store(self, store: InMemoryDatasetStore) -> None:
        """An empty store reports zero totals and empty breakdowns."""
        summary = store.get_stats()

        expected_empty = {
            "total_datasets": 0,
            "total_samples": 0,
            "by_generator": {},
            "by_tag": {},
        }
        for key, value in expected_empty.items():
            assert summary[key] == value

    def test_stats_populated_store(self, store: InMemoryDatasetStore) -> None:
        """Totals and per-generator/per-tag counts reflect every saved dataset."""
        specs = [
            ("ds-1", "spiral", 100, ["train", "v1"]),
            ("ds-2", "spiral", 200, ["train", "v2"]),
            ("ds-3", "xor", 50, ["test"]),
        ]
        prepared = [
            (ds_id, _create_test_meta(ds_id, generator=gen, n_samples=size, tags=labels), size)
            for ds_id, gen, size, labels in specs
        ]
        for ds_id, ds_meta, size in prepared:
            store.save(ds_id, ds_meta, _create_test_arrays(size))

        summary = store.get_stats()

        assert summary["total_datasets"] == 3
        assert summary["total_samples"] == 350
        assert summary["by_generator"] == {"spiral": 2, "xor": 1}
        assert summary["by_tag"] == {"train": 2, "v1": 1, "v2": 1, "test": 1}

    def test_stats_counts_expired(self, store: InMemoryDatasetStore) -> None:
        """get_stats reports how many stored datasets have expired."""
        stale_time = datetime.now(UTC) - timedelta(hours=2)
        expired_meta = _create_test_meta("expired", ttl_seconds=3600, created_at=stale_time)
        live_meta = _create_test_meta("valid")

        store.save("expired", expired_meta, _create_test_arrays())
        store.save("valid", live_meta, _create_test_arrays())

        assert store.get_stats()["expired_count"] == 1
374
+
375
+
376
@pytest.mark.unit
class TestListAllMetadata:
    """list_all_metadata() returns every stored metadata record."""

    def test_list_all_metadata_empty(self, store: InMemoryDatasetStore) -> None:
        """Nothing stored means an empty list."""
        assert store.list_all_metadata() == []

    def test_list_all_metadata_returns_all(self, store: InMemoryDatasetStore) -> None:
        """Every saved dataset appears exactly once in the listing."""
        expected_ids = [f"ds-{idx}" for idx in range(5)]
        for ds_id in expected_ids:
            store.save(ds_id, _create_test_meta(ds_id), _create_test_arrays())

        everything = store.list_all_metadata()

        assert len(everything) == 5
        assert {m.dataset_id for m in everything} == set(expected_ids)
@@ -0,0 +1,127 @@
1
+ """Unit tests for __main__.py entry point."""
2
+
3
+ import sys
4
+ from unittest.mock import patch
5
+
6
+ import pytest
7
+
8
+ from juniper_data.api.settings import (
9
+ _JUNIPER_DATA_API_HOST_DEFAULT,
10
+ )
11
+
12
+
13
@pytest.mark.unit
class TestMain:
    """Tests for the main() entry point function."""

    def test_main_import_error_uvicorn_not_installed(self) -> None:
        """Test main returns 1 when uvicorn is not installed."""
        import builtins
        import importlib

        original_import = builtins.__import__

        def mock_import(name, *args, **kwargs):
            # Simulate an environment where uvicorn is absent.
            if name == "uvicorn":
                raise ImportError("No module named 'uvicorn'")
            return original_import(name, *args, **kwargs)

        with (
            patch.object(sys, "argv", ["juniper_data"]),
            patch("builtins.print") as mock_print,
            patch.object(builtins, "__import__", side_effect=mock_import),
            patch.dict(sys.modules, {"uvicorn": None}),
        ):
            from juniper_data import __main__ as main_module

            # Only a failing reload may skip the test: then the module cannot
            # even be imported without uvicorn and the scenario is untestable.
            # An ImportError escaping main() itself is a real regression (main
            # is expected to handle it and return 1), so main() stays outside
            # this try block.
            try:
                importlib.reload(main_module)
            except ImportError as e:
                pytest.skip(f"Cannot test uvicorn import error scenario: {e}")

            result = main_module.main()
            assert result == 1
            mock_print.assert_called()

    def test_main_parses_host_argument(self) -> None:
        """Test main correctly parses --host argument."""
        # Use a documentation-range address (RFC 5737) that differs from the
        # configured default; if the two were equal, the assertion below would
        # pass even when --host is ignored.
        custom_host = "192.0.2.10"
        assert custom_host != _JUNIPER_DATA_API_HOST_DEFAULT
        with patch("uvicorn.run") as mock_run:
            with patch.object(sys, "argv", ["juniper_data", "--host", custom_host]):
                call_kwargs = self._get_call_args_from_mocked_main_run(mock_run)
                assert call_kwargs[1]["host"] == custom_host

    def test_main_parses_port_argument(self) -> None:
        """Test main correctly parses --port argument."""
        with patch("uvicorn.run") as mock_run:
            with patch.object(sys, "argv", ["juniper_data", "--port", "9000"]):
                call_kwargs = self._get_call_args_from_mocked_main_run(mock_run)
                assert call_kwargs[1]["port"] == 9000

    def test_main_parses_log_level_argument(self) -> None:
        """Test main correctly parses --log-level argument."""
        with patch("uvicorn.run") as mock_run:
            with patch.object(sys, "argv", ["juniper_data", "--log-level", "DEBUG"]):
                call_kwargs = self._get_call_args_from_mocked_main_run(mock_run)
                assert call_kwargs[1]["log_level"] == "debug"

    def test_main_parses_reload_argument(self) -> None:
        """Test main correctly parses --reload argument."""
        with patch("uvicorn.run") as mock_run:
            with patch.object(sys, "argv", ["juniper_data", "--reload"]):
                call_kwargs = self._get_call_args_from_mocked_main_run(mock_run)
                assert call_kwargs[1]["reload"] is True

    def test_main_parses_storage_path_argument(self) -> None:
        """Test main correctly parses --storage-path argument and sets env var."""
        with patch("uvicorn.run") as mock_run:
            # patch.dict restores os.environ afterwards, so the variable set by
            # main() does not leak into other tests.
            with patch.dict("os.environ", {}, clear=False):
                with patch.object(sys, "argv", ["juniper_data", "--storage-path", "/custom/path"]):
                    import os

                    from juniper_data.__main__ import main

                    main()
                    assert os.environ.get("JUNIPER_DATA_STORAGE_PATH") == "/custom/path"
                    mock_run.assert_called_once()

    def test_main_uses_default_settings_when_no_args(self) -> None:
        """Test main uses settings defaults when no args provided."""
        with patch("uvicorn.run") as mock_run:
            with patch.object(sys, "argv", ["juniper_data"]):
                self._validate_mocked_host_name_and_port_args(mock_run, _JUNIPER_DATA_API_HOST_DEFAULT)

    def test_main_returns_zero_on_success(self) -> None:
        """Test main returns 0 on successful run."""
        with patch("uvicorn.run"):
            with patch.object(sys, "argv", ["juniper_data"]):
                from juniper_data.__main__ import main

                result = main()
                assert result == 0

    def test_main_app_string(self) -> None:
        """Test main passes correct app string to uvicorn."""
        with patch("uvicorn.run") as mock_run:
            with patch.object(sys, "argv", ["juniper_data"]):
                call_args = self._get_call_args_from_mocked_main_run(mock_run)
                assert call_args[0][0] == "juniper_data.api.app:app"

    def test_main_combines_custom_and_default_args(self) -> None:
        """Test main combines custom args with settings defaults."""
        with patch("uvicorn.run") as mock_run:
            with patch.object(sys, "argv", ["juniper_data", "--host", "localhost"]):
                self._validate_mocked_host_name_and_port_args(mock_run, "localhost")

    def _validate_mocked_host_name_and_port_args(self, mock_run, expected_host):
        """Run main() and check uvicorn.run received ``expected_host`` and port 8100.

        NOTE(review): 8100 is presumably the settings' default port; confirm it
        against juniper_data.api.settings if the default changes.
        """
        call_kwargs = self._get_call_args_from_mocked_main_run(mock_run)
        assert call_kwargs[1]["host"] == expected_host
        assert call_kwargs[1]["port"] == 8100

    def _get_call_args_from_mocked_main_run(self, mock_run):
        """Invoke main() once and return the single call recorded on ``mock_run``."""
        from juniper_data.__main__ import main

        main()
        mock_run.assert_called_once()
        return mock_run.call_args
@@ -0,0 +1,79 @@
1
+ """Unit tests for SecurityMiddleware."""
2
+
3
+ import pytest
4
+ from fastapi import FastAPI
5
+ from fastapi.testclient import TestClient
6
+
7
+ from juniper_data.api.middleware import EXEMPT_PATHS, SecurityMiddleware
8
+ from juniper_data.api.security import APIKeyAuth, RateLimiter
9
+
10
+
11
@pytest.fixture
def app_with_middleware():
    """Return a factory that builds a FastAPI app wrapped in SecurityMiddleware.

    The factory takes ``api_keys`` (handed to APIKeyAuth), plus
    ``rate_limit_enabled`` and ``rpm`` (handed to RateLimiter). It registers
    two GET routes: ``/v1/health`` and ``/v1/datasets``.
    """

    def build_app(api_keys=None, rate_limit_enabled=False, rpm=60):
        application = FastAPI()
        application.add_middleware(
            SecurityMiddleware,
            api_key_auth=APIKeyAuth(api_keys),
            rate_limiter=RateLimiter(requests_per_minute=rpm, enabled=rate_limit_enabled),
        )

        @application.get("/v1/health")
        async def health():
            return {"status": "ok"}

        @application.get("/v1/datasets")
        async def datasets():
            return {"data": []}

        return application

    return build_app
32
+
33
+
34
@pytest.mark.unit
class TestSecurityMiddleware:
    """Tests for authentication and rate limiting enforced by SecurityMiddleware."""

    def test_exempt_path_bypasses_security(self, app_with_middleware):
        """An exempt path is reachable without an API key even when auth is on."""
        app = app_with_middleware(api_keys=["secret"])
        client = TestClient(app)
        response = client.get("/v1/health")
        assert response.status_code == 200

    def test_auth_required_returns_401(self, app_with_middleware):
        """A protected path without an API key is rejected with 401."""
        app = app_with_middleware(api_keys=["secret"])
        client = TestClient(app)
        response = client.get("/v1/datasets")
        assert response.status_code == 401

    def test_invalid_key_returns_401(self, app_with_middleware):
        """A protected path with an unknown API key is rejected with 401."""
        app = app_with_middleware(api_keys=["secret"])
        client = TestClient(app)
        response = client.get("/v1/datasets", headers={"X-API-Key": "wrong"})
        assert response.status_code == 401

    def test_valid_key_passes(self, app_with_middleware):
        """A protected path with a configured API key is served."""
        app = app_with_middleware(api_keys=["secret"])
        client = TestClient(app)
        response = client.get("/v1/datasets", headers={"X-API-Key": "secret"})
        assert response.status_code == 200

    def test_rate_limit_exceeded_returns_429(self, app_with_middleware):
        """Requests beyond the per-minute budget are rejected with 429."""
        app = app_with_middleware(rate_limit_enabled=True, rpm=2)
        client = TestClient(app)
        # Requests within the budget must succeed; otherwise a limiter that
        # rejects everything would also satisfy the final assertion.
        for _ in range(2):
            assert client.get("/v1/datasets").status_code == 200
        response = client.get("/v1/datasets")
        assert response.status_code == 429

    def test_rate_limit_headers_included(self, app_with_middleware):
        """Rate-limited responses expose limit and remaining-quota headers."""
        app = app_with_middleware(rate_limit_enabled=True, rpm=10)
        client = TestClient(app)
        response = client.get("/v1/datasets")
        assert response.status_code == 200
        assert "X-RateLimit-Limit" in response.headers
        assert "X-RateLimit-Remaining" in response.headers

    def test_is_exempt_checks_known_paths(self):
        """EXEMPT_PATHS contains the health and docs paths but not data routes."""
        assert "/v1/health" in EXEMPT_PATHS
        assert "/docs" in EXEMPT_PATHS
        assert "/v1/datasets" not in EXEMPT_PATHS