juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,181 @@
1
+ """Unit tests for dataset ID generation.
2
+
3
+ Tests cover:
4
+ - Deterministic ID generation
5
+ - Different params produce different IDs
6
+ - ID format matches expected pattern
7
+ """
8
+
9
+ import pytest
10
+
11
+ from juniper_data.core.dataset_id import generate_dataset_id
12
+
13
+
14
@pytest.mark.unit
class TestDatasetIdGeneration:
    """Determinism of ``generate_dataset_id`` and its sensitivity to each input."""

    def test_deterministic_id_generation(self) -> None:
        """Two calls with the same generator, version and params agree."""
        spiral_params = {"n_spirals": 2, "n_points_per_spiral": 100, "seed": 42}

        first, second = (generate_dataset_id("spiral", "v1.0.0", spiral_params) for _ in range(2))

        assert first == second

    def test_multiple_calls_identical(self) -> None:
        """Five back-to-back calls collapse to a single distinct ID."""
        spiral_params = {"n_spirals": 3, "n_points": 50}

        produced = {generate_dataset_id("spiral", "v1.0.0", spiral_params) for _ in range(5)}

        assert len(produced) == 1

    def test_different_params_produce_different_ids(self) -> None:
        """Changing a single parameter value changes the ID."""
        base = {"n_spirals": 2, "seed": 42}
        changed = {**base, "n_spirals": 3}

        base_id = generate_dataset_id("spiral", "v1.0.0", base)
        changed_id = generate_dataset_id("spiral", "v1.0.0", changed)

        assert base_id != changed_id

    def test_different_generator_produces_different_id(self) -> None:
        """Swapping the generator name changes the ID."""
        shared = {"n_spirals": 2}

        spiral_id, circle_id = (generate_dataset_id(name, "v1.0.0", shared) for name in ("spiral", "circle"))

        assert spiral_id != circle_id

    def test_different_version_produces_different_id(self) -> None:
        """Bumping the generator version changes the ID."""
        shared = {"n_spirals": 2}

        old_id, new_id = (generate_dataset_id("spiral", ver, shared) for ver in ("v1.0.0", "v2.0.0"))

        assert old_id != new_id
66
+
67
+
68
@pytest.mark.unit
class TestDatasetIdFormat:
    """Tests for dataset ID format validation."""

    def test_id_format_matches_pattern(self) -> None:
        """Verify ID format is '{generator}-{version}-{hash[:16]}'."""
        params = {"n_spirals": 2}

        dataset_id = generate_dataset_id("spiral", "v1.0.0", params)

        assert dataset_id.startswith("spiral-v1.0.0-")
        parts = dataset_id.split("-")
        assert len(parts) == 3
        assert parts[0] == "spiral"
        assert parts[1] == "v1.0.0"
        assert len(parts[2]) == 16

    def test_hash_is_hex(self) -> None:
        """Verify hash portion consists solely of hexadecimal digits.

        ``int(s, 16)`` on its own is too permissive for this check: it also
        accepts a ``0x`` prefix, ``_`` digit separators, a leading sign and
        surrounding whitespace. The character-set check below rejects those.
        """
        params = {"n_spirals": 2}

        dataset_id = generate_dataset_id("spiral", "v1.0.0", params)
        hash_part = dataset_id.split("-")[-1]

        # Both letter cases are accepted: only hex-ness is asserted here, not case.
        hex_digits = set("0123456789abcdefABCDEF")
        assert hash_part, "hash segment must not be empty"
        assert set(hash_part) <= hex_digits, f"non-hex characters in hash: {hash_part!r}"

    def test_id_length_consistent(self) -> None:
        """Verify ID has consistent length structure."""
        params1 = {"n_spirals": 2}
        params2 = {"n_spirals": 3, "noise": 0.5, "seed": 12345}

        id1 = generate_dataset_id("spiral", "v1.0.0", params1)
        id2 = generate_dataset_id("spiral", "v1.0.0", params2)

        hash1 = id1.split("-")[-1]
        hash2 = id2.split("-")[-1]

        assert len(hash1) == 16
        assert len(hash2) == 16
