juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,525 @@
1
+ """Unit tests for the ARC-AGI dataset generator."""
2
+
3
+ import json
4
+ from unittest.mock import MagicMock, patch
5
+
6
+ import numpy as np
7
+ import pytest
8
+
9
+ from juniper_data.generators.arc_agi.params import ArcAgiParams
10
+
11
+
12
def _make_sample_tasks(n_tasks=3, train_pairs=2, test_pairs=1):
    """Build a list of synthetic ARC tasks for use as test input.

    Task ``idx`` is labelled ``task_<idx>``. Its train pair ``j`` maps the 2x2
    grid ``[[j, j+1], [j+2, j+3]]`` to the same values in reverse order, and
    every test pair maps ``[[1, 2], [3, 4]]`` to ``[[4, 3], [2, 1]]``.
    """

    def build_train_pair(j):
        # Fresh lists per pair so callers may mutate one task without affecting others.
        return {
            "input": [[j, j + 1], [j + 2, j + 3]],
            "output": [[j + 3, j + 2], [j + 1, j]],
        }

    def build_holdout_pair():
        return {"input": [[1, 2], [3, 4]], "output": [[4, 3], [2, 1]]}

    return [
        {
            "task_id": f"task_{idx}",
            "train": [build_train_pair(j) for j in range(train_pairs)],
            "test": [build_holdout_pair() for _ in range(test_pairs)],
        }
        for idx in range(n_tasks)
    ]
35
+
36
+
37
def _make_mock_hf_dataset(tasks):
    """Create a mock HuggingFace dataset that yields ``tasks``.

    ``len()`` on the mock returns ``len(tasks)``. The mock can be iterated any
    number of times, and each pass yields the task dicts in order.
    """
    mock_ds = MagicMock()
    mock_ds.__len__.return_value = len(tasks)
    # MagicMock's built-in __iter__ calls iter() on its return value on every
    # iteration. Storing the list (not a pre-built iterator) keeps the mock
    # re-iterable instead of silently yielding nothing after the first pass.
    mock_ds.__iter__.return_value = tasks
    return mock_ds
44
+
45
+
46
@pytest.fixture
def mock_hf_load():
    """Yield a MagicMock standing in for ``hf_load_dataset``, with HF marked available.

    For the duration of the test, the arc_agi generator module sees
    ``HF_AVAILABLE`` as True, and its ``hf_load_dataset`` is replaced by the
    yielded mock.
    """
    loader = MagicMock()
    with patch("juniper_data.generators.arc_agi.generator.HF_AVAILABLE", True), patch(
        "juniper_data.generators.arc_agi.generator.hf_load_dataset", loader
    ):
        yield loader
54
+
55
+
56
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiParams:
    """Tests for ArcAgiParams validation."""

    def test_default_params(self) -> None:
        """Default parameters are valid."""
        params = ArcAgiParams()
        assert params.source == "huggingface"
        assert params.local_path is None
        assert params.subset == "training"
        assert params.n_tasks is None
        assert params.pad_to == 30
        assert params.pad_value == -1
        assert params.include_test is True
        assert params.flatten_pairs is True
        assert params.train_ratio == 0.8

    def test_custom_params(self) -> None:
        """Every custom parameter is accepted and stored unchanged."""
        custom = {
            "source": "local",
            "local_path": "/data/arc",
            "subset": "evaluation",
            "n_tasks": 10,
            "pad_to": 15,
            "pad_value": 0,
            "include_test": False,
            "flatten_pairs": False,
            "seed": 42,
        }
        params = ArcAgiParams(**custom)
        # Check every field that was set, not just a sample, so a validator that
        # silently rewrites one of them is caught.
        for field_name, expected in custom.items():
            assert getattr(params, field_name) == expected, field_name

    def test_invalid_pad_to(self) -> None:
        """pad_to must be >= 1 and <= 50."""
        with pytest.raises(ValueError):
            ArcAgiParams(pad_to=0)
        with pytest.raises(ValueError):
            ArcAgiParams(pad_to=51)

    def test_invalid_n_tasks(self) -> None:
        """n_tasks must be >= 1."""
        with pytest.raises(ValueError):
            ArcAgiParams(n_tasks=0)

    def test_invalid_train_ratio(self) -> None:
        """train_ratio must be in (0, 1]."""
        with pytest.raises(ValueError):
            ArcAgiParams(train_ratio=0)
