juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,256 @@
1
+ """Unit tests for the concentric circles dataset generator."""
2
+
3
+ import numpy as np
4
+ import pytest
5
+
6
+ from juniper_data.generators.circles import VERSION, CirclesGenerator, CirclesParams, get_schema
7
+
8
+
9
class TestCirclesParams:
    """Validation behaviour and defaults of ``CirclesParams``."""

    def test_default_params(self) -> None:
        """A no-argument ``CirclesParams`` exposes the expected default values."""
        defaults = CirclesParams()
        expected = {
            "n_samples": 100,
            "outer_radius": 1.0,
            "factor": 0.5,
            "noise": 0.0,
            "inner_ratio": 0.5,
            "train_ratio": 0.8,
            "test_ratio": 0.2,
        }
        for field_name, value in expected.items():
            assert getattr(defaults, field_name) == value, field_name
        assert defaults.shuffle is True

    def test_custom_params(self) -> None:
        """Explicitly supplied values are stored unchanged."""
        overrides = {
            "n_samples": 200,
            "outer_radius": 2.0,
            "factor": 0.3,
            "noise": 0.1,
            "seed": 42,
        }
        custom = CirclesParams(**overrides)
        for field_name, value in overrides.items():
            assert getattr(custom, field_name) == value, field_name

    def test_invalid_factor_too_low(self) -> None:
        """A factor of 0 is rejected (factor must be > 0)."""
        bad_kwargs = {"factor": 0.0}
        with pytest.raises(ValueError):
            CirclesParams(**bad_kwargs)

    def test_invalid_factor_too_high(self) -> None:
        """A factor of 1 is rejected (factor must be < 1)."""
        bad_kwargs = {"factor": 1.0}
        with pytest.raises(ValueError):
            CirclesParams(**bad_kwargs)

    def test_invalid_outer_radius_negative(self) -> None:
        """A negative outer radius is rejected."""
        bad_kwargs = {"outer_radius": -1.0}
        with pytest.raises(ValueError):
            CirclesParams(**bad_kwargs)

    def test_invalid_n_samples_too_low(self) -> None:
        """Fewer than two samples is rejected."""
        bad_kwargs = {"n_samples": 1}
        with pytest.raises(ValueError):
            CirclesParams(**bad_kwargs)
