juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95) hide show
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,189 @@
1
+ """Integration tests for API security middleware."""
2
+
3
+ import pytest
4
+ from fastapi.testclient import TestClient
5
+
6
+ from juniper_data.api.app import create_app
7
+ from juniper_data.api.routes import datasets
8
+ from juniper_data.api.settings import Settings
9
+ from juniper_data.storage import InMemoryDatasetStore
10
+
11
+
12
def _make_client(**settings_overrides) -> TestClient:
    """Build a TestClient for a freshly created app.

    All fixtures in this module share ``storage_path="./test_data"``; callers
    pass only the security-related settings that differ.

    Args:
        **settings_overrides: Extra keyword arguments forwarded to ``Settings``.

    Returns:
        A ``TestClient`` wrapping the new app.
    """
    settings = Settings(storage_path="./test_data", **settings_overrides)
    app = create_app(settings)
    # Swap the dataset routes onto an empty in-memory store on every call so a
    # test never sees datasets created by an earlier one.
    # NOTE(review): set_store is called on the routes module, which suggests
    # module-level state; confirm against juniper_data.api.routes.datasets.
    datasets.set_store(InMemoryDatasetStore())
    return TestClient(app)


@pytest.fixture
def auth_enabled_client() -> TestClient:
    """Create a test client with API key authentication enabled."""
    return _make_client(
        api_keys=["valid-key-1", "valid-key-2"],
        rate_limit_enabled=False,
    )


@pytest.fixture
def rate_limited_client() -> TestClient:
    """Create a test client with rate limiting enabled (5 requests per minute, no auth)."""
    return _make_client(
        api_keys=None,
        rate_limit_enabled=True,
        rate_limit_requests_per_minute=5,
    )


@pytest.fixture
def fully_secured_client() -> TestClient:
    """Create a test client with both auth (key ``secure-key``) and rate limiting (10/min) enabled."""
    return _make_client(
        api_keys=["secure-key"],
        rate_limit_enabled=True,
        rate_limit_requests_per_minute=10,
    )
51
+
52
+
53
class TestAPIKeyAuthentication:
    """End-to-end checks of X-API-Key authentication on the v1 API."""

    def test_health_endpoint_exempt(self, auth_enabled_client: TestClient) -> None:
        """Every health probe answers 200 even though no API key is sent."""
        for probe in ("/v1/health", "/v1/health/live", "/v1/health/ready"):
            assert auth_enabled_client.get(probe).status_code == 200

    def test_protected_endpoint_requires_key(self, auth_enabled_client: TestClient) -> None:
        """Omitting the X-API-Key header on a protected route yields 401 'Missing API key'."""
        reply = auth_enabled_client.get("/v1/generators")
        assert reply.status_code == 401
        assert "Missing API key" in reply.json()["detail"]

    def test_invalid_key_rejected(self, auth_enabled_client: TestClient) -> None:
        """An unrecognized key yields 401 'Invalid API key'."""
        reply = auth_enabled_client.get("/v1/generators", headers={"X-API-Key": "invalid-key"})
        assert reply.status_code == 401
        assert "Invalid API key" in reply.json()["detail"]

    def test_valid_key_accepted(self, auth_enabled_client: TestClient) -> None:
        """Each configured key, used in turn, unlocks the protected route."""
        for api_key in ("valid-key-1", "valid-key-2"):
            reply = auth_enabled_client.get("/v1/generators", headers={"X-API-Key": api_key})
            assert reply.status_code == 200

    def test_create_dataset_with_auth(self, auth_enabled_client: TestClient) -> None:
        """POSTing a dataset request with a valid key returns 201 and a dataset id."""
        reply = auth_enabled_client.post(
            "/v1/datasets",
            json={"generator": "spiral", "params": {"seed": 42}},
            headers={"X-API-Key": "valid-key-1"},
        )
        assert reply.status_code == 201
        body = reply.json()
        assert "dataset_id" in body
