juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,566 @@
1
+ """Unit tests for the SpiralGenerator and related modules.
2
+
3
+ Tests cover:
4
+ - Output shapes and dimensions
5
+ - One-hot encoding correctness
6
+ - Deterministic reproducibility
7
+ - Parameter validation
8
+ """
9
+
10
+ import numpy as np
11
+ import pytest
12
+ from pydantic import ValidationError
13
+
14
+ from juniper_data.generators.spiral import SpiralGenerator, SpiralParams
15
+
16
+
17
@pytest.mark.unit
@pytest.mark.spiral
@pytest.mark.generators
class TestSpiralShapes:
    """Shape and size checks for the arrays returned by ``SpiralGenerator.generate``."""

    def test_2_spiral_shapes(self, two_spiral_params: SpiralParams) -> None:
        """Full feature and label arrays are both (200, 2) for the two-spiral fixture.

        The fixture is defined outside this file; the expected shapes imply it
        is configured with 2 spirals of 100 points each (confirm in conftest).
        """
        dataset = SpiralGenerator.generate(two_spiral_params)

        for key in ("X_full", "y_full"):
            assert dataset[key].shape == (200, 2)

    def test_3_spiral_shapes(self, three_spiral_params: SpiralParams) -> None:
        """Three-spiral fixture yields (150, 2) features and (150, 3) labels.

        The fixture is defined outside this file; the expected shapes imply it
        is configured with 3 spirals of 50 points each (confirm in conftest).
        """
        dataset = SpiralGenerator.generate(three_spiral_params)

        wanted_shapes = {"X_full": (150, 2), "y_full": (150, 3)}
        for key, wanted in wanted_shapes.items():
            assert dataset[key].shape == wanted

    def test_train_test_split_sizes(self, two_spiral_params: SpiralParams) -> None:
        """Train/test row counts are within one row of ``round(total * ratio)``."""
        dataset = SpiralGenerator.generate(two_spiral_params)

        n_total = two_spiral_params.total_points()
        want_train = int(np.round(n_total * two_spiral_params.train_ratio))
        want_test = int(np.round(n_total * two_spiral_params.test_ratio))

        wanted_rows = {
            "X_train": want_train,
            "y_train": want_train,
            "X_test": want_test,
            "y_test": want_test,
        }
        for key, wanted in wanted_rows.items():
            # A tolerance of one row absorbs rounding differences in the splitter.
            assert abs(dataset[key].shape[0] - wanted) <= 1

    def test_custom_split_ratios(self) -> None:
        """Explicit 0.6 / 0.3 train/test ratios are reflected in the split sizes."""
        custom = SpiralParams(n_spirals=2, n_points_per_spiral=100, train_ratio=0.6, test_ratio=0.3, seed=42)
        dataset = SpiralGenerator.generate(custom)

        n_total = custom.total_points()
        want_train = int(np.round(n_total * custom.train_ratio))
        want_test = int(np.round(n_total * custom.test_ratio))

        for key, wanted in (("X_train", want_train), ("X_test", want_test)):
            assert abs(dataset[key].shape[0] - wanted) <= 1

    def test_output_keys_present(self, generated_minimal_dataset: dict[str, np.ndarray]) -> None:
        """The generated dataset exposes exactly the six train/test/full array keys."""
        actual_keys = set(generated_minimal_dataset.keys())
        assert actual_keys == {"X_train", "y_train", "X_test", "y_test", "X_full", "y_full"}
