juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""Unit tests for dataset ID generation.
|
|
2
|
+
|
|
3
|
+
Tests cover:
|
|
4
|
+
- Deterministic ID generation
|
|
5
|
+
- Different params produce different IDs
|
|
6
|
+
- ID format matches expected pattern
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import pytest
|
|
10
|
+
|
|
11
|
+
from juniper_data.core.dataset_id import generate_dataset_id
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@pytest.mark.unit
class TestDatasetIdGeneration:
    """Dataset IDs must depend only on (generator name, version, params)."""

    def test_deterministic_id_generation(self) -> None:
        """Two calls with identical arguments yield the same ID."""
        spiral_params = {"n_spirals": 2, "n_points_per_spiral": 100, "seed": 42}

        first = generate_dataset_id("spiral", "v1.0.0", spiral_params)
        second = generate_dataset_id("spiral", "v1.0.0", spiral_params)

        assert first == second

    def test_multiple_calls_identical(self) -> None:
        """Five back-to-back calls collapse to a single distinct ID."""
        spiral_params = {"n_spirals": 3, "n_points": 50}

        distinct_ids = {generate_dataset_id("spiral", "v1.0.0", spiral_params) for _ in range(5)}

        assert len(distinct_ids) == 1

    def test_different_params_produce_different_ids(self) -> None:
        """Changing a single parameter value changes the ID."""
        base = {"n_spirals": 2, "seed": 42}
        changed = {**base, "n_spirals": 3}

        assert generate_dataset_id("spiral", "v1.0.0", base) != generate_dataset_id("spiral", "v1.0.0", changed)

    def test_different_generator_produces_different_id(self) -> None:
        """The generator name participates in the ID."""
        shared = {"n_spirals": 2}

        spiral_id = generate_dataset_id("spiral", "v1.0.0", shared)
        circle_id = generate_dataset_id("circle", "v1.0.0", shared)

        assert spiral_id != circle_id

    def test_different_version_produces_different_id(self) -> None:
        """The generator version participates in the ID."""
        shared = {"n_spirals": 2}

        v1_id = generate_dataset_id("spiral", "v1.0.0", shared)
        v2_id = generate_dataset_id("spiral", "v2.0.0", shared)

        assert v1_id != v2_id
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@pytest.mark.unit
class TestDatasetIdFormat:
    """Tests for dataset ID format validation."""

    def test_id_format_matches_pattern(self) -> None:
        """Verify ID format is '{generator}-{version}-{hash[:16]}'."""
        params = {"n_spirals": 2}

        dataset_id = generate_dataset_id("spiral", "v1.0.0", params)

        assert dataset_id.startswith("spiral-v1.0.0-")
        parts = dataset_id.split("-")
        assert len(parts) == 3
        assert parts[0] == "spiral"
        assert parts[1] == "v1.0.0"
        assert len(parts[2]) == 16

    def test_hash_is_hex(self) -> None:
        """Verify the hash portion is non-empty and made only of hexadecimal digits."""
        import string

        params = {"n_spirals": 2}

        dataset_id = generate_dataset_id("spiral", "v1.0.0", params)
        hash_part = dataset_id.split("-")[-1]

        # Relying on int(hash_part, 16) not raising is both implicit (no assert)
        # and too lenient: int() also accepts a "0x" prefix, "_" digit
        # separators, surrounding whitespace and a leading "+". Check the
        # characters explicitly instead.
        assert hash_part, "hash portion must not be empty"
        assert set(hash_part) <= set(string.hexdigits), f"non-hex characters in {hash_part!r}"

    def test_id_length_consistent(self) -> None:
        """Verify ID has consistent length structure."""
        params1 = {"n_spirals": 2}
        params2 = {"n_spirals": 3, "noise": 0.5, "seed": 12345}

        id1 = generate_dataset_id("spiral", "v1.0.0", params1)
        id2 = generate_dataset_id("spiral", "v1.0.0", params2)

        hash1 = id1.split("-")[-1]
        hash2 = id2.split("-")[-1]

        # The hash length must not depend on how many params were supplied.
        assert len(hash1) == 16
        assert len(hash2) == 16
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@pytest.mark.unit
class TestDatasetIdEdgeCases:
    """Unusual-but-legal inputs to generate_dataset_id."""

    @staticmethod
    def _assert_spiral_v1_prefix(params: dict) -> None:
        """Build a spiral/v1.0.0 ID from *params* and check its prefix."""
        assert generate_dataset_id("spiral", "v1.0.0", params).startswith("spiral-v1.0.0-")

    def test_empty_params(self) -> None:
        """An empty params mapping still yields a well-formed ID."""
        result = generate_dataset_id("spiral", "v1.0.0", {})

        assert result.startswith("spiral-v1.0.0-")
        *_, digest = result.split("-")
        assert len(digest) == 16

    def test_nested_params(self) -> None:
        """A params dict containing a nested dict is accepted."""
        self._assert_spiral_v1_prefix({"n_spirals": 2, "advanced": {"noise": 0.5, "scale": 1.0}})

    def test_params_order_independent(self) -> None:
        """Key insertion order must not influence the ID."""
        in_order = {"a": 1, "b": 2, "c": 3}
        reordered = {"c": 3, "a": 1, "b": 2}

        assert generate_dataset_id("spiral", "v1.0.0", in_order) == generate_dataset_id("spiral", "v1.0.0", reordered)

    def test_float_params(self) -> None:
        """Float-valued params are accepted."""
        self._assert_spiral_v1_prefix({"noise": 0.25, "ratio": 0.8})

    def test_boolean_params(self) -> None:
        """Bool-valued params are accepted."""
        self._assert_spiral_v1_prefix({"clockwise": True, "shuffle": False})

    def test_none_param_value(self) -> None:
        """A None-valued param is accepted."""
        self._assert_spiral_v1_prefix({"seed": None, "n_spirals": 2})

    def test_special_characters_in_generator_name(self) -> None:
        """An underscore in the generator name is carried verbatim into the ID prefix."""
        result = generate_dataset_id("spiral_v2", "v1.0.0", {"n_spirals": 2})

        assert result.startswith("spiral_v2-v1.0.0-")

    def test_different_float_precision_different_id(self) -> None:
        """Floats differing only in the sixth decimal place give distinct IDs."""
        coarse = generate_dataset_id("spiral", "v1.0.0", {"noise": 0.25})
        fine = generate_dataset_id("spiral", "v1.0.0", {"noise": 0.250001})

        assert coarse != fine
