juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,370 @@
|
|
|
1
|
+
"""Unit tests for the MNIST dataset generator."""
|
|
2
|
+
|
|
3
|
+
from unittest.mock import MagicMock, patch
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
from juniper_data.generators.mnist.params import MnistParams
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _make_mock_hf_dataset(n_samples=20, n_classes=10):
    """Create a mock HuggingFace dataset for MNIST-like data.

    The returned mock supports exactly the access patterns configured below:
    ``len()``, iteration over ``{"image", "label"}`` rows, column lookup via
    ``ds["label"]`` (any other key yields the image list), ``shuffle()`` and
    ``select()`` (both return the same mock), and ``with_format(...)``, which
    returns a second mock whose columns are NumPy arrays.

    Args:
        n_samples: Number of random 28x28 uint8 images to fabricate.
        n_classes: Labels cycle through ``0 .. n_classes - 1``.

    Returns:
        A ``(mock_ds, labels, images)`` tuple: the dataset mock, the list of
        integer labels, and the list of per-image mocks whose ``convert()``
        returns the matching 28x28 array.
    """
    mock_ds = MagicMock()
    mock_ds.__len__ = MagicMock(return_value=n_samples)

    labels = [i % n_classes for i in range(n_samples)]

    # The upper bound of np.random.randint is exclusive, so 256 (not 255) is
    # required for the fixture to cover the full uint8 range including 255.
    np_images = np.random.randint(0, 256, (n_samples, 28, 28), dtype=np.uint8)
    np_labels = np.array(labels)

    # One mock per image; convert() hands back that image's pixel array.
    images = []
    for pixels in np_images:
        mock_img = MagicMock()
        mock_img.convert.return_value = pixels
        images.append(mock_img)

    def ds_iter():
        # A fresh generator per iter() call, so the mock can be iterated
        # more than once.
        for image, label in zip(images, labels):
            yield {"image": image, "label": label}

    mock_ds.__iter__ = MagicMock(side_effect=ds_iter)
    mock_ds.__getitem__ = MagicMock(side_effect=lambda key: labels if key == "label" else images)
    mock_ds.shuffle.return_value = mock_ds
    mock_ds.select.return_value = mock_ds

    # Mock with_format("numpy") to return a dataset providing numpy arrays
    formatted_ds = MagicMock()
    formatted_ds.__getitem__ = MagicMock(side_effect=lambda key: np_labels if key == "label" else np_images)
    mock_ds.with_format.return_value = formatted_ds

    return mock_ds, labels, images
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@pytest.fixture
def mock_hf_load():
    """Yield a MagicMock standing in for the generator module's dataset loader.

    While the fixture is active, ``HF_AVAILABLE`` in the mnist generator
    module is patched to ``True`` and ``hf_load_dataset`` is replaced by the
    yielded mock; both patches are undone when the test finishes.
    """
    loader = MagicMock()
    module_path = "juniper_data.generators.mnist.generator"

    # Same patch order as nested ``with`` blocks: the flag first, then the loader.
    with patch(f"{module_path}.HF_AVAILABLE", True), patch(f"{module_path}.hf_load_dataset", loader):
        yield loader
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestMnistParams:
    """Tests for MnistParams validation."""

    def test_default_params(self) -> None:
        """Default parameters are valid."""
        params = MnistParams()
        assert params.dataset == "mnist"
        assert params.n_samples is None
        assert params.flatten is True
        assert params.normalize is True
        assert params.one_hot_labels is True
        assert params.train_ratio == 0.8
        assert params.test_ratio == 0.2

    def test_fashion_mnist(self) -> None:
        """Fashion-MNIST variant is accepted."""
        params = MnistParams(dataset="fashion_mnist")
        assert params.dataset == "fashion_mnist"

    def test_custom_params(self) -> None:
        """Custom parameters are accepted and stored unchanged."""
        params = MnistParams(
            n_samples=100,
            flatten=False,
            normalize=False,
            one_hot_labels=False,
            seed=42,
            train_ratio=0.7,
            test_ratio=0.3,
        )
        # Check every value that was passed in, not just a subset, so a
        # silently ignored field is caught.
        assert params.n_samples == 100
        assert params.flatten is False
        assert params.normalize is False
        assert params.one_hot_labels is False
        assert params.seed == 42
        assert params.train_ratio == 0.7
        assert params.test_ratio == 0.3

    def test_invalid_n_samples(self) -> None:
        """n_samples must be >= 1."""
        with pytest.raises(ValueError):
            MnistParams(n_samples=0)

    def test_invalid_train_ratio(self) -> None:
        """train_ratio must be in (0, 1]; the lower bound 0 is rejected."""
        with pytest.raises(ValueError):
            MnistParams(train_ratio=0)

    def test_invalid_dataset_name(self) -> None:
        """Invalid dataset name is rejected."""
        with pytest.raises(ValueError):
            MnistParams(dataset="cifar10")  # type: ignore[arg-type]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestMnistGenerator:
    """Tests for MnistGenerator functionality.

    Every test that uses the ``mock_hf_load`` fixture runs against a mocked
    dataset loader, so no real dataset is downloaded. ``MnistGenerator`` is
    imported inside each test body rather than at module level.
    NOTE(review): presumably so the name is resolved after the fixture's
    patches are active and after any module re-import done by other tests;
    confirm before hoisting the import.
    """

    def test_generate_correct_shapes(self, mock_hf_load) -> None:
        """Generated arrays have correct shapes."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=20)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(seed=42)
        result = MnistGenerator.generate(params)

        # Default MnistParams split is 0.8 / 0.2.
        n_total = 20
        n_train = int(n_total * 0.8)
        n_test = n_total - n_train

        assert result["X_train"].shape[0] == n_train
        assert result["X_test"].shape[0] == n_test
        assert result["X_full"].shape[0] == n_total
        assert result["y_full"].shape[0] == n_total

    def test_generate_flattened(self, mock_hf_load) -> None:
        """Flatten option produces 1D features (784 for 28x28)."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(flatten=True, seed=42)
        result = MnistGenerator.generate(params)

        assert result["X_full"].shape[1] == 784

    def test_generate_not_flattened(self, mock_hf_load) -> None:
        """Non-flatten produces 2D features (28x28)."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(flatten=False, seed=42)
        result = MnistGenerator.generate(params)

        assert result["X_full"].shape[1:] == (28, 28)

    def test_generate_normalized(self, mock_hf_load) -> None:
        """Normalized values are in [0, 1]."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(normalize=True, seed=42)
        result = MnistGenerator.generate(params)

        assert result["X_full"].max() <= 1.0
        assert result["X_full"].min() >= 0.0

    def test_generate_not_normalized(self, mock_hf_load) -> None:
        """Non-normalized values can exceed 1.0 and are still float32."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(normalize=False, seed=42)
        result = MnistGenerator.generate(params)

        assert result["X_full"].dtype == np.float32
        # The mock supplies 10 * 784 random uint8 pixels, so raw
        # (un-normalized) output must contain values above 1.0; a dtype check
        # alone would also pass if normalization were applied by mistake.
        assert result["X_full"].max() > 1.0

    def test_generate_one_hot_labels(self, mock_hf_load) -> None:
        """One-hot encoding produces correct label shape."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=20, n_classes=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(one_hot_labels=True, seed=42)
        result = MnistGenerator.generate(params)

        assert result["y_full"].shape[1] == 10
        # Each one-hot row must sum to exactly one.
        row_sums = result["y_full"].sum(axis=1)
        np.testing.assert_array_almost_equal(row_sums, np.ones(20))

    def test_generate_integer_labels(self, mock_hf_load) -> None:
        """Non-one-hot produces integer labels."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10, n_classes=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(one_hot_labels=False, seed=42)
        result = MnistGenerator.generate(params)

        assert result["y_full"].shape[1] == 1

    def test_generate_with_seed_shuffle(self, mock_hf_load) -> None:
        """Seed triggers dataset shuffle."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(seed=42)
        MnistGenerator.generate(params)

        mock_ds.shuffle.assert_called_once_with(seed=42)

    def test_generate_with_n_samples(self, mock_hf_load) -> None:
        """n_samples limits the dataset."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=100)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(n_samples=50, seed=42)
        MnistGenerator.generate(params)

        mock_ds.select.assert_called_once()

    def test_generate_no_seed_no_shuffle(self, mock_hf_load) -> None:
        """Without seed, no shuffle is called."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(seed=None, n_samples=None)
        MnistGenerator.generate(params)

        mock_ds.shuffle.assert_not_called()
        mock_ds.select.assert_not_called()

    def test_generate_image_without_convert(self, mock_hf_load) -> None:
        """Handle images that don't have a convert method."""
        n_images = 5
        mock_ds = MagicMock()
        mock_ds.__len__ = MagicMock(return_value=n_images)

        # Plain ndarrays (no ``convert`` attribute) stand in for the images.
        # The upper bound of np.random.randint is exclusive, hence 256.
        raw_images = [np.random.randint(0, 256, (28, 28), dtype=np.uint8) for _ in range(n_images)]
        labels = [0, 1, 2, 3, 4]

        def ds_iter():
            for image, label in zip(raw_images, labels):
                yield {"image": image, "label": label}

        mock_ds.__iter__ = MagicMock(side_effect=ds_iter)
        mock_ds.__getitem__ = MagicMock(side_effect=lambda key: labels if key == "label" else raw_images)
        mock_ds.shuffle.return_value = mock_ds
        mock_ds.select.return_value = mock_ds

        # Mock with_format("numpy") to return numpy arrays
        np_images = np.array(raw_images)
        np_labels = np.array(labels)
        formatted_ds = MagicMock()
        formatted_ds.__getitem__ = MagicMock(side_effect=lambda key: np_labels if key == "label" else np_images)
        mock_ds.with_format.return_value = formatted_ds

        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(seed=42)
        result = MnistGenerator.generate(params)
        assert result["X_full"].shape[0] == n_images

    def test_generate_raises_without_datasets(self) -> None:
        """Raises ImportError when datasets not installed."""
        with patch("juniper_data.generators.mnist.generator.HF_AVAILABLE", False):
            from juniper_data.generators.mnist.generator import MnistGenerator

            params = MnistParams()
            with pytest.raises(ImportError, match="Hugging Face datasets package not installed"):
                MnistGenerator.generate(params)

    def test_generate_correct_dtypes(self, mock_hf_load) -> None:
        """All arrays are float32."""
        mock_ds, _, _ = _make_mock_hf_dataset(n_samples=10)
        mock_hf_load.return_value = mock_ds

        from juniper_data.generators.mnist.generator import MnistGenerator

        params = MnistParams(seed=42)
        result = MnistGenerator.generate(params)

        for key in ["X_train", "y_train", "X_test", "y_test", "X_full", "y_full"]:
            assert result[key].dtype == np.float32
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestMnistGetSchema:
    """Checks on the schema helper exposed by the MNIST generator module."""

    def test_get_schema_returns_dict(self) -> None:
        """The value returned by get_schema is a dict instance."""
        from juniper_data.generators.mnist.generator import get_schema

        assert isinstance(get_schema(), dict)

    def test_get_schema_has_properties(self) -> None:
        """The schema lists the dataset, n_samples and flatten properties."""
        from juniper_data.generators.mnist.generator import get_schema

        described = get_schema()
        assert "properties" in described

        fields = described["properties"]
        for expected in ("dataset", "n_samples", "flatten"):
            assert expected in fields
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestMnistVersion:
    """Checks on the generator module's VERSION constant."""

    def test_version_format(self) -> None:
        """VERSION is three dot-separated groups of digits."""
        from juniper_data.generators.mnist.generator import VERSION

        components = VERSION.split(".")
        assert len(components) == 3
        assert all(piece.isdigit() for piece in components)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestMnistImports:
    """Tests for __init__.py imports."""

    def test_init_exports(self) -> None:
        """__init__.py exports expected symbols."""
        from juniper_data.generators.mnist import VERSION, MnistGenerator, MnistParams, get_schema

        assert MnistGenerator is not None
        assert MnistParams is not None
        assert VERSION is not None
        assert get_schema is not None

    def test_module_level_hf_available_true(self) -> None:
        """Module-level HF_AVAILABLE is True when datasets is importable.

        The generator module is re-imported with a fake ``datasets`` module
        in ``sys.modules``. Afterwards the previous ``datasets`` entry (or
        its absence) is put back and the generator module is imported once
        more so later tests see a module built against the real environment.
        """
        import importlib
        import sys
        from types import ModuleType

        # Inject a fake 'datasets' module so the try-branch succeeds
        fake_datasets = ModuleType("datasets")
        fake_datasets.load_dataset = MagicMock()  # type: ignore[attr-defined]

        # Remember whatever was registered before; a sentinel distinguishes
        # "not imported" from any real module object.
        not_loaded = object()
        previous_datasets = sys.modules.get("datasets", not_loaded)
        sys.modules["datasets"] = fake_datasets

        mod_name = "juniper_data.generators.mnist.generator"
        sys.modules.pop(mod_name, None)
        try:
            mod = importlib.import_module(mod_name)
            assert mod.HF_AVAILABLE is True
        finally:
            # Restore original state: simply popping "datasets" would also
            # discard a real, already-imported datasets package.
            if previous_datasets is not_loaded:
                sys.modules.pop("datasets", None)
            else:
                sys.modules["datasets"] = previous_datasets
            sys.modules.pop(mod_name, None)
            importlib.import_module(mod_name)
|