juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95):
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,120 @@
1
+ """Split and shuffle utilities for dataset partitioning.
2
+
3
+ This module provides pure NumPy utilities for shuffling and splitting datasets
4
+ into train/test sets with reproducible random number generation.
5
+ """
6
+
7
+ import numpy as np
8
+
9
+
10
def shuffle_data(
    X: np.ndarray,
    y: np.ndarray,
    rng: np.random.Generator,
) -> tuple[np.ndarray, np.ndarray]:
    """Apply one shared random permutation to the rows of X and y.

    Args:
        X: Feature array; its first axis indexes samples.
        y: Label array; its first axis indexes samples.
        rng: NumPy random generator that supplies the permutation.

    Returns:
        Tuple ``(X_shuffled, y_shuffled)``. Both are reordered by the same
        permutation, so row ``i`` of each output comes from the same
        original sample. The inputs are not modified.

    Raises:
        ValueError: If X and y differ in length along the first axis.
    """
    n_feature_rows, n_label_rows = X.shape[0], y.shape[0]
    if n_feature_rows != n_label_rows:
        message = (
            "X and y must have the same number of samples. "
            f"Got X.shape[0]={n_feature_rows}, y.shape[0]={n_label_rows}"
        )
        raise ValueError(message)

    # A single index vector drives both lookups, which keeps rows paired.
    order = rng.permutation(n_feature_rows)
    shuffled_X = X[order]
    shuffled_y = y[order]
    return shuffled_X, shuffled_y
35
+
36
+
37
def split_data(
    X: np.ndarray,
    y: np.ndarray,
    train_ratio: float,
    test_ratio: float,
) -> dict[str, np.ndarray]:
    """Split arrays into contiguous train and test sets based on ratios.

    The first ``n_train`` rows become the training set and the rows that
    immediately follow become the test set. No shuffling happens here.

    When ``train_ratio + test_ratio`` equals 1.0 the test set receives every
    row that is not in the training set, so the split is a complete partition
    and no sample is lost to rounding. When the ratios sum to less than 1.0,
    trailing rows are intentionally left out of both sets.

    Args:
        X: Feature array of shape (n_samples, ...).
        y: Label array of shape (n_samples, ...).
        train_ratio: Fraction of data for training (0.0 to 1.0).
        test_ratio: Fraction of data for testing (0.0 to 1.0).

    Returns:
        Dictionary with keys "X_train", "y_train", "X_test", "y_test".

    Raises:
        ValueError: If ratios are invalid or X and y have different sample counts.
    """
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"X and y must have the same number of samples. Got X.shape[0]={X.shape[0]}, y.shape[0]={y.shape[0]}"
        )

    if not (0.0 <= train_ratio <= 1.0):
        raise ValueError(f"train_ratio must be between 0 and 1. Got {train_ratio}")

    if not (0.0 <= test_ratio <= 1.0):
        raise ValueError(f"test_ratio must be between 0 and 1. Got {test_ratio}")

    ratio_total = train_ratio + test_ratio
    if ratio_total > 1.0:
        raise ValueError(
            f"train_ratio + test_ratio must not exceed 1.0. "
            f"Got {train_ratio} + {test_ratio} = {ratio_total}"
        )

    n_samples = X.shape[0]
    n_train = int(np.round(n_samples * train_ratio))
    n_remaining = n_samples - n_train

    if np.isclose(ratio_total, 1.0, rtol=0.0, atol=1e-9):
        # np.round rounds halves to the nearest even integer, so rounding the
        # two counts independently can lose a row (e.g. 5 samples at 0.5/0.5
        # would give 2 + 2). A full partition takes the remainder instead.
        n_test = n_remaining
    else:
        # Independent rounding can also overshoot; never read past the end.
        n_test = min(int(np.round(n_samples * test_ratio)), n_remaining)

    test_stop = n_train + n_test
    return {
        "X_train": X[:n_train],
        "y_train": y[:n_train],
        "X_test": X[n_train:test_stop],
        "y_test": y[n_train:test_stop],
    }