107
+
108
+
109
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiGeneratorHuggingFace:
    """Tests for ARC-AGI generation when tasks come from the HuggingFace loader."""

    @staticmethod
    def _generate(mock_hf_load, tasks, **param_kwargs):
        """Serve ``tasks`` through the patched HF loader and run the generator."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        mock_hf_load.return_value = _make_mock_hf_dataset(tasks)
        return ArcAgiGenerator.generate(ArcAgiParams(**param_kwargs))

    def test_generate_from_hf(self, mock_hf_load) -> None:
        """The generated result exposes every expected split and metadata key."""
        result = self._generate(
            mock_hf_load,
            _make_sample_tasks(n_tasks=2, train_pairs=2, test_pairs=1),
            source="huggingface",
            seed=42,
        )

        for key in ("X_train", "y_train", "X_test", "y_test", "X_full", "y_full", "task_ids"):
            assert key in result

    def test_generate_correct_shapes_flattened(self, mock_hf_load) -> None:
        """With flatten_pairs, each row is a pad_to*pad_to vector."""
        result = self._generate(
            mock_hf_load,
            _make_sample_tasks(n_tasks=2, train_pairs=2, test_pairs=1),
            pad_to=5,
            flatten_pairs=True,
            seed=42,
        )

        expected_rows = 2 * 2 + 2 * 1  # 2 tasks x 2 train pairs + 2 tasks x 1 test pair
        assert result["X_full"].shape == (expected_rows, 25)
        assert result["y_full"].shape == (expected_rows, 25)

    def test_generate_correct_shapes_not_flattened(self, mock_hf_load) -> None:
        """Without flatten_pairs, each row is a pad_to x pad_to grid."""
        result = self._generate(
            mock_hf_load,
            _make_sample_tasks(n_tasks=2, train_pairs=2, test_pairs=1),
            pad_to=5,
            flatten_pairs=False,
            seed=42,
        )

        expected_rows = 2 * 2 + 2 * 1
        assert result["X_full"].shape == (expected_rows, 5, 5)

    def test_generate_without_test_pairs(self, mock_hf_load) -> None:
        """With include_test disabled, only the train pairs are emitted."""
        result = self._generate(
            mock_hf_load,
            _make_sample_tasks(n_tasks=2, train_pairs=3, test_pairs=2),
            include_test=False,
            pad_to=5,
            seed=42,
        )

        assert result["X_full"].shape[0] == 2 * 3  # train pairs only

    def test_generate_with_n_tasks_seed(self, mock_hf_load) -> None:
        """n_tasks combined with a seed selects that many tasks."""
        result = self._generate(
            mock_hf_load,
            _make_sample_tasks(n_tasks=10, train_pairs=1, test_pairs=0),
            n_tasks=3,
            seed=42,
            include_test=False,
            pad_to=5,
        )

        assert result["X_full"].shape[0] == 3

    def test_generate_with_n_tasks_no_seed(self, mock_hf_load) -> None:
        """n_tasks without a seed still yields exactly that many tasks."""
        result = self._generate(
            mock_hf_load,
            _make_sample_tasks(n_tasks=10, train_pairs=1, test_pairs=0),
            n_tasks=3,
            seed=None,
            include_test=False,
            pad_to=5,
        )

        assert result["X_full"].shape[0] == 3

    def test_generate_hf_fallback_dataset(self, mock_hf_load) -> None:
        """When the first HuggingFace dataset fails to load, a second load is attempted."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        fallback_ds = _make_mock_hf_dataset(_make_sample_tasks(n_tasks=2, train_pairs=1, test_pairs=0))
        attempts = []

        def flaky_loader(name, split=None):
            # Fail only on the very first call, then serve the fallback dataset.
            attempts.append(name)
            if len(attempts) == 1:
                raise Exception("Dataset not found")
            return fallback_ds

        mock_hf_load.side_effect = flaky_loader
        result = ArcAgiGenerator.generate(ArcAgiParams(include_test=False, pad_to=5, seed=42))

        assert mock_hf_load.call_count == 2
        assert result["X_full"].shape[0] == 2

    def test_generate_raises_without_datasets(self) -> None:
        """generate() raises ImportError when the datasets package is unavailable."""
        with patch("juniper_data.generators.arc_agi.generator.HF_AVAILABLE", False):
            from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

            hf_params = ArcAgiParams(source="huggingface")
            with pytest.raises(ImportError, match="Hugging Face datasets package not installed"):
                ArcAgiGenerator.generate(hf_params)

    def test_generate_hf_missing_task_id(self, mock_hf_load) -> None:
        """Items without a task_id still get an identifier derived from their position."""
        orphan_tasks = [{"train": [{"input": [[1]], "output": [[2]]}], "test": []}]
        result = self._generate(mock_hf_load, orphan_tasks, include_test=False, pad_to=5, seed=42)

        assert "task_0" in str(result["task_ids"][0])
