juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""Unit tests for the concentric circles dataset generator."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from juniper_data.generators.circles import VERSION, CirclesGenerator, CirclesParams, get_schema
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestCirclesParams:
    """Validation behaviour of the CirclesParams model."""

    def test_default_params(self) -> None:
        """A bare CirclesParams() carries the documented default values."""
        defaults = CirclesParams()
        expected_values = {
            "n_samples": 100,
            "outer_radius": 1.0,
            "factor": 0.5,
            "noise": 0.0,
            "inner_ratio": 0.5,
            "train_ratio": 0.8,
            "test_ratio": 0.2,
        }
        for field_name, expected in expected_values.items():
            assert getattr(defaults, field_name) == expected
        # Identity check: shuffle must be the boolean True, not merely truthy.
        assert defaults.shuffle is True

    def test_custom_params(self) -> None:
        """Explicitly supplied values are stored unchanged."""
        custom = CirclesParams(n_samples=200, outer_radius=2.0, factor=0.3, noise=0.1, seed=42)
        observed = (custom.n_samples, custom.outer_radius, custom.factor, custom.noise, custom.seed)
        assert observed == (200, 2.0, 0.3, 0.1, 42)

    def test_invalid_factor_too_low(self) -> None:
        """A factor of 0.0 is rejected (factor must be greater than 0)."""
        rejected = {"factor": 0.0}
        with pytest.raises(ValueError):
            CirclesParams(**rejected)

    def test_invalid_factor_too_high(self) -> None:
        """A factor of 1.0 is rejected (factor must be less than 1)."""
        rejected = {"factor": 1.0}
        with pytest.raises(ValueError):
            CirclesParams(**rejected)

    def test_invalid_outer_radius_negative(self) -> None:
        """A negative outer radius is rejected."""
        rejected = {"outer_radius": -1.0}
        with pytest.raises(ValueError):
            CirclesParams(**rejected)

    def test_invalid_n_samples_too_low(self) -> None:
        """A single sample is rejected (n_samples must be at least 2)."""
        rejected = {"n_samples": 1}
        with pytest.raises(ValueError):
            CirclesParams(**rejected)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class TestCirclesGenerator:
    """Behavioural tests for CirclesGenerator.generate."""

    def test_generate_returns_expected_keys(self) -> None:
        """The result mapping exposes exactly the train/test/full arrays."""
        generated = CirclesGenerator.generate(CirclesParams(seed=42))

        assert set(generated.keys()) == {"X_train", "y_train", "X_test", "y_test", "X_full", "y_full"}

    def test_generate_shapes(self) -> None:
        """Full feature and label arrays are (n_samples, 2)."""
        generated = CirclesGenerator.generate(CirclesParams(n_samples=150, seed=42))

        for key in ("X_full", "y_full"):
            assert generated[key].shape == (150, 2)

    def test_generate_dtypes(self) -> None:
        """Train and full arrays are stored as float32."""
        generated = CirclesGenerator.generate(CirclesParams(seed=42))

        for key in ("X_train", "y_train", "X_full", "y_full"):
            assert generated[key].dtype == np.float32

    def test_determinism_with_seed(self) -> None:
        """Two runs with the same seeded params yield identical arrays."""
        seeded = CirclesParams(seed=123)

        first_run = CirclesGenerator.generate(seeded)
        second_run = CirclesGenerator.generate(seeded)

        for key in ("X_full", "y_full"):
            np.testing.assert_array_equal(first_run[key], second_run[key])

    def test_different_seeds_produce_different_data(self) -> None:
        """Changing the seed changes the generated features."""
        params_a = CirclesParams(seed=42)
        params_b = CirclesParams(seed=43)

        features_a = CirclesGenerator.generate(params_a)["X_full"]
        features_b = CirclesGenerator.generate(params_b)["X_full"]

        assert not np.allclose(features_a, features_b)

    def test_one_hot_labels(self) -> None:
        """Every label row is a valid two-class one-hot vector."""
        labels = CirclesGenerator.generate(CirclesParams(seed=42))["y_full"]

        per_row_total = labels.sum(axis=1)
        np.testing.assert_array_almost_equal(per_row_total, np.ones(len(per_row_total)))

        # Each row must contain exactly one 1.0 and exactly one 0.0.
        assert np.all((labels == 1.0).sum(axis=1) == 1)
        assert np.all((labels == 0.0).sum(axis=1) == 1)

    def test_class_distribution(self) -> None:
        """inner_ratio=0.5 splits 100 samples evenly between the classes."""
        generated = CirclesGenerator.generate(CirclesParams(n_samples=100, inner_ratio=0.5, seed=42))

        per_class = generated["y_full"].sum(axis=0)
        assert per_class[0] == 50
        assert per_class[1] == 50

    def test_class_distribution_custom_ratio(self) -> None:
        """inner_ratio=0.3 yields a 70/30 split over 100 samples."""
        generated = CirclesGenerator.generate(CirclesParams(n_samples=100, inner_ratio=0.3, seed=42))

        per_class = generated["y_full"].sum(axis=0)
        assert per_class[0] == 70
        assert per_class[1] == 30

    def test_train_test_split_ratio(self) -> None:
        """A 0.7/0.3 split of 100 samples gives 70 train and 30 test rows."""
        split_params = CirclesParams(n_samples=100, train_ratio=0.7, test_ratio=0.3, seed=42)
        generated = CirclesGenerator.generate(split_params)

        assert (len(generated["X_train"]), len(generated["X_test"])) == (70, 30)

    def test_points_on_circles_no_noise(self) -> None:
        """With noise=0.0 every point sits on its circle's radius."""
        noiseless = CirclesParams(
            n_samples=100,
            outer_radius=2.0,
            factor=0.5,
            noise=0.0,
            inner_ratio=0.5,
            seed=42,
            shuffle=False,
        )
        points = CirclesGenerator.generate(noiseless)["X_full"]

        # With shuffle disabled the test treats rows 0-49 as the outer circle
        # and rows 50-99 as the inner circle.
        radii = np.linalg.norm(points, axis=1)

        np.testing.assert_array_almost_equal(radii[:50], np.full(50, 2.0))
        np.testing.assert_array_almost_equal(radii[50:], np.full(50, 1.0))

    def test_noise_adds_variation(self) -> None:
        """Radial variance is larger with noise=0.5 than with noise=0.0."""

        def radial_variance(noise_level: float):
            cloud = CirclesGenerator.generate(CirclesParams(n_samples=100, noise=noise_level, seed=42))
            return np.var(np.linalg.norm(cloud["X_full"], axis=1))

        quiet_variance = radial_variance(0.0)
        noisy_variance = radial_variance(0.5)

        assert noisy_variance > quiet_variance

    def test_factor_affects_inner_radius(self) -> None:
        """Inner radius equals outer_radius * factor (4.0 * 0.25 == 1.0)."""
        scaled = CirclesParams(
            n_samples=100,
            outer_radius=4.0,
            factor=0.25,
            noise=0.0,
            inner_ratio=0.5,
            seed=42,
            shuffle=False,
        )
        points = CirclesGenerator.generate(scaled)["X_full"]

        inner_radii = np.linalg.norm(points[50:], axis=1)

        np.testing.assert_array_almost_equal(inner_radii, np.full(50, 1.0))

    def test_generate_with_noise_covers_branch(self) -> None:
        """noise > 0 runs the noise branch: outer radii are no longer all 1.0."""
        noisy = CirclesParams(n_samples=100, noise=0.3, seed=42, shuffle=False)
        points = CirclesGenerator.generate(noisy)["X_full"]

        outer_radii = np.linalg.norm(points[:50], axis=1)
        assert not np.allclose(outer_radii, np.full(50, 1.0))
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class TestGetSchema:
    """Checks on the schema returned by get_schema()."""

    def test_returns_dict(self) -> None:
        """The schema is a plain dictionary."""
        assert isinstance(get_schema(), dict)

    def test_schema_has_properties(self) -> None:
        """The schema has a top-level "properties" entry."""
        assert "properties" in get_schema()

    def test_schema_includes_all_params(self) -> None:
        """Every known parameter name is listed under "properties"."""
        declared = set(get_schema()["properties"].keys())
        required = {
            "n_samples",
            "outer_radius",
            "factor",
            "noise",
            "inner_ratio",
            "seed",
            "train_ratio",
            "test_ratio",
            "shuffle",
        }
        # An empty difference means every required name is declared.
        assert not (required - declared)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class TestVersion:
    """Checks on the module-level VERSION constant."""

    def test_version_format(self) -> None:
        """VERSION is three dot-separated, purely numeric components."""
        components = VERSION.split(".")
        assert len(components) == 3
        non_numeric = [piece for piece in components if not piece.isdigit()]
        assert not non_numeric