105
+
106
+
107
class TestRateLimiting:
    """Integration tests for rate limiting.

    Uses ``rate_limited_client``, which is configured with
    ``rate_limit_requests_per_minute=5`` and no API keys.
    """

    # Must match rate_limit_requests_per_minute in the rate_limited_client fixture.
    _LIMIT = 5

    def test_health_endpoint_exempt(self, rate_limited_client: TestClient) -> None:
        """Health endpoints should be exempt from rate limiting."""
        # Four times the budget: any limiting of /v1/health would surface as a non-200.
        for _ in range(4 * self._LIMIT):
            response = rate_limited_client.get("/v1/health")
            assert response.status_code == 200

    def test_allows_requests_within_limit(self, rate_limited_client: TestClient) -> None:
        """Requests within limit should be allowed, with the remaining budget counting down."""
        for used in range(1, self._LIMIT + 1):
            response = rate_limited_client.get("/v1/generators")
            assert response.status_code == 200
            assert "X-RateLimit-Remaining" in response.headers
            assert int(response.headers["X-RateLimit-Remaining"]) == self._LIMIT - used

    def test_blocks_requests_over_limit(self, rate_limited_client: TestClient) -> None:
        """Requests over limit should be blocked with 429."""
        for _ in range(self._LIMIT):
            # Each budget-consuming request must itself succeed; otherwise the
            # 429 below would not demonstrate that the limit was the cause.
            assert rate_limited_client.get("/v1/generators").status_code == 200

        response = rate_limited_client.get("/v1/generators")
        assert response.status_code == 429
        assert "Rate limit exceeded" in response.json()["detail"]
        assert "Retry-After" in response.headers

    def test_rate_limit_headers_present(self, rate_limited_client: TestClient) -> None:
        """Rate limit headers should be present in responses."""
        response = rate_limited_client.get("/v1/generators")
        assert response.status_code == 200

        assert "X-RateLimit-Limit" in response.headers
        assert response.headers["X-RateLimit-Limit"] == str(self._LIMIT)
        assert "X-RateLimit-Remaining" in response.headers
        assert "X-RateLimit-Reset" in response.headers
143
+
144
+
145
class TestCombinedSecurity:
    """Integration tests for combined auth and rate limiting.

    Uses ``fully_secured_client``, which accepts only ``secure-key`` and is
    configured with ``rate_limit_requests_per_minute=10``.
    """

    # Must match the fully_secured_client fixture configuration.
    _LIMIT = 10
    _VALID_HEADERS = {"X-API-Key": "secure-key"}

    def test_auth_checked_before_rate_limit(self, fully_secured_client: TestClient) -> None:
        """Requests without a valid key are rejected by authentication with 401.

        NOTE(review): both requests stay within the rate-limit budget, so this
        confirms auth rejection while rate limiting is enabled but does not by
        itself prove the middleware ordering.
        """
        response = fully_secured_client.get("/v1/generators")
        assert response.status_code == 401
        assert "Missing API key" in response.json()["detail"]

        response = fully_secured_client.get(
            "/v1/generators",
            headers={"X-API-Key": "invalid"},
        )
        assert response.status_code == 401
        assert "Invalid API key" in response.json()["detail"]

    def test_rate_limit_applied_after_auth(self, fully_secured_client: TestClient) -> None:
        """Authenticated requests are still rate limited once the budget is spent."""
        for _ in range(self._LIMIT):
            response = fully_secured_client.get("/v1/generators", headers=self._VALID_HEADERS)
            assert response.status_code == 200

        response = fully_secured_client.get("/v1/generators", headers=self._VALID_HEADERS)
        assert response.status_code == 429

    def test_full_workflow_with_security(self, fully_secured_client: TestClient) -> None:
        """Full dataset workflow (create, then fetch by id) should work with security enabled."""
        response = fully_secured_client.post(
            "/v1/datasets",
            json={"generator": "xor", "params": {"seed": 42}},
            headers=self._VALID_HEADERS,
        )
        assert response.status_code == 201
        dataset_id = response.json()["dataset_id"]

        response = fully_secured_client.get(
            f"/v1/datasets/{dataset_id}",
            headers=self._VALID_HEADERS,
        )
        assert response.status_code == 200
        assert response.json()["dataset_id"] == dataset_id