72
+
73
+
74
@pytest.mark.unit
@pytest.mark.spiral
@pytest.mark.generators
class TestOneHotEncoding:
    """Checks that the label arrays are well-formed one-hot encodings."""

    def test_row_sums_to_one(self, generated_two_spiral_dataset: dict[str, np.ndarray]) -> None:
        """Every label row must sum to 1.0 (exactly one active class per sample)."""
        labels = generated_two_spiral_dataset["y_full"]

        per_row_total = labels.sum(axis=1)
        all_ones = np.ones(labels.shape[0], dtype=np.float32)

        np.testing.assert_allclose(per_row_total, all_ones, rtol=1e-6)

    def test_class_distribution(self) -> None:
        """With 3 spirals of 50 points, each label column sums to 50."""
        dataset = SpiralGenerator.generate(SpiralParams(n_spirals=3, n_points_per_spiral=50, seed=42))

        # Column sums of a one-hot matrix are the per-class sample counts.
        per_class = dataset["y_full"].sum(axis=0).astype(int)

        np.testing.assert_array_equal(per_class, np.array([50, 50, 50]))

    def test_dtype_is_float32(self, generated_two_spiral_dataset: dict[str, np.ndarray]) -> None:
        """All six output arrays are stored as float32."""
        for key in ("X_full", "y_full", "X_train", "y_train", "X_test", "y_test"):
            assert generated_two_spiral_dataset[key].dtype == np.float32

    def test_one_hot_values_binary(self, generated_minimal_dataset: dict[str, np.ndarray]) -> None:
        """The label array contains exactly two distinct values: 0.0 and 1.0."""
        distinct = np.unique(generated_minimal_dataset["y_full"])

        assert len(distinct) == 2
        for wanted in (0.0, 1.0):
            assert wanted in distinct
121
+
122
+
123
@pytest.mark.unit
@pytest.mark.spiral
@pytest.mark.generators
class TestDeterminism:
    """Reproducibility checks: the seed alone decides the generated arrays."""

    def test_same_seed_identical_output(self) -> None:
        """Two equal parameter objects with the same seed give element-wise equal arrays."""
        first_params, second_params = (
            SpiralParams(n_spirals=2, n_points_per_spiral=50, seed=12345) for _ in range(2)
        )

        first = SpiralGenerator.generate(first_params)
        second = SpiralGenerator.generate(second_params)

        for key in ("X_full", "y_full", "X_train", "y_train", "X_test", "y_test"):
            np.testing.assert_array_equal(first[key], second[key])

    def test_different_seed_different_output(self) -> None:
        """Changing only the seed changes the generated coordinates."""
        params_by_seed = [
            SpiralParams(n_spirals=2, n_points_per_spiral=50, seed=seed) for seed in (12345, 54321)
        ]

        first, second = (SpiralGenerator.generate(params) for params in params_by_seed)

        assert not np.allclose(first["X_full"], second["X_full"])

    def test_multiple_calls_same_seed_identical(self) -> None:
        """Repeated calls with one parameter object keep producing the same full arrays."""
        params = SpiralParams(n_spirals=2, n_points_per_spiral=30, seed=999)

        baseline, *repeats = [SpiralGenerator.generate(params) for _ in range(3)]

        for repeat in repeats:
            for key in ("X_full", "y_full"):
                np.testing.assert_array_equal(baseline[key], repeat[key])