107
+
108
+
109
@pytest.mark.unit
class TestDatasetIdEdgeCases:
    """Tests for edge cases in dataset ID generation."""

    def test_empty_params(self) -> None:
        """Verify empty params dict works."""
        dataset_id = generate_dataset_id("spiral", "v1.0.0", {})

        assert dataset_id.startswith("spiral-v1.0.0-")
        assert len(dataset_id.split("-")[-1]) == 16

    def test_nested_params(self) -> None:
        """Verify nested params are accepted and that nested values affect the ID."""
        params = {
            "n_spirals": 2,
            "advanced": {"noise": 0.5, "scale": 1.0},
        }
        # Identical to ``params`` except for one value inside the nested dict.
        params_nested_changed = {
            "n_spirals": 2,
            "advanced": {"noise": 0.6, "scale": 1.0},
        }

        dataset_id = generate_dataset_id("spiral", "v1.0.0", params)
        other_id = generate_dataset_id("spiral", "v1.0.0", params_nested_changed)

        assert dataset_id.startswith("spiral-v1.0.0-")
        # If nested values were ignored by the hashing, these two would collide.
        assert dataset_id != other_id

    def test_params_order_independent(self) -> None:
        """Verify param order doesn't affect ID (sorted keys)."""
        params1 = {"a": 1, "b": 2, "c": 3}
        params2 = {"c": 3, "a": 1, "b": 2}

        id1 = generate_dataset_id("spiral", "v1.0.0", params1)
        id2 = generate_dataset_id("spiral", "v1.0.0", params2)

        assert id1 == id2

    def test_float_params(self) -> None:
        """Verify float params work correctly."""
        params = {"noise": 0.25, "ratio": 0.8}

        dataset_id = generate_dataset_id("spiral", "v1.0.0", params)

        assert dataset_id.startswith("spiral-v1.0.0-")

    def test_boolean_params(self) -> None:
        """Verify boolean params work correctly."""
        params = {"clockwise": True, "shuffle": False}

        dataset_id = generate_dataset_id("spiral", "v1.0.0", params)

        assert dataset_id.startswith("spiral-v1.0.0-")

    def test_none_param_value(self) -> None:
        """Verify None param values work correctly."""
        params = {"seed": None, "n_spirals": 2}

        dataset_id = generate_dataset_id("spiral", "v1.0.0", params)

        assert dataset_id.startswith("spiral-v1.0.0-")

    def test_special_characters_in_generator_name(self) -> None:
        """Verify an underscore in the generator name is preserved verbatim in the ID prefix."""
        params = {"n_spirals": 2}

        dataset_id = generate_dataset_id("spiral_v2", "v1.0.0", params)

        assert dataset_id.startswith("spiral_v2-v1.0.0-")

    def test_different_float_precision_different_id(self) -> None:
        """Verify different float values produce different IDs."""
        params1 = {"noise": 0.25}
        params2 = {"noise": 0.250001}

        id1 = generate_dataset_id("spiral", "v1.0.0", params1)
        id2 = generate_dataset_id("spiral", "v1.0.0", params2)

        assert id1 != id2
@@ -0,0 +1,333 @@
1
+ """Unit tests for the Gaussian blobs dataset generator."""
2
+
3
+ import numpy as np
4
+ import pytest
5
+
6
+ from juniper_data.generators.gaussian import VERSION, GaussianGenerator, GaussianParams, get_schema
7
+
8
+
9
class TestGaussianParams:
    """Validation behaviour of the ``GaussianParams`` model."""

    def test_default_params(self) -> None:
        """A no-argument construction yields the documented defaults."""
        defaults = GaussianParams()
        expected_values = {
            "n_classes": 2,
            "n_samples_per_class": 50,
            "n_features": 2,
            "class_std": 1.0,
            "center_radius": 3.0,
            "noise": 0.0,
            "train_ratio": 0.8,
            "test_ratio": 0.2,
        }
        for field_name, value in expected_values.items():
            assert getattr(defaults, field_name) == value, field_name
        assert defaults.centers is None
        assert defaults.shuffle is True

    def test_custom_params(self) -> None:
        """Explicitly supplied values are stored unchanged."""
        overrides = {
            "n_classes": 3,
            "n_samples_per_class": 100,
            "n_features": 4,
            "class_std": 0.5,
            "seed": 42,
        }
        custom = GaussianParams(**overrides)
        for field_name, value in overrides.items():
            assert getattr(custom, field_name) == value, field_name

    def test_list_class_std(self) -> None:
        """A per-class list of standard deviations is accepted."""
        per_class = [0.5, 1.0, 1.5]
        assert GaussianParams(n_classes=3, class_std=per_class).class_std == per_class

    def test_invalid_class_std_negative(self) -> None:
        """A negative scalar class_std is rejected."""
        with pytest.raises(ValueError, match="positive"):
            GaussianParams(class_std=-0.5)

    def test_invalid_class_std_list_negative(self) -> None:
        """A class_std list containing a negative entry is rejected."""
        with pytest.raises(ValueError, match="positive"):
            GaussianParams(class_std=[0.5, -1.0, 1.5])

    def test_custom_centers(self) -> None:
        """Explicit class centers are accepted and stored."""
        explicit_centers = [[0.0, 0.0], [5.0, 5.0]]
        configured = GaussianParams(n_classes=2, n_features=2, centers=explicit_centers)
        assert configured.centers == explicit_centers

    def test_empty_centers_invalid(self) -> None:
        """An empty centers list is rejected."""
        with pytest.raises(ValueError, match="cannot be empty"):
            GaussianParams(centers=[])

    def test_invalid_n_classes_too_low(self) -> None:
        """Fewer than two classes is rejected."""
        with pytest.raises(ValueError):
            GaussianParams(n_classes=1)

    def test_invalid_n_classes_too_high(self) -> None:
        """More than ten classes is rejected."""
        with pytest.raises(ValueError):
            GaussianParams(n_classes=11)