@@ -0,0 +1,259 @@
1
+ """Integration tests for storage workflows.
2
+
3
+ Tests cover end-to-end scenarios:
4
+ - Generate → Store → Retrieve → Verify
5
+ - Cross-store compatibility
6
+ - Full dataset lifecycle
7
+ """
8
+
9
+ import io
10
+ import tempfile
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+ import pytest
16
+
17
+ from juniper_data.core.artifacts import compute_checksum
18
+ from juniper_data.core.models import DatasetMeta
19
+ from juniper_data.generators.spiral import SpiralGenerator, SpiralParams
20
+ from juniper_data.storage import InMemoryDatasetStore, LocalFSDatasetStore
21
+
22
+
23
@pytest.fixture
def temp_storage_dir():
    """Yield a fresh temporary directory as a ``Path``; it is deleted when the test finishes."""
    with tempfile.TemporaryDirectory() as scratch:
        root = Path(scratch)
        yield root
28
+
29
+
30
@pytest.fixture
def spiral_params() -> SpiralParams:
    """Shared spiral config: 2 arms x 100 points, noise 0.1, seed 42."""
    settings = {"n_spirals": 2, "n_points_per_spiral": 100, "noise": 0.1, "seed": 42}
    return SpiralParams(**settings)
39
+
40
+
41
def class_stats(y: np.ndarray) -> tuple[int, dict[str, int]]:
    """Return ``(n_classes, class_distribution)`` for a label array.

    Two layouts are supported:

    * 2-D (one-hot style): the class count is the number of columns and each
      column sum (truncated to int) is that class's count, keyed ``"0"``, ``"1"``, ...
    * 1-D class labels: only labels that actually occur are counted, keyed by
      their integer value. This keeps ``n_classes`` equal to the number of
      distribution entries even when labels have gaps (e.g. ``[0, 2]``), which
      ``np.bincount`` would pad with zero-count classes.

    Args:
        y: Label array.

    Returns:
        The number of classes and a ``{class_key: count}`` mapping.
    """
    labels = np.asarray(y)
    if labels.ndim > 1:
        n_classes = labels.shape[1]
        counts = np.sum(labels, axis=0).astype(int)
        distribution = {str(i): int(c) for i, c in enumerate(counts)}
    else:
        values, counts = np.unique(labels, return_counts=True)
        n_classes = len(values)
        distribution = {str(int(v)): int(c) for v, c in zip(values, counts)}
    return n_classes, distribution


def create_dataset_meta(
    dataset_id: str, params: "SpiralParams", X: np.ndarray, y: np.ndarray, n_train: int, n_test: int
) -> "DatasetMeta":
    """Build a ``DatasetMeta`` describing spiral data produced by ``SpiralGenerator``.

    Args:
        dataset_id: Identifier stored in the metadata.
        params: Generator parameters; serialized via ``model_dump()``.
        X: Full feature matrix; its row count becomes ``n_samples`` and its
            column count ``n_features``.
        y: Full label array (2-D one-hot or 1-D labels; see ``class_stats``).
        n_train: Number of training samples.
        n_test: Number of test samples.

    Returns:
        Metadata with generator ``"spiral"``, version ``"1.0.0"``, and a
        ``created_at`` timestamp of "now" (naive local time).
    """
    n_classes, class_distribution = class_stats(y)

    return DatasetMeta(
        dataset_id=dataset_id,
        generator="spiral",
        generator_version="1.0.0",
        params=params.model_dump(),
        n_samples=len(X),
        n_features=X.shape[1],
        n_classes=n_classes,
        n_train=n_train,
        n_test=n_test,
        class_distribution=class_distribution,
        created_at=datetime.now(),
    )
