juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Split and shuffle utilities for dataset partitioning.
|
|
2
|
+
|
|
3
|
+
This module provides pure NumPy utilities for shuffling and splitting datasets
|
|
4
|
+
into train/test sets with reproducible random number generation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def shuffle_data(
    X: np.ndarray,
    y: np.ndarray,
    rng: np.random.Generator,
) -> tuple[np.ndarray, np.ndarray]:
    """Reorder X and y along their first axis with one shared permutation.

    Args:
        X: Feature array whose first axis indexes samples.
        y: Label array whose first axis indexes samples.
        rng: NumPy random generator that draws the permutation.

    Returns:
        Tuple ``(X, y)`` of new arrays indexed by the same permutation, so
        row ``i`` of both outputs still refers to the same sample.

    Raises:
        ValueError: If the first dimensions of X and y differ.
    """
    n_x, n_y = X.shape[0], y.shape[0]
    if n_x != n_y:
        message = f"X and y must have the same number of samples. Got X.shape[0]={n_x}, y.shape[0]={n_y}"
        raise ValueError(message)

    # A single index vector is drawn once and reused so both arrays stay aligned.
    order = rng.permutation(n_x)
    shuffled_X = X[order]
    shuffled_y = y[order]
    return shuffled_X, shuffled_y
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def split_data(
    X: np.ndarray,
    y: np.ndarray,
    train_ratio: float,
    test_ratio: float,
) -> dict[str, np.ndarray]:
    """Split arrays into contiguous train and test sets based on ratios.

    The first ``n_train`` samples become the training set and the following
    ``n_test`` samples become the test set; no shuffling happens here.

    Args:
        X: Feature array of shape (n_samples, ...).
        y: Label array of shape (n_samples, ...).
        train_ratio: Fraction of data for training (0.0 to 1.0).
        test_ratio: Fraction of data for testing (0.0 to 1.0).

    Returns:
        Dictionary with keys "X_train", "y_train", "X_test", "y_test".

    Raises:
        ValueError: If ratios are invalid or X and y have different sample counts.
    """
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"X and y must have the same number of samples. Got X.shape[0]={X.shape[0]}, y.shape[0]={y.shape[0]}"
        )

    if not (0.0 <= train_ratio <= 1.0):
        raise ValueError(f"train_ratio must be between 0 and 1. Got {train_ratio}")

    if not (0.0 <= test_ratio <= 1.0):
        raise ValueError(f"test_ratio must be between 0 and 1. Got {test_ratio}")

    if train_ratio + test_ratio > 1.0:
        raise ValueError(
            f"train_ratio + test_ratio must not exceed 1.0. "
            f"Got {train_ratio} + {test_ratio} = {train_ratio + test_ratio}"
        )

    n_samples = X.shape[0]
    n_train = int(np.round(n_samples * train_ratio))
    n_test = int(np.round(n_samples * test_ratio))

    # np.round rounds halves to the nearest even integer, so rounding both
    # counts independently can drop a sample (5 * 0.5 -> 2 and 2) or overshoot
    # (3 * 0.5 -> 2 and 2). When the ratios cover the whole dataset (allowing
    # for float error such as 0.7 + 0.3), or the counts overshoot, the test
    # split simply takes whatever remains after the training split.
    ratios_cover_all = abs(train_ratio + test_ratio - 1.0) <= 1e-9
    if ratios_cover_all or n_train + n_test > n_samples:
        n_test = n_samples - n_train

    test_end = n_train + n_test
    return {
        "X_train": X[:n_train],
        "y_train": y[:n_train],
        "X_test": X[n_train:test_end],
        "y_test": y[n_train:test_end],
    }
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def shuffle_and_split(
    X: np.ndarray,
    y: np.ndarray,
    train_ratio: float,
    test_ratio: float,
    seed: int | None = None,
    shuffle: bool = True,
) -> dict[str, np.ndarray]:
    """Split data into train/test sets, shuffling first unless disabled.

    Convenience wrapper that chains ``shuffle_data`` and ``split_data``. The
    shuffle uses a fresh ``np.random.Generator`` built from ``seed``.

    Args:
        X: Feature array of shape (n_samples, ...).
        y: Label array of shape (n_samples, ...).
        train_ratio: Fraction of data for training (0.0 to 1.0).
        test_ratio: Fraction of data for testing (0.0 to 1.0).
        seed: Seed for the shuffle generator. ``None`` gives a
            non-deterministic shuffle.
        shuffle: When False the data is split in its given order and
            ``seed`` is not used.

    Returns:
        Dictionary with keys "X_train", "y_train", "X_test", "y_test".

    Raises:
        ValueError: If ratios are invalid or X and y have different sample counts.
    """
    if not shuffle:
        return split_data(X, y, train_ratio, test_ratio)

    generator = np.random.default_rng(seed)
    shuffled_X, shuffled_y = shuffle_data(X, y, generator)
    return split_data(shuffled_X, shuffled_y, train_ratio, test_ratio)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Data generators module for Juniper Data."""
|
|
2
|
+
|
|
3
|
+
from .spiral import SpiralGenerator, SpiralParams
|
|
4
|
+
from .spiral import get_schema as get_spiral_schema
|
|
5
|
+
|
|
6
|
+
# Backwards-compatible alias: existing code may still import `get_schema`
|
|
7
|
+
# from this package. Prefer `get_spiral_schema` for new code.
|
|
8
|
+
get_schema = get_spiral_schema
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"SpiralGenerator",
|
|
12
|
+
"SpiralParams",
|
|
13
|
+
"get_spiral_schema",
|
|
14
|
+
"get_schema",
|
|
15
|
+
]
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""ARC-AGI (Abstraction and Reasoning Corpus) dataset generator."""
|
|
2
|
+
|
|
3
|
+
from juniper_data.generators.arc_agi.generator import VERSION, ArcAgiGenerator, get_schema
|
|
4
|
+
from juniper_data.generators.arc_agi.params import ArcAgiParams
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"ArcAgiGenerator",
|
|
8
|
+
"ArcAgiParams",
|
|
9
|
+
"VERSION",
|
|
10
|
+
"get_schema",
|
|
11
|
+
]
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""ARC-AGI dataset generator.
|
|
2
|
+
|
|
3
|
+
This module provides the ArcAgiGenerator class for loading ARC-AGI
|
|
4
|
+
(Abstraction and Reasoning Corpus) tasks from Hugging Face or local files.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from juniper_data.core.split import shuffle_and_split
|
|
13
|
+
|
|
14
|
+
from .params import ArcAgiParams
|
|
15
|
+
|
|
16
|
+
VERSION = "1.0.0"
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
from datasets import load_dataset as hf_load_dataset
|
|
20
|
+
|
|
21
|
+
HF_AVAILABLE = True
|
|
22
|
+
except ImportError:
|
|
23
|
+
HF_AVAILABLE = False
|
|
24
|
+
hf_load_dataset = None # type: ignore[assignment]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ArcAgiGenerator:
    """Generator for ARC-AGI reasoning tasks.

    Loads ARC tasks from Hugging Face or local JSON files and converts them
    to padded NumPy arrays suitable for machine learning. Each task contains
    input/output grid pairs demonstrating a transformation pattern.

    Grid values are integers 0-9 (colors); cells added by padding hold
    ``params.pad_value``.

    Requires the `datasets` package for the Hugging Face source:
        pip install datasets

    All methods are static to ensure the generator is stateless and side-effect free.
    """

    @staticmethod
    def generate(params: ArcAgiParams) -> dict[str, np.ndarray]:
        """Generate an ARC-AGI dataset with train/test splits.

        Args:
            params: ArcAgiParams instance defining loading configuration.

        Returns:
            Dictionary containing:
                - X_train: Training input grids
                - y_train: Training output grids
                - X_test: Test input grids
                - y_test: Test output grids
                - X_full: All input grids (in load order, not shuffled)
                - y_full: All output grids (in load order, not shuffled)
                - task_ids: Task identifier for each row of X_full / y_full

        Raises:
            ImportError: If datasets package is not installed (HF source).
            FileNotFoundError: If local path does not exist.
            ValueError: If source is local and no local_path was given, or a
                grid cannot be padded to ``params.pad_to``.
        """
        if params.source == "huggingface":
            tasks = ArcAgiGenerator._load_from_huggingface(params)
        else:
            tasks = ArcAgiGenerator._load_from_local(params)

        X, y, task_ids = ArcAgiGenerator._convert_tasks_to_arrays(tasks, params)

        split_result = shuffle_and_split(
            X=X,
            y=y,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        return {
            "X_train": split_result["X_train"],
            "y_train": split_result["y_train"],
            "X_test": split_result["X_test"],
            "y_test": split_result["y_test"],
            "X_full": X,
            "y_full": y,
            "task_ids": task_ids,
        }

    @staticmethod
    def _select_tasks(tasks: list[dict], n_tasks: int | None, seed: int | None) -> list[dict]:
        """Limit ``tasks`` to at most ``n_tasks`` entries.

        Without a seed the first ``n_tasks`` tasks are kept in order; with a
        seed a reproducible random sample (without replacement) is drawn.
        ``n_tasks=None`` returns the list unchanged.
        """
        if n_tasks is None:
            return tasks
        if seed is None:
            return tasks[:n_tasks]

        rng = np.random.default_rng(seed)
        indices = rng.choice(len(tasks), min(n_tasks, len(tasks)), replace=False)
        return [tasks[i] for i in indices]

    @staticmethod
    def _load_from_huggingface(params: ArcAgiParams) -> list[dict]:
        """Load ARC tasks from Hugging Face Hub.

        Raises:
            ImportError: If the ``datasets`` package is not installed.
        """
        if not HF_AVAILABLE:
            raise ImportError("Hugging Face datasets package not installed. Install with: pip install datasets")

        # NOTE(review): params.subset is not consulted here; the "train" split
        # of the hub dataset is always loaded.
        try:
            ds = hf_load_dataset("fchollet/arc-agi", split="train")  # nosec B615
        except Exception:
            # Deliberate best-effort fallback: any failure on the primary
            # repository falls through to the mirror, whose errors propagate.
            ds = hf_load_dataset("multimodal-reasoning-lab/ARC-AGI", split="train")  # nosec B615

        tasks: list[dict] = [
            {
                # Items without an id get a positional placeholder.
                "task_id": item.get("task_id", f"task_{index}"),
                "train": item.get("train", []),
                "test": item.get("test", []),
            }
            for index, item in enumerate(ds)
        ]

        return ArcAgiGenerator._select_tasks(tasks, params.n_tasks, params.seed)

    @staticmethod
    def _load_from_local(params: ArcAgiParams) -> list[dict]:
        """Load ARC tasks from local JSON files.

        Looks for ``training`` and/or ``evaluation`` sub-directories of
        ``params.local_path`` depending on ``params.subset``; a missing
        sub-directory is skipped silently.

        Raises:
            ValueError: If ``params.local_path`` is None.
            FileNotFoundError: If ``params.local_path`` does not exist.
        """
        if params.local_path is None:
            raise ValueError("local_path is required when source='local'")

        base_path = Path(params.local_path)
        if not base_path.exists():
            raise FileNotFoundError(f"Path not found: {params.local_path}")

        tasks: list[dict] = []
        # Order matters: training tasks come before evaluation tasks.
        for subset_name in ("training", "evaluation"):
            if params.subset in (subset_name, "all"):
                subset_path = base_path / subset_name
                if subset_path.exists():
                    tasks.extend(ArcAgiGenerator._load_json_dir(subset_path))

        return ArcAgiGenerator._select_tasks(tasks, params.n_tasks, params.seed)

    @staticmethod
    def _load_json_dir(dir_path: Path) -> list[dict]:
        """Load all JSON task files from a directory, sorted by file name.

        Each task's ``task_id`` is set to the file name without extension.
        """
        tasks = []
        for json_file in sorted(dir_path.glob("*.json")):
            with open(json_file, encoding="utf-8") as f:
                task_data = json.load(f)
            task_data["task_id"] = json_file.stem
            tasks.append(task_data)
        return tasks

    @staticmethod
    def _convert_tasks_to_arrays(tasks: list[dict], params: ArcAgiParams) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Convert ARC tasks to padded numpy arrays.

        Every train pair (and every test pair when ``params.include_test`` is
        set) contributes one row. Test pairs without an ``output`` get an
        output grid made entirely of ``params.pad_value``.

        Returns:
            Tuple ``(X, y, task_ids)``. X and y are float32 with shape
            ``(n, pad_to * pad_to)`` when ``params.flatten_pairs`` is set and
            ``(n, pad_to, pad_to)`` otherwise; task_ids is an object array of
            length n.
        """
        pad_to, pad_value = params.pad_to, params.pad_value
        inputs = []
        outputs = []
        task_ids = []

        for task in tasks:
            task_id = task.get("task_id", "unknown")

            for pair in task.get("train", []):
                input_grid = ArcAgiGenerator._pad_grid(pair["input"], pad_to, pad_value)
                output_grid = ArcAgiGenerator._pad_grid(pair["output"], pad_to, pad_value)
                inputs.append(input_grid)
                outputs.append(output_grid)
                task_ids.append(task_id)

            if params.include_test:
                for pair in task.get("test", []):
                    input_grid = ArcAgiGenerator._pad_grid(pair["input"], pad_to, pad_value)
                    # Held-out test pairs may lack an output grid.
                    output_grid = ArcAgiGenerator._pad_grid(pair.get("output", [[pad_value]]), pad_to, pad_value)
                    inputs.append(input_grid)
                    outputs.append(output_grid)
                    task_ids.append(task_id)

        if not inputs:
            # Keep the empty result's rank consistent with the non-empty case
            # so downstream code sees the same layout either way.
            empty_shape = (0, pad_to * pad_to) if params.flatten_pairs else (0, pad_to, pad_to)
            X_arr = np.zeros(empty_shape, dtype=np.float32)
            y_arr = np.zeros(empty_shape, dtype=np.float32)
            return X_arr, y_arr, np.array([], dtype=object)

        X_stacked = np.stack(inputs)
        y_stacked = np.stack(outputs)

        if params.flatten_pairs:
            X_arr = X_stacked.reshape(len(X_stacked), -1).astype(np.float32)
            y_arr = y_stacked.reshape(len(y_stacked), -1).astype(np.float32)
        else:
            X_arr = X_stacked.astype(np.float32)
            y_arr = y_stacked.astype(np.float32)

        return X_arr, y_arr, np.array(task_ids, dtype=object)

    @staticmethod
    def _pad_grid(grid: list[list[int]], pad_to: int, pad_value: int) -> np.ndarray:
        """Pad a grid to ``pad_to`` x ``pad_to``, anchored at the top-left.

        Returns:
            int16 array of shape ``(pad_to, pad_to)``.

        Raises:
            ValueError: If ``grid`` is not two-dimensional or is larger than
                ``pad_to`` in either dimension.
        """
        arr = np.array(grid, dtype=np.int16)
        if arr.ndim != 2:
            raise ValueError(f"Grid must be a 2-D list of rows. Got array with shape {arr.shape}")

        h, w = arr.shape
        if h > pad_to or w > pad_to:
            # Fail with an actionable message instead of NumPy's broadcast error.
            raise ValueError(f"Grid of shape ({h}, {w}) does not fit into pad_to={pad_to}")

        padded = np.full((pad_to, pad_to), pad_value, dtype=np.int16)
        padded[:h, :w] = arr

        return padded
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def get_schema() -> dict:
    """Build the JSON schema for the ARC-AGI generator parameters.

    Returns:
        The dictionary produced by ``ArcAgiParams.model_json_schema()``.
    """
    schema = ArcAgiParams.model_json_schema()
    return schema
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Parameters for the ARC-AGI dataset generator."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ArcAgiParams(BaseModel):
    """Configuration parameters for ARC-AGI dataset loading.

    Loads ARC-AGI tasks from Hugging Face Hub or local JSON files.
    The ARC (Abstraction and Reasoning Corpus) contains grid-based
    reasoning tasks with input/output pairs.
    """

    # NOTE(review): pydantic presumably surfaces the class docstring and the
    # Field descriptions in model_json_schema() output — confirm before
    # rewording any of them, since that would change the published schema.

    # --- Where the tasks come from -------------------------------------
    source: Literal["huggingface", "local"] = Field(
        default="huggingface",
        description="Data source: 'huggingface' or 'local' JSON files",
    )
    # Optional at the model level; no validator in this class ties it to
    # source='local', so a missing path is not rejected here.
    local_path: str | None = Field(
        default=None,
        description="Path to local ARC JSON files (required if source='local')",
    )
    subset: Literal["training", "evaluation", "all"] = Field(
        default="training",
        description="Which subset to load: 'training', 'evaluation', or 'all'",
    )
    # None means "no limit"; any given limit must be at least 1.
    n_tasks: int | None = Field(
        default=None,
        ge=1,
        description="Limit number of tasks to load (None for all)",
    )

    # --- How grids are shaped ------------------------------------------
    # The allowed range (1..50) is wider than the 30x30 maximum mentioned in
    # the description, so values below a grid's actual size are accepted here.
    pad_to: int = Field(
        default=30,
        ge=1,
        le=50,
        description="Pad all grids to this size (max ARC grid is 30x30)",
    )
    # Restricted to -1..9: -1 lies outside the 0-9 color range, while 0-9
    # would coincide with a real color value.
    pad_value: int = Field(
        default=-1,
        ge=-1,
        le=9,
        description="Value to use for padding (-1 recommended for masking)",
    )
    include_test: bool = Field(
        default=True,
        description="Include test input/output pairs (in addition to train pairs)",
    )
    flatten_pairs: bool = Field(
        default=True,
        description="Flatten all input/output pairs into single arrays",
    )

    # --- Sampling and splitting ----------------------------------------
    seed: int | None = Field(default=None, ge=0, description="Random seed for reproducibility")
    # train_ratio must be strictly positive (gt=0) while test_ratio may be 0.
    # This class has no cross-field check that the two ratios sum to <= 1.
    train_ratio: float = Field(default=0.8, gt=0, le=1, description="Fraction of data for training")
    test_ratio: float = Field(default=0.2, ge=0, le=1, description="Fraction of data for testing")
    shuffle: bool = Field(default=True, description="Shuffle before splitting")
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Checkerboard classification dataset generator."""
|
|
2
|
+
|
|
3
|
+
from juniper_data.generators.checkerboard.generator import (
|
|
4
|
+
VERSION,
|
|
5
|
+
CheckerboardGenerator,
|
|
6
|
+
get_schema,
|
|
7
|
+
)
|
|
8
|
+
from juniper_data.generators.checkerboard.params import CheckerboardParams
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"CheckerboardGenerator",
|
|
12
|
+
"CheckerboardParams",
|
|
13
|
+
"VERSION",
|
|
14
|
+
"get_schema",
|
|
15
|
+
]
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""Core NumPy-only checkerboard dataset generator.
|
|
2
|
+
|
|
3
|
+
This module provides the CheckerboardGenerator class for generating
|
|
4
|
+
checkerboard classification datasets using only NumPy operations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from juniper_data.core.split import shuffle_and_split
|
|
10
|
+
|
|
11
|
+
from .params import CheckerboardParams
|
|
12
|
+
|
|
13
|
+
VERSION = "1.0.0"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CheckerboardGenerator:
    """Stateless NumPy generator for two-class checkerboard datasets.

    Points are drawn uniformly over a rectangular region that is divided into
    ``n_squares`` x ``n_squares`` cells; neighbouring cells alternate between
    the two classes.

    Every method is static, so the class holds no state between calls.
    """

    @staticmethod
    def generate(params: CheckerboardParams) -> dict[str, np.ndarray]:
        """Build a checkerboard dataset together with its train/test split.

        Args:
            params: CheckerboardParams instance describing the dataset.

        Returns:
            Dictionary containing:
                - X_train: Training features (n_train, 2)
                - y_train: Training labels (n_train, 2)
                - X_test: Test features (n_test, 2)
                - y_test: Test labels (n_test, 2)
                - X_full: Full dataset features (n_samples, 2)
                - y_full: Full dataset labels (n_samples, 2)
        """
        rng = np.random.default_rng(params.seed)
        features, one_hot = CheckerboardGenerator._generate_raw(params, rng)

        split_result = shuffle_and_split(
            X=features,
            y=one_hot,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        dataset = {key: split_result[key] for key in ("X_train", "y_train", "X_test", "y_test")}
        dataset["X_full"] = features
        dataset["y_full"] = one_hot
        return dataset

    @staticmethod
    def _generate_raw(params: CheckerboardParams, rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
        """Sample checkerboard points and their one-hot class labels.

        Args:
            params: CheckerboardParams instance describing the dataset.
            rng: NumPy random generator used for all draws.

        Returns:
            Tuple of (X, y) where:
                - X: float32 feature array of shape (n_samples, 2)
                - y: float32 one-hot label array of shape (n_samples, 2)
        """
        x_lo, x_hi = params.x_range
        y_lo, y_hi = params.y_range
        count = params.n_samples
        squares = params.n_squares

        # Draw order (x, then y, then optional noise) fixes the random stream.
        xs = rng.uniform(x_lo, x_hi, count)
        ys = rng.uniform(y_lo, y_hi, count)

        points = np.stack([xs, ys], axis=1)
        if params.noise > 0:
            points = points + rng.standard_normal(points.shape) * params.noise
        features = points.astype(np.float32)

        # Labels come from the noise-free coordinates, so noise moves points
        # without changing the cell they were drawn from.
        cell_width = (x_hi - x_lo) / squares
        cell_height = (y_hi - y_lo) / squares
        col = np.floor((xs - x_lo) / cell_width).astype(int)
        row = np.floor((ys - y_lo) / cell_height).astype(int)

        # Clamp so a point exactly on the upper edge lands in the last cell.
        last_cell = squares - 1
        col = np.clip(col, 0, last_cell)
        row = np.clip(row, 0, last_cell)

        parity = (col + row) % 2
        one_hot = np.eye(2, dtype=np.float32)[parity]

        return features, one_hot
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def get_schema() -> dict:
    """Build the JSON schema for the checkerboard generator parameters.

    Returns:
        The dictionary produced by ``CheckerboardParams.model_json_schema()``.
    """
    schema = CheckerboardParams.model_json_schema()
    return schema
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Parameters for the checkerboard dataset generator."""
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CheckerboardParams(BaseModel):
    """Configuration parameters for checkerboard dataset generation.

    Generates a checkerboard pattern classification dataset where
    alternating squares belong to different classes.
    """

    # NOTE(review): pydantic presumably surfaces the class docstring and the
    # Field descriptions in model_json_schema() output — confirm before
    # rewording any of them, since that would change the published schema.

    # --- Size and layout of the board ----------------------------------
    n_samples: int = Field(default=200, ge=2, description="Total number of samples")
    # Between 2 and 16 squares per side.
    n_squares: int = Field(
        default=4,
        ge=2,
        le=16,
        description="Number of squares per side (total squares = n_squares^2)",
    )
    # Each range is a (min, max) pair. Nothing in this class enforces
    # min < max, so a zero-width or reversed range is not rejected here.
    x_range: tuple[float, float] = Field(
        default=(0.0, 1.0),
        description="Range of x values (min, max)",
    )
    y_range: tuple[float, float] = Field(
        default=(0.0, 1.0),
        description="Range of y values (min, max)",
    )
    # 0.0 (the default) means no noise; negative values are rejected.
    noise: float = Field(default=0.0, ge=0, description="Gaussian noise level")

    # --- Sampling and splitting ----------------------------------------
    seed: int | None = Field(default=None, ge=0, description="Random seed for reproducibility")
    # train_ratio must be strictly positive (gt=0) while test_ratio may be 0.
    # This class has no cross-field check that the two ratios sum to <= 1.
    train_ratio: float = Field(default=0.8, gt=0, le=1, description="Fraction of data for training")
    test_ratio: float = Field(default=0.2, ge=0, le=1, description="Fraction of data for testing")
    shuffle: bool = Field(default=True, description="Shuffle before splitting")
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Concentric circles classification dataset generator."""
|
|
2
|
+
|
|
3
|
+
from juniper_data.generators.circles.generator import VERSION, CirclesGenerator, get_schema
|
|
4
|
+
from juniper_data.generators.circles.params import CirclesParams
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"CirclesGenerator",
|
|
8
|
+
"CirclesParams",
|
|
9
|
+
"VERSION",
|
|
10
|
+
"get_schema",
|
|
11
|
+
]
|