juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95) hide show
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,199 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ Golden Dataset Generator for JuniperData Parity Testing
4
+
5
+ This script generates golden reference datasets from the existing JuniperCascor
6
+ SpiralProblem implementation for use in validating the new JuniperData implementation.
7
+
8
+ Run this script from the JuniperCascor environment to generate the golden datasets.
9
+
10
+ Usage:
11
+ # Optionally set environment variables to configure paths:
12
+ # JUNIPER_CASCOR_SRC - path to the JuniperCascor 'src' directory
13
+ # GOLDEN_DATASETS_DIR - output directory for generated golden datasets
14
+ #
15
+ # Example:
16
+ # export JUNIPER_CASCOR_SRC=/path/to/JuniperCascor/juniper_cascor/src
17
+ # export GOLDEN_DATASETS_DIR=/path/to/JuniperData/tests/fixtures/golden_datasets
18
+ # python -m juniper_data.tests.fixtures.generate_golden_datasets
19
+ """
20
+
21
+ import json
22
+ import os
23
+ import sys
24
+ from pathlib import Path
25
+
26
+ import numpy as np
27
+
28
+ # Append JuniperCascor source directory for local script execution.
29
+ # The path can be configured via the JUNIPER_CASCOR_SRC environment variable.
30
+ # If not set, we fall back to a path derived relative to this file.
31
+ _default_cascor_src = Path(__file__).resolve().parents[3] / "JuniperCascor" / "juniper_cascor" / "src"
32
+ JUNIPER_CASCOR_SRC = Path(os.environ.get("JUNIPER_CASCOR_SRC", str(_default_cascor_src)))
33
+ sys.path.insert(0, str(JUNIPER_CASCOR_SRC))
34
+
35
+ from spiral_problem.spiral_problem import SpiralProblem # noqa: E402
36
+
37
+ # Directory where golden datasets will be written. Can be overridden via the
38
+ # GOLDEN_DATASETS_DIR environment variable; by default, we use a directory
39
+ # named 'golden_datasets' alongside this script.
40
# Fallback output location: a 'golden_datasets' folder next to this script.
_default_golden_dir = Path(__file__).resolve().parent / "golden_datasets"
# The GOLDEN_DATASETS_DIR environment variable, when present, wins over the fallback.
GOLDEN_DATASETS_DIR = Path(os.environ.get("GOLDEN_DATASETS_DIR", str(_default_golden_dir)))
46
+
47
# Settings common to every golden dataset: RNG seed and train/test split.
_COMMON_CONFIG = {"seed": 42, "train_ratio": 0.8, "test_ratio": 0.2}

# One entry per golden dataset to generate. Key order is name, geometry,
# noise, then the shared seed/split settings.
DATASET_CONFIGS = [
    {"name": "2_spiral", "n_spirals": 2, "n_points": 100, "noise": 0.1, **_COMMON_CONFIG},
    {"name": "3_spiral", "n_spirals": 3, "n_points": 50, "noise": 0.05, **_COMMON_CONFIG},
]
67
+
68
+
69
def generate_golden_dataset(config: dict) -> dict:
    """Generate a golden dataset with the specified configuration.

    Seeds NumPy and torch with ``config["seed"]``, builds a ``SpiralProblem``
    and calls its ``generate_n_spiral_dataset`` method, then converts the
    train/test splits with ``.numpy()`` and summarises them.

    Args:
        config: Mapping with the keys ``seed``, ``n_spirals``, ``n_points``,
            ``noise``, ``train_ratio`` and ``test_ratio`` (see DATASET_CONFIGS).

    Returns:
        A dict with the arrays ``X_train``, ``y_train``, ``X_test``, ``y_test``
        and a ``metadata`` dict holding a snapshot of the config, the array
        shapes and dtypes, the class distribution per split and the min/max
        of the feature arrays.
    """
    seed = config["seed"]
    np.random.seed(seed)

    import torch  # noqa: E402

    torch.manual_seed(seed)

    # SpiralProblem is called here with name-mangled keyword arguments.
    problem = SpiralProblem(
        _SpiralProblem__n_spirals=config["n_spirals"],
        _SpiralProblem__n_points=config["n_points"],
        _SpiralProblem__noise=config["noise"],
        _SpiralProblem__random_seed=seed,
        _SpiralProblem__train_ratio=config["train_ratio"],
        _SpiralProblem__test_ratio=config["test_ratio"],
    )

    # The third pair (the full, unsplit dataset) is not part of the golden
    # artifact, so it is unpacked into throwaway names.
    (X_train, y_train), (X_test, y_test), (_X_full, _y_full) = problem.generate_n_spiral_dataset(
        n_spirals=config["n_spirals"],
        n_points=config["n_points"],
        noise_level=config["noise"],
        train_ratio=config["train_ratio"],
        test_ratio=config["test_ratio"],
    )

    # NOTE(review): presumably torch tensors; .numpy() is the only conversion applied.
    arrays = {
        "X_train": X_train.numpy(),
        "y_train": y_train.numpy(),
        "X_test": X_test.numpy(),
        "y_test": y_test.numpy(),
    }

    metadata = {
        # Shallow copy so later edits to the caller's config dict (e.g. an
        # entry of DATASET_CONFIGS) cannot silently change the metadata.
        "config": dict(config),
        "shapes": {key: list(arr.shape) for key, arr in arrays.items()},
        "dtypes": {key: str(arr.dtype) for key, arr in arrays.items()},
        "class_distribution": {
            "train": compute_class_distribution(arrays["y_train"]),
            "test": compute_class_distribution(arrays["y_test"]),
        },
        "value_ranges": {
            key: {"min": float(arrays[key].min()), "max": float(arrays[key].max())}
            for key in ("X_train", "X_test")
        },
    }

    return {**arrays, "metadata": metadata}
130
+
131
+
132
def compute_class_distribution(y: np.ndarray) -> dict:
    """Count samples per class.

    Args:
        y: Either one-hot encoded labels (2-D, one row per sample) or a 1-D
            array of class indices.

    Returns:
        A dict mapping ``"class_<index>"`` to the number of samples in that
        class, with plain Python ints as values. Classes with no samples are
        absent; empty input yields an empty dict.
    """
    labels = np.asarray(y)
    # 1-D input is already a vector of class indices; otherwise the index of
    # the largest entry along axis 1 is taken as the class of each row.
    class_indices = labels if labels.ndim == 1 else np.argmax(labels, axis=1)
    unique, counts = np.unique(class_indices, return_counts=True)
    # Cast NumPy scalars to built-in ints so the result is JSON-serialisable.
    return {f"class_{int(c)}": int(cnt) for c, cnt in zip(unique, counts)}
137
+
138
+
139
def save_golden_dataset(data: dict, name: str) -> None:
    """Write the dataset arrays to ``<name>.npz`` and its metadata to ``<name>_metadata.json``.

    Both files go into GOLDEN_DATASETS_DIR, which is created if missing.
    Each written path is echoed to stdout.
    """
    target_dir = GOLDEN_DATASETS_DIR
    target_dir.mkdir(parents=True, exist_ok=True)

    # Array archive: only the four split arrays are stored, in this order.
    array_keys = ("X_train", "y_train", "X_test", "y_test")
    npz_path = target_dir / f"{name}.npz"
    np.savez(npz_path, **{key: data[key] for key in array_keys})
    print(f"Saved: {npz_path}")

    # Metadata goes into a JSON file next to the archive.
    metadata_path = target_dir / f"{name}_metadata.json"
    with metadata_path.open("w") as handle:
        json.dump(data["metadata"], handle, indent=2)
    print(f"Saved: {metadata_path}")
157
+
158
+
159
def print_dataset_info(data: dict, name: str) -> None:
    """Write a summary of ``data["metadata"]`` to stdout.

    Prints a banner with the dataset name, then the configuration, shapes,
    dtypes, class distribution and value ranges, one entry per line.
    """
    info = data["metadata"]
    rule = "=" * 60

    print(f"\n{rule}")
    print(f"Dataset: {name}")
    print(f"{rule}")

    print("Configuration:")
    for option, setting in info["config"].items():
        print(f" {option}: {setting}")

    # These three sections share the same "label: value" layout.
    simple_sections = (
        ("\nShapes:", "shapes"),
        ("\nDtypes:", "dtypes"),
        ("\nClass Distribution:", "class_distribution"),
    )
    for heading, section in simple_sections:
        print(heading)
        for label, entry in info[section].items():
            print(f" {label}: {entry}")

    print("\nValue Ranges:")
    for label, bounds in info["value_ranges"].items():
        low, high = bounds["min"], bounds["max"]
        print(f" {label}: min={low:.6f}, max={high:.6f}")
180
+
181
+
182
def main():
    """Generate, print and save every dataset listed in DATASET_CONFIGS."""
    banner = "=" * 60
    print("Generating Golden Reference Datasets")
    print(banner)

    for cfg in DATASET_CONFIGS:
        label = cfg["name"]
        print(f"\nGenerating {label} dataset...")
        dataset = generate_golden_dataset(cfg)
        # Show the summary before writing, so it appears even if saving fails.
        print_dataset_info(dataset, label)
        save_golden_dataset(dataset, label)

    print("\n" + banner)
    print("Golden dataset generation complete!")
    print(f"Output directory: {GOLDEN_DATASETS_DIR}")


if __name__ == "__main__":
    main()
@@ -0,0 +1 @@
1
+ """Integration tests for Juniper Data."""
@@ -0,0 +1,283 @@
1
+ """Integration tests for the FastAPI REST API.
2
+
3
+ Tests cover all endpoints:
4
+ - Health check
5
+ - Generators listing and schema
6
+ - Dataset CRUD operations
7
+ - Artifact download
8
+ - Preview functionality
9
+ """
10
+
11
+ import io
12
+
13
+ import numpy as np
14
+ import pytest
15
+ from fastapi.testclient import TestClient
16
+
17
+ from juniper_data import __version__
18
+ from juniper_data.api.app import create_app
19
+ from juniper_data.api.routes import datasets
20
+ from juniper_data.api.settings import Settings
21
+ from juniper_data.storage.memory import InMemoryDatasetStore
22
+
23
+
24
@pytest.fixture
def memory_store() -> InMemoryDatasetStore:
    """Provide a new, empty InMemoryDatasetStore to each test that requests it."""
    store = InMemoryDatasetStore()
    return store
28
+
29
+
30
@pytest.fixture
def test_settings(tmp_path) -> Settings:
    """Create test settings.

    The storage path points into pytest's per-test ``tmp_path`` directory
    rather than a fixed ``/tmp`` location, so tests never share on-disk state
    and the path is valid on platforms without ``/tmp``.
    """
    # Keep the original leaf name; only the parent directory is now test-scoped.
    return Settings(storage_path=str(tmp_path / "juniper_data_test"))
34
+
35
+
36
@pytest.fixture
def client(memory_store: InMemoryDatasetStore, test_settings: Settings) -> TestClient:
    """Build the app from the test settings and return a TestClient for it.

    The in-memory store is registered with the datasets routes after the app
    is created.
    """
    application = create_app(settings=test_settings)
    # Register the store only after app creation, matching the original order.
    datasets.set_store(memory_store)
    http_client = TestClient(application)
    return http_client
42
+
43
+
44
@pytest.fixture
def spiral_request() -> dict:
    """Request body used by the tests to create a persisted two-spiral dataset."""
    generator_params = {
        "n_spirals": 2,
        "n_points_per_spiral": 50,
        "seed": 42,
    }
    payload = {
        "generator": "spiral",
        "params": generator_params,
        "persist": True,
    }
    return payload
56
+
57
+
58
@pytest.mark.integration
class TestHealthEndpoint:
    """Checks for the health, liveness and readiness routes under /v1/health."""

    def test_health_returns_ok(self, client: TestClient) -> None:
        """The base health route answers 200 with status "ok"."""
        resp = client.get("/v1/health")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ok"

    def test_health_includes_version(self, client: TestClient) -> None:
        """The health payload carries the package version."""
        resp = client.get("/v1/health")

        assert resp.status_code == 200
        body = resp.json()
        assert "version" in body
        assert body["version"] == __version__

    def test_liveness_probe(self, client: TestClient) -> None:
        """The liveness route answers 200 with status "alive"."""
        resp = client.get("/v1/health/live")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "alive"

    def test_readiness_probe(self, client: TestClient) -> None:
        """The readiness route answers 200 with status "ready" and the version."""
        resp = client.get("/v1/health/ready")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ready"
        assert body["version"] == __version__
95
+
96
+
97
@pytest.mark.integration
class TestGeneratorsEndpoint:
    """Checks for the generator listing and schema routes under /v1/generators."""

    def test_list_generators(self, client: TestClient) -> None:
        """The listing is a non-empty list that includes the "spiral" generator."""
        resp = client.get("/v1/generators")

        assert resp.status_code == 200
        listing = resp.json()
        assert isinstance(listing, list)
        assert len(listing) >= 1

        names = [entry["name"] for entry in listing]
        assert "spiral" in names

    def test_get_generator_schema(self, client: TestClient) -> None:
        """The spiral schema is a dict exposing the expected parameter properties."""
        resp = client.get("/v1/generators/spiral/schema")

        assert resp.status_code == 200
        schema = resp.json()
        assert isinstance(schema, dict)
        assert "properties" in schema
        for param_name in ("n_spirals", "n_points_per_spiral"):
            assert param_name in schema["properties"]

    def test_unknown_generator_404(self, client: TestClient) -> None:
        """Asking for the schema of an unregistered generator yields 404 with a detail."""
        resp = client.get("/v1/generators/unknown/schema")

        assert resp.status_code == 404
        body = resp.json()
        assert "detail" in body
        assert "unknown" in body["detail"].lower()
132
+
133
+
134
@pytest.mark.integration
class TestDatasetsEndpoint:
    """Checks for create, list, fetch and delete on /v1/datasets."""

    @staticmethod
    def _post_dataset(client: TestClient, payload: dict):
        """POST the payload to /v1/datasets and hand back the raw response."""
        return client.post("/v1/datasets", json=payload)

    def test_create_spiral_dataset(self, client: TestClient, spiral_request: dict) -> None:
        """Creating a spiral dataset answers 201 with an id and matching meta."""
        resp = self._post_dataset(client, spiral_request)

        assert resp.status_code == 201
        body = resp.json()
        assert "dataset_id" in body
        assert "meta" in body
        assert body["generator"] == "spiral"
        meta = body["meta"]
        assert meta["generator"] == "spiral"
        assert meta["n_samples"] == 100

    def test_create_returns_artifact_url(self, client: TestClient, spiral_request: dict) -> None:
        """The creation response carries an artifact_url under /v1/datasets/."""
        resp = self._post_dataset(client, spiral_request)

        assert resp.status_code == 201
        body = resp.json()
        assert "artifact_url" in body
        for fragment in ("/v1/datasets/", "/artifact"):
            assert fragment in body["artifact_url"]

    def test_list_datasets(self, client: TestClient, spiral_request: dict) -> None:
        """After one creation, the listing is a list with at least one entry."""
        self._post_dataset(client, spiral_request)

        resp = client.get("/v1/datasets")

        assert resp.status_code == 200
        listing = resp.json()
        assert isinstance(listing, list)
        assert len(listing) >= 1

    def test_get_dataset_meta(self, client: TestClient, spiral_request: dict) -> None:
        """Fetching a created dataset by id returns its metadata."""
        dataset_id = self._post_dataset(client, spiral_request).json()["dataset_id"]

        resp = client.get(f"/v1/datasets/{dataset_id}")

        assert resp.status_code == 200
        body = resp.json()
        assert body["dataset_id"] == dataset_id
        assert body["generator"] == "spiral"
        assert "n_samples" in body

    def test_get_dataset_404(self, client: TestClient) -> None:
        """Fetching an id that was never created yields 404 with a detail."""
        resp = client.get("/v1/datasets/nonexistent")

        assert resp.status_code == 404
        body = resp.json()
        assert "detail" in body

    def test_delete_dataset(self, client: TestClient, spiral_request: dict) -> None:
        """Deleting a dataset answers 204 and a later fetch answers 404."""
        dataset_id = self._post_dataset(client, spiral_request).json()["dataset_id"]
        dataset_url = f"/v1/datasets/{dataset_id}"

        delete_resp = client.delete(dataset_url)

        assert delete_resp.status_code == 204

        followup = client.get(dataset_url)
        assert followup.status_code == 404

    def test_caching_same_params(self, client: TestClient, spiral_request: dict) -> None:
        """Posting identical params twice yields the same dataset_id both times."""
        first = self._post_dataset(client, spiral_request)
        second = self._post_dataset(client, spiral_request)

        assert first.status_code == 201
        assert second.status_code == 201

        first_id = first.json()["dataset_id"]
        second_id = second.json()["dataset_id"]
        assert first_id == second_id
215
+
216
+
217
@pytest.mark.integration
class TestArtifactEndpoint:
    """Checks for downloading the NPZ artifact of a created dataset."""

    @staticmethod
    def _fetch_artifact(client: TestClient, payload: dict):
        """Create a dataset from the payload, then GET its artifact."""
        dataset_id = client.post("/v1/datasets", json=payload).json()["dataset_id"]
        return client.get(f"/v1/datasets/{dataset_id}/artifact")

    def test_download_artifact(self, client: TestClient, spiral_request: dict) -> None:
        """The artifact is a non-empty octet-stream that NumPy can load."""
        resp = self._fetch_artifact(client, spiral_request)

        assert resp.status_code == 200
        assert resp.headers["content-type"] == "application/octet-stream"
        assert len(resp.content) > 0

        # Parse the raw bytes in memory to prove they form an NPZ archive.
        with np.load(io.BytesIO(resp.content)) as archive:
            assert len(archive.files) > 0

    def test_artifact_contains_expected_keys(self, client: TestClient, spiral_request: dict) -> None:
        """The archive holds the four train/test arrays."""
        resp = self._fetch_artifact(client, spiral_request)

        assert resp.status_code == 200

        with np.load(io.BytesIO(resp.content)) as archive:
            for array_name in ("X_train", "y_train", "X_test", "y_test"):
                assert array_name in archive.files
249
+
250
+
251
@pytest.mark.integration
class TestPreviewEndpoint:
    """Checks for the JSON preview of a created dataset."""

    @staticmethod
    def _fetch_preview(client: TestClient, payload: dict, suffix: str = ""):
        """Create a dataset from the payload, then GET its preview (plus an optional query suffix)."""
        dataset_id = client.post("/v1/datasets", json=payload).json()["dataset_id"]
        return client.get(f"/v1/datasets/{dataset_id}/preview{suffix}")

    def test_preview_returns_samples(self, client: TestClient, spiral_request: dict) -> None:
        """The preview reports n_samples and non-empty X/y sample lists."""
        resp = self._fetch_preview(client, spiral_request)

        assert resp.status_code == 200
        body = resp.json()
        for field in ("n_samples", "X_sample", "y_sample"):
            assert field in body
        for field in ("X_sample", "y_sample"):
            assert isinstance(body[field], list)
        for field in ("X_sample", "y_sample"):
            assert len(body[field]) > 0

    def test_preview_respects_n_param(self, client: TestClient, spiral_request: dict) -> None:
        """With ?n=10 the preview contains exactly ten samples."""
        resp = self._fetch_preview(client, spiral_request, "?n=10")

        assert resp.status_code == 200
        body = resp.json()
        assert body["n_samples"] == 10
        assert len(body["X_sample"]) == 10
        assert len(body["y_sample"]) == 10