179
+
180
+
181
@pytest.mark.unit
@pytest.mark.spiral
@pytest.mark.generators
class TestParamValidation:
    """Checks that ``SpiralParams`` rejects out-of-range values at construction time."""

    def test_invalid_n_spirals_too_low(self) -> None:
        """n_spirals=1 is rejected and the error mentions the field or the bound."""
        with pytest.raises(ValidationError) as caught:
            SpiralParams(n_spirals=1)

        message = str(caught.value)
        assert any(fragment in message for fragment in ("n_spirals", "greater than or equal to"))

    def test_invalid_n_spirals_zero(self) -> None:
        """n_spirals=0 is rejected."""
        with pytest.raises(ValidationError):
            SpiralParams(n_spirals=0)

    def test_invalid_n_spirals_negative(self) -> None:
        """A negative n_spirals is rejected."""
        with pytest.raises(ValidationError):
            SpiralParams(n_spirals=-1)

    def test_invalid_n_points_too_low(self) -> None:
        """n_points_per_spiral=5 is rejected and the error mentions the field or the bound."""
        with pytest.raises(ValidationError) as caught:
            SpiralParams(n_points_per_spiral=5)

        message = str(caught.value)
        assert any(fragment in message for fragment in ("n_points_per_spiral", "greater than or equal to"))

    def test_invalid_n_points_zero(self) -> None:
        """n_points_per_spiral=0 is rejected."""
        with pytest.raises(ValidationError):
            SpiralParams(n_points_per_spiral=0)

    def test_invalid_ratios_exceed_one(self) -> None:
        """train_ratio + test_ratio above 1.0 (0.7 + 0.5) is rejected."""
        with pytest.raises(ValidationError) as caught:
            SpiralParams(train_ratio=0.7, test_ratio=0.5)

        message = str(caught.value)
        assert any(fragment in message for fragment in ("train_ratio", "test_ratio", "<= 1.0"))

    def test_invalid_train_ratio_negative(self) -> None:
        """A negative train_ratio is rejected."""
        with pytest.raises(ValidationError):
            SpiralParams(train_ratio=-0.1)

    def test_invalid_test_ratio_negative(self) -> None:
        """A negative test_ratio is rejected."""
        with pytest.raises(ValidationError):
            SpiralParams(test_ratio=-0.1)

    def test_invalid_noise_negative(self) -> None:
        """A negative noise level is rejected."""
        with pytest.raises(ValidationError):
            SpiralParams(noise=-0.1)

    def test_invalid_noise_too_high(self) -> None:
        """noise=10.0 is rejected (presumably above the model's maximum noise; bound not visible here)."""
        with pytest.raises(ValidationError):
            SpiralParams(noise=10.0)

    def test_invalid_n_rotations_too_low(self) -> None:
        """n_rotations=0.1 is rejected (presumably below the model's minimum; bound not visible here)."""
        with pytest.raises(ValidationError):
            SpiralParams(n_rotations=0.1)

    def test_valid_edge_case_min_values(self) -> None:
        """The smallest values this suite treats as valid are accepted and stored unchanged."""
        lowest = SpiralParams(n_spirals=2, n_points_per_spiral=10, n_rotations=0.5, noise=0.0)

        stored = (lowest.n_spirals, lowest.n_points_per_spiral, lowest.n_rotations, lowest.noise)
        assert stored == (2, 10, 0.5, 0.0)
263
+
264
+
265
@pytest.mark.unit
@pytest.mark.spiral
@pytest.mark.generators
class TestSpiralGeometry:
    """Coarse geometric sanity checks on the generated coordinates."""

    def test_coordinates_centered_near_origin(self, generated_two_spiral_dataset: dict[str, np.ndarray]) -> None:
        """The mean of each coordinate column lies within 2.0 of zero."""
        points = generated_two_spiral_dataset["X_full"]

        for column in (0, 1):
            assert abs(points[:, column].mean()) < 2.0

    def test_coordinates_within_expected_radius(self) -> None:
        """Without noise, no point is farther than 12.0 from the origin."""
        dataset = SpiralGenerator.generate(SpiralParams(n_spirals=2, n_points_per_spiral=100, noise=0.0, seed=42))

        xs, ys = dataset["X_full"][:, 0], dataset["X_full"][:, 1]
        farthest = np.sqrt(xs ** 2 + ys ** 2).max()

        assert farthest <= 12.0

    def test_noise_increases_variance(self) -> None:
        """With the same seed, noise=1.0 yields a larger overall variance than noise=0.0."""
        quiet_params, noisy_params = (
            SpiralParams(n_spirals=2, n_points_per_spiral=100, noise=level, seed=42) for level in (0.0, 1.0)
        )

        quiet = SpiralGenerator.generate(quiet_params)
        noisy = SpiralGenerator.generate(noisy_params)

        assert noisy["X_full"].var() > quiet["X_full"].var()
