juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,370 @@
1
+ """Unit tests for the MNIST dataset generator."""
2
+
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import numpy as np
6
+ import pytest
7
+
8
+ from juniper_data.generators.mnist.params import MnistParams
9
+
10
+
11
+ def _make_mock_hf_dataset(n_samples=20, n_classes=10):
12
+ """Create a mock HuggingFace dataset for MNIST-like data."""
13
+ mock_ds = MagicMock()
14
+ mock_ds.__len__ = MagicMock(return_value=n_samples)
15
+
16
+ labels = [i % n_classes for i in range(n_samples)]
17
+
18
+ np_images = np.random.randint(0, 255, (n_samples, 28, 28), dtype=np.uint8)
19
+ np_labels = np.array(labels)
20
+
21
+ images = []
22
+ for i in range(n_samples):
23
+ mock_img = MagicMock()
24
+ mock_img.convert.return_value = np_images[i]
25
+ images.append(mock_img)
26
+
27
+ def ds_iter():
28
+ for i in range(n_samples):
29
+ yield {"image": images[i], "label": labels[i]}
30
+
31
+ mock_ds.__iter__ = MagicMock(side_effect=ds_iter)
32
+ mock_ds.__getitem__ = MagicMock(side_effect=lambda key: labels if key == "label" else images)
33
+ mock_ds.shuffle.return_value = mock_ds
34
+ mock_ds.select.return_value = mock_ds
35
+
36
+ # Mock with_format("numpy") to return a dataset providing numpy arrays
37
+ formatted_ds = MagicMock()
38
+ formatted_ds.__getitem__ = MagicMock(side_effect=lambda key: np_labels if key == "label" else np_images)
39
+ mock_ds.with_format.return_value = formatted_ds
40
+
41
+ return mock_ds, labels, images
42
+
43
+
44
@pytest.fixture
def mock_hf_load():
    """Yield a MagicMock standing in for the generator's ``hf_load_dataset``.

    While the fixture is active, the MNIST generator module sees
    ``HF_AVAILABLE`` as ``True``. Its ``hf_load_dataset`` name is bound to the
    yielded mock, so tests can set ``return_value`` to supply a fake dataset.
    Both patches are undone when the test finishes.
    """
    loader = MagicMock()
    force_available = patch("juniper_data.generators.mnist.generator.HF_AVAILABLE", True)
    swap_loader = patch("juniper_data.generators.mnist.generator.hf_load_dataset", loader)

    with force_available, swap_loader:
        yield loader
52
+
53
+
54
@pytest.mark.unit
@pytest.mark.generators
class TestMnistParams:
    """Tests for MnistParams validation."""

    def test_default_params(self) -> None:
        """Default parameters are valid."""
        params = MnistParams()
        assert params.dataset == "mnist"
        assert params.n_samples is None
        assert params.flatten is True
        assert params.normalize is True
        assert params.one_hot_labels is True
        assert params.train_ratio == 0.8
        assert params.test_ratio == 0.2

    def test_fashion_mnist(self) -> None:
        """Fashion-MNIST variant is accepted."""
        params = MnistParams(dataset="fashion_mnist")
        assert params.dataset == "fashion_mnist"

    def test_custom_params(self) -> None:
        """Custom parameters are accepted and stored unchanged."""
        params = MnistParams(
            n_samples=100,
            flatten=False,
            normalize=False,
            one_hot_labels=False,
            seed=42,
            train_ratio=0.7,
            test_ratio=0.3,
        )
        # Check every field that was overridden, not just a subset, so a
        # dropped or ignored argument is caught.
        assert params.n_samples == 100
        assert params.flatten is False
        assert params.normalize is False
        assert params.one_hot_labels is False
        assert params.seed == 42
        assert params.train_ratio == 0.7
        assert params.test_ratio == 0.3

    def test_invalid_n_samples(self) -> None:
        """n_samples must be >= 1."""
        with pytest.raises(ValueError):
            MnistParams(n_samples=0)

    def test_invalid_train_ratio(self) -> None:
        """train_ratio must be in (0, 1]; the excluded lower bound 0 is rejected."""
        with pytest.raises(ValueError):
            MnistParams(train_ratio=0)

    def test_invalid_dataset_name(self) -> None:
        """Invalid dataset name is rejected."""
        with pytest.raises(ValueError):
            MnistParams(dataset="cifar10")  # type: ignore[arg-type]