61
+
62
+
63
@pytest.mark.integration
class TestGenerateStoreRetrieveWorkflow:
    """Tests for the complete generate → store → retrieve workflow."""

    @staticmethod
    def _generate(dataset_id: str, params: SpiralParams) -> tuple[DatasetMeta, dict[str, np.ndarray]]:
        """Generate a spiral dataset; return its metadata and the four split arrays."""
        data = SpiralGenerator.generate(params)
        arrays = {name: data[name] for name in ("X_train", "y_train", "X_test", "y_test")}
        meta = create_dataset_meta(
            dataset_id,
            params,
            data["X_full"],
            data["y_full"],
            len(arrays["X_train"]),
            len(arrays["X_test"]),
        )
        return meta, arrays

    @staticmethod
    def _assert_artifact_round_trips(artifact_bytes, arrays: dict[str, np.ndarray]) -> None:
        """Assert the stored ``.npz`` payload holds every array in ``arrays`` unchanged."""
        assert artifact_bytes is not None
        # NpzFile keeps the underlying zip open until closed; the context manager closes it.
        with np.load(io.BytesIO(artifact_bytes)) as loaded:
            for name, expected in arrays.items():
                np.testing.assert_array_equal(loaded[name], expected)

    def test_memory_store_full_workflow(self, spiral_params: SpiralParams):
        """Generate spiral → store in memory → retrieve and verify."""
        dataset_id = "spiral-mem-001"
        store = InMemoryDatasetStore()
        meta, arrays = self._generate(dataset_id, spiral_params)
        expected_samples = spiral_params.n_spirals * spiral_params.n_points_per_spiral

        store.save(dataset_id, meta, arrays)

        assert store.exists(dataset_id)
        retrieved_meta = store.get_meta(dataset_id)
        assert retrieved_meta is not None
        assert retrieved_meta.dataset_id == dataset_id
        assert retrieved_meta.n_samples == expected_samples

        self._assert_artifact_round_trips(store.get_artifact_bytes(dataset_id), arrays)

    def test_fs_store_full_workflow(self, temp_storage_dir: Path, spiral_params: SpiralParams):
        """Generate spiral → store to filesystem → retrieve and verify."""
        dataset_id = "spiral-fs-001"
        store = LocalFSDatasetStore(temp_storage_dir)
        meta, arrays = self._generate(dataset_id, spiral_params)
        expected_samples = spiral_params.n_spirals * spiral_params.n_points_per_spiral

        store.save(dataset_id, meta, arrays)

        assert store.exists(dataset_id)
        assert (temp_storage_dir / f"{dataset_id}.meta.json").exists()
        assert (temp_storage_dir / f"{dataset_id}.npz").exists()

        retrieved_meta = store.get_meta(dataset_id)
        assert retrieved_meta is not None
        assert retrieved_meta.dataset_id == dataset_id
        assert retrieved_meta.n_samples == expected_samples

        self._assert_artifact_round_trips(store.get_artifact_bytes(dataset_id), arrays)

    def test_persistence_across_store_instances(self, temp_storage_dir: Path, spiral_params: SpiralParams):
        """Data persists when creating new store instance."""
        dataset_id = "persist-test"
        writer = LocalFSDatasetStore(temp_storage_dir)
        meta, arrays = self._generate(dataset_id, spiral_params)

        writer.save(dataset_id, meta, arrays)

        # A second instance over the same directory must see the saved dataset.
        reader = LocalFSDatasetStore(temp_storage_dir)

        assert reader.exists(dataset_id)
        retrieved = reader.get_meta(dataset_id)
        assert retrieved is not None
        assert retrieved.dataset_id == dataset_id

        listed = reader.list_datasets()
        assert dataset_id in listed