246
+
247
+
248
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiGeneratorLocal:
    """Tests for ARC-AGI generation from local JSON files."""

    # Minimal one-train-pair task shared by most tests. It is only ever
    # serialised with json.dumps, so sharing one dict across tests is safe.
    _SIMPLE_TASK = {"train": [{"input": [[1]], "output": [[2]]}], "test": []}

    @staticmethod
    def _write_tasks(directory, file_names, task):
        """Create ``directory`` and write ``task`` as JSON into each of ``file_names``."""
        directory.mkdir()
        payload = json.dumps(task)
        for file_name in file_names:
            (directory / file_name).write_text(payload)

    def test_generate_from_local(self, tmp_path) -> None:
        """Generate from local JSON files."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        task = {
            "train": [
                {"input": [[1, 2], [3, 4]], "output": [[4, 3], [2, 1]]},
            ],
            "test": [
                {"input": [[5, 6], [7, 8]], "output": [[8, 7], [6, 5]]},
            ],
        }
        self._write_tasks(tmp_path / "training", ["task1.json", "task2.json"], task)

        params = ArcAgiParams(source="local", local_path=str(tmp_path), subset="training", pad_to=5, seed=42)
        result = ArcAgiGenerator.generate(params)

        assert result["X_full"].shape[0] == 4  # 2 tasks * (1 train + 1 test)
        assert result["task_ids"].shape[0] == 4

    def test_generate_local_evaluation_subset(self, tmp_path) -> None:
        """Generate from evaluation subset."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        self._write_tasks(tmp_path / "evaluation", ["eval1.json"], self._SIMPLE_TASK)

        params = ArcAgiParams(
            source="local", local_path=str(tmp_path), subset="evaluation", include_test=False, pad_to=5, seed=42
        )
        result = ArcAgiGenerator.generate(params)

        assert result["X_full"].shape[0] == 1

    def test_generate_local_all_subsets(self, tmp_path) -> None:
        """Generate from all subsets."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        self._write_tasks(tmp_path / "training", ["t1.json"], self._SIMPLE_TASK)
        self._write_tasks(tmp_path / "evaluation", ["e1.json"], self._SIMPLE_TASK)

        params = ArcAgiParams(
            source="local", local_path=str(tmp_path), subset="all", include_test=False, pad_to=5, seed=42
        )
        result = ArcAgiGenerator.generate(params)

        assert result["X_full"].shape[0] == 2

    def test_generate_local_missing_path(self) -> None:
        """Raises ValueError when local_path is None."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        params = ArcAgiParams(source="local", local_path=None)
        with pytest.raises(ValueError, match="local_path is required"):
            ArcAgiGenerator.generate(params)

    def test_generate_local_nonexistent_path(self, tmp_path) -> None:
        """Raises FileNotFoundError when local_path does not exist."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        # Derive the missing path from tmp_path so the test never depends on
        # what happens to exist on the host filesystem.
        missing = tmp_path / "does-not-exist"
        params = ArcAgiParams(source="local", local_path=str(missing))
        with pytest.raises(FileNotFoundError, match="Path not found"):
            ArcAgiGenerator.generate(params)

    def test_generate_local_with_n_tasks_seed(self, tmp_path) -> None:
        """n_tasks with seed selects random subset from local files."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        self._write_tasks(tmp_path / "training", [f"task_{i}.json" for i in range(10)], self._SIMPLE_TASK)

        params = ArcAgiParams(
            source="local", local_path=str(tmp_path), n_tasks=3, seed=42, include_test=False, pad_to=5
        )
        result = ArcAgiGenerator.generate(params)

        assert result["X_full"].shape[0] == 3

    def test_generate_local_with_n_tasks_no_seed(self, tmp_path) -> None:
        """n_tasks without seed takes first N local tasks."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        # Zero-padded names keep the lexical and numeric file order identical.
        self._write_tasks(tmp_path / "training", [f"task_{i:02d}.json" for i in range(10)], self._SIMPLE_TASK)

        params = ArcAgiParams(
            source="local", local_path=str(tmp_path), n_tasks=3, seed=None, include_test=False, pad_to=5
        )
        result = ArcAgiGenerator.generate(params)

        assert result["X_full"].shape[0] == 3

    def test_generate_local_missing_subdirs(self, tmp_path) -> None:
        """Handle missing training/evaluation subdirectories gracefully."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        params = ArcAgiParams(
            source="local", local_path=str(tmp_path), subset="training", include_test=False, pad_to=5, seed=42
        )
        result = ArcAgiGenerator.generate(params)

        assert result["X_full"].shape[0] == 0