87
+
88
+
89
def shuffle_and_split(
    X: np.ndarray,
    y: np.ndarray,
    train_ratio: float,
    test_ratio: float,
    seed: int | None = None,
    shuffle: bool = True,
) -> dict[str, np.ndarray]:
    """Split data into train/test sets, shuffling first unless disabled.

    Convenience wrapper: when ``shuffle`` is true, a fresh
    ``np.random.Generator`` seeded with ``seed`` permutes X and y together
    via ``shuffle_data``; the (possibly shuffled) arrays are then handed to
    ``split_data``.

    Args:
        X: Feature array of shape (n_samples, ...).
        y: Label array of shape (n_samples, ...).
        train_ratio: Fraction of data for training (0.0 to 1.0).
        test_ratio: Fraction of data for testing (0.0 to 1.0).
        seed: Seed passed to ``np.random.default_rng``. ``None`` gives a
            non-deterministic shuffle. Ignored when ``shuffle`` is false.
        shuffle: Whether to shuffle data before splitting. Defaults to True.

    Returns:
        Dictionary with keys "X_train", "y_train", "X_test", "y_test".

    Raises:
        ValueError: If ratios are invalid or X and y have different sample counts.
    """
    if not shuffle:
        # Order-preserving path: split the arrays exactly as given.
        return split_data(X, y, train_ratio, test_ratio)

    generator = np.random.default_rng(seed)
    shuffled_X, shuffled_y = shuffle_data(X, y, generator)
    return split_data(shuffled_X, shuffled_y, train_ratio, test_ratio)
