juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,525 @@
|
|
|
1
|
+
"""Unit tests for the ARC-AGI dataset generator."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from unittest.mock import MagicMock, patch
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pytest
|
|
8
|
+
|
|
9
|
+
from juniper_data.generators.arc_agi.params import ArcAgiParams
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _make_sample_tasks(n_tasks=3, train_pairs=2, test_pairs=1):
|
|
13
|
+
"""Create sample ARC tasks for testing."""
|
|
14
|
+
tasks = []
|
|
15
|
+
for i in range(n_tasks):
|
|
16
|
+
task = {
|
|
17
|
+
"task_id": f"task_{i}",
|
|
18
|
+
"train": [
|
|
19
|
+
{
|
|
20
|
+
"input": [[j, j + 1], [j + 2, j + 3]],
|
|
21
|
+
"output": [[j + 3, j + 2], [j + 1, j]],
|
|
22
|
+
}
|
|
23
|
+
for j in range(train_pairs)
|
|
24
|
+
],
|
|
25
|
+
"test": [
|
|
26
|
+
{
|
|
27
|
+
"input": [[1, 2], [3, 4]],
|
|
28
|
+
"output": [[4, 3], [2, 1]],
|
|
29
|
+
}
|
|
30
|
+
for _ in range(test_pairs)
|
|
31
|
+
],
|
|
32
|
+
}
|
|
33
|
+
tasks.append(task)
|
|
34
|
+
return tasks
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _make_mock_hf_dataset(tasks):
|
|
38
|
+
"""Create a mock HuggingFace dataset from task list."""
|
|
39
|
+
mock_ds = MagicMock()
|
|
40
|
+
mock_ds.__len__ = MagicMock(return_value=len(tasks))
|
|
41
|
+
mock_ds.__iter__ = MagicMock(return_value=iter(tasks))
|
|
42
|
+
|
|
43
|
+
return mock_ds
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@pytest.fixture
def mock_hf_load():
    """Yield a MagicMock that stands in for the arc_agi generator's dataset loader.

    While the fixture is active, the generator module's ``HF_AVAILABLE`` is
    patched to True and its ``hf_load_dataset`` is replaced by the yielded
    mock. Both patches are undone when the test finishes.
    """
    loader = MagicMock()
    with patch("juniper_data.generators.arc_agi.generator.HF_AVAILABLE", True), patch(
        "juniper_data.generators.arc_agi.generator.hf_load_dataset", loader
    ):
        yield loader
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiParams:
    """Tests for ArcAgiParams validation."""

    def test_default_params(self) -> None:
        """Default parameters are valid."""
        params = ArcAgiParams()
        assert params.source == "huggingface"
        assert params.local_path is None
        assert params.subset == "training"
        assert params.n_tasks is None
        assert params.pad_to == 30
        assert params.pad_value == -1
        assert params.include_test is True
        assert params.flatten_pairs is True
        assert params.train_ratio == 0.8

    def test_custom_params(self) -> None:
        """Custom parameters are accepted and stored as given."""
        params = ArcAgiParams(
            source="local",
            local_path="/data/arc",
            subset="evaluation",
            n_tasks=10,
            pad_to=15,
            pad_value=0,
            include_test=False,
            flatten_pairs=False,
            seed=42,
        )
        # Check every field that was set, not just a few, so a field that is
        # silently dropped or defaulted is caught.
        assert params.source == "local"
        assert params.local_path == "/data/arc"
        assert params.subset == "evaluation"
        assert params.n_tasks == 10
        assert params.pad_to == 15
        assert params.pad_value == 0
        assert params.include_test is False
        assert params.flatten_pairs is False
        assert params.seed == 42

    def test_invalid_pad_to(self) -> None:
        """pad_to must be >= 1 and <= 50."""
        with pytest.raises(ValueError):
            ArcAgiParams(pad_to=0)
        with pytest.raises(ValueError):
            ArcAgiParams(pad_to=51)

    def test_invalid_n_tasks(self) -> None:
        """n_tasks must be >= 1."""
        with pytest.raises(ValueError):
            ArcAgiParams(n_tasks=0)

    def test_invalid_train_ratio(self) -> None:
        """train_ratio must be in (0, 1]."""
        with pytest.raises(ValueError):
            ArcAgiParams(train_ratio=0)
        # Upper bound: values above 1 are rejected, while 1 itself is
        # inclusive per the (0, 1] contract.
        with pytest.raises(ValueError):
            ArcAgiParams(train_ratio=1.5)
        assert ArcAgiParams(train_ratio=1.0).train_ratio == 1.0
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiGeneratorHuggingFace:
    """Tests for ARC-AGI generation when the data comes from Hugging Face."""

    @staticmethod
    def _generate_with(mock_hf_load, tasks, **param_kwargs):
        """Serve ``tasks`` through the patched loader, then run the generator."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        mock_hf_load.return_value = _make_mock_hf_dataset(tasks)
        return ArcAgiGenerator.generate(ArcAgiParams(**param_kwargs))

    def test_generate_from_hf(self, mock_hf_load) -> None:
        """The result contains every expected array key when loading from Hugging Face."""
        tasks = _make_sample_tasks(n_tasks=2, train_pairs=2, test_pairs=1)
        result = self._generate_with(mock_hf_load, tasks, source="huggingface", seed=42)

        for key in ("X_train", "y_train", "X_test", "y_test", "X_full", "y_full", "task_ids"):
            assert key in result

    def test_generate_correct_shapes_flattened(self, mock_hf_load) -> None:
        """With flatten_pairs=True each pair becomes a row of pad_to*pad_to cells."""
        tasks = _make_sample_tasks(n_tasks=2, train_pairs=2, test_pairs=1)
        result = self._generate_with(mock_hf_load, tasks, pad_to=5, flatten_pairs=True, seed=42)

        # Each of the 2 tasks contributes 2 train pairs and 1 test pair.
        pair_count = 2 * (2 + 1)
        assert result["X_full"].shape == (pair_count, 5 * 5)
        assert result["y_full"].shape == (pair_count, 5 * 5)

    def test_generate_correct_shapes_not_flattened(self, mock_hf_load) -> None:
        """With flatten_pairs=False each pair keeps its (pad_to, pad_to) grid shape."""
        tasks = _make_sample_tasks(n_tasks=2, train_pairs=2, test_pairs=1)
        result = self._generate_with(mock_hf_load, tasks, pad_to=5, flatten_pairs=False, seed=42)

        pair_count = 2 * (2 + 1)
        assert result["X_full"].shape == (pair_count, 5, 5)

    def test_generate_without_test_pairs(self, mock_hf_load) -> None:
        """include_test=False keeps only the train pairs."""
        tasks = _make_sample_tasks(n_tasks=2, train_pairs=3, test_pairs=2)
        result = self._generate_with(mock_hf_load, tasks, include_test=False, pad_to=5, seed=42)

        # 2 tasks x 3 train pairs; the 2 test pairs per task are excluded.
        assert result["X_full"].shape[0] == 2 * 3

    def test_generate_with_n_tasks_seed(self, mock_hf_load) -> None:
        """A seeded n_tasks picks exactly that many tasks."""
        tasks = _make_sample_tasks(n_tasks=10, train_pairs=1, test_pairs=0)
        result = self._generate_with(mock_hf_load, tasks, n_tasks=3, seed=42, include_test=False, pad_to=5)

        assert result["X_full"].shape[0] == 3

    def test_generate_with_n_tasks_no_seed(self, mock_hf_load) -> None:
        """An unseeded n_tasks also yields exactly that many tasks."""
        tasks = _make_sample_tasks(n_tasks=10, train_pairs=1, test_pairs=0)
        result = self._generate_with(mock_hf_load, tasks, n_tasks=3, seed=None, include_test=False, pad_to=5)

        assert result["X_full"].shape[0] == 3

    def test_generate_hf_fallback_dataset(self, mock_hf_load) -> None:
        """If the first dataset load raises, the generator retries with a fallback."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        fallback_ds = _make_mock_hf_dataset(_make_sample_tasks(n_tasks=2, train_pairs=1, test_pairs=0))
        attempts = []

        def load_side_effect(name, split=None):
            # Fail on the first attempt only; later attempts get the dataset.
            attempts.append(name)
            if len(attempts) == 1:
                raise Exception("Dataset not found")
            return fallback_ds

        mock_hf_load.side_effect = load_side_effect

        result = ArcAgiGenerator.generate(ArcAgiParams(include_test=False, pad_to=5, seed=42))

        assert mock_hf_load.call_count == 2
        assert result["X_full"].shape[0] == 2

    def test_generate_raises_without_datasets(self) -> None:
        """ImportError is raised when the datasets package is unavailable."""
        with patch("juniper_data.generators.arc_agi.generator.HF_AVAILABLE", False):
            from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

            hf_params = ArcAgiParams(source="huggingface")
            with pytest.raises(ImportError, match="Hugging Face datasets package not installed"):
                ArcAgiGenerator.generate(hf_params)

    def test_generate_hf_missing_task_id(self, mock_hf_load) -> None:
        """Items lacking a task_id still get an identifier containing 'task_0'."""
        tasks = [{"train": [{"input": [[1]], "output": [[2]]}], "test": []}]
        result = self._generate_with(mock_hf_load, tasks, include_test=False, pad_to=5, seed=42)

        assert "task_0" in str(result["task_ids"][0])
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiGeneratorLocal:
    """Tests for ARC-AGI generation from local JSON task files."""

    # Minimal task: a single 1x1 train pair and no test pairs.
    _ONE_PAIR_TASK = {"train": [{"input": [[1]], "output": [[2]]}], "test": []}

    @staticmethod
    def _write_tasks(directory, filenames, task):
        """Create ``directory`` and write ``task`` as JSON into each of ``filenames``."""
        directory.mkdir()
        payload = json.dumps(task)
        for filename in filenames:
            (directory / filename).write_text(payload)

    def test_generate_from_local(self, tmp_path) -> None:
        """Task files under training/ are loaded and expanded into pairs."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        task = {
            "train": [{"input": [[1, 2], [3, 4]], "output": [[4, 3], [2, 1]]}],
            "test": [{"input": [[5, 6], [7, 8]], "output": [[8, 7], [6, 5]]}],
        }
        self._write_tasks(tmp_path / "training", ["task1.json", "task2.json"], task)

        params = ArcAgiParams(source="local", local_path=str(tmp_path), subset="training", pad_to=5, seed=42)
        result = ArcAgiGenerator.generate(params)

        # 2 task files x (1 train pair + 1 test pair) each.
        assert result["X_full"].shape[0] == 4
        assert result["task_ids"].shape[0] == 4

    def test_generate_local_evaluation_subset(self, tmp_path) -> None:
        """subset='evaluation' loads tasks from the evaluation/ directory."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        self._write_tasks(tmp_path / "evaluation", ["eval1.json"], self._ONE_PAIR_TASK)

        params = ArcAgiParams(
            source="local", local_path=str(tmp_path), subset="evaluation", include_test=False, pad_to=5, seed=42
        )
        result = ArcAgiGenerator.generate(params)

        assert result["X_full"].shape[0] == 1

    def test_generate_local_all_subsets(self, tmp_path) -> None:
        """subset='all' combines the training and evaluation directories."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        self._write_tasks(tmp_path / "training", ["t1.json"], self._ONE_PAIR_TASK)
        self._write_tasks(tmp_path / "evaluation", ["e1.json"], self._ONE_PAIR_TASK)

        params = ArcAgiParams(
            source="local", local_path=str(tmp_path), subset="all", include_test=False, pad_to=5, seed=42
        )
        result = ArcAgiGenerator.generate(params)

        # One single-pair task from each of the two subset directories.
        assert result["X_full"].shape[0] == 2

    def test_generate_local_missing_path(self) -> None:
        """A local source without local_path raises ValueError."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        params = ArcAgiParams(source="local", local_path=None)
        with pytest.raises(ValueError, match="local_path is required"):
            ArcAgiGenerator.generate(params)

    def test_generate_local_nonexistent_path(self) -> None:
        """A local_path that does not exist raises FileNotFoundError."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        params = ArcAgiParams(source="local", local_path="/nonexistent/path")
        with pytest.raises(FileNotFoundError, match="Path not found"):
            ArcAgiGenerator.generate(params)

    def test_generate_local_with_n_tasks_seed(self, tmp_path) -> None:
        """A seeded n_tasks picks exactly that many local tasks."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        self._write_tasks(tmp_path / "training", [f"task_{i}.json" for i in range(10)], self._ONE_PAIR_TASK)

        params = ArcAgiParams(
            source="local", local_path=str(tmp_path), n_tasks=3, seed=42, include_test=False, pad_to=5
        )
        result = ArcAgiGenerator.generate(params)

        assert result["X_full"].shape[0] == 3

    def test_generate_local_with_n_tasks_no_seed(self, tmp_path) -> None:
        """An unseeded n_tasks also yields exactly that many local tasks."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        # Zero-padded names keep lexical order equal to numeric order.
        self._write_tasks(tmp_path / "training", [f"task_{i:02d}.json" for i in range(10)], self._ONE_PAIR_TASK)

        params = ArcAgiParams(
            source="local", local_path=str(tmp_path), n_tasks=3, seed=None, include_test=False, pad_to=5
        )
        result = ArcAgiGenerator.generate(params)

        assert result["X_full"].shape[0] == 3

    def test_generate_local_missing_subdirs(self, tmp_path) -> None:
        """An existing root without subset directories yields zero pairs, not an error."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        params = ArcAgiParams(
            source="local", local_path=str(tmp_path), subset="training", include_test=False, pad_to=5, seed=42
        )
        result = ArcAgiGenerator.generate(params)

        assert result["X_full"].shape[0] == 0
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiGeneratorPadGrid:
    """Tests for the _pad_grid helper."""

    def test_pad_grid_smaller_than_target(self) -> None:
        """A 2x2 grid is placed top-left in a 5x5 canvas filled with pad_value."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        padded = ArcAgiGenerator._pad_grid([[1, 2], [3, 4]], pad_to=5, pad_value=-1)

        assert padded.shape == (5, 5)
        assert padded[0, 0] == 1
        # Cells beyond the original two columns / two rows hold pad_value.
        assert padded[0, 2] == -1
        assert padded[2, 0] == -1

    def test_pad_grid_exact_size(self) -> None:
        """A grid that already matches pad_to comes back unchanged."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        source_grid = [[1, 2], [3, 4]]
        padded = ArcAgiGenerator._pad_grid(source_grid, pad_to=2, pad_value=0)

        assert padded.shape == (2, 2)
        np.testing.assert_array_equal(padded, [[1, 2], [3, 4]])
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiGeneratorConvertTasks:
    """Tests for the _convert_tasks_to_arrays helper."""

    def test_convert_empty_tasks(self) -> None:
        """No tasks produce zero-row arrays and an empty id list."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        params = ArcAgiParams(pad_to=5, include_test=False)
        X, y, ids = ArcAgiGenerator._convert_tasks_to_arrays([], params)

        # Zero rows, but still pad_to * pad_to (= 25) columns.
        assert X.shape == (0, 25)
        assert y.shape == (0, 25)
        assert len(ids) == 0

    def test_convert_tasks_with_test_missing_output(self) -> None:
        """A test pair without an output grid still contributes a row to X and y."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        task = {
            "task_id": "test-task",
            "train": [{"input": [[1]], "output": [[2]]}],
            # The test pair deliberately has no "output" grid.
            "test": [{"input": [[3]]}],
        }
        params = ArcAgiParams(pad_to=5, include_test=True, flatten_pairs=True)
        X, y, _ids = ArcAgiGenerator._convert_tasks_to_arrays([task], params)

        # 1 train pair + 1 test pair.
        assert X.shape[0] == 2
        assert y.shape[0] == 2

    def test_convert_tasks_unknown_task_id(self) -> None:
        """Pairs from a task with no task_id are labelled 'unknown'."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        anonymous_task = {"train": [{"input": [[1]], "output": [[2]]}], "test": []}
        params = ArcAgiParams(pad_to=5, include_test=False, flatten_pairs=True)
        _X, _y, ids = ArcAgiGenerator._convert_tasks_to_arrays([anonymous_task], params)

        assert ids[0] == "unknown"
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiGetSchema:
    """Tests for the module-level get_schema function."""

    def test_get_schema_returns_dict(self) -> None:
        """get_schema produces a plain dict."""
        from juniper_data.generators.arc_agi.generator import get_schema

        assert isinstance(get_schema(), dict)

    def test_get_schema_has_properties(self) -> None:
        """The schema lists the key parameters under 'properties'."""
        from juniper_data.generators.arc_agi.generator import get_schema

        schema = get_schema()
        assert "properties" in schema
        for field_name in ("source", "pad_to", "n_tasks"):
            assert field_name in schema["properties"]
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiVersion:
    """Tests for the generator's VERSION constant."""

    def test_version_format(self) -> None:
        """VERSION has three dot-separated, all-digit components (MAJOR.MINOR.PATCH)."""
        from juniper_data.generators.arc_agi.generator import VERSION

        components = VERSION.split(".")
        assert len(components) == 3
        assert all(component.isdigit() for component in components)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiImports:
    """Tests for __init__.py imports."""

    def test_init_exports(self) -> None:
        """__init__.py exports expected symbols."""
        from juniper_data.generators.arc_agi import VERSION, ArcAgiGenerator, ArcAgiParams, get_schema

        assert ArcAgiGenerator is not None
        assert ArcAgiParams is not None
        assert VERSION is not None
        assert get_schema is not None

    def test_module_level_hf_available_true(self) -> None:
        """Module-level HF_AVAILABLE is True when datasets is importable.

        A fake ``datasets`` module is injected so the generator's optional
        import succeeds. Whatever ``sys.modules["datasets"]`` held beforehand
        (possibly the real package) is put back afterwards, and the generator
        module is re-imported so later tests get a fresh copy.
        """
        import importlib
        import sys
        from types import ModuleType

        mod_name = "juniper_data.generators.arc_agi.generator"

        # Remember any pre-existing entry (e.g. the real datasets package). It
        # must be restored rather than simply popped; otherwise this test drops
        # an already-imported module from sys.modules for every later test.
        had_datasets = "datasets" in sys.modules
        original_datasets = sys.modules.get("datasets")

        # Inject a fake 'datasets' module so the try-branch succeeds
        fake_datasets = ModuleType("datasets")
        fake_datasets.load_dataset = MagicMock()  # type: ignore[attr-defined]
        sys.modules["datasets"] = fake_datasets

        sys.modules.pop(mod_name, None)
        try:
            mod = importlib.import_module(mod_name)
            assert mod.HF_AVAILABLE is True
        finally:
            # Restore original state
            if had_datasets:
                sys.modules["datasets"] = original_datasets
            else:
                sys.modules.pop("datasets", None)
            sys.modules.pop(mod_name, None)
            importlib.import_module(mod_name)
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Unit tests for core artifacts module."""
|
|
2
|
+
|
|
3
|
+
import tempfile
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pytest
|
|
8
|
+
|
|
9
|
+
from juniper_data.core.artifacts import arrays_to_bytes, compute_checksum, load_npz, save_npz
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@pytest.fixture
def sample_arrays() -> dict[str, np.ndarray]:
    """Provide a small 'X'/'y' pair of 2x2 float32 arrays."""
    features = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    labels = np.array([[1, 0], [0, 1]], dtype=np.float32)
    return {"X": features, "y": labels}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.fixture
def temp_npz_path():
    """Yield a path named ``test.npz`` inside a fresh temporary directory.

    Only the directory is created; the file itself is not. The directory and
    everything in it are deleted once the test using this fixture finishes.
    """
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name) / "test.npz"
    finally:
        scratch.cleanup()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TestSaveNpz:
    """Tests for the save_npz function."""

    @pytest.mark.unit
    def test_save_creates_file(self, temp_npz_path: Path, sample_arrays: dict[str, np.ndarray]):
        """save_npz writes a file at the requested path."""
        save_npz(temp_npz_path, sample_arrays)

        assert temp_npz_path.exists()

    @pytest.mark.unit
    def test_save_correct_content(self, temp_npz_path: Path, sample_arrays: dict[str, np.ndarray]):
        """The written archive holds exactly the given arrays, with equal contents."""
        save_npz(temp_npz_path, sample_arrays)

        with np.load(temp_npz_path) as archive:
            assert set(archive.files) == sample_arrays.keys()
            for name, expected in sample_arrays.items():
                np.testing.assert_array_equal(archive[name], expected)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class TestLoadNpz:
    """Tests for the load_npz function."""

    @pytest.mark.unit
    def test_load_returns_arrays(self, temp_npz_path: Path, sample_arrays: dict[str, np.ndarray]):
        """load_npz returns a dict with the same keys and array contents that were saved."""
        save_npz(temp_npz_path, sample_arrays)

        restored = load_npz(temp_npz_path)

        assert isinstance(restored, dict)
        assert restored.keys() == sample_arrays.keys()
        for name, expected in sample_arrays.items():
            np.testing.assert_array_equal(restored[name], expected)

    @pytest.mark.unit
    def test_load_returns_mutable_arrays(self, temp_npz_path: Path, sample_arrays: dict[str, np.ndarray]):
        """Arrays returned by load_npz can be written to in place (not memory-mapped)."""
        save_npz(temp_npz_path, sample_arrays)
        restored = load_npz(temp_npz_path)

        features = restored["X"]
        features[0, 0] = 999.0

        assert restored["X"][0, 0] == 999.0
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class TestArraysToBytes:
    """Tests for arrays_to_bytes function."""

    @pytest.mark.unit
    def test_returns_bytes(self, sample_arrays: dict[str, np.ndarray]):
        """Test that arrays_to_bytes returns bytes."""
        result = arrays_to_bytes(sample_arrays)
        assert isinstance(result, bytes)

    @pytest.mark.unit
    def test_bytes_contain_npz_data(self, sample_arrays: dict[str, np.ndarray]):
        """Test that bytes contain valid NPZ data with the original arrays."""
        import io

        result = arrays_to_bytes(sample_arrays)

        # np.load on NPZ data returns an NpzFile that keeps the underlying zip
        # archive open until it is closed; the with-block releases it
        # deterministically instead of leaking the handle.
        with np.load(io.BytesIO(result)) as loaded:
            assert set(loaded.files) == set(sample_arrays.keys())
            for key in sample_arrays:
                np.testing.assert_array_equal(loaded[key], sample_arrays[key])

    @pytest.mark.unit
    def test_bytes_non_empty(self, sample_arrays: dict[str, np.ndarray]):
        """Test that result is non-empty."""
        result = arrays_to_bytes(sample_arrays)
        assert len(result) > 0
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class TestComputeChecksum:
    """Tests for compute_checksum function."""

    @pytest.mark.unit
    def test_returns_hex_string(self, sample_arrays: dict[str, np.ndarray]):
        """Test that compute_checksum returns a hex string."""
        result = compute_checksum(sample_arrays)
        assert isinstance(result, str)
        assert len(result) == 64
        assert all(c in "0123456789abcdef" for c in result)

    @pytest.mark.unit
    def test_deterministic(self, sample_arrays: dict[str, np.ndarray]):
        """Test that checksum is deterministic for same input."""
        checksum1 = compute_checksum(sample_arrays)
        checksum2 = compute_checksum(sample_arrays)
        assert checksum1 == checksum2

    @pytest.mark.unit
    def test_different_for_different_data(self, sample_arrays: dict[str, np.ndarray]):
        """Test that checksum differs for different data."""
        checksum1 = compute_checksum(sample_arrays)

        different_arrays = {
            "X": np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32),
            "y": sample_arrays["y"],
        }
        checksum2 = compute_checksum(different_arrays)

        assert checksum1 != checksum2

    @pytest.mark.unit
    def test_sha256_format(self, sample_arrays: dict[str, np.ndarray]):
        """Test that checksum is a canonical lowercase SHA-256 hex digest."""
        import hashlib

        result = compute_checksum(sample_arrays)

        # bytes.fromhex rejects non-hex characters and odd lengths. Checking the
        # round-trip through .hex() also rejects forms that int(result, 16)
        # would wrongly accept: uppercase digits, a "0x" prefix, underscores,
        # or embedded whitespace.
        try:
            digest = bytes.fromhex(result)
        except ValueError:
            pytest.fail(f"checksum is not valid hex: {result!r}")

        assert digest.hex() == result
        assert len(digest) == hashlib.sha256().digest_size
|