146
+
147
+
148
@pytest.mark.integration
class TestDatasetLifecycle:
    """Create, list and delete datasets against a filesystem-backed store."""

    def test_create_update_delete_lifecycle(self, temp_storage_dir: Path, spiral_params: SpiralParams):
        """Save one dataset, see it listed, delete it, then confirm every lookup misses."""
        dataset_id = "lifecycle-test"
        fs_store = LocalFSDatasetStore(temp_storage_dir)

        generated = SpiralGenerator.generate(spiral_params)
        split_arrays = {key: generated[key] for key in ("X_train", "y_train", "X_test", "y_test")}
        meta = create_dataset_meta(
            dataset_id,
            spiral_params,
            generated["X_full"],
            generated["y_full"],
            len(split_arrays["X_train"]),
            len(split_arrays["X_test"]),
        )

        fs_store.save(dataset_id, meta, split_arrays)
        assert fs_store.exists(dataset_id)
        assert dataset_id in fs_store.list_datasets()

        was_deleted = fs_store.delete(dataset_id)
        assert was_deleted
        assert not fs_store.exists(dataset_id)
        assert dataset_id not in fs_store.list_datasets()
        assert fs_store.get_meta(dataset_id) is None
        assert fs_store.get_artifact_bytes(dataset_id) is None

    def test_multiple_datasets(self, temp_storage_dir: Path):
        """Three datasets of different sizes coexist; deleting one leaves the other two."""
        fs_store = LocalFSDatasetStore(temp_storage_dir)
        arm_sizes = [50, 100, 200]

        for seed, points in enumerate(arm_sizes):
            dataset_id = f"multi-{seed}"
            params = SpiralParams(n_spirals=2, n_points_per_spiral=points, seed=seed)
            generated = SpiralGenerator.generate(params)
            split_arrays = {key: generated[key] for key in ("X_train", "y_train", "X_test", "y_test")}
            meta = create_dataset_meta(
                dataset_id,
                params,
                generated["X_full"],
                generated["y_full"],
                len(split_arrays["X_train"]),
                len(split_arrays["X_test"]),
            )
            fs_store.save(dataset_id, meta, split_arrays)

        listed = fs_store.list_datasets()
        assert len(listed) == 3
        assert all(f"multi-{seed}" in listed for seed in range(3))

        fs_store.delete("multi-1")
        remaining = fs_store.list_datasets()
        assert len(remaining) == 2
        assert "multi-1" not in remaining
199
+
200
+
201
@pytest.mark.integration
class TestChecksumVerification:
    """Tests for checksum verification across storage operations."""

    def test_checksum_consistency(self, temp_storage_dir: Path, spiral_params: SpiralParams):
        """Checksum remains consistent through store/retrieve cycle.

        Two checks must both equal the checksum of the arrays that were saved:
        the checksum recorded in the stored metadata, and a checksum recomputed
        from the stored ``.npz`` payload.
        """
        store = LocalFSDatasetStore(temp_storage_dir)

        data = SpiralGenerator.generate(spiral_params)
        X, y = data["X_full"], data["y_full"]
        X_train, y_train = data["X_train"], data["y_train"]
        X_test, y_test = data["X_test"], data["y_test"]

        arrays = {"X_train": X_train, "y_train": y_train, "X_test": X_test, "y_test": y_test}
        original_checksum = compute_checksum(arrays)

        meta = create_dataset_meta("checksum-test", spiral_params, X, y, len(X_train), len(X_test))
        meta.checksum = original_checksum

        store.save("checksum-test", meta, arrays)

        retrieved_meta = store.get_meta("checksum-test")
        assert retrieved_meta is not None
        assert retrieved_meta.checksum == original_checksum

        artifact_bytes = store.get_artifact_bytes("checksum-test")
        assert artifact_bytes is not None
        # Materialize every member while the archive is open; the NpzFile
        # returned by np.load holds the zip handle until it is closed.
        with np.load(io.BytesIO(artifact_bytes)) as loaded:
            loaded_arrays = {name: loaded[name] for name in loaded.files}

        assert compute_checksum(loaded_arrays) == original_checksum
234
+
235
+
236
@pytest.mark.integration
class TestReproducibility:
    """Seeded spiral generation must be deterministic."""

    @staticmethod
    def _generate_pair(first_seed: int, second_seed: int):
        """Generate two 2-spiral, 100-points-per-spiral datasets with the given seeds."""
        return tuple(
            SpiralGenerator.generate(SpiralParams(n_spirals=2, n_points_per_spiral=100, seed=seed))
            for seed in (first_seed, second_seed)
        )

    def test_seed_reproducibility(self):
        """Running the generator twice with seed 42 yields identical features and labels."""
        first, second = self._generate_pair(42, 42)
        for key in ("X_full", "y_full"):
            np.testing.assert_array_equal(first[key], second[key])

    def test_different_seeds_produce_different_data(self):
        """Seeds 42 and 43 yield feature matrices that are not element-wise close."""
        first, second = self._generate_pair(42, 43)
        assert not np.allclose(first["X_full"], second["X_full"])