|
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
"""Unit tests for the Gaussian blobs dataset generator."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from juniper_data.generators.gaussian import VERSION, GaussianGenerator, GaussianParams, get_schema
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestGaussianParams:
    """Defaults and validation rules of the GaussianParams model."""

    def test_default_params(self) -> None:
        """A bare GaussianParams() carries the expected defaults."""
        defaults = GaussianParams()

        expected_values = {
            "n_classes": 2,
            "n_samples_per_class": 50,
            "n_features": 2,
            "class_std": 1.0,
            "center_radius": 3.0,
            "noise": 0.0,
            "train_ratio": 0.8,
            "test_ratio": 0.2,
        }
        for attr_name, expected in expected_values.items():
            assert getattr(defaults, attr_name) == expected
        # Identity checks: centers must be exactly None, shuffle exactly True.
        assert defaults.centers is None
        assert defaults.shuffle is True

    def test_custom_params(self) -> None:
        """Explicit keyword overrides are stored unchanged."""
        overrides = {
            "n_classes": 3,
            "n_samples_per_class": 100,
            "n_features": 4,
            "class_std": 0.5,
            "seed": 42,
        }
        configured = GaussianParams(**overrides)

        for attr_name, expected in overrides.items():
            assert getattr(configured, attr_name) == expected

    def test_list_class_std(self) -> None:
        """class_std may be given as one value per class."""
        configured = GaussianParams(n_classes=3, class_std=[0.5, 1.0, 1.5])

        assert configured.class_std == [0.5, 1.0, 1.5]

    def test_invalid_class_std_negative(self) -> None:
        """A negative scalar class_std is rejected with a 'positive' message."""
        with pytest.raises(ValueError, match="positive"):
            GaussianParams(class_std=-0.5)

    def test_invalid_class_std_list_negative(self) -> None:
        """Any negative entry in a class_std list is rejected."""
        with pytest.raises(ValueError, match="positive"):
            GaussianParams(class_std=[0.5, -1.0, 1.5])

    def test_custom_centers(self) -> None:
        """Explicit centers are stored as given."""
        explicit_centers = [[0.0, 0.0], [5.0, 5.0]]

        configured = GaussianParams(n_classes=2, n_features=2, centers=explicit_centers)

        assert configured.centers == explicit_centers

    def test_empty_centers_invalid(self) -> None:
        """An empty centers list is rejected."""
        with pytest.raises(ValueError, match="cannot be empty"):
            GaussianParams(centers=[])

    def test_invalid_n_classes_too_low(self) -> None:
        """Fewer than two classes is rejected."""
        with pytest.raises(ValueError):
            GaussianParams(n_classes=1)

    def test_invalid_n_classes_too_high(self) -> None:
        """More than ten classes is rejected."""
        with pytest.raises(ValueError):
            GaussianParams(n_classes=11)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class TestGaussianGenerator:
    """Tests for GaussianGenerator."""

    def test_generate_returns_expected_keys(self) -> None:
        """Generated data should contain all expected keys."""
        params = GaussianParams(seed=42)
        result = GaussianGenerator.generate(params)

        expected_keys = {"X_train", "y_train", "X_test", "y_test", "X_full", "y_full"}
        assert set(result.keys()) == expected_keys

    def test_generate_shapes(self) -> None:
        """Generated arrays should have correct shapes."""
        params = GaussianParams(
            n_classes=3,
            n_samples_per_class=40,
            n_features=5,
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        total_samples = 3 * 40
        assert result["X_full"].shape == (total_samples, 5)
        assert result["y_full"].shape == (total_samples, 3)

    def test_generate_dtypes(self) -> None:
        """Every generated array (all splits, features and labels) should be float32."""
        params = GaussianParams(seed=42)
        result = GaussianGenerator.generate(params)

        # Check the test split too; previously only train/full were verified.
        for key in ("X_train", "y_train", "X_test", "y_test", "X_full", "y_full"):
            assert result[key].dtype == np.float32, f"{key} has dtype {result[key].dtype}"

    def test_determinism_with_seed(self) -> None:
        """Same seed should produce identical results."""
        params = GaussianParams(seed=123)

        result1 = GaussianGenerator.generate(params)
        result2 = GaussianGenerator.generate(params)

        np.testing.assert_array_equal(result1["X_full"], result2["X_full"])
        np.testing.assert_array_equal(result1["y_full"], result2["y_full"])

    def test_different_seeds_produce_different_data(self) -> None:
        """Different seeds should produce different results."""
        params1 = GaussianParams(seed=42)
        params2 = GaussianParams(seed=43)

        result1 = GaussianGenerator.generate(params1)
        result2 = GaussianGenerator.generate(params2)

        assert not np.allclose(result1["X_full"], result2["X_full"])

    def test_one_hot_labels(self) -> None:
        """Labels should be valid one-hot encoded."""
        params = GaussianParams(n_classes=4, seed=42)
        result = GaussianGenerator.generate(params)

        labels = result["y_full"]
        assert labels.shape[1] == params.n_classes

        np.testing.assert_array_almost_equal(labels.sum(axis=1), np.ones(len(labels)))

        # Vectorized equivalent of "exactly one 1.0 and n_classes - 1 zeros per row".
        assert np.all((labels == 0.0) | (labels == 1.0))
        np.testing.assert_array_equal((labels == 1.0).sum(axis=1), np.ones(len(labels)))

    def test_class_distribution(self) -> None:
        """Each class should have n_samples_per_class samples."""
        params = GaussianParams(n_classes=3, n_samples_per_class=50, seed=42)
        result = GaussianGenerator.generate(params)

        class_counts = result["y_full"].sum(axis=0)
        np.testing.assert_array_equal(class_counts, [50, 50, 50])

    def test_train_test_split_ratio(self) -> None:
        """Train/test split should respect configured ratios."""
        params = GaussianParams(
            n_samples_per_class=50,
            train_ratio=0.7,
            test_ratio=0.3,
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        total = 2 * 50
        expected_train = int(total * 0.7)
        expected_test = int(total * 0.3)

        assert len(result["X_train"]) == expected_train
        assert len(result["X_test"]) == expected_test

    def test_custom_centers(self) -> None:
        """Custom centers should position class means correctly."""
        centers = [[0.0, 0.0], [10.0, 10.0]]
        params = GaussianParams(
            n_classes=2,
            n_samples_per_class=100,
            centers=centers,
            class_std=0.1,
            noise=0.0,
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        # Assumes X_full is laid out class by class (class 0 rows first).
        class_0_samples = result["X_full"][:100]
        class_1_samples = result["X_full"][100:]

        class_0_mean = class_0_samples.mean(axis=0)
        class_1_mean = class_1_samples.mean(axis=0)

        np.testing.assert_array_almost_equal(class_0_mean, [0.0, 0.0], decimal=0)
        np.testing.assert_array_almost_equal(class_1_mean, [10.0, 10.0], decimal=0)

    def test_centers_dimension_mismatch_raises_error(self) -> None:
        """Centers with wrong dimensions should raise error."""
        centers = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
        params = GaussianParams(
            n_classes=2,
            n_features=2,
            centers=centers,
            seed=42,
        )
        with pytest.raises(ValueError, match="n_features"):
            GaussianGenerator.generate(params)

    def test_centers_count_mismatch_raises_error(self) -> None:
        """Wrong number of centers should raise error."""
        centers = [[0.0, 0.0]]
        params = GaussianParams(
            n_classes=2,
            n_features=2,
            centers=centers,
            seed=42,
        )
        with pytest.raises(ValueError, match="n_classes"):
            GaussianGenerator.generate(params)

    def test_noise_adds_variation(self) -> None:
        """Noise parameter should increase data variance."""
        params_no_noise = GaussianParams(
            n_samples_per_class=100,
            class_std=0.5,
            noise=0.0,
            seed=42,
        )
        params_with_noise = GaussianParams(
            n_samples_per_class=100,
            class_std=0.5,
            noise=1.0,
            seed=42,
        )

        result_no_noise = GaussianGenerator.generate(params_no_noise)
        result_with_noise = GaussianGenerator.generate(params_with_noise)

        var_no_noise = np.var(result_no_noise["X_full"])
        var_with_noise = np.var(result_with_noise["X_full"])

        assert var_with_noise > var_no_noise

    def test_auto_center_placement(self) -> None:
        """Auto-placed centers should be on a circle."""
        params = GaussianParams(
            n_classes=4,
            n_samples_per_class=100,
            center_radius=5.0,
            class_std=0.1,
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        for i in range(4):
            start = i * 100
            end = start + 100
            class_mean = result["X_full"][start:end].mean(axis=0)
            distance_from_origin = np.linalg.norm(class_mean)
            np.testing.assert_almost_equal(distance_from_origin, 5.0, decimal=0)

    def test_generate_with_list_class_std(self) -> None:
        """Per-class std list should apply different stds to each class."""
        stds = [0.1, 0.5, 2.0]
        params = GaussianParams(
            n_classes=3,
            n_samples_per_class=100,
            class_std=stds,
            seed=42,
        )
        result = GaussianGenerator.generate(params)
        features = result["X_full"]

        assert features.shape == (300, 2)

        # Previously this test checked only the shape, so a generator that
        # ignored the per-class list would still pass. X_full is laid out
        # class by class (the same assumption test_custom_centers makes),
        # so each class's empirical spread must follow the configured order.
        spreads = [float(features[i * 100 : (i + 1) * 100].std(axis=0).mean()) for i in range(len(stds))]
        assert spreads[0] < spreads[1] < spreads[2], f"per-class spreads {spreads} do not follow {stds}"

    def test_generate_single_feature(self) -> None:
        """Single feature should skip sin component in center placement."""
        params = GaussianParams(
            n_classes=2,
            n_samples_per_class=50,
            n_features=1,
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        assert result["X_full"].shape == (100, 1)

    def test_get_stds_scalar(self) -> None:
        """Scalar class_std should return a list of repeated values."""
        params = GaussianParams(n_classes=3, class_std=0.5)
        stds = GaussianGenerator._get_stds(params)
        assert stds == [0.5, 0.5, 0.5]

    def test_get_stds_list(self) -> None:
        """List class_std should be returned as-is."""
        params = GaussianParams(n_classes=3, class_std=[0.1, 0.5, 2.0])
        stds = GaussianGenerator._get_stds(params)
        assert stds == [0.1, 0.5, 2.0]
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class TestGetSchema:
    """Checks on the parameter schema returned by get_schema()."""

    def test_returns_dict(self) -> None:
        """The schema is a plain dict."""
        assert isinstance(get_schema(), dict)

    def test_schema_has_properties(self) -> None:
        """The schema exposes a 'properties' entry."""
        assert "properties" in get_schema()

    def test_schema_includes_all_params(self) -> None:
        """Every GaussianParams field name appears under 'properties'."""
        schema = get_schema()
        required_names = {
            "n_classes",
            "n_samples_per_class",
            "n_features",
            "class_std",
            "centers",
            "center_radius",
            "noise",
            "seed",
            "train_ratio",
            "test_ratio",
            "shuffle",
        }

        missing_names = required_names - set(schema["properties"])

        assert not missing_names
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
class TestVersion:
    """Checks on the generator's VERSION constant."""

    def test_version_format(self) -> None:
        """VERSION has exactly three dot-separated, all-digit components."""
        components = VERSION.split(".")

        assert len(components) == 3
        for component in components:
            assert component.isdigit()
|