104
+
105
+
106
@pytest.mark.unit
@pytest.mark.generators
class TestMnistGenerator:
    """Tests for MnistGenerator functionality.

    Every test (except the ImportError one) uses the ``mock_hf_load`` fixture
    so no real Hugging Face download happens.
    """

    # NOTE(review): MnistGenerator is imported inside each test rather than at
    # module level, presumably because another test in this file re-imports
    # the generator module and a module-level reference could go stale.

    def test_generate_correct_shapes(self, mock_hf_load) -> None:
        """Generated arrays have correct shapes."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=20)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(seed=42)
        result = MnistGenerator.generate(params)

        n_total = 20
        n_train = int(n_total * 0.8)  # default train_ratio
        n_test = n_total - n_train

        assert result["X_train"].shape[0] == n_train
        assert result["X_test"].shape[0] == n_test
        assert result["X_full"].shape[0] == n_total
        assert result["y_full"].shape[0] == n_total

    def test_generate_flattened(self, mock_hf_load) -> None:
        """Flatten option produces 1D features (784 for 28x28)."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(flatten=True, seed=42)
        result = MnistGenerator.generate(params)

        assert result["X_full"].shape[1] == 784

    def test_generate_not_flattened(self, mock_hf_load) -> None:
        """Non-flatten produces 2D features (28x28)."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(flatten=False, seed=42)
        result = MnistGenerator.generate(params)

        assert result["X_full"].shape[1:] == (28, 28)

    def test_generate_normalized(self, mock_hf_load) -> None:
        """Normalized values are in [0, 1]."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(normalize=True, seed=42)
        result = MnistGenerator.generate(params)

        assert result["X_full"].max() <= 1.0
        assert result["X_full"].min() >= 0.0

    def test_generate_not_normalized(self, mock_hf_load) -> None:
        """Non-normalized values can exceed 1.0."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(normalize=False, seed=42)
        result = MnistGenerator.generate(params)

        x_full = result["X_full"]
        assert x_full.dtype == np.float32
        # The mock pixels are random uint8 values over 10 * 784 pixels, so
        # unscaled data is (overwhelmingly) guaranteed to contain values > 1.
        # Without this check the test would pass even if normalization were
        # wrongly applied.
        assert x_full.max() > 1.0

    def test_generate_one_hot_labels(self, mock_hf_load) -> None:
        """One-hot encoding produces correct label shape."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=20, n_classes=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(one_hot_labels=True, seed=42)
        result = MnistGenerator.generate(params)

        assert result["y_full"].shape[1] == 10
        row_sums = result["y_full"].sum(axis=1)
        np.testing.assert_array_almost_equal(row_sums, np.ones(20))

    def test_generate_integer_labels(self, mock_hf_load) -> None:
        """Non-one-hot produces integer labels."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10, n_classes=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(one_hot_labels=False, seed=42)
        result = MnistGenerator.generate(params)

        y_full = result["y_full"]
        assert y_full.shape[1] == 1
        # Labels may be stored with a float dtype, but they must still be
        # whole class indices within the mock's 0..9 label range.
        np.testing.assert_array_equal(y_full, np.round(y_full))
        assert y_full.min() >= 0
        assert y_full.max() <= 9

    def test_generate_with_seed_shuffle(self, mock_hf_load) -> None:
        """Seed triggers dataset shuffle."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(seed=42)
        MnistGenerator.generate(params)

        mock_ds.shuffle.assert_called_once_with(seed=42)

    def test_generate_with_n_samples(self, mock_hf_load) -> None:
        """n_samples limits the dataset."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=100)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(n_samples=50, seed=42)
        MnistGenerator.generate(params)

        mock_ds.select.assert_called_once()

    def test_generate_no_seed_no_shuffle(self, mock_hf_load) -> None:
        """Without seed, no shuffle is called."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(seed=None, n_samples=None)
        MnistGenerator.generate(params)

        mock_ds.shuffle.assert_not_called()
        mock_ds.select.assert_not_called()

    def test_generate_image_without_convert(self, mock_hf_load) -> None:
        """Handle images that don't have a convert method."""
        n_samples = 5
        mock_ds = MagicMock()
        mock_ds.__len__ = MagicMock(return_value=n_samples)

        # Plain numpy arrays (no .convert()) stand in for decoded images. A
        # seeded local generator keeps the data reproducible; endpoint=True
        # covers the full 0..255 uint8 range.
        rng = np.random.default_rng(0)
        raw_images = [
            rng.integers(0, 255, size=(28, 28), endpoint=True, dtype=np.uint8) for _ in range(n_samples)
        ]
        labels = [0, 1, 2, 3, 4]

        def ds_iter():
            for image, label in zip(raw_images, labels):
                yield {"image": image, "label": label}

        mock_ds.__iter__ = MagicMock(side_effect=ds_iter)
        mock_ds.__getitem__ = MagicMock(side_effect=lambda key: labels if key == "label" else raw_images)
        mock_ds.shuffle.return_value = mock_ds
        mock_ds.select.return_value = mock_ds

        # Mock with_format("numpy") to return numpy arrays
        np_images = np.array(raw_images)
        np_labels = np.array(labels)
        formatted_ds = MagicMock()
        formatted_ds.__getitem__ = MagicMock(side_effect=lambda key: np_labels if key == "label" else np_images)
        mock_ds.with_format.return_value = formatted_ds

        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(seed=42)
        result = MnistGenerator.generate(params)
        assert result["X_full"].shape[0] == n_samples

    def test_generate_raises_without_datasets(self) -> None:
        """Raises ImportError when datasets not installed."""
        with patch("juniper_data.generators.mnist.generator.HF_AVAILABLE", False):
            from juniper_data.generators.mnist.generator import MnistGenerator

            params = MnistParams()
            with pytest.raises(ImportError, match="Hugging Face datasets package not installed"):
                MnistGenerator.generate(params)

    def test_generate_correct_dtypes(self, mock_hf_load) -> None:
        """All arrays are float32."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(seed=42)
        result = MnistGenerator.generate(params)

        for key in ["X_train", "y_train", "X_test", "y_test", "X_full", "y_full"]:
            assert result[key].dtype == np.float32
295
+
296
+
297
@pytest.mark.unit
@pytest.mark.generators
class TestMnistGetSchema:
    """Checks on the schema helper exported by the MNIST generator module."""

    def test_get_schema_returns_dict(self) -> None:
        """get_schema hands back a plain dictionary."""
        from juniper_data.generators.mnist.generator import get_schema

        assert isinstance(get_schema(), dict)

    def test_get_schema_has_properties(self) -> None:
        """The schema lists the dataset, n_samples and flatten parameters."""
        from juniper_data.generators.mnist.generator import get_schema

        schema = get_schema()
        assert "properties" in schema

        declared = schema["properties"]
        for field_name in ("dataset", "n_samples", "flatten"):
            assert field_name in declared
318
+
319
+
320
@pytest.mark.unit
@pytest.mark.generators
class TestMnistVersion:
    """Checks on the generator's VERSION constant."""

    def test_version_format(self) -> None:
        """VERSION has three dot-separated, purely numeric components."""
        from juniper_data.generators.mnist.generator import VERSION

        components = VERSION.split(".")
        assert len(components) == 3
        assert all(component.isdigit() for component in components)