@@ -0,0 +1,15 @@
1
+ """Data generators module for Juniper Data."""
2
+
3
+ from .spiral import SpiralGenerator, SpiralParams
4
+ from .spiral import get_schema as get_spiral_schema
5
+
6
+ # Backwards-compatible alias: existing code may still import `get_schema`
7
+ # from this package. Prefer `get_spiral_schema` for new code.
8
+ get_schema = get_spiral_schema
9
+
10
+ __all__ = [
11
+ "SpiralGenerator",
12
+ "SpiralParams",
13
+ "get_spiral_schema",
14
+ "get_schema",
15
+ ]
@@ -0,0 +1,11 @@
1
+ """ARC-AGI (Abstraction and Reasoning Corpus) dataset generator."""
2
+
3
+ from juniper_data.generators.arc_agi.generator import VERSION, ArcAgiGenerator, get_schema
4
+ from juniper_data.generators.arc_agi.params import ArcAgiParams
5
+
6
+ __all__ = [
7
+ "ArcAgiGenerator",
8
+ "ArcAgiParams",
9
+ "VERSION",
10
+ "get_schema",
11
+ ]
@@ -0,0 +1,229 @@
1
+ """ARC-AGI dataset generator.
2
+
3
+ This module provides the ArcAgiGenerator class for loading ARC-AGI
4
+ (Abstraction and Reasoning Corpus) tasks from Hugging Face or local files.
5
+ """
6
+
7
+ import json
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+
12
+ from juniper_data.core.split import shuffle_and_split
13
+
14
+ from .params import ArcAgiParams
15
+
16
+ VERSION = "1.0.0"
17
+
18
+ try:
19
+ from datasets import load_dataset as hf_load_dataset
20
+
21
+ HF_AVAILABLE = True
22
+ except ImportError:
23
+ HF_AVAILABLE = False
24
+ hf_load_dataset = None # type: ignore[assignment]
25
+
26
+
27
class ArcAgiGenerator:
    """Generator for ARC-AGI reasoning tasks.

    Loads ARC tasks and converts them to padded numpy arrays suitable
    for machine learning. Each task contains input/output grid pairs
    demonstrating a transformation pattern.

    Grid values are integers 0-9 (colors), with -1 used for padding.

    Requires the `datasets` package for HuggingFace source:
        pip install datasets

    All methods are static to ensure the generator is stateless and side-effect free.
    """

    @staticmethod
    def generate(params: ArcAgiParams) -> dict[str, np.ndarray]:
        """Generate an ARC-AGI dataset with train/test splits.

        Args:
            params: ArcAgiParams instance defining loading configuration.

        Returns:
            Dictionary containing:
                - X_train: Training input grids
                - y_train: Training output grids
                - X_test: Test input grids
                - y_test: Test output grids
                - X_full: All input grids, in load order (not shuffled)
                - y_full: All output grids, in load order (not shuffled)
                - task_ids: Task identifier for each row of X_full / y_full

        Raises:
            ImportError: If datasets package is not installed (HF source).
            FileNotFoundError: If local path does not exist.
            ValueError: If ``local_path`` is missing for the local source, or
                a grid is not 2-dimensional or does not fit in ``pad_to``.
        """
        if params.source == "huggingface":
            tasks = ArcAgiGenerator._load_from_huggingface(params)
        else:
            tasks = ArcAgiGenerator._load_from_local(params)

        X, y, task_ids = ArcAgiGenerator._convert_tasks_to_arrays(tasks, params)

        split_result = shuffle_and_split(
            X=X,
            y=y,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        return {
            "X_train": split_result["X_train"],
            "y_train": split_result["y_train"],
            "X_test": split_result["X_test"],
            "y_test": split_result["y_test"],
            "X_full": X,
            "y_full": y,
            "task_ids": task_ids,
        }

    @staticmethod
    def _select_tasks(tasks: list[dict], params: ArcAgiParams) -> list[dict]:
        """Apply the optional ``n_tasks`` limit to a list of loaded tasks.

        Without a seed the first ``n_tasks`` tasks are kept. With a seed,
        ``n_tasks`` tasks (or all of them, if fewer exist) are drawn without
        replacement using a generator seeded with ``params.seed``.
        """
        if params.n_tasks is None or not tasks:
            return tasks

        if params.seed is None:
            return tasks[: params.n_tasks]

        rng = np.random.default_rng(params.seed)
        indices = rng.choice(len(tasks), min(params.n_tasks, len(tasks)), replace=False)
        return [tasks[i] for i in indices]

    @staticmethod
    def _load_from_huggingface(params: ArcAgiParams) -> list[dict]:
        """Load ARC tasks from Hugging Face Hub.

        Tries the ``fchollet/arc-agi`` dataset first and falls back to
        ``multimodal-reasoning-lab/ARC-AGI`` if that load raises. Only the
        ``"train"`` split is requested; ``params.subset`` is not consulted here.

        Raises:
            ImportError: If the ``datasets`` package is not installed.
        """
        if not HF_AVAILABLE:
            raise ImportError("Hugging Face datasets package not installed. Install with: pip install datasets")

        try:
            ds = hf_load_dataset("fchollet/arc-agi", split="train")  # nosec B615
        except Exception:
            # Deliberate best-effort: whatever went wrong with the primary
            # dataset id, try the alternate id; its errors propagate.
            ds = hf_load_dataset("multimodal-reasoning-lab/ARC-AGI", split="train")  # nosec B615

        # Items without a "task_id" get a positional fallback identifier.
        tasks: list[dict] = [
            {
                "task_id": item.get("task_id", f"task_{position}"),
                "train": item.get("train", []),
                "test": item.get("test", []),
            }
            for position, item in enumerate(ds)
        ]

        return ArcAgiGenerator._select_tasks(tasks, params)

    @staticmethod
    def _load_from_local(params: ArcAgiParams) -> list[dict]:
        """Load ARC tasks from local JSON files.

        Reads ``<local_path>/training`` and/or ``<local_path>/evaluation``
        depending on ``params.subset``. A subset directory that does not
        exist is skipped silently.

        Raises:
            ValueError: If ``params.local_path`` is None.
            FileNotFoundError: If ``params.local_path`` does not exist.
        """
        if params.local_path is None:
            raise ValueError("local_path is required when source='local'")

        base_path = Path(params.local_path)
        if not base_path.exists():
            raise FileNotFoundError(f"Path not found: {params.local_path}")

        tasks: list[dict] = []
        # Training tasks always precede evaluation tasks in the result.
        for subset_name in ("training", "evaluation"):
            if params.subset not in (subset_name, "all"):
                continue
            subset_dir = base_path / subset_name
            if subset_dir.exists():
                tasks.extend(ArcAgiGenerator._load_json_dir(subset_dir))

        return ArcAgiGenerator._select_tasks(tasks, params)

    @staticmethod
    def _load_json_dir(dir_path: Path) -> list[dict]:
        """Load all JSON task files from a directory.

        Files are read in sorted filename order and each task's ``task_id``
        is set to the file name without its extension.
        """
        tasks = []
        for json_file in sorted(dir_path.glob("*.json")):
            with open(json_file, encoding="utf-8") as f:
                task_data = json.load(f)
            task_data["task_id"] = json_file.stem
            tasks.append(task_data)
        return tasks

    @staticmethod
    def _convert_tasks_to_arrays(tasks: list[dict], params: ArcAgiParams) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Convert ARC tasks to padded numpy arrays.

        Every train pair contributes one row. Test pairs are added as well
        when ``params.include_test`` is set. A test pair without an
        ``"output"`` gets an output grid made entirely of ``params.pad_value``.

        Returns:
            Tuple ``(X, y, task_ids)``. X and y are float32 with shape
            ``(n_pairs, pad_to * pad_to)`` when ``params.flatten_pairs`` is
            set, otherwise ``(n_pairs, pad_to, pad_to)``. The same layout is
            used when there are no pairs. ``task_ids`` is an object array
            with one entry per row.
        """
        size = params.pad_to
        fill = params.pad_value

        inputs = []
        outputs = []
        task_ids = []

        for task in tasks:
            task_id = task.get("task_id", "unknown")

            for pair in task.get("train", []):
                inputs.append(ArcAgiGenerator._pad_grid(pair["input"], size, fill))
                outputs.append(ArcAgiGenerator._pad_grid(pair["output"], size, fill))
                task_ids.append(task_id)

            if params.include_test:
                for pair in task.get("test", []):
                    inputs.append(ArcAgiGenerator._pad_grid(pair["input"], size, fill))
                    # A 1x1 grid of pad_value pads out to an all-padding grid.
                    outputs.append(ArcAgiGenerator._pad_grid(pair.get("output", [[fill]]), size, fill))
                    task_ids.append(task_id)

        if not inputs:
            # Keep the empty result's rank consistent with the non-empty case.
            empty_shape = (0, size * size) if params.flatten_pairs else (0, size, size)
            X_arr = np.zeros(empty_shape, dtype=np.float32)
            y_arr = np.zeros(empty_shape, dtype=np.float32)
            ids = np.array([], dtype=object)
            return X_arr, y_arr, ids

        X_stacked = np.stack(inputs)
        y_stacked = np.stack(outputs)

        if params.flatten_pairs:
            X_arr = X_stacked.reshape(len(X_stacked), -1).astype(np.float32)
            y_arr = y_stacked.reshape(len(y_stacked), -1).astype(np.float32)
        else:
            X_arr = X_stacked.astype(np.float32)
            y_arr = y_stacked.astype(np.float32)

        return X_arr, y_arr, np.array(task_ids, dtype=object)

    @staticmethod
    def _pad_grid(grid: list[list[int]], pad_to: int, pad_value: int) -> np.ndarray:
        """Pad a grid to the specified size.

        The grid is placed in the top-left corner of a ``pad_to`` x ``pad_to``
        int16 array whose remaining cells hold ``pad_value``.

        Raises:
            ValueError: If the grid is not 2-dimensional or is larger than
                ``pad_to`` in either dimension.
        """
        arr = np.array(grid, dtype=np.int16)
        if arr.ndim != 2:
            raise ValueError(f"Grid must be 2-dimensional, got {arr.ndim} dimension(s)")

        h, w = arr.shape
        if h > pad_to or w > pad_to:
            raise ValueError(f"Grid of shape {h}x{w} does not fit in pad_to={pad_to}")

        padded = np.full((pad_to, pad_to), pad_value, dtype=np.int16)
        padded[:h, :w] = arr

        return padded
221
+
222
+
223
def get_schema() -> dict:
    """Build the JSON schema for the ARC-AGI generator's parameter model.

    Returns:
        The dictionary produced by ``ArcAgiParams.model_json_schema()``.
    """
    schema = ArcAgiParams.model_json_schema()
    return schema
@@ -0,0 +1,56 @@
1
+ """Parameters for the ARC-AGI dataset generator."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class ArcAgiParams(BaseModel):
    """Configuration parameters for ARC-AGI dataset loading.

    Loads ARC-AGI tasks from Hugging Face Hub or local JSON files.
    The ARC (Abstraction and Reasoning Corpus) contains grid-based
    reasoning tasks with input/output pairs.
    """

    # --- Where the tasks come from -------------------------------------
    source: Literal["huggingface", "local"] = Field(
        default="huggingface",
        description="Data source: 'huggingface' or 'local' JSON files",
    )
    # Not validated against `source` on this model.
    local_path: str | None = Field(
        default=None,
        description="Path to local ARC JSON files (required if source='local')",
    )
    subset: Literal["training", "evaluation", "all"] = Field(
        default="training",
        description="Which subset to load: 'training', 'evaluation', or 'all'",
    )
    n_tasks: int | None = Field(
        default=None,
        ge=1,
        description="Limit number of tasks to load (None for all)",
    )

    # --- How grids are turned into arrays ------------------------------
    pad_to: int = Field(
        default=30,
        ge=1,
        le=50,
        description="Pad all grids to this size (max ARC grid is 30x30)",
    )
    pad_value: int = Field(
        default=-1,
        ge=-1,
        le=9,
        description="Value to use for padding (-1 recommended for masking)",
    )
    include_test: bool = Field(
        default=True,
        description="Include test input/output pairs (in addition to train pairs)",
    )
    flatten_pairs: bool = Field(
        default=True,
        description="Flatten all input/output pairs into single arrays",
    )

    # --- Reproducibility and train/test partitioning -------------------
    seed: int | None = Field(
        default=None,
        ge=0,
        description="Random seed for reproducibility",
    )
    # The sum of the two ratios is not checked on this model.
    train_ratio: float = Field(
        default=0.8,
        gt=0,
        le=1,
        description="Fraction of data for training",
    )
    test_ratio: float = Field(
        default=0.2,
        ge=0,
        le=1,
        description="Fraction of data for testing",
    )
    shuffle: bool = Field(
        default=True,
        description="Shuffle before splitting",
    )
@@ -0,0 +1,15 @@
1
+ """Checkerboard classification dataset generator."""
2
+
3
+ from juniper_data.generators.checkerboard.generator import (
4
+ VERSION,
5
+ CheckerboardGenerator,
6
+ get_schema,
7
+ )
8
+ from juniper_data.generators.checkerboard.params import CheckerboardParams
9
+
10
+ __all__ = [
11
+ "CheckerboardGenerator",
12
+ "CheckerboardParams",
13
+ "VERSION",
14
+ "get_schema",
15
+ ]
@@ -0,0 +1,114 @@
1
+ """Core NumPy-only checkerboard dataset generator.
2
+
3
+ This module provides the CheckerboardGenerator class for generating
4
+ checkerboard classification datasets using only NumPy operations.
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ from juniper_data.core.split import shuffle_and_split
10
+
11
+ from .params import CheckerboardParams
12
+
13
+ VERSION = "1.0.0"
14
+
15
+
16
class CheckerboardGenerator:
    """NumPy-only generator for checkerboard classification datasets.

    Generates a 2D checkerboard pattern where alternating squares
    belong to different classes. Points are uniformly distributed
    across the grid.

    All methods are static to ensure the generator is stateless and side-effect free.
    """

    @staticmethod
    def generate(params: CheckerboardParams) -> dict[str, np.ndarray]:
        """Generate a complete checkerboard dataset with train/test splits.

        The same ``params.seed`` seeds both the point generation and the
        shuffle performed by ``shuffle_and_split``.

        Args:
            params: CheckerboardParams instance defining generation configuration.

        Returns:
            Dictionary containing:
                - X_train: Training features (n_train, 2)
                - y_train: Training labels (n_train, 2)
                - X_test: Test features (n_test, 2)
                - y_test: Test labels (n_test, 2)
                - X_full: Full dataset features (n_samples, 2), in generation order
                - y_full: Full dataset labels (n_samples, 2), in generation order
        """
        rng = np.random.default_rng(params.seed)

        X, y = CheckerboardGenerator._generate_raw(params, rng)

        split_result = shuffle_and_split(
            X=X,
            y=y,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        return {
            "X_train": split_result["X_train"],
            "y_train": split_result["y_train"],
            "X_test": split_result["X_test"],
            "y_test": split_result["y_test"],
            "X_full": X,
            "y_full": y,
        }

    @staticmethod
    def _cell_indices(coords: np.ndarray, low: float, high: float, n_squares: int) -> np.ndarray:
        """Map coordinates in ``[low, high]`` to board cell indices ``0..n_squares-1``."""
        step = (high - low) / n_squares
        if step == 0:
            # Zero-width range: every point sits in the first cell. This
            # avoids dividing by zero and casting NaN to int.
            return np.zeros(coords.shape, dtype=int)

        idx = np.floor((coords - low) / step).astype(int)
        # A point exactly on the upper bound would land one cell too far.
        return np.clip(idx, 0, n_squares - 1)

    @staticmethod
    def _generate_raw(params: CheckerboardParams, rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
        """Generate raw checkerboard coordinates and labels.

        Labels are computed from the coordinates as originally drawn. When
        ``params.noise`` is positive, Gaussian noise is added to the returned
        features only, so a noisy point keeps the label of the square it was
        drawn in.

        Args:
            params: CheckerboardParams instance defining generation configuration.
            rng: NumPy random generator for reproducibility.

        Returns:
            Tuple of (X, y) where:
                - X: float32 feature array of shape (n_samples, 2)
                - y: float32 one-hot label array of shape (n_samples, 2)
        """
        x_min, x_max = params.x_range
        y_min, y_max = params.y_range
        n_samples = params.n_samples
        n_squares = params.n_squares

        # Draw order (x, y, then noise) fixes the output for a given seed.
        x = rng.uniform(x_min, x_max, n_samples)
        y_coord = rng.uniform(y_min, y_max, n_samples)

        X = np.column_stack([x, y_coord])

        if params.noise > 0:
            X += rng.standard_normal(X.shape) * params.noise

        X = X.astype(np.float32)

        x_idx = CheckerboardGenerator._cell_indices(x, x_min, x_max, n_squares)
        y_idx = CheckerboardGenerator._cell_indices(y_coord, y_min, y_max, n_squares)

        # Cells whose column and row indices have equal parity are class 0.
        labels = (x_idx + y_idx) % 2

        y = np.zeros((n_samples, 2), dtype=np.float32)
        y[np.arange(n_samples), labels] = 1.0

        return X, y
106
+
107
+
108
def get_schema() -> dict:
    """Build the JSON schema for the checkerboard generator's parameter model.

    Returns:
        The dictionary produced by ``CheckerboardParams.model_json_schema()``.
    """
    schema = CheckerboardParams.model_json_schema()
    return schema
@@ -0,0 +1,32 @@
1
+ """Parameters for the checkerboard dataset generator."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
class CheckerboardParams(BaseModel):
    """Configuration parameters for checkerboard dataset generation.

    Generates a checkerboard pattern classification dataset where
    alternating squares belong to different classes.
    """

    # --- Board geometry and sample count -------------------------------
    n_samples: int = Field(
        default=200,
        ge=2,
        description="Total number of samples",
    )
    n_squares: int = Field(
        default=4,
        ge=2,
        le=16,
        description="Number of squares per side (total squares = n_squares^2)",
    )
    # The ordering min < max is not validated on this model.
    x_range: tuple[float, float] = Field(
        default=(0.0, 1.0),
        description="Range of x values (min, max)",
    )
    y_range: tuple[float, float] = Field(
        default=(0.0, 1.0),
        description="Range of y values (min, max)",
    )
    noise: float = Field(
        default=0.0,
        ge=0,
        description="Gaussian noise level",
    )

    # --- Reproducibility and train/test partitioning -------------------
    seed: int | None = Field(
        default=None,
        ge=0,
        description="Random seed for reproducibility",
    )
    # The sum of the two ratios is not checked on this model.
    train_ratio: float = Field(
        default=0.8,
        gt=0,
        le=1,
        description="Fraction of data for training",
    )
    test_ratio: float = Field(
        default=0.2,
        ge=0,
        le=1,
        description="Fraction of data for testing",
    )
    shuffle: bool = Field(
        default=True,
        description="Shuffle before splitting",
    )
@@ -0,0 +1,11 @@
1
+ """Concentric circles classification dataset generator."""
2
+
3
+ from juniper_data.generators.circles.generator import VERSION, CirclesGenerator, get_schema
4
+ from juniper_data.generators.circles.params import CirclesParams
5
+
6
+ __all__ = [
7
+ "CirclesGenerator",
8
+ "CirclesParams",
9
+ "VERSION",
10
+ "get_schema",
11
+ ]