juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
"""
|
|
3
|
+
Golden Dataset Generator for JuniperData Parity Testing
|
|
4
|
+
|
|
5
|
+
This script generates golden reference datasets from the existing JuniperCascor
|
|
6
|
+
SpiralProblem implementation for use in validating the new JuniperData implementation.
|
|
7
|
+
|
|
8
|
+
Run this script from the JuniperCascor environment to generate the golden datasets.
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
# Optionally set environment variables to configure paths:
|
|
12
|
+
# JUNIPER_CASCOR_SRC - path to the JuniperCascor 'src' directory
|
|
13
|
+
# GOLDEN_DATASETS_DIR - output directory for generated golden datasets
|
|
14
|
+
#
|
|
15
|
+
# Example:
|
|
16
|
+
# export JUNIPER_CASCOR_SRC=/path/to/JuniperCascor/juniper_cascor/src
|
|
17
|
+
# export GOLDEN_DATASETS_DIR=/path/to/JuniperData/tests/fixtures/golden_datasets
|
|
18
|
+
# python -m juniper_data.tests.fixtures.generate_golden_datasets
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import json
|
|
22
|
+
import os
|
|
23
|
+
import sys
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
|
|
28
|
+
# Prepend the JuniperCascor source directory to sys.path for local script
# execution (insert at index 0, so it is searched before any other entry when
# the spiral_problem package is imported below).
# The path can be configured via the JUNIPER_CASCOR_SRC environment variable.
# If not set, we fall back to a path derived relative to this file.
# NOTE(review): parents[3] is three directories above this file's own folder;
# confirm that is where the JuniperCascor checkout is expected to live.
_default_cascor_src = Path(__file__).resolve().parents[3] / "JuniperCascor" / "juniper_cascor" / "src"
JUNIPER_CASCOR_SRC = Path(os.environ.get("JUNIPER_CASCOR_SRC", str(_default_cascor_src)))
sys.path.insert(0, str(JUNIPER_CASCOR_SRC))
|
|
34
|
+
|
|
35
|
+
from spiral_problem.spiral_problem import SpiralProblem # noqa: E402
|
|
36
|
+
|
|
37
|
+
# Output location for the generated golden datasets. The GOLDEN_DATASETS_DIR
# environment variable takes precedence; when it is not set, a directory named
# 'golden_datasets' next to this script is used.
_default_golden_dir = Path(__file__).resolve().parent / "golden_datasets"
GOLDEN_DATASETS_DIR = Path(os.environ.get("GOLDEN_DATASETS_DIR", str(_default_golden_dir)))
|
|
46
|
+
|
|
47
|
+
# Seed and split ratios are identical for every golden dataset; they are kept
# last in each entry so the key order of the configs stays
# name, n_spirals, n_points, noise, seed, train_ratio, test_ratio.
_SHARED_SEED_AND_SPLIT = {
    "seed": 42,
    "train_ratio": 0.8,
    "test_ratio": 0.2,
}

# One entry per golden dataset to generate; "name" is also used as the output
# file stem.
DATASET_CONFIGS = [
    {"name": "2_spiral", "n_spirals": 2, "n_points": 100, "noise": 0.1, **_SHARED_SEED_AND_SPLIT},
    {"name": "3_spiral", "n_spirals": 3, "n_points": 50, "noise": 0.05, **_SHARED_SEED_AND_SPLIT},
]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def generate_golden_dataset(config: dict) -> dict:
    """Produce one golden reference dataset from the JuniperCascor SpiralProblem.

    NumPy's and PyTorch's global random generators are both seeded with
    ``config["seed"]`` before the problem is built.

    Args:
        config: Mapping with the keys "n_spirals", "n_points", "noise", "seed",
            "train_ratio" and "test_ratio" (see DATASET_CONFIGS).

    Returns:
        A dict holding the arrays "X_train", "y_train", "X_test", "y_test"
        (converted with ``.numpy()``) and a "metadata" dict describing the
        config, array shapes, dtypes, class distribution and feature ranges.
    """
    seed = config["seed"]
    np.random.seed(seed)

    # torch is imported lazily, only when a dataset is actually generated.
    import torch  # noqa: E402

    torch.manual_seed(seed)

    n_spirals = config["n_spirals"]
    n_points = config["n_points"]
    noise = config["noise"]
    train_ratio = config["train_ratio"]
    test_ratio = config["test_ratio"]

    # The keyword names are the name-mangled private attributes that
    # SpiralProblem's constructor is called with here; they must stay verbatim.
    problem = SpiralProblem(
        _SpiralProblem__n_spirals=n_spirals,
        _SpiralProblem__n_points=n_points,
        _SpiralProblem__noise=noise,
        _SpiralProblem__random_seed=seed,
        _SpiralProblem__train_ratio=train_ratio,
        _SpiralProblem__test_ratio=test_ratio,
    )

    # The third pair (the full, unsplit dataset) is not needed for the golden files.
    (train_inputs, train_labels), (test_inputs, test_labels), (_, _) = problem.generate_n_spiral_dataset(
        n_spirals=n_spirals,
        n_points=n_points,
        noise_level=noise,
        train_ratio=train_ratio,
        test_ratio=test_ratio,
    )

    arrays = {
        "X_train": train_inputs.numpy(),
        "y_train": train_labels.numpy(),
        "X_test": test_inputs.numpy(),
        "y_test": test_labels.numpy(),
    }

    metadata = {
        "config": config,
        "shapes": {key: list(array.shape) for key, array in arrays.items()},
        "dtypes": {key: str(array.dtype) for key, array in arrays.items()},
        "class_distribution": {
            "train": compute_class_distribution(arrays["y_train"]),
            "test": compute_class_distribution(arrays["y_test"]),
        },
        "value_ranges": {
            key: {"min": float(arrays[key].min()), "max": float(arrays[key].max())}
            for key in ("X_train", "X_test")
        },
    }

    return {**arrays, "metadata": metadata}
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def compute_class_distribution(y: np.ndarray) -> dict:
    """Count how many samples fall into each class of one-hot encoded labels.

    Args:
        y: Two-dimensional label array; the class of a row is the index of its
            largest entry (``argmax`` along axis 1).

    Returns:
        Dict mapping ``"class_<index>"`` to the number of rows with that class,
        ordered by ascending class index. Classes with no rows are omitted.
    """
    labels = np.argmax(y, axis=1)
    classes, tallies = np.unique(labels, return_counts=True)

    distribution = {}
    # tolist() yields plain Python ints, keeping the result JSON-serialisable.
    for label, tally in zip(classes.tolist(), tallies.tolist()):
        distribution[f"class_{label}"] = tally
    return distribution
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def save_golden_dataset(data: dict, name: str) -> None:
    """Write a dataset to GOLDEN_DATASETS_DIR as an NPZ archive plus a JSON metadata file.

    The output directory is created if it does not exist. The arrays go to
    ``<name>.npz`` and ``data["metadata"]`` goes to ``<name>_metadata.json``;
    the path of each file is printed once it has been written.

    Args:
        data: Dict with the arrays "X_train", "y_train", "X_test", "y_test"
            and a JSON-serialisable "metadata" entry.
        name: File stem used for both output files.
    """
    out_dir = GOLDEN_DATASETS_DIR
    out_dir.mkdir(parents=True, exist_ok=True)

    array_file = out_dir / f"{name}.npz"
    split_arrays = {key: data[key] for key in ("X_train", "y_train", "X_test", "y_test")}
    np.savez(array_file, **split_arrays)
    print(f"Saved: {array_file}")

    meta_file = out_dir / f"{name}_metadata.json"
    with meta_file.open("w") as handle:
        json.dump(data["metadata"], handle, indent=2)
    print(f"Saved: {meta_file}")
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def print_dataset_info(data: dict, name: str) -> None:
    """Print a human-readable summary of a dataset's metadata for verification.

    Args:
        data: Dict whose "metadata" entry contains the sections "config",
            "shapes", "dtypes", "class_distribution" and "value_ranges".
        name: Dataset name shown in the banner.
    """
    metadata = data["metadata"]
    rule = "=" * 60

    print(f"\n{rule}")
    print(f"Dataset: {name}")
    print(f"{rule}")

    print("Configuration:")
    for label, setting in metadata["config"].items():
        print(f" {label}: {setting}")

    # These three sections share the same "key: value" layout.
    plain_sections = (
        ("Shapes", "shapes"),
        ("Dtypes", "dtypes"),
        ("Class Distribution", "class_distribution"),
    )
    for heading, section in plain_sections:
        print(f"\n{heading}:")
        for label, entry in metadata[section].items():
            print(f" {label}: {entry}")

    print("\nValue Ranges:")
    for label, bounds in metadata["value_ranges"].items():
        low, high = bounds["min"], bounds["max"]
        print(f" {label}: min={low:.6f}, max={high:.6f}")
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def main():
    """Generate, summarise and save every dataset listed in DATASET_CONFIGS."""
    banner = "=" * 60
    print("Generating Golden Reference Datasets")
    print(banner)

    for spec in DATASET_CONFIGS:
        label = spec["name"]
        print(f"\nGenerating {label} dataset...")
        dataset = generate_golden_dataset(spec)
        print_dataset_info(dataset, label)
        save_golden_dataset(dataset, label)

    print("\n" + banner)
    print("Golden dataset generation complete!")
    print(f"Output directory: {GOLDEN_DATASETS_DIR}")


# Run the generator only when executed as a script or via ``python -m``.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Integration tests for Juniper Data."""
|
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
"""Integration tests for the FastAPI REST API.
|
|
2
|
+
|
|
3
|
+
Tests cover all endpoints:
|
|
4
|
+
- Health check
|
|
5
|
+
- Generators listing and schema
|
|
6
|
+
- Dataset CRUD operations
|
|
7
|
+
- Artifact download
|
|
8
|
+
- Preview functionality
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import io
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pytest
|
|
15
|
+
from fastapi.testclient import TestClient
|
|
16
|
+
|
|
17
|
+
from juniper_data import __version__
|
|
18
|
+
from juniper_data.api.app import create_app
|
|
19
|
+
from juniper_data.api.routes import datasets
|
|
20
|
+
from juniper_data.api.settings import Settings
|
|
21
|
+
from juniper_data.storage.memory import InMemoryDatasetStore
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.fixture
def memory_store() -> InMemoryDatasetStore:
    """Provide a new, empty in-memory dataset store to each test that requests it."""
    store = InMemoryDatasetStore()
    return store
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@pytest.fixture
def test_settings(tmp_path) -> Settings:
    """Create test settings whose storage path is a per-test temporary directory.

    pytest's built-in ``tmp_path`` fixture is used instead of a fixed
    ``/tmp/...`` location, so tests never share on-disk state with each other
    or with earlier runs, and the path is valid on every platform.

    Args:
        tmp_path: Unique temporary directory supplied by pytest for this test.

    Returns:
        A ``Settings`` instance with ``storage_path`` pointing at ``tmp_path``.
    """
    # Settings was given a plain string before, so keep passing a string.
    return Settings(storage_path=str(tmp_path))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@pytest.fixture
def client(memory_store: InMemoryDatasetStore, test_settings: Settings) -> TestClient:
    """Build a TestClient for an app created from the test settings.

    The datasets routes are pointed at the in-memory store after the app has
    been created and before the client is handed to the test.
    """
    application = create_app(settings=test_settings)
    datasets.set_store(memory_store)
    test_client = TestClient(application)
    return test_client
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@pytest.fixture
def spiral_request() -> dict:
    """Request body used by the tests to create the default spiral dataset."""
    generator_params = {
        "n_spirals": 2,
        "n_points_per_spiral": 50,
        "seed": 42,
    }
    request_body = {
        "generator": "spiral",
        "params": generator_params,
        "persist": True,
    }
    return request_body
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@pytest.mark.integration
class TestHealthEndpoint:
    """Checks for the health, liveness and readiness routes under /v1/health."""

    def test_health_returns_ok(self, client: TestClient) -> None:
        """The base health route answers 200 with status "ok"."""
        resp = client.get("/v1/health")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ok"

    def test_health_includes_version(self, client: TestClient) -> None:
        """The health payload carries the package version."""
        resp = client.get("/v1/health")

        assert resp.status_code == 200
        body = resp.json()
        assert "version" in body
        assert body["version"] == __version__

    def test_liveness_probe(self, client: TestClient) -> None:
        """The liveness route answers 200 with status "alive"."""
        resp = client.get("/v1/health/live")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "alive"

    def test_readiness_probe(self, client: TestClient) -> None:
        """The readiness route answers 200 with status "ready" and the version."""
        resp = client.get("/v1/health/ready")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ready"
        assert body["version"] == __version__
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@pytest.mark.integration
class TestGeneratorsEndpoint:
    """Checks for the generator listing and schema routes under /v1/generators."""

    def test_list_generators(self, client: TestClient) -> None:
        """The listing is a non-empty list that includes the spiral generator."""
        resp = client.get("/v1/generators")

        assert resp.status_code == 200
        listing = resp.json()
        assert isinstance(listing, list)
        assert len(listing) >= 1

        names = [entry["name"] for entry in listing]
        assert "spiral" in names

    def test_get_generator_schema(self, client: TestClient) -> None:
        """The spiral schema is a dict exposing the expected parameter properties."""
        resp = client.get("/v1/generators/spiral/schema")

        assert resp.status_code == 200
        spiral_schema = resp.json()
        assert isinstance(spiral_schema, dict)
        assert "properties" in spiral_schema
        properties = spiral_schema["properties"]
        assert "n_spirals" in properties
        assert "n_points_per_spiral" in properties

    def test_unknown_generator_404(self, client: TestClient) -> None:
        """Requesting the schema of an unregistered generator yields 404 with a detail message."""
        resp = client.get("/v1/generators/unknown/schema")

        assert resp.status_code == 404
        body = resp.json()
        assert "detail" in body
        assert "unknown" in body["detail"].lower()
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@pytest.mark.integration
class TestDatasetsEndpoint:
    """Tests for the /v1/datasets endpoints.

    Every test that needs an existing dataset first creates one with a POST
    and asserts that the creation returned 201, so a broken create endpoint is
    reported as a status mismatch at the setup step instead of a ``KeyError``
    on ``"dataset_id"`` further down.
    """

    def test_create_spiral_dataset(self, client: TestClient, spiral_request: dict) -> None:
        """POST /v1/datasets creates dataset and returns meta."""
        response = client.post("/v1/datasets", json=spiral_request)

        assert response.status_code == 201
        data = response.json()
        assert "dataset_id" in data
        assert "meta" in data
        assert data["generator"] == "spiral"
        assert data["meta"]["generator"] == "spiral"
        # spiral_request asks for 2 spirals of 50 points each.
        assert data["meta"]["n_samples"] == 100

    def test_create_returns_artifact_url(self, client: TestClient, spiral_request: dict) -> None:
        """Response includes artifact_url."""
        response = client.post("/v1/datasets", json=spiral_request)

        assert response.status_code == 201
        data = response.json()
        assert "artifact_url" in data
        assert "/v1/datasets/" in data["artifact_url"]
        assert "/artifact" in data["artifact_url"]

    def test_list_datasets(self, client: TestClient, spiral_request: dict) -> None:
        """GET /v1/datasets returns list after creation."""
        create_response = client.post("/v1/datasets", json=spiral_request)
        # Without this check a failed creation would only show up as an empty list below.
        assert create_response.status_code == 201

        response = client.get("/v1/datasets")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) >= 1

    def test_get_dataset_meta(self, client: TestClient, spiral_request: dict) -> None:
        """GET /v1/datasets/{id} returns metadata."""
        create_response = client.post("/v1/datasets", json=spiral_request)
        assert create_response.status_code == 201
        dataset_id = create_response.json()["dataset_id"]

        response = client.get(f"/v1/datasets/{dataset_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["dataset_id"] == dataset_id
        assert data["generator"] == "spiral"
        assert "n_samples" in data

    def test_get_dataset_404(self, client: TestClient) -> None:
        """GET /v1/datasets/nonexistent returns 404."""
        response = client.get("/v1/datasets/nonexistent")

        assert response.status_code == 404
        data = response.json()
        assert "detail" in data

    def test_delete_dataset(self, client: TestClient, spiral_request: dict) -> None:
        """DELETE /v1/datasets/{id} returns 204 and the dataset is gone afterwards."""
        create_response = client.post("/v1/datasets", json=spiral_request)
        assert create_response.status_code == 201
        dataset_id = create_response.json()["dataset_id"]

        response = client.delete(f"/v1/datasets/{dataset_id}")

        assert response.status_code == 204

        get_response = client.get(f"/v1/datasets/{dataset_id}")
        assert get_response.status_code == 404

    def test_caching_same_params(self, client: TestClient, spiral_request: dict) -> None:
        """Same params twice returns same dataset_id (no regeneration)."""
        response1 = client.post("/v1/datasets", json=spiral_request)
        response2 = client.post("/v1/datasets", json=spiral_request)

        assert response1.status_code == 201
        assert response2.status_code == 201

        data1 = response1.json()
        data2 = response2.json()
        assert data1["dataset_id"] == data2["dataset_id"]
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@pytest.mark.integration
class TestArtifactEndpoint:
    """Tests for the /v1/datasets/{id}/artifact endpoint.

    Each test creates a dataset first and asserts that the creation returned
    201 before reading ``dataset_id`` from the response, so a failing create
    endpoint is reported at the setup step rather than as a ``KeyError``.
    """

    def test_download_artifact(self, client: TestClient, spiral_request: dict) -> None:
        """GET /v1/datasets/{id}/artifact returns NPZ bytes."""
        create_response = client.post("/v1/datasets", json=spiral_request)
        assert create_response.status_code == 201
        dataset_id = create_response.json()["dataset_id"]

        response = client.get(f"/v1/datasets/{dataset_id}/artifact")

        assert response.status_code == 200
        assert response.headers["content-type"] == "application/octet-stream"
        assert len(response.content) > 0

        # Loading the payload proves the bytes form a readable NPZ archive.
        with np.load(io.BytesIO(response.content)) as data:
            assert len(data.files) > 0

    def test_artifact_contains_expected_keys(self, client: TestClient, spiral_request: dict) -> None:
        """NPZ has X_train, y_train, X_test, y_test."""
        create_response = client.post("/v1/datasets", json=spiral_request)
        assert create_response.status_code == 201
        dataset_id = create_response.json()["dataset_id"]

        response = client.get(f"/v1/datasets/{dataset_id}/artifact")

        assert response.status_code == 200

        with np.load(io.BytesIO(response.content)) as data:
            assert "X_train" in data.files
            assert "y_train" in data.files
            assert "X_test" in data.files
            assert "y_test" in data.files
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
@pytest.mark.integration
class TestPreviewEndpoint:
    """Tests for the /v1/datasets/{id}/preview endpoint.

    Each test creates a dataset first and asserts that the creation returned
    201 before reading ``dataset_id`` from the response, so a failing create
    endpoint is reported at the setup step rather than as a ``KeyError``.
    """

    def test_preview_returns_samples(self, client: TestClient, spiral_request: dict) -> None:
        """GET /v1/datasets/{id}/preview returns JSON with samples."""
        create_response = client.post("/v1/datasets", json=spiral_request)
        assert create_response.status_code == 201
        dataset_id = create_response.json()["dataset_id"]

        response = client.get(f"/v1/datasets/{dataset_id}/preview")

        assert response.status_code == 200
        data = response.json()
        assert "n_samples" in data
        assert "X_sample" in data
        assert "y_sample" in data
        assert isinstance(data["X_sample"], list)
        assert isinstance(data["y_sample"], list)
        assert len(data["X_sample"]) > 0
        assert len(data["y_sample"]) > 0

    def test_preview_respects_n_param(self, client: TestClient, spiral_request: dict) -> None:
        """?n=10 returns 10 samples."""
        create_response = client.post("/v1/datasets", json=spiral_request)
        assert create_response.status_code == 201
        dataset_id = create_response.json()["dataset_id"]

        response = client.get(f"/v1/datasets/{dataset_id}/preview?n=10")

        assert response.status_code == 200
        data = response.json()
        assert data["n_samples"] == 10
        assert len(data["X_sample"]) == 10
        assert len(data["y_sample"]) == 10
|