318
+
319
+
320
@pytest.mark.unit
@pytest.mark.spiral
@pytest.mark.generators
class TestSpiralGeneratorLegacyMode:
    """Behavioural checks for ``algorithm="legacy_cascor"``."""

    @staticmethod
    def _radii(points: np.ndarray) -> np.ndarray:
        """Return each row's Euclidean distance from (0, 0), using the first two columns."""
        return np.sqrt(points[:, 0] ** 2 + points[:, 1] ** 2)

    def test_legacy_mode_generates_correct_shapes(self) -> None:
        """Legacy mode returns (100, 2) full arrays and 2-column training arrays."""
        legacy = SpiralParams(n_spirals=2, n_points_per_spiral=50, algorithm="legacy_cascor", seed=42)
        dataset = SpiralGenerator.generate(legacy)

        for key in ("X_full", "y_full"):
            assert dataset[key].shape == (100, 2)
        for key in ("X_train", "y_train"):
            assert dataset[key].shape[1] == 2

    def test_legacy_mode_deterministic_with_seed(self) -> None:
        """Two equal legacy parameter objects with one seed give equal full arrays."""
        first_params, second_params = (
            SpiralParams(n_spirals=2, n_points_per_spiral=50, algorithm="legacy_cascor", seed=12345)
            for _ in range(2)
        )

        first = SpiralGenerator.generate(first_params)
        second = SpiralGenerator.generate(second_params)

        for key in ("X_full", "y_full"):
            np.testing.assert_array_equal(first[key], second[key])

    def test_legacy_mode_different_from_modern(self) -> None:
        """With identical settings and seed, "modern" and "legacy_cascor" coordinates differ."""
        modern_params = SpiralParams(n_spirals=2, n_points_per_spiral=50, algorithm="modern", seed=42)
        legacy_params = SpiralParams(n_spirals=2, n_points_per_spiral=50, algorithm="legacy_cascor", seed=42)

        modern = SpiralGenerator.generate(modern_params)
        legacy = SpiralGenerator.generate(legacy_params)

        assert not np.allclose(modern["X_full"], legacy["X_full"])

    def test_legacy_mode_uniform_noise_range(self) -> None:
        """Per-coordinate noise in legacy mode lies in [0, noise).

        The noise is recovered by subtracting a noise-free run generated with
        the same seed from a run with noise=1.0.
        """
        noisy, clean = (
            SpiralGenerator.generate(
                SpiralParams(
                    n_spirals=2,
                    n_points_per_spiral=100,
                    algorithm="legacy_cascor",
                    noise=level,
                    seed=42,
                )
            )
            for level in (1.0, 0.0)
        )

        delta_x = noisy["X_full"][:, 0] - clean["X_full"][:, 0]
        delta_y = noisy["X_full"][:, 1] - clean["X_full"][:, 1]

        assert delta_x.min() >= 0.0
        assert delta_x.max() < 1.0
        assert delta_y.min() >= 0.0
        assert delta_y.max() < 1.0

    def test_legacy_mode_radii_distribution(self) -> None:
        """Mean of the max-normalised squared radii stays below 0.6 in legacy mode.

        The original test describes this as a check for a sqrt-uniform radii
        distribution; the generator itself is not visible from this file.
        """
        legacy = SpiralParams(
            n_spirals=2,
            n_points_per_spiral=1000,
            algorithm="legacy_cascor",
            noise=0.0,
            seed=42,
        )
        dataset = SpiralGenerator.generate(legacy)

        squared = self._radii(dataset["X_full"]) ** 2
        normalised = squared / squared.max()

        assert normalised.mean() < 0.6

    def test_origin_offset_works(self) -> None:
        """Moving ``origin`` from (0, 0) to (5, 10) shifts the dataset mean by about (5, 10)."""
        centered_params, shifted_params = (
            SpiralParams(
                n_spirals=2,
                n_points_per_spiral=100,
                algorithm="legacy_cascor",
                origin=origin,
                seed=42,
            )
            for origin in ((0.0, 0.0), (5.0, 10.0))
        )

        centered = SpiralGenerator.generate(centered_params)
        shifted = SpiralGenerator.generate(shifted_params)

        centered_mean = centered["X_full"].mean(axis=0)
        shifted_mean = shifted["X_full"].mean(axis=0)

        np.testing.assert_allclose(shifted_mean[0] - centered_mean[0], 5.0, atol=0.1)
        np.testing.assert_allclose(shifted_mean[1] - centered_mean[1], 10.0, atol=0.1)

    def test_radius_parameter_controls_spread(self) -> None:
        """Raising ``radius`` from 5 to 20 scales the farthest point by a factor between 3 and 5."""
        small_params, large_params = (
            SpiralParams(
                n_spirals=2,
                n_points_per_spiral=100,
                algorithm="legacy_cascor",
                radius=radius,
                noise=0.0,
                seed=42,
            )
            for radius in (5.0, 20.0)
        )

        small = SpiralGenerator.generate(small_params)
        large = SpiralGenerator.generate(large_params)

        reach_small = self._radii(small["X_full"]).max()
        reach_large = self._radii(large["X_full"]).max()

        assert 3.0 < reach_large / reach_small < 5.0

    def test_algorithm_param_validation(self) -> None:
        """An unknown algorithm name is rejected with a ValidationError."""
        with pytest.raises(ValidationError) as caught:
            SpiralParams(algorithm="invalid_algorithm")  # type: ignore[arg-type]  # negative test: ensure runtime validation rejects invalid algorithm

        message = str(caught.value)
        assert any(fragment in message for fragment in ("algorithm", "Input should be"))