58
+
59
+
60
class TestCirclesGenerator:
    """Behaviour of ``CirclesGenerator.generate``."""

    def test_generate_returns_expected_keys(self) -> None:
        """The result holds train, test and full splits for features and labels."""
        data = CirclesGenerator.generate(CirclesParams(seed=42))

        assert set(data) == {"X_train", "y_train", "X_test", "y_test", "X_full", "y_full"}

    def test_generate_shapes(self) -> None:
        """Full feature and label arrays both have shape (n_samples, 2)."""
        data = CirclesGenerator.generate(CirclesParams(n_samples=150, seed=42))

        for key in ("X_full", "y_full"):
            assert data[key].shape == (150, 2), key

    def test_generate_dtypes(self) -> None:
        """Train and full arrays are produced as float32."""
        data = CirclesGenerator.generate(CirclesParams(seed=42))

        for key in ("X_train", "y_train", "X_full", "y_full"):
            assert data[key].dtype == np.float32, key

    def test_determinism_with_seed(self) -> None:
        """Generating twice from the same seeded params yields identical arrays."""
        seeded = CirclesParams(seed=123)

        first, second = (CirclesGenerator.generate(seeded) for _ in range(2))

        for key in ("X_full", "y_full"):
            np.testing.assert_array_equal(first[key], second[key])

    def test_different_seeds_produce_different_data(self) -> None:
        """Changing only the seed changes the generated features."""
        features_a, features_b = (
            CirclesGenerator.generate(CirclesParams(seed=seed))["X_full"] for seed in (42, 43)
        )

        assert not np.allclose(features_a, features_b)

    def test_one_hot_labels(self) -> None:
        """Every label row contains exactly one 1.0 and exactly one 0.0."""
        labels = CirclesGenerator.generate(CirclesParams(seed=42))["y_full"]

        np.testing.assert_array_almost_equal(labels.sum(axis=1), np.ones(labels.shape[0]))
        # Row-wise counts replace a Python loop over the rows.
        assert np.all(np.count_nonzero(labels == 1.0, axis=1) == 1)
        assert np.all(np.count_nonzero(labels == 0.0, axis=1) == 1)

    def test_class_distribution(self) -> None:
        """With inner_ratio=0.5 the 100 samples split evenly between the classes."""
        params = CirclesParams(n_samples=100, inner_ratio=0.5, seed=42)
        per_class = CirclesGenerator.generate(params)["y_full"].sum(axis=0)

        assert per_class[0] == 50
        assert per_class[1] == 50

    def test_class_distribution_custom_ratio(self) -> None:
        """inner_ratio=0.3 puts 30 of 100 samples in class 1 and 70 in class 0."""
        params = CirclesParams(n_samples=100, inner_ratio=0.3, seed=42)
        per_class = CirclesGenerator.generate(params)["y_full"].sum(axis=0)

        assert per_class[0] == 70
        assert per_class[1] == 30

    def test_train_test_split_ratio(self) -> None:
        """A 0.7 / 0.3 split of 100 samples gives 70 train and 30 test rows."""
        data = CirclesGenerator.generate(
            CirclesParams(n_samples=100, train_ratio=0.7, test_ratio=0.3, seed=42)
        )

        assert len(data["X_train"]) == 70
        assert len(data["X_test"]) == 30

    def test_points_on_circles_no_noise(self) -> None:
        """Noise-free, unshuffled points lie exactly on radii 2.0 and 1.0."""
        data = CirclesGenerator.generate(
            CirclesParams(
                n_samples=100,
                outer_radius=2.0,
                factor=0.5,
                noise=0.0,
                inner_ratio=0.5,
                seed=42,
                shuffle=False,
            )
        )
        radii = np.linalg.norm(data["X_full"], axis=1)

        # Unshuffled output lists the 50 outer points before the 50 inner ones.
        np.testing.assert_array_almost_equal(radii[:50], np.full(50, 2.0))
        np.testing.assert_array_almost_equal(radii[50:], np.full(50, 1.0))

    def test_noise_adds_variation(self) -> None:
        """Noisy data has a larger variance of point radii than noise-free data."""

        def radius_variance(noise: float) -> float:
            features = CirclesGenerator.generate(
                CirclesParams(n_samples=100, noise=noise, seed=42)
            )["X_full"]
            return np.var(np.linalg.norm(features, axis=1))

        baseline = radius_variance(0.0)
        noisy = radius_variance(0.5)

        assert noisy > baseline

    def test_factor_affects_inner_radius(self) -> None:
        """The inner radius is outer_radius * factor (4.0 * 0.25 == 1.0)."""
        data = CirclesGenerator.generate(
            CirclesParams(
                n_samples=100,
                outer_radius=4.0,
                factor=0.25,
                noise=0.0,
                inner_ratio=0.5,
                seed=42,
                shuffle=False,
            )
        )
        inner_radii = np.linalg.norm(data["X_full"][50:], axis=1)

        np.testing.assert_array_almost_equal(inner_radii, np.full(50, 1.0))

    def test_generate_with_noise_covers_branch(self) -> None:
        """With noise=0.3, unshuffled outer points no longer all sit on radius 1.0."""
        data = CirclesGenerator.generate(
            CirclesParams(n_samples=100, noise=0.3, seed=42, shuffle=False)
        )
        outer_radii = np.linalg.norm(data["X_full"][:50], axis=1)

        assert not np.allclose(outer_radii, np.full(50, 1.0))
217
+
218
+
219
class TestGetSchema:
    """Shape of the schema returned by ``get_schema``."""

    def test_returns_dict(self) -> None:
        """The schema is returned as a ``dict``."""
        assert isinstance(get_schema(), dict)

    def test_schema_has_properties(self) -> None:
        """The schema contains a ``properties`` entry."""
        assert "properties" in get_schema()

    def test_schema_includes_all_params(self) -> None:
        """Every listed parameter name appears under ``properties``."""
        documented = set(get_schema()["properties"])
        required_fields = {
            "n_samples",
            "outer_radius",
            "factor",
            "noise",
            "inner_ratio",
            "seed",
            "train_ratio",
            "test_ratio",
            "shuffle",
        }
        assert required_fields <= documented