376
+
377
+
378
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiGeneratorPadGrid:
    """Tests for the _pad_grid helper."""

    def test_pad_grid_smaller_than_target(self) -> None:
        """A small grid is anchored top-left and every other cell holds pad_value."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        grid = [[1, 2], [3, 4]]
        result = ArcAgiGenerator._pad_grid(grid, pad_to=5, pad_value=-1)

        # Compare the whole array rather than spot-checking a few cells, so a
        # misplaced value anywhere in the padded grid is caught.
        expected = np.full((5, 5), -1)
        expected[:2, :2] = grid
        assert result.shape == (5, 5)
        np.testing.assert_array_equal(result, expected)

    def test_pad_grid_exact_size(self) -> None:
        """Grid exactly matching target size is unchanged."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        grid = [[1, 2], [3, 4]]
        result = ArcAgiGenerator._pad_grid(grid, pad_to=2, pad_value=0)

        assert result.shape == (2, 2)
        np.testing.assert_array_equal(result, [[1, 2], [3, 4]])
404
+
405
+
406
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiGeneratorConvertTasks:
    """Tests for the _convert_tasks_to_arrays helper."""

    def test_convert_empty_tasks(self) -> None:
        """No tasks in gives zero-row flattened arrays and no ids out."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        config = ArcAgiParams(pad_to=5, include_test=False)
        inputs, outputs, task_ids = ArcAgiGenerator._convert_tasks_to_arrays([], config)

        assert inputs.shape == (0, 25)
        assert outputs.shape == (0, 25)
        assert len(task_ids) == 0

    def test_convert_tasks_with_test_missing_output(self) -> None:
        """A test pair lacking an "output" grid still contributes a row to X and y."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        task_list = [
            {
                "task_id": "test-task",
                "train": [{"input": [[1]], "output": [[2]]}],
                "test": [{"input": [[3]]}],
            }
        ]
        config = ArcAgiParams(pad_to=5, include_test=True, flatten_pairs=True)
        inputs, outputs, _ = ArcAgiGenerator._convert_tasks_to_arrays(task_list, config)

        assert inputs.shape[0] == 2  # one train pair + one test pair
        assert outputs.shape[0] == 2

    def test_convert_tasks_unknown_task_id(self) -> None:
        """A task with no "task_id" key is labelled "unknown"."""
        from juniper_data.generators.arc_agi.generator import ArcAgiGenerator

        task_list = [{"train": [{"input": [[1]], "output": [[2]]}], "test": []}]
        config = ArcAgiParams(pad_to=5, include_test=False, flatten_pairs=True)
        _, _, task_ids = ArcAgiGenerator._convert_tasks_to_arrays(task_list, config)

        assert task_ids[0] == "unknown"
450
+
451
+
452
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiGetSchema:
    """Tests for the module-level get_schema function."""

    def test_get_schema_returns_dict(self) -> None:
        """get_schema() produces a plain dict."""
        from juniper_data.generators.arc_agi.generator import get_schema

        assert isinstance(get_schema(), dict)

    def test_get_schema_has_properties(self) -> None:
        """The schema lists source, pad_to and n_tasks among its properties."""
        from juniper_data.generators.arc_agi.generator import get_schema

        schema = get_schema()
        assert "properties" in schema
        properties = schema["properties"]
        for field_name in ("source", "pad_to", "n_tasks"):
            assert field_name in properties