76
+
77
+
78
class TestGaussianGenerator:
    """Tests for GaussianGenerator."""

    @staticmethod
    def _samples_for_class(result: dict, class_idx: int) -> np.ndarray:
        """Return the rows of ``X_full`` whose one-hot label selects ``class_idx``.

        Selecting by label, rather than slicing ``X_full`` in blocks of
        ``n_samples_per_class``, keeps the assertions valid whether the
        generator emits samples grouped by class or in shuffled order.
        """
        labels = result["y_full"].argmax(axis=1)
        return result["X_full"][labels == class_idx]

    def test_generate_returns_expected_keys(self) -> None:
        """Generated data should contain all expected keys."""
        params = GaussianParams(seed=42)
        result = GaussianGenerator.generate(params)

        expected_keys = {"X_train", "y_train", "X_test", "y_test", "X_full", "y_full"}
        assert set(result.keys()) == expected_keys

    def test_generate_shapes(self) -> None:
        """Generated arrays should have correct shapes."""
        params = GaussianParams(
            n_classes=3,
            n_samples_per_class=40,
            n_features=5,
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        total_samples = 3 * 40
        assert result["X_full"].shape == (total_samples, 5)
        assert result["y_full"].shape == (total_samples, 3)

    def test_generate_dtypes(self) -> None:
        """Every generated array (train, test and full splits) should be float32."""
        params = GaussianParams(seed=42)
        result = GaussianGenerator.generate(params)

        for key in ("X_train", "y_train", "X_test", "y_test", "X_full", "y_full"):
            assert result[key].dtype == np.float32, key

    def test_determinism_with_seed(self) -> None:
        """Same seed should produce identical results."""
        params = GaussianParams(seed=123)

        result1 = GaussianGenerator.generate(params)
        result2 = GaussianGenerator.generate(params)

        np.testing.assert_array_equal(result1["X_full"], result2["X_full"])
        np.testing.assert_array_equal(result1["y_full"], result2["y_full"])

    def test_different_seeds_produce_different_data(self) -> None:
        """Different seeds should produce different results."""
        params1 = GaussianParams(seed=42)
        params2 = GaussianParams(seed=43)

        result1 = GaussianGenerator.generate(params1)
        result2 = GaussianGenerator.generate(params2)

        assert not np.allclose(result1["X_full"], result2["X_full"])

    def test_one_hot_labels(self) -> None:
        """Labels should be valid one-hot encoded."""
        params = GaussianParams(n_classes=4, seed=42)
        result = GaussianGenerator.generate(params)

        row_sums = result["y_full"].sum(axis=1)
        np.testing.assert_array_almost_equal(row_sums, np.ones(len(row_sums)))

        for row in result["y_full"]:
            assert np.sum(row == 1.0) == 1
            assert np.sum(row == 0.0) == params.n_classes - 1

    def test_class_distribution(self) -> None:
        """Each class should have n_samples_per_class samples."""
        params = GaussianParams(n_classes=3, n_samples_per_class=50, seed=42)
        result = GaussianGenerator.generate(params)

        class_counts = result["y_full"].sum(axis=0)
        np.testing.assert_array_equal(class_counts, [50, 50, 50])

    def test_train_test_split_ratio(self) -> None:
        """Train/test split should respect configured ratios."""
        params = GaussianParams(
            n_samples_per_class=50,
            train_ratio=0.7,
            test_ratio=0.3,
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        total = 2 * 50
        expected_train = int(total * 0.7)
        expected_test = int(total * 0.3)

        assert len(result["X_train"]) == expected_train
        assert len(result["X_test"]) == expected_test

    def test_custom_centers(self) -> None:
        """Custom centers should position class means correctly."""
        centers = [[0.0, 0.0], [10.0, 10.0]]
        params = GaussianParams(
            n_classes=2,
            n_samples_per_class=100,
            centers=centers,
            class_std=0.1,
            noise=0.0,
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        class_0_mean = self._samples_for_class(result, 0).mean(axis=0)
        class_1_mean = self._samples_for_class(result, 1).mean(axis=0)

        np.testing.assert_array_almost_equal(class_0_mean, [0.0, 0.0], decimal=0)
        np.testing.assert_array_almost_equal(class_1_mean, [10.0, 10.0], decimal=0)

    def test_centers_dimension_mismatch_raises_error(self) -> None:
        """Centers with wrong dimensions should raise error."""
        centers = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
        params = GaussianParams(
            n_classes=2,
            n_features=2,
            centers=centers,
            seed=42,
        )
        with pytest.raises(ValueError, match="n_features"):
            GaussianGenerator.generate(params)

    def test_centers_count_mismatch_raises_error(self) -> None:
        """Wrong number of centers should raise error."""
        centers = [[0.0, 0.0]]
        params = GaussianParams(
            n_classes=2,
            n_features=2,
            centers=centers,
            seed=42,
        )
        with pytest.raises(ValueError, match="n_classes"):
            GaussianGenerator.generate(params)

    def test_noise_adds_variation(self) -> None:
        """Noise parameter should increase data variance."""
        params_no_noise = GaussianParams(
            n_samples_per_class=100,
            class_std=0.5,
            noise=0.0,
            seed=42,
        )
        params_with_noise = GaussianParams(
            n_samples_per_class=100,
            class_std=0.5,
            noise=1.0,
            seed=42,
        )

        result_no_noise = GaussianGenerator.generate(params_no_noise)
        result_with_noise = GaussianGenerator.generate(params_with_noise)

        var_no_noise = np.var(result_no_noise["X_full"])
        var_with_noise = np.var(result_with_noise["X_full"])

        assert var_with_noise > var_no_noise

    def test_auto_center_placement(self) -> None:
        """Auto-placed centers should be on a circle."""
        params = GaussianParams(
            n_classes=4,
            n_samples_per_class=100,
            center_radius=5.0,
            class_std=0.1,
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        for class_idx in range(params.n_classes):
            class_mean = self._samples_for_class(result, class_idx).mean(axis=0)
            distance_from_origin = np.linalg.norm(class_mean)
            np.testing.assert_almost_equal(distance_from_origin, 5.0, decimal=0)

    def test_generate_with_list_class_std(self) -> None:
        """Per-class std list should apply different stds to each class."""
        params = GaussianParams(
            n_classes=3,
            n_samples_per_class=100,
            class_std=[0.1, 0.5, 2.0],
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        assert result["X_full"].shape == (300, 2)
        # Empirical spread per class (sample std per feature, averaged over
        # features) should follow the requested ordering 0.1 < 0.5 < 2.0.
        spreads = [self._samples_for_class(result, i).std(axis=0).mean() for i in range(3)]
        assert spreads[0] < spreads[1] < spreads[2]

    def test_generate_single_feature(self) -> None:
        """Single feature should skip sin component in center placement."""
        params = GaussianParams(
            n_classes=2,
            n_samples_per_class=50,
            n_features=1,
            seed=42,
        )
        result = GaussianGenerator.generate(params)

        assert result["X_full"].shape == (100, 1)

    def test_get_stds_scalar(self) -> None:
        """Scalar class_std should return a list of repeated values."""
        params = GaussianParams(n_classes=3, class_std=0.5)
        stds = GaussianGenerator._get_stds(params)
        assert stds == [0.5, 0.5, 0.5]

    def test_get_stds_list(self) -> None:
        """List class_std should be returned as-is."""
        params = GaussianParams(n_classes=3, class_std=[0.1, 0.5, 2.0])
        stds = GaussianGenerator._get_stds(params)
        assert stds == [0.1, 0.5, 2.0]
292
+
293
+
294
class TestGetSchema:
    """Checks on the schema dictionary returned by ``get_schema``."""

    def test_returns_dict(self) -> None:
        """``get_schema`` hands back a plain dictionary."""
        assert isinstance(get_schema(), dict)

    def test_schema_has_properties(self) -> None:
        """The schema exposes a ``properties`` entry."""
        assert "properties" in get_schema()

    def test_schema_includes_all_params(self) -> None:
        """Each expected parameter name is listed under ``properties``."""
        required_names = {
            "n_classes",
            "n_samples_per_class",
            "n_features",
            "class_std",
            "centers",
            "center_radius",
            "noise",
            "seed",
            "train_ratio",
            "test_ratio",
            "shuffle",
        }
        listed_names = set(get_schema()["properties"])
        missing = required_names - listed_names
        assert not missing
324
+
325
+
326
class TestVersion:
    """Checks on the module-level ``VERSION`` string."""

    def test_version_format(self) -> None:
        """VERSION looks like ``MAJOR.MINOR.PATCH`` with purely numeric components."""
        components = VERSION.split(".")
        assert len(components) == 3
        for component in components:
            assert component.isdigit()