247
+
248
+
249
class TestVersion:
    """Format of the module-level ``VERSION`` string."""

    def test_version_format(self) -> None:
        """VERSION consists of exactly three dot-separated numeric components."""
        components = VERSION.split(".")
        assert len(components) == 3
        assert all(map(str.isdigit, components))
@@ -0,0 +1,345 @@
1
+ """Unit tests for the CSV/JSON import generator."""
2
+
3
+ import tempfile
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pytest
8
+
9
+ from juniper_data.generators.csv_import import (
10
+ VERSION,
11
+ CsvImportGenerator,
12
+ CsvImportParams,
13
+ get_schema,
14
+ )
15
+
16
+
17
@pytest.fixture
def sample_csv_file(tmp_path: Path) -> Path:
    """Create a sample CSV file for testing.

    The file has a header row, two numeric feature columns and a string
    ``label`` column with values A/B, across four data rows. It is written into
    pytest's per-test ``tmp_path`` directory, whose lifetime pytest manages.
    The previous ``NamedTemporaryFile(delete=False)`` left a stray file in the
    system temp dir on every run.
    """
    path = tmp_path / "sample.csv"
    path.write_text(
        "feature1,feature2,label\n"
        "1.0,2.0,A\n"
        "3.0,4.0,B\n"
        "5.0,6.0,A\n"
        "7.0,8.0,B\n",
        encoding="utf-8",
    )
    return path
27
+
28
+
29
@pytest.fixture
def sample_json_file(tmp_path: Path) -> Path:
    """Create a sample JSON file for testing.

    The content is a single JSON array of four records with ``feature1``,
    ``feature2`` and a string ``label`` (A/B). It is written under pytest's
    ``tmp_path`` so no stray file is left in the system temp directory.
    """
    path = tmp_path / "sample.json"
    path.write_text(
        '[{"feature1": 1.0, "feature2": 2.0, "label": "A"},'
        '{"feature1": 3.0, "feature2": 4.0, "label": "B"},'
        '{"feature1": 5.0, "feature2": 6.0, "label": "A"},'
        '{"feature1": 7.0, "feature2": 8.0, "label": "B"}]',
        encoding="utf-8",
    )
    return path
38
+
39
+
40
@pytest.fixture
def sample_jsonl_file(tmp_path: Path) -> Path:
    """Create a sample JSONL file for testing.

    The file has one JSON object per line: four records with ``feature1``,
    ``feature2`` and an integer ``label`` (0/1). It is written under pytest's
    ``tmp_path`` so no stray file is left in the system temp directory.
    """
    path = tmp_path / "sample.jsonl"
    path.write_text(
        '{"feature1": 1.0, "feature2": 2.0, "label": 0}\n'
        '{"feature1": 3.0, "feature2": 4.0, "label": 1}\n'
        '{"feature1": 5.0, "feature2": 6.0, "label": 0}\n'
        '{"feature1": 7.0, "feature2": 8.0, "label": 1}\n',
        encoding="utf-8",
    )
    return path
49
+
50
+
51
class TestCsvImportParams:
    """Validation and defaults of ``CsvImportParams``."""

    def test_valid_params(self) -> None:
        """Explicitly supplied fields are stored unchanged."""
        supplied = {
            "file_path": "/path/to/file.csv",
            "feature_columns": ["col1", "col2"],
            "label_column": "target",
        }
        params = CsvImportParams(**supplied)
        for field_name, value in supplied.items():
            assert getattr(params, field_name) == value, field_name

    def test_default_values(self) -> None:
        """Only ``file_path`` is required; the other fields fall back to defaults."""
        params = CsvImportParams(file_path="/path/to/file.csv")

        equal_defaults = {"file_format": "auto", "label_column": "label", "delimiter": ","}
        for field_name, value in equal_defaults.items():
            assert getattr(params, field_name) == value, field_name

        assert params.feature_columns is None
        assert params.header is True
        assert params.one_hot_labels is True
        assert params.normalize_features is False