493
+
494
+
495
@pytest.mark.unit
@pytest.mark.spiral
@pytest.mark.generators
class TestGetSchema:
    """Tests for the ``get_schema`` function of the spiral generator module.

    The class carries the ``generators`` marker, like the other test classes
    in this module, so that marker-based selection (``pytest -m generators``)
    also runs the schema tests.
    """

    def test_get_schema_returns_dict(self) -> None:
        """Verify get_schema returns a dictionary."""
        # Imported locally, as in the original tests, so only this class
        # depends on the generator submodule's import path.
        from juniper_data.generators.spiral.generator import get_schema

        schema = get_schema()
        assert isinstance(schema, dict)

    def test_get_schema_contains_properties(self) -> None:
        """Verify the schema lists the core parameter fields under "properties"."""
        from juniper_data.generators.spiral.generator import get_schema

        schema = get_schema()
        assert "properties" in schema
        for field_name in ("n_spirals", "n_points_per_spiral", "noise", "seed"):
            assert field_name in schema["properties"]

    def test_get_schema_contains_title(self) -> None:
        """Verify the schema title is "SpiralParams"."""
        from juniper_data.generators.spiral.generator import get_schema

        schema = get_schema()
        assert "title" in schema
        assert schema["title"] == "SpiralParams"
525
+
526
+
527
@pytest.mark.unit
@pytest.mark.spiral
@pytest.mark.generators
class TestParameterAliases:
    """Checks that the alternative input keys ``n_points`` and ``noise_level`` are accepted."""

    def test_n_points_alias(self) -> None:
        """``n_points`` in the input mapping populates ``n_points_per_spiral``."""
        payload = {"n_points": 50, "n_spirals": 2, "seed": 42}

        assert SpiralParams.model_validate(payload).n_points_per_spiral == 50

    def test_noise_level_alias(self) -> None:
        """``noise_level`` in the input mapping populates ``noise``."""
        payload = {"noise_level": 0.5, "n_spirals": 2, "seed": 42}

        assert SpiralParams.model_validate(payload).noise == 0.5

    def test_canonical_name_takes_precedence(self) -> None:
        """Canonical field names passed as keywords populate the fields directly.

        NOTE(review): despite the test name, no alias is supplied here, so the
        precedence of a canonical name over its alias is not actually exercised.
        """
        canonical = SpiralParams(n_points_per_spiral=100, noise=0.2, n_spirals=2, seed=42)

        assert canonical.n_points_per_spiral == 100
        assert canonical.noise == 0.2

    def test_alias_generates_correct_dataset(self) -> None:
        """Parameters built from alias keys generate a dataset of the expected shape."""
        payload = {"n_points": 25, "noise_level": 0.15, "n_spirals": 2, "seed": 42}
        dataset = SpiralGenerator.generate(SpiralParams.model_validate(payload))

        for key in ("X_full", "y_full"):
            assert dataset[key].shape == (50, 2)

    def test_alias_determinism(self) -> None:
        """Canonical and alias spellings of the same settings produce equal full arrays."""
        via_canonical = SpiralParams(n_points_per_spiral=30, noise=0.1, n_spirals=2, seed=123)
        via_alias = SpiralParams.model_validate({"n_points": 30, "noise_level": 0.1, "n_spirals": 2, "seed": 123})

        from_canonical = SpiralGenerator.generate(via_canonical)
        from_alias = SpiralGenerator.generate(via_alias)

        for key in ("X_full", "y_full"):
            np.testing.assert_array_equal(from_canonical[key], from_alias[key])