@@ -0,0 +1 @@
1
+ """Performance benchmark tests for juniper_data."""
@@ -0,0 +1,178 @@
1
+ #####################################################################################################################################################################################################
2
+ # Project: Juniper
3
+ # Sub-Project: JuniperData
4
+ # Application: juniper_data
5
+ # File Name: test_generator_benchmarks.py
6
+ # Author: Paul Calnon
7
+ # Version: 0.4.2
8
+ #
9
+ # Date Created: 2026-02-25
10
+ # Last Modified: 2026-02-25
11
+ #
12
+ # License: MIT License
13
+ # Copyright: Copyright (c) 2024-2026 Paul Calnon
14
+ #
15
+ # Description:
16
+ # Performance benchmarks for dataset generators.
17
+ # Measures generation throughput (points/second) for each synthetic generator
18
+ # and scaling behavior across dataset sizes.
19
+ #
20
+ # Usage:
21
+ # # Run benchmarks with timing (disabled by default in addopts):
22
+ # pytest juniper_data/tests/performance/test_generator_benchmarks.py --benchmark-enable -v
23
+ #
24
+ # # Run with autosave for regression tracking:
25
+ # pytest juniper_data/tests/performance/test_generator_benchmarks.py --benchmark-enable --benchmark-autosave
26
+ #
27
+ # # Compare against saved baseline:
28
+ # pytest juniper_data/tests/performance/test_generator_benchmarks.py --benchmark-enable --benchmark-compare
29
+ #
30
+ # References:
31
+ # - RD-009: Performance Test Infrastructure
32
+ # - pytest-benchmark: https://pytest-benchmark.readthedocs.io/
33
+ #####################################################################################################################################################################################################
34
+
35
+ """Performance benchmarks for dataset generators.
36
+
37
+ Benchmarks measure generation throughput for each synthetic generator
38
+ and scaling behavior with increasing dataset sizes. External-dependency
39
+ generators (MNIST, ARC-AGI, CSV import) are excluded as they measure
40
+ I/O rather than generation performance.
41
+ """
42
+
43
+ import numpy as np
44
+ import pytest
45
+
46
+ from juniper_data.generators.checkerboard.generator import CheckerboardGenerator
47
+ from juniper_data.generators.checkerboard.params import CheckerboardParams
48
+ from juniper_data.generators.circles.generator import CirclesGenerator
49
+ from juniper_data.generators.circles.params import CirclesParams
50
+ from juniper_data.generators.gaussian.generator import GaussianGenerator
51
+ from juniper_data.generators.gaussian.params import GaussianParams
52
+ from juniper_data.generators.spiral.generator import SpiralGenerator
53
+ from juniper_data.generators.spiral.params import SpiralParams
54
+ from juniper_data.generators.xor.generator import XorGenerator
55
+ from juniper_data.generators.xor.params import XorParams
56
+
57
+ # ═══════════════════════════════════════════════════════════════════════════════
58
+ # Generator Throughput Benchmarks
59
+ # ═══════════════════════════════════════════════════════════════════════════════
60
+
61
+
62
@pytest.mark.performance
class TestGeneratorThroughput:
    """Time each synthetic generator on a ~1000-point dataset.

    ``benchmark`` (pytest-benchmark) drives timing and repetition; each test
    then sanity-checks the structure of the last generated result.
    """

    @staticmethod
    def _check_output(result, expected_points: int) -> None:
        """Assert a training split exists, ``X_full`` has ``expected_points`` rows, and ``X_train`` is float32."""
        assert "X_train" in result
        assert result["X_full"].shape[0] == expected_points
        assert result["X_train"].dtype == np.float32

    def test_spiral_generator(self, benchmark):
        """Benchmark spiral dataset generation (1000 points)."""
        config = SpiralParams(n_spirals=2, n_points_per_spiral=500, seed=42)
        self._check_output(benchmark(SpiralGenerator.generate, config), 1000)

    def test_xor_generator(self, benchmark):
        """Benchmark XOR dataset generation (1000 points)."""
        config = XorParams(n_points_per_quadrant=250, seed=42)
        self._check_output(benchmark(XorGenerator.generate, config), 1000)

    def test_gaussian_generator(self, benchmark):
        """Benchmark Gaussian blobs dataset generation (1000 points)."""
        config = GaussianParams(n_classes=2, n_samples_per_class=500, seed=42)
        self._check_output(benchmark(GaussianGenerator.generate, config), 1000)

    def test_circles_generator(self, benchmark):
        """Benchmark concentric circles dataset generation (1000 points)."""
        config = CirclesParams(n_samples=1000, seed=42)
        self._check_output(benchmark(CirclesGenerator.generate, config), 1000)

    def test_checkerboard_generator(self, benchmark):
        """Benchmark checkerboard dataset generation (1000 points)."""
        config = CheckerboardParams(n_samples=1000, seed=42)
        self._check_output(benchmark(CheckerboardGenerator.generate, config), 1000)