|
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
"""Unit tests for the CSV/JSON import generator."""
|
|
2
|
+
|
|
3
|
+
import tempfile
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pytest
|
|
8
|
+
|
|
9
|
+
from juniper_data.generators.csv_import import (
|
|
10
|
+
VERSION,
|
|
11
|
+
CsvImportGenerator,
|
|
12
|
+
CsvImportParams,
|
|
13
|
+
get_schema,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@pytest.fixture
def sample_csv_file(tmp_path: Path) -> Path:
    """Create a sample CSV file for testing.

    The file is written under pytest's per-test ``tmp_path`` directory, so
    pytest manages its lifetime. The previous implementation used
    ``NamedTemporaryFile(delete=False)`` and never removed the file, leaving
    one stray file in the system temp directory per test that used it.

    Args:
        tmp_path: Built-in pytest fixture providing a unique temporary directory.

    Returns:
        Path to a CSV file with a header row, two numeric feature columns
        and a string ``label`` column (four data rows, labels ``A``/``B``).
    """
    csv_path = tmp_path / "sample.csv"
    csv_path.write_text(
        "feature1,feature2,label\n"
        "1.0,2.0,A\n"
        "3.0,4.0,B\n"
        "5.0,6.0,A\n"
        "7.0,8.0,B\n",
        encoding="utf-8",
    )
    return csv_path
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@pytest.fixture
def sample_json_file(tmp_path: Path) -> Path:
    """Create a sample JSON file for testing.

    The file is written under pytest's per-test ``tmp_path`` directory, so
    pytest manages its lifetime. The previous implementation used
    ``NamedTemporaryFile(delete=False)`` and never removed the file.

    Args:
        tmp_path: Built-in pytest fixture providing a unique temporary directory.

    Returns:
        Path to a ``.json`` file holding a single JSON array of four records,
        each with ``feature1``, ``feature2`` and a string ``label``.
    """
    json_path = tmp_path / "sample.json"
    # Adjacent literals are concatenated without separators, reproducing the
    # single-line array the original fixture wrote.
    json_path.write_text(
        '[{"feature1": 1.0, "feature2": 2.0, "label": "A"},'
        '{"feature1": 3.0, "feature2": 4.0, "label": "B"},'
        '{"feature1": 5.0, "feature2": 6.0, "label": "A"},'
        '{"feature1": 7.0, "feature2": 8.0, "label": "B"}]',
        encoding="utf-8",
    )
    return json_path
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@pytest.fixture
def sample_jsonl_file(tmp_path: Path) -> Path:
    """Create a sample JSONL file for testing.

    The file is written under pytest's per-test ``tmp_path`` directory, so
    pytest manages its lifetime. The previous implementation used
    ``NamedTemporaryFile(delete=False)`` and never removed the file.

    Args:
        tmp_path: Built-in pytest fixture providing a unique temporary directory.

    Returns:
        Path to a ``.jsonl`` file with four one-object-per-line records, each
        with ``feature1``, ``feature2`` and an integer ``label`` (0 or 1).
    """
    jsonl_path = tmp_path / "sample.jsonl"
    jsonl_path.write_text(
        '{"feature1": 1.0, "feature2": 2.0, "label": 0}\n'
        '{"feature1": 3.0, "feature2": 4.0, "label": 1}\n'
        '{"feature1": 5.0, "feature2": 6.0, "label": 0}\n'
        '{"feature1": 7.0, "feature2": 8.0, "label": 1}\n',
        encoding="utf-8",
    )
    return jsonl_path
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class TestCsvImportParams:
    """Validation behaviour of the CsvImportParams model."""

    def test_valid_params(self) -> None:
        """Explicitly supplied values are stored unchanged."""
        configured = CsvImportParams(
            file_path="/path/to/file.csv",
            feature_columns=["col1", "col2"],
            label_column="target",
        )
        observed = (configured.file_path, configured.feature_columns, configured.label_column)
        assert observed == ("/path/to/file.csv", ["col1", "col2"], "target")

    def test_default_values(self) -> None:
        """Only file_path is required; the other fields fall back to defaults."""
        defaults = CsvImportParams(file_path="/path/to/file.csv")

        # Fields compared by value.
        by_value = {"file_format": "auto", "label_column": "label", "delimiter": ","}
        for field_name, expected in by_value.items():
            assert getattr(defaults, field_name) == expected

        # Fields compared by identity against the None/True/False singletons.
        by_identity = {
            "feature_columns": None,
            "header": True,
            "one_hot_labels": True,
            "normalize_features": False,
        }
        for field_name, expected in by_identity.items():
            assert getattr(defaults, field_name) is expected
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class TestCsvImportGenerator:
    """Tests for CsvImportGenerator.

    Tests that need a bespoke input file write it under pytest's ``tmp_path``
    fixture instead of ``tempfile.NamedTemporaryFile(delete=False)``. The
    latter was never cleaned up and left a stray file in the system temp
    directory on every run.
    """

    def test_load_csv_file(self, sample_csv_file: Path) -> None:
        """Should load data from CSV file."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)
        assert result["y_full"].shape == (4, 2)

    def test_load_json_file(self, sample_json_file: Path) -> None:
        """Should load data from JSON file."""
        params = CsvImportParams(
            file_path=str(sample_json_file),
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)
        assert result["y_full"].shape == (4, 2)

    def test_load_jsonl_file(self, sample_jsonl_file: Path) -> None:
        """Should load data from JSONL file."""
        params = CsvImportParams(
            file_path=str(sample_jsonl_file),
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)
        assert result["y_full"].shape == (4, 2)

    def test_feature_values(self, sample_csv_file: Path) -> None:
        """Feature values should be correctly parsed."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            shuffle=False,
        )
        result = CsvImportGenerator.generate(params)

        # shuffle=False keeps file order, so rows can be compared directly.
        expected_X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)
        np.testing.assert_array_equal(result["X_full"], expected_X)

    def test_one_hot_labels(self, sample_csv_file: Path) -> None:
        """Labels should be one-hot encoded."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            one_hot_labels=True,
            shuffle=False,
        )
        result = CsvImportGenerator.generate(params)

        row_sums = result["y_full"].sum(axis=1)
        np.testing.assert_array_almost_equal(row_sums, np.ones(4))

    def test_non_one_hot_labels(self, sample_csv_file: Path) -> None:
        """Labels should be indices when one_hot=False."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            one_hot_labels=False,
            shuffle=False,
        )
        result = CsvImportGenerator.generate(params)

        assert result["y_full"].shape == (4, 1)

    def test_normalize_features(self, sample_csv_file: Path) -> None:
        """Features should be normalized to [0, 1]."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            normalize_features=True,
            shuffle=False,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].min() >= 0.0
        assert result["X_full"].max() <= 1.0

    def test_file_not_found(self) -> None:
        """Should raise FileNotFoundError for missing file."""
        params = CsvImportParams(file_path="/nonexistent/path/file.csv")

        with pytest.raises(FileNotFoundError):
            CsvImportGenerator.generate(params)

    def test_train_test_split(self, sample_csv_file: Path) -> None:
        """Train/test split should work correctly."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            train_ratio=0.5,
            test_ratio=0.5,
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert len(result["X_train"]) == 2
        assert len(result["X_test"]) == 2

    def test_auto_detect_unsupported_extension(self, tmp_path: Path) -> None:
        """Unsupported file extension should raise ValueError."""
        xml_path = tmp_path / "data.xml"
        xml_path.write_text("<data></data>", encoding="utf-8")

        params = CsvImportParams(file_path=str(xml_path))
        with pytest.raises(ValueError, match="Cannot auto-detect format"):
            CsvImportGenerator.generate(params)

    def test_csv_without_header(self, tmp_path: Path) -> None:
        """Should load headerless CSV with auto-generated column names."""
        csv_path = tmp_path / "no_header.csv"
        csv_path.write_text(
            "1.0,2.0,A\n"
            "3.0,4.0,B\n"
            "5.0,6.0,A\n"
            "7.0,8.0,B\n",
            encoding="utf-8",
        )

        # The third column is addressed by its auto-generated name "col_2".
        params = CsvImportParams(
            file_path=str(csv_path),
            header=False,
            label_column="col_2",
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)
        assert result["y_full"].shape == (4, 2)

    def test_json_jsonl_format(self, tmp_path: Path) -> None:
        """Should load JSONL (non-array) format via the else branch."""
        # Deliberately a ".json" suffix with line-delimited content.
        jsonl_path = tmp_path / "lines.json"
        jsonl_path.write_text(
            '{"feature1": 1.0, "feature2": 2.0, "label": "A"}\n'
            '{"feature1": 3.0, "feature2": 4.0, "label": "B"}\n',
            encoding="utf-8",
        )

        params = CsvImportParams(file_path=str(jsonl_path), seed=42)
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (2, 2)

    def test_convert_to_arrays_empty_data(self, tmp_path: Path) -> None:
        """Empty file should raise ValueError."""
        # Header row only: the file has columns but no data rows.
        csv_path = tmp_path / "header_only.csv"
        csv_path.write_text("feature1,feature2,label\n", encoding="utf-8")

        params = CsvImportParams(file_path=str(csv_path), seed=42)
        with pytest.raises(ValueError, match="No data found"):
            CsvImportGenerator.generate(params)

    def test_feature_columns_explicit(self, tmp_path: Path) -> None:
        """Explicit feature_columns should select only those columns."""
        csv_path = tmp_path / "three_features.csv"
        csv_path.write_text(
            "a,b,c,label\n"
            "1.0,2.0,3.0,A\n"
            "4.0,5.0,6.0,B\n",
            encoding="utf-8",
        )

        params = CsvImportParams(
            file_path=str(csv_path),
            feature_columns=["a", "c"],
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (2, 2)

    def test_non_numeric_feature_values(self, tmp_path: Path) -> None:
        """Non-numeric feature values should be replaced with 0.0."""
        csv_path = tmp_path / "non_numeric.csv"
        csv_path.write_text(
            "feature1,feature2,label\n"
            "1.0,hello,A\n"
            "3.0,world,B\n",
            encoding="utf-8",
        )

        params = CsvImportParams(
            file_path=str(csv_path),
            shuffle=False,
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"][0, 1] == 0.0
        assert result["X_full"][1, 1] == 0.0

    def test_empty_csv_without_header_raises(self, tmp_path: Path) -> None:
        """Empty CSV with header=False should raise ValueError."""
        # Zero-byte file: no header and no rows.
        path = tmp_path / "empty.csv"
        path.write_text("", encoding="utf-8")

        params = CsvImportParams(
            file_path=str(path),
            header=False,
            seed=42,
        )
        with pytest.raises(ValueError, match="CSV file is empty"):
            CsvImportGenerator.generate(params)

    def test_normalize_with_constant_feature(self, tmp_path: Path) -> None:
        """Normalization with a constant feature column should not produce NaN."""
        csv_path = tmp_path / "constant_feature.csv"
        csv_path.write_text(
            "feature1,feature2,label\n"
            "5.0,1.0,A\n"
            "5.0,2.0,B\n"
            "5.0,3.0,A\n",
            encoding="utf-8",
        )

        params = CsvImportParams(
            file_path=str(csv_path),
            normalize_features=True,
            shuffle=False,
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        # feature1 is constant (zero range); it must come out as all zeros,
        # not NaN.
        assert not np.any(np.isnan(result["X_full"]))
        assert result["X_full"][:, 0].min() == 0.0
        assert result["X_full"][:, 0].max() == 0.0

    def test_explicit_csv_format(self, sample_csv_file: Path) -> None:
        """Explicit file_format='csv' should bypass auto-detect."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            file_format="csv",
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)

    def test_explicit_json_format(self, sample_json_file: Path) -> None:
        """Explicit file_format='json' should bypass auto-detect."""
        params = CsvImportParams(
            file_path=str(sample_json_file),
            file_format="json",
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
class TestGetSchema:
    """Checks on the schema returned by get_schema()."""

    def test_returns_dict(self) -> None:
        """The schema is a plain dictionary."""
        assert isinstance(get_schema(), dict)

    def test_schema_has_properties(self) -> None:
        """The schema has a top-level "properties" entry."""
        assert "properties" in get_schema()
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
class TestVersion:
    """Checks on the module-level VERSION constant."""

    def test_version_format(self) -> None:
        """VERSION is three dot-separated, purely numeric components."""
        components = VERSION.split(".")
        assert len(components) == 3
        non_numeric = [piece for piece in components if not piece.isdigit()]
        assert not non_numeric
|