473
+
474
+
475
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiVersion:
    """Tests for the generator's VERSION constant."""

    def test_version_format(self) -> None:
        """VERSION has three dot-separated numeric components (MAJOR.MINOR.PATCH)."""
        from juniper_data.generators.arc_agi.generator import VERSION

        components = VERSION.split(".")
        assert len(components) == 3
        assert all(component.isdigit() for component in components)
488
+
489
+
490
@pytest.mark.unit
@pytest.mark.generators
class TestArcAgiImports:
    """Tests for the arc_agi package's public imports and optional-dependency detection."""

    def test_init_exports(self) -> None:
        """__init__.py exports expected symbols."""
        from juniper_data.generators.arc_agi import VERSION, ArcAgiGenerator, ArcAgiParams, get_schema

        assert ArcAgiGenerator is not None
        assert ArcAgiParams is not None
        assert VERSION is not None
        assert get_schema is not None

    def test_module_level_hf_available_true(self) -> None:
        """Module-level HF_AVAILABLE is True when datasets is importable.

        The generator module is re-imported against a fake ``datasets`` module.
        Afterwards, the previously registered ``datasets`` and generator module
        objects are restored exactly. Other modules (e.g. the package
        ``__init__``) keep references to classes from the original generator
        module, so leaving a fresh copy registered could make later
        ``patch("...generator.X")`` calls miss those classes.
        """
        import importlib
        import sys
        from types import ModuleType
        from unittest.mock import MagicMock

        mod_name = "juniper_data.generators.arc_agi.generator"
        parent_name, _, child_name = mod_name.rpartition(".")

        # Snapshot the current state before touching sys.modules.
        parent_pkg = importlib.import_module(parent_name)
        saved_datasets = sys.modules.get("datasets")
        saved_generator = sys.modules.get(mod_name)

        # Inject a fake 'datasets' module so the try-branch succeeds
        fake_datasets = ModuleType("datasets")
        fake_datasets.load_dataset = MagicMock()  # type: ignore[attr-defined]
        sys.modules["datasets"] = fake_datasets
        sys.modules.pop(mod_name, None)
        try:
            mod = importlib.import_module(mod_name)
            assert mod.HF_AVAILABLE is True
        finally:
            # Restore original state: put back a real 'datasets' if one was loaded.
            if saved_datasets is None:
                sys.modules.pop("datasets", None)
            else:
                sys.modules["datasets"] = saved_datasets
            if saved_generator is None:
                sys.modules.pop(mod_name, None)
                importlib.import_module(mod_name)
            else:
                sys.modules[mod_name] = saved_generator
                # Importing the submodule rebinds the package attribute; point it back too.
                setattr(parent_pkg, child_name, saved_generator)
@@ -0,0 +1,145 @@
1
+ """Unit tests for core artifacts module."""
2
+
3
+ import tempfile
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pytest
8
+
9
+ from juniper_data.core.artifacts import arrays_to_bytes, compute_checksum, load_npz, save_npz
10
+
11
+
12
@pytest.fixture
def sample_arrays() -> dict[str, np.ndarray]:
    """Provide a tiny 2x2 float32 feature matrix ``X`` and matching 2x2 float32 ``y``."""
    features = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    targets = np.array([[1, 0], [0, 1]], dtype=np.float32)
    return {"X": features, "y": targets}
19
+
20
+
21
@pytest.fixture
def temp_npz_path():
    """Yield a not-yet-created ``test.npz`` path inside a directory deleted after the test."""
    scratch = tempfile.TemporaryDirectory()
    with scratch as directory:
        yield Path(directory, "test.npz")
26
+
27
+
28
class TestSaveNpz:
    """Tests for the save_npz function."""

    @pytest.mark.unit
    def test_save_creates_file(self, temp_npz_path: Path, sample_arrays: dict[str, np.ndarray]):
        """save_npz writes something at the requested path."""
        save_npz(temp_npz_path, sample_arrays)

        assert temp_npz_path.exists()

    @pytest.mark.unit
    def test_save_correct_content(self, temp_npz_path: Path, sample_arrays: dict[str, np.ndarray]):
        """Every array written by save_npz reads back unchanged through np.load."""
        save_npz(temp_npz_path, sample_arrays)

        with np.load(temp_npz_path) as archive:
            assert sorted(archive.files) == sorted(sample_arrays)
            for name, expected in sample_arrays.items():
                np.testing.assert_array_equal(archive[name], expected)