110
+
111
+
112
+ # ═══════════════════════════════════════════════════════════════════════════════
113
+ # Scaling Benchmarks
114
+ # ═══════════════════════════════════════════════════════════════════════════════
115
+
116
+
117
@pytest.mark.performance
class TestGeneratorScaling:
    """Time generation across growing dataset sizes.

    The spiral generator stands in for the numpy-based generators; Gaussian
    blobs provide a second data point.
    """

    # Points per spiral arm; every spiral case here uses two arms.
    _SPIRAL_ARM_SIZES = (100, 500, 1000, 5000)
    # Total sample counts for the two-class Gaussian cases.
    _GAUSSIAN_TOTALS = (100, 500, 1000, 5000)

    @pytest.mark.parametrize(
        "n_points_per_spiral",
        _SPIRAL_ARM_SIZES,
        ids=[f"{2 * size}pts" for size in _SPIRAL_ARM_SIZES],
    )
    def test_spiral_scaling(self, benchmark, n_points_per_spiral):
        """Time a two-arm spiral at each size and confirm the total point count."""
        config = SpiralParams(n_spirals=2, n_points_per_spiral=n_points_per_spiral, seed=42)
        generated = benchmark(SpiralGenerator.generate, config)
        expected_rows = 2 * n_points_per_spiral
        assert generated["X_full"].shape[0] == expected_rows

    @pytest.mark.parametrize(
        "n_samples",
        _GAUSSIAN_TOTALS,
        ids=[f"{total}pts" for total in _GAUSSIAN_TOTALS],
    )
    def test_gaussian_scaling(self, benchmark, n_samples):
        """Time two-class Gaussian blobs at each total size and confirm the row count."""
        per_class = n_samples // 2
        config = GaussianParams(n_classes=2, n_samples_per_class=per_class, seed=42)
        generated = benchmark(GaussianGenerator.generate, config)
        assert generated["X_full"].shape[0] == n_samples
146
+
147
+
148
+ # ═══════════════════════════════════════════════════════════════════════════════
149
+ # Multi-Class Benchmarks
150
+ # ═══════════════════════════════════════════════════════════════════════════════
151
+
152
+
153
@pytest.mark.performance
class TestMultiClassScaling:
    """Time generation as the number of classes grows."""

    _CLASS_COUNTS = (2, 3, 5, 8)
    _CLASS_IDS = [f"{count}class" for count in _CLASS_COUNTS]

    @pytest.mark.parametrize("n_spirals", _CLASS_COUNTS, ids=_CLASS_IDS)
    def test_spiral_class_scaling(self, benchmark, n_spirals):
        """Time spirals with 200 points per arm; check row count and label width."""
        arm_points = 200
        config = SpiralParams(n_spirals=n_spirals, n_points_per_spiral=arm_points, seed=42)
        generated = benchmark(SpiralGenerator.generate, config)
        assert generated["X_full"].shape[0] == n_spirals * arm_points
        assert generated["y_full"].shape[1] == n_spirals

    @pytest.mark.parametrize("n_classes", _CLASS_COUNTS, ids=_CLASS_IDS)
    def test_gaussian_class_scaling(self, benchmark, n_classes):
        """Time Gaussian blobs with 100 samples per class; check the row count."""
        per_class = 100
        config = GaussianParams(n_classes=n_classes, n_samples_per_class=per_class, seed=42)
        generated = benchmark(GaussianGenerator.generate, config)
        assert generated["X_full"].shape[0] == n_classes * per_class