75
+
76
+
77
class TestCsvImportGenerator:
    """Tests for CsvImportGenerator.

    Tests that need a bespoke input file write it with ``_write_file`` into
    pytest's per-test ``tmp_path`` directory. This replaces
    ``tempfile.NamedTemporaryFile(delete=False)``, which never removed the
    files it created.
    """

    @staticmethod
    def _write_file(directory: Path, name: str, content: str) -> Path:
        """Write ``content`` as UTF-8 text to ``directory / name`` and return that path."""
        path = directory / name
        path.write_text(content, encoding="utf-8")
        return path

    def test_load_csv_file(self, sample_csv_file: Path) -> None:
        """Should load data from CSV file."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)
        assert result["y_full"].shape == (4, 2)

    def test_load_json_file(self, sample_json_file: Path) -> None:
        """Should load data from JSON file."""
        params = CsvImportParams(
            file_path=str(sample_json_file),
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)
        assert result["y_full"].shape == (4, 2)

    def test_load_jsonl_file(self, sample_jsonl_file: Path) -> None:
        """Should load data from JSONL file."""
        params = CsvImportParams(
            file_path=str(sample_jsonl_file),
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)
        assert result["y_full"].shape == (4, 2)

    def test_feature_values(self, sample_csv_file: Path) -> None:
        """Feature values should be correctly parsed."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            shuffle=False,
        )
        result = CsvImportGenerator.generate(params)

        expected_X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)
        np.testing.assert_array_equal(result["X_full"], expected_X)

    def test_one_hot_labels(self, sample_csv_file: Path) -> None:
        """Labels should be one-hot encoded."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            one_hot_labels=True,
            shuffle=False,
        )
        result = CsvImportGenerator.generate(params)

        row_sums = result["y_full"].sum(axis=1)
        np.testing.assert_array_almost_equal(row_sums, np.ones(4))

    def test_non_one_hot_labels(self, sample_csv_file: Path) -> None:
        """Labels should be indices when one_hot=False."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            one_hot_labels=False,
            shuffle=False,
        )
        result = CsvImportGenerator.generate(params)

        assert result["y_full"].shape == (4, 1)

    def test_normalize_features(self, sample_csv_file: Path) -> None:
        """Features should be normalized to [0, 1]."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            normalize_features=True,
            shuffle=False,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].min() >= 0.0
        assert result["X_full"].max() <= 1.0

    def test_file_not_found(self, tmp_path: Path) -> None:
        """Should raise FileNotFoundError for missing file."""
        # The "missing" subdirectory is never created inside the fresh tmp_path,
        # so this path cannot exist. A hard-coded absolute path only hopes so.
        params = CsvImportParams(file_path=str(tmp_path / "missing" / "file.csv"))

        with pytest.raises(FileNotFoundError):
            CsvImportGenerator.generate(params)

    def test_train_test_split(self, sample_csv_file: Path) -> None:
        """Train/test split should work correctly."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            train_ratio=0.5,
            test_ratio=0.5,
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert len(result["X_train"]) == 2
        assert len(result["X_test"]) == 2

    def test_auto_detect_unsupported_extension(self, tmp_path: Path) -> None:
        """Unsupported file extension should raise ValueError."""
        xml_path = self._write_file(tmp_path, "data.xml", "<data></data>")

        params = CsvImportParams(file_path=str(xml_path))
        with pytest.raises(ValueError, match="Cannot auto-detect format"):
            CsvImportGenerator.generate(params)

    def test_csv_without_header(self, tmp_path: Path) -> None:
        """Should load headerless CSV with auto-generated column names."""
        csv_path = self._write_file(
            tmp_path,
            "no_header.csv",
            "1.0,2.0,A\n3.0,4.0,B\n5.0,6.0,A\n7.0,8.0,B\n",
        )

        params = CsvImportParams(
            file_path=str(csv_path),
            header=False,
            label_column="col_2",
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)
        assert result["y_full"].shape == (4, 2)

    def test_json_jsonl_format(self, tmp_path: Path) -> None:
        """Should load JSONL (non-array) format via the else branch."""
        # Line-delimited records deliberately saved with a ".json" suffix.
        jsonl_path = self._write_file(
            tmp_path,
            "records.json",
            '{"feature1": 1.0, "feature2": 2.0, "label": "A"}\n'
            '{"feature1": 3.0, "feature2": 4.0, "label": "B"}\n',
        )

        params = CsvImportParams(file_path=str(jsonl_path), seed=42)
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (2, 2)

    def test_convert_to_arrays_empty_data(self, tmp_path: Path) -> None:
        """Empty file should raise ValueError."""
        csv_path = self._write_file(tmp_path, "header_only.csv", "feature1,feature2,label\n")

        params = CsvImportParams(file_path=str(csv_path), seed=42)
        with pytest.raises(ValueError, match="No data found"):
            CsvImportGenerator.generate(params)

    def test_feature_columns_explicit(self, tmp_path: Path) -> None:
        """Explicit feature_columns should select only those columns."""
        csv_path = self._write_file(
            tmp_path,
            "wide.csv",
            "a,b,c,label\n1.0,2.0,3.0,A\n4.0,5.0,6.0,B\n",
        )

        params = CsvImportParams(
            file_path=str(csv_path),
            feature_columns=["a", "c"],
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (2, 2)

    def test_non_numeric_feature_values(self, tmp_path: Path) -> None:
        """Non-numeric feature values should be replaced with 0.0."""
        csv_path = self._write_file(
            tmp_path,
            "non_numeric.csv",
            "feature1,feature2,label\n1.0,hello,A\n3.0,world,B\n",
        )

        params = CsvImportParams(
            file_path=str(csv_path),
            shuffle=False,
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"][0, 1] == 0.0
        assert result["X_full"][1, 1] == 0.0

    def test_empty_csv_without_header_raises(self, tmp_path: Path) -> None:
        """Empty CSV with header=False should raise ValueError."""
        path = self._write_file(tmp_path, "empty.csv", "")

        params = CsvImportParams(
            file_path=str(path),
            header=False,
            seed=42,
        )
        with pytest.raises(ValueError, match="CSV file is empty"):
            CsvImportGenerator.generate(params)

    def test_normalize_with_constant_feature(self, tmp_path: Path) -> None:
        """Normalization with a constant feature column should not produce NaN."""
        csv_path = self._write_file(
            tmp_path,
            "constant.csv",
            "feature1,feature2,label\n5.0,1.0,A\n5.0,2.0,B\n5.0,3.0,A\n",
        )

        params = CsvImportParams(
            file_path=str(csv_path),
            normalize_features=True,
            shuffle=False,
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert not np.any(np.isnan(result["X_full"]))
        # The zero-range column maps to 0.0 rather than dividing by zero.
        assert result["X_full"][:, 0].min() == 0.0
        assert result["X_full"][:, 0].max() == 0.0

    def test_explicit_csv_format(self, sample_csv_file: Path) -> None:
        """Explicit file_format='csv' should bypass auto-detect."""
        params = CsvImportParams(
            file_path=str(sample_csv_file),
            file_format="csv",
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)

    def test_explicit_json_format(self, sample_json_file: Path) -> None:
        """Explicit file_format='json' should bypass auto-detect."""
        params = CsvImportParams(
            file_path=str(sample_json_file),
            file_format="json",
            seed=42,
        )
        result = CsvImportGenerator.generate(params)

        assert result["X_full"].shape == (4, 2)
322
+
323
+
324
class TestGetSchema:
    """Shape of the schema returned by ``get_schema``."""

    def test_returns_dict(self) -> None:
        """The schema is returned as a ``dict``."""
        assert isinstance(get_schema(), dict)

    def test_schema_has_properties(self) -> None:
        """The schema contains a ``properties`` entry."""
        assert "properties" in get_schema()
336
+
337
+
338
class TestVersion:
    """Format of the module-level ``VERSION`` string."""

    def test_version_format(self) -> None:
        """VERSION consists of exactly three dot-separated numeric components."""
        components = VERSION.split(".")
        assert len(components) == 3
        assert all(map(str.isdigit, components))