46
+
47
+
48
class TestLoadNpz:
    """Tests for the load_npz function."""

    @pytest.mark.unit
    def test_load_returns_arrays(self, temp_npz_path: Path, sample_arrays: dict[str, np.ndarray]):
        """load_npz gives back a dict holding exactly the arrays that were saved."""
        save_npz(temp_npz_path, sample_arrays)
        restored = load_npz(temp_npz_path)

        assert isinstance(restored, dict)
        assert restored.keys() == sample_arrays.keys()
        for name, expected in sample_arrays.items():
            np.testing.assert_array_equal(restored[name], expected)

    @pytest.mark.unit
    def test_load_returns_mutable_arrays(self, temp_npz_path: Path, sample_arrays: dict[str, np.ndarray]):
        """Loaded arrays accept in-place writes (i.e. they are not read-only maps)."""
        save_npz(temp_npz_path, sample_arrays)
        restored = load_npz(temp_npz_path)

        features = restored["X"]
        features[0, 0] = 999.0
        assert restored["X"][0, 0] == 999.0
70
+
71
+
72
class TestArraysToBytes:
    """Tests for arrays_to_bytes function."""

    @pytest.mark.unit
    def test_returns_bytes(self, sample_arrays: dict[str, np.ndarray]):
        """Test that arrays_to_bytes returns bytes."""
        result = arrays_to_bytes(sample_arrays)
        assert isinstance(result, bytes)

    @pytest.mark.unit
    def test_bytes_contain_npz_data(self, sample_arrays: dict[str, np.ndarray]):
        """Test that bytes contain valid NPZ data holding the original arrays."""
        import io

        result = arrays_to_bytes(sample_arrays)

        # np.load on an .npz payload returns an NpzFile wrapping an open zip
        # archive. Use it as a context manager so it is closed deterministically,
        # as the save_npz round-trip test does.
        with np.load(io.BytesIO(result)) as loaded:
            assert set(loaded.files) == set(sample_arrays.keys())
            for key in sample_arrays:
                np.testing.assert_array_equal(loaded[key], sample_arrays[key])

    @pytest.mark.unit
    def test_bytes_non_empty(self, sample_arrays: dict[str, np.ndarray]):
        """Test that result is non-empty."""
        result = arrays_to_bytes(sample_arrays)
        assert len(result) > 0
98
+
99
+
100
class TestComputeChecksum:
    """Tests for compute_checksum function."""

    @pytest.mark.unit
    def test_returns_hex_string(self, sample_arrays: dict[str, np.ndarray]):
        """Test that compute_checksum returns a hex string."""
        result = compute_checksum(sample_arrays)
        assert isinstance(result, str)
        assert len(result) == 64
        assert all(c in "0123456789abcdef" for c in result)

    @pytest.mark.unit
    def test_deterministic(self, sample_arrays: dict[str, np.ndarray]):
        """Test that checksum is deterministic for same input."""
        checksum1 = compute_checksum(sample_arrays)
        checksum2 = compute_checksum(sample_arrays)
        assert checksum1 == checksum2

    @pytest.mark.unit
    def test_different_for_different_data(self, sample_arrays: dict[str, np.ndarray]):
        """Test that checksum differs for different data."""
        checksum1 = compute_checksum(sample_arrays)

        different_arrays = {
            "X": np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32),
            "y": sample_arrays["y"],
        }
        checksum2 = compute_checksum(different_arrays)

        assert checksum1 != checksum2

    @pytest.mark.unit
    def test_sha256_format(self, sample_arrays: dict[str, np.ndarray]):
        """Test that checksum looks like a SHA-256 hexdigest: 64 lowercase hex characters."""
        import hashlib

        result = compute_checksum(sample_arrays)

        # int(result, 16) would also accept "0x" prefixes, underscores, surrounding
        # whitespace and uppercase digits, so check the exact character set instead.
        assert len(result) == hashlib.sha256().digest_size * 2
        assert set(result) <= set("0123456789abcdef")