333
+
334
+
335
@pytest.mark.unit
@pytest.mark.generators
class TestMnistImports:
    """Tests for __init__.py imports."""

    def test_init_exports(self) -> None:
        """__init__.py exports expected symbols."""
        from juniper_data.generators.mnist import VERSION, MnistGenerator, MnistParams, get_schema

        assert MnistGenerator is not None
        assert MnistParams is not None
        assert VERSION is not None
        assert get_schema is not None

    def test_module_level_hf_available_true(self) -> None:
        """Module-level HF_AVAILABLE is True when datasets is importable."""
        import importlib
        import sys
        from types import ModuleType

        mod_name = "juniper_data.generators.mnist.generator"

        # Inject a fake 'datasets' module so the try-branch succeeds
        fake_datasets = ModuleType("datasets")
        fake_datasets.load_dataset = MagicMock()  # type: ignore[attr-defined]

        # Remember any pre-existing 'datasets' entry so it can be put back
        # exactly. This may be a real package that is already imported, or
        # None used as an import blocker. Popping it unconditionally would
        # evict that entry and leak state into other tests.
        had_datasets = "datasets" in sys.modules
        original_datasets = sys.modules.get("datasets")

        sys.modules["datasets"] = fake_datasets
        sys.modules.pop(mod_name, None)
        try:
            mod = importlib.import_module(mod_name)
            assert mod.HF_AVAILABLE is True
        finally:
            # Restore original state
            if had_datasets:
                sys.modules["datasets"] = original_datasets  # type: ignore[assignment]
            else:
                sys.modules.pop("datasets", None)
            # Re-import the generator against the restored environment.
            sys.modules.pop(mod_name, None)
            importlib.import_module(mod_name)