juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95) hide show
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,124 @@
1
+ """MNIST dataset generator using Hugging Face datasets.
2
+
3
+ This module provides the MnistGenerator class for loading and preprocessing
4
+ MNIST and Fashion-MNIST datasets from the Hugging Face Hub.
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ from juniper_data.core.split import shuffle_and_split
10
+
11
+ from .params import MnistParams
12
+
13
# Version string for this generator implementation.
VERSION = "1.0.0"

# The Hugging Face `datasets` package is an optional dependency. Guard the
# import so this module can always be imported; MnistGenerator.generate checks
# HF_AVAILABLE and raises ImportError when the package is missing.
try:
    from datasets import load_dataset as hf_load_dataset

    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    # Placeholder so the name exists; calling it is never valid.
    hf_load_dataset = None  # type: ignore[assignment]
22
+
23
+
24
class MnistGenerator:
    """Generator for MNIST and Fashion-MNIST datasets.

    Loads datasets from Hugging Face Hub and converts them to the
    JuniperData format with train/test splits.

    Requires the `datasets` package: pip install datasets

    All methods are static to ensure the generator is stateless and side-effect free.
    """

    @staticmethod
    def generate(params: MnistParams) -> dict[str, np.ndarray]:
        """Generate a complete MNIST dataset with train/test splits.

        Args:
            params: MnistParams instance defining generation configuration.

        Returns:
            Dictionary containing:
                - X_train: Training features
                - y_train: Training labels
                - X_test: Test features
                - y_test: Test labels
                - X_full: Full dataset features
                - y_full: Full dataset labels

        Raises:
            ImportError: If datasets package is not installed.
        """
        if not HF_AVAILABLE:
            raise ImportError("Hugging Face datasets package not installed. Install with: pip install datasets")

        X, y = MnistGenerator._load_and_preprocess(params)

        split_result = shuffle_and_split(
            X=X,
            y=y,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        return {
            "X_train": split_result["X_train"],
            "y_train": split_result["y_train"],
            "X_test": split_result["X_test"],
            "y_test": split_result["y_test"],
            "X_full": X,
            "y_full": y,
        }

    @staticmethod
    def _load_and_preprocess(params: MnistParams) -> tuple[np.ndarray, np.ndarray]:
        """Load the training split from Hugging Face and preprocess it.

        Args:
            params: MnistParams instance.

        Returns:
            Tuple of (X, y) arrays. X is float32, scaled to [0, 1] when
            ``params.normalize`` is set and reshaped to (n, -1) when
            ``params.flatten`` is set. y is float32 of shape (n, 10) when
            ``params.one_hot_labels`` is set, otherwise (n, 1).

        Raises:
            ImportError: If the `datasets` package is not installed. This is
                checked here too because this helper can be called directly,
                bypassing ``generate``.
        """
        if hf_load_dataset is None:
            raise ImportError("Hugging Face datasets package not installed. Install with: pip install datasets")

        # params.dataset is validated by MnistParams (Pydantic) as Literal["mnist", "fashion_mnist"],
        # so this argument to hf_load_dataset is restricted to these known-safe values.
        ds = hf_load_dataset(params.dataset, split="train")  # nosec B615

        if params.seed is not None:
            ds = ds.shuffle(seed=params.seed)

        if params.n_samples is not None:
            # n_samples is an upper limit: clamp it so asking for more rows than
            # the split contains yields the whole split instead of an IndexError
            # from Dataset.select.
            ds = ds.select(range(min(params.n_samples, len(ds))))

        # Use bulk column access with numpy formatting for efficient conversion
        ds = ds.with_format("numpy")
        X = np.array(ds["image"])
        X = X.astype(np.float32) / 255.0 if params.normalize else X.astype(np.float32)
        if params.flatten:
            X = X.reshape(len(X), -1)

        labels = np.array(ds["label"])
        if params.one_hot_labels:
            # Both supported datasets (MNIST, Fashion-MNIST) have 10 classes.
            n_classes = 10

            y = np.zeros((len(labels), n_classes), dtype=np.float32)
            y[np.arange(len(labels)), labels] = 1.0
        else:
            y = labels.astype(np.float32).reshape(-1, 1)

        return X, y
116
+
117
+
118
def get_schema() -> dict:
    """Describe the MNIST generator's parameters as a JSON schema.

    Returns:
        The JSON schema dictionary produced by ``MnistParams.model_json_schema()``.
    """
    schema = MnistParams.model_json_schema()
    return schema
@@ -0,0 +1,39 @@
1
+ """Parameters for the MNIST dataset generator."""
2
+
3
from typing import Literal

from pydantic import BaseModel, Field, model_validator
6
+
7
+
8
class MnistParams(BaseModel):
    """Configuration parameters for MNIST dataset generation.

    Loads and preprocesses MNIST or Fashion-MNIST datasets from
    Hugging Face Hub.
    """

    dataset: Literal["mnist", "fashion_mnist"] = Field(
        default="mnist",
        description="Dataset to load: 'mnist' or 'fashion_mnist'",
    )
    n_samples: int | None = Field(
        default=None,
        ge=1,
        description="Limit number of samples (None for full dataset)",
    )
    flatten: bool = Field(
        default=True,
        description="Flatten images to 1D (784 features) or keep 2D (28x28)",
    )
    normalize: bool = Field(
        default=True,
        description="Normalize pixel values to [0, 1]",
    )
    one_hot_labels: bool = Field(
        default=True,
        description="One-hot encode labels (10 classes)",
    )
    seed: int | None = Field(default=None, ge=0, description="Random seed for reproducibility")
    train_ratio: float = Field(default=0.8, gt=0, le=1, description="Fraction of data for training")
    test_ratio: float = Field(default=0.2, ge=0, le=1, description="Fraction of data for testing")
    shuffle: bool = Field(default=True, description="Shuffle before splitting")

    @model_validator(mode="after")
    def validate_ratios_sum(self) -> "MnistParams":
        """Validate that train_ratio + test_ratio <= 1.0.

        Each ratio is bounded individually by its Field constraints, but only
        this check prevents the train and test fractions from overlapping.

        Raises:
            ValueError: If the two ratios sum to more than 1.0.
        """
        if self.train_ratio + self.test_ratio > 1.0:
            raise ValueError(
                f"train_ratio ({self.train_ratio}) + test_ratio ({self.test_ratio}) must be <= 1.0, "
                f"got {self.train_ratio + self.test_ratio}"
            )
        return self
@@ -0,0 +1,57 @@
1
+ """Spiral dataset generator module."""
2
+
3
+ from .defaults import (
4
+ MAX_NOISE,
5
+ MAX_POINTS,
6
+ MAX_ROTATIONS,
7
+ MAX_SPIRALS,
8
+ MIN_NOISE,
9
+ MIN_POINTS,
10
+ MIN_ROTATIONS,
11
+ MIN_SPIRALS,
12
+ SPIRAL_DEFAULT_CLOCKWISE,
13
+ SPIRAL_DEFAULT_DISTRIBUTION,
14
+ SPIRAL_DEFAULT_N_POINTS,
15
+ SPIRAL_DEFAULT_N_ROTATIONS,
16
+ SPIRAL_DEFAULT_N_SPIRALS,
17
+ SPIRAL_DEFAULT_NOISE,
18
+ SPIRAL_DEFAULT_ORIGIN,
19
+ SPIRAL_DEFAULT_RADIUS,
20
+ SPIRAL_DEFAULT_RANDOM_VALUE_SCALE,
21
+ SPIRAL_DEFAULT_SEED,
22
+ SPIRAL_DEFAULT_TEST_RATIO,
23
+ SPIRAL_DEFAULT_TRAIN_RATIO,
24
+ )
25
+ from .generator import VERSION, SpiralGenerator, get_schema
26
+ from .params import SpiralParams
27
+
28
# Public API of the spiral generator package: the generator entry points, the
# default constants and validation bounds from .defaults, and the Pydantic
# parameter model. These are the names exported by a star-import.
__all__ = [
    # Generator
    "SpiralGenerator",
    "get_schema",
    "VERSION",
    # Default constants
    "SPIRAL_DEFAULT_N_SPIRALS",
    "SPIRAL_DEFAULT_N_POINTS",
    "SPIRAL_DEFAULT_N_ROTATIONS",
    "SPIRAL_DEFAULT_CLOCKWISE",
    "SPIRAL_DEFAULT_DISTRIBUTION",
    "SPIRAL_DEFAULT_ORIGIN",
    "SPIRAL_DEFAULT_RADIUS",
    "SPIRAL_DEFAULT_NOISE",
    "SPIRAL_DEFAULT_RANDOM_VALUE_SCALE",
    "SPIRAL_DEFAULT_SEED",
    "SPIRAL_DEFAULT_TRAIN_RATIO",
    "SPIRAL_DEFAULT_TEST_RATIO",
    # Validation bounds
    "MIN_SPIRALS",
    "MAX_SPIRALS",
    "MIN_POINTS",
    "MAX_POINTS",
    "MIN_ROTATIONS",
    "MAX_ROTATIONS",
    "MIN_NOISE",
    "MAX_NOISE",
    # Pydantic model
    "SpiralParams",
]
@@ -0,0 +1,39 @@
1
+ """Default constants for spiral dataset generation.
2
+
3
+ This module defines all default constants and validation bounds for the spiral
4
+ dataset generator, migrated from JuniperCascor constants_problem.py.
5
+ """
6
+
7
# Spiral Geometry Defaults
# Most of these are the `default=` values of the matching SpiralParams fields.
SPIRAL_DEFAULT_N_SPIRALS: int = 2
SPIRAL_DEFAULT_N_POINTS: int = 97  # points per spiral arm (n_points_per_spiral)
SPIRAL_DEFAULT_N_ROTATIONS: float = 3.0
SPIRAL_DEFAULT_CLOCKWISE: bool = True
# NOTE(review): not referenced by SpiralParams or SpiralGenerator; presumably
# kept for parity with JuniperCascor's constants — confirm before removing.
SPIRAL_DEFAULT_DISTRIBUTION: float = 0.80
# Same value that SpiralParams.origin uses as its literal default.
SPIRAL_DEFAULT_ORIGIN: tuple[float, float] = (0.0, 0.0)
SPIRAL_DEFAULT_RADIUS: float = 10.0

# Noise & Randomness Defaults
SPIRAL_DEFAULT_NOISE: float = 0.25
# NOTE(review): not referenced by SpiralParams or SpiralGenerator; presumably
# kept for parity with JuniperCascor's constants — confirm before removing.
SPIRAL_DEFAULT_RANDOM_VALUE_SCALE: float = 0.1
SPIRAL_DEFAULT_SEED: int = 42

# Dataset Splitting Defaults
# The defaults sum to 1.0, which SpiralParams' ratio-sum validator permits.
SPIRAL_DEFAULT_TRAIN_RATIO: float = 0.8
SPIRAL_DEFAULT_TEST_RATIO: float = 0.2

# Validation Bounds - Spirals
# The bounds below are used as inclusive ge/le constraints on SpiralParams fields.
MIN_SPIRALS: int = 2
MAX_SPIRALS: int = 10

# Validation Bounds - Points
# Bounds on points per spiral arm, not on the total dataset size.
MIN_POINTS: int = 10
MAX_POINTS: int = 10000

# Validation Bounds - Rotations
MIN_ROTATIONS: float = 0.5
MAX_ROTATIONS: float = 10.0

# Validation Bounds - Noise
MIN_NOISE: float = 0.0
MAX_NOISE: float = 2.0
@@ -0,0 +1,206 @@
1
+ """Core NumPy-only spiral dataset generator.
2
+
3
+ This module provides the SpiralGenerator class for generating multi-spiral
4
+ classification datasets using only NumPy operations.
5
+ """
6
+
7
+ from typing import Literal
8
+
9
+ import numpy as np
10
+
11
+ from juniper_data.core.split import shuffle_and_split
12
+
13
+ from .params import SpiralParams
14
+
15
# Generator implementation version; re-exported from the spiral package __init__.
VERSION = "1.0.0"
16
+
17
+
18
class SpiralGenerator:
    """NumPy-only generator for multi-spiral classification datasets.

    All methods are static to ensure the generator is stateless and side-effect free.
    """

    @staticmethod
    def generate(params: "SpiralParams") -> dict[str, np.ndarray]:
        """Generate a complete spiral dataset with train/test splits.

        Main public API for spiral dataset generation.

        Args:
            params: SpiralParams instance defining generation configuration.

        Returns:
            Dictionary containing:
                - X_train: Training features (n_train, 2)
                - y_train: Training labels (n_train, n_spirals)
                - X_test: Test features (n_test, 2)
                - y_test: Test labels (n_test, n_spirals)
                - X_full: Full dataset features (total_points, 2)
                - y_full: Full dataset labels (total_points, n_spirals)
        """
        rng = np.random.default_rng(params.seed)

        X, y = SpiralGenerator._generate_raw(params, rng)

        split_result = shuffle_and_split(
            X=X,
            y=y,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        return {
            "X_train": split_result["X_train"],
            "y_train": split_result["y_train"],
            "X_test": split_result["X_test"],
            "y_test": split_result["y_test"],
            "X_full": X,
            "y_full": y,
        }

    @staticmethod
    def _generate_raw(params: "SpiralParams", rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
        """Generate raw spiral coordinates and labels.

        Arms are generated in class order and evenly spaced in angle
        (offset 2*pi*i / n_spirals). The rng is consumed arm by arm, so results
        are reproducible for a given seed.

        Args:
            params: SpiralParams instance defining generation configuration.
            rng: NumPy random generator for reproducibility.

        Returns:
            Tuple of (X, y) where:
                - X: float32 feature array of shape (total_points, 2)
                - y: float32 one-hot label array of shape (total_points, n_spirals)
        """
        all_coords = []

        for i in range(params.n_spirals):
            angle_offset = 2 * np.pi * i / params.n_spirals
            coords = SpiralGenerator._generate_spiral_coordinates(
                n_points=params.n_points_per_spiral,
                radius=params.radius,
                n_rotations=params.n_rotations,
                angle_offset=angle_offset,
                clockwise=params.clockwise,
                noise=params.noise,
                rng=rng,
                algorithm=params.algorithm,
                origin=params.origin,
            )
            all_coords.append(coords)

        X = np.vstack(all_coords).astype(np.float32)
        y = SpiralGenerator._create_one_hot_labels(
            n_spirals=params.n_spirals,
            n_points_per_spiral=params.n_points_per_spiral,
        )

        return X, y

    @staticmethod
    def _generate_spiral_coordinates(
        n_points: int,
        radius: float,
        n_rotations: float,
        angle_offset: float,
        clockwise: bool,
        noise: float,
        rng: np.random.Generator,
        algorithm: Literal["modern", "legacy_cascor"] = "modern",
        origin: tuple[float, float] = (0.0, 0.0),
    ) -> np.ndarray:
        """Generate coordinates for a single spiral arm.

        Args:
            n_points: Number of points to generate.
            radius: Maximum radius of the spiral.
            n_rotations: Number of full rotations (used by 'modern' only).
            angle_offset: Angular offset for this spiral arm.
            clockwise: Winding direction. Flipping it mirrors the arm, so its
                winding sense is reversed in both algorithms.
            noise: Noise level to apply.
            rng: NumPy random generator.
            algorithm: Generation algorithm ('modern' or 'legacy_cascor').
            origin: Origin point (x, y) for spiral center.

        Returns:
            float32 array of shape (n_points, 2) containing x, y coordinates.
        """
        # NOTE(review): with direction=+1 the angle increases with radius, which
        # is counter-clockwise in standard math orientation; the flag's name is
        # inherited from the original code — confirm the intended convention.
        direction = 1 if clockwise else -1

        if algorithm == "legacy_cascor":
            distance = np.sqrt(rng.random(n_points)) * radius
            theta = direction * (distance + angle_offset)
            x = np.cos(theta) * distance + SpiralGenerator._make_noise_uniform(n_points, noise, rng)
            y = np.sin(theta) * distance + SpiralGenerator._make_noise_uniform(n_points, noise, rng)
        else:
            radii = np.linspace(0, radius, n_points)
            # Apply the direction to the angle, as the legacy branch does. The
            # previous form negated both x and y, which only rotates the arm by
            # pi and leaves its winding sense unchanged.
            theta = direction * (np.linspace(0, 2 * np.pi * n_rotations, n_points) + angle_offset)
            x = radii * np.cos(theta) + SpiralGenerator._make_noise(n_points, noise, rng)
            y = radii * np.sin(theta) + SpiralGenerator._make_noise(n_points, noise, rng)

        x += origin[0]
        y += origin[1]

        return np.column_stack([x, y]).astype(np.float32)

    @staticmethod
    def _make_noise(n_points: int, noise: float, rng: np.random.Generator) -> np.ndarray:
        """Generate random noise array using normal distribution.

        Args:
            n_points: Number of noise values to generate.
            noise: Noise scale factor (standard deviation).
            rng: NumPy random generator.

        Returns:
            Array of shape (n_points,) containing scaled random noise.
        """
        return rng.standard_normal(n_points) * noise

    @staticmethod
    def _make_noise_uniform(n_points: int, noise: float, rng: np.random.Generator) -> np.ndarray:
        """Generate uniform random noise in [0, noise).

        Args:
            n_points: Number of noise values to generate.
            noise: Noise scale factor.
            rng: NumPy random generator.

        Returns:
            Array of shape (n_points,) containing uniform random noise.
        """
        return rng.random(n_points) * noise

    @staticmethod
    def _create_one_hot_labels(n_spirals: int, n_points_per_spiral: int) -> np.ndarray:
        """Create one-hot encoded labels for spiral classes.

        Rows are grouped by class: the first n_points_per_spiral rows are class
        0, the next block is class 1, and so on, matching the arm order of
        ``_generate_raw``.

        Args:
            n_spirals: Number of spiral classes.
            n_points_per_spiral: Number of points per spiral class.

        Returns:
            float32 array of shape (total_points, n_spirals) with one-hot encoding.
        """
        # Repeat each identity row once per point of its class.
        return np.repeat(np.eye(n_spirals, dtype=np.float32), n_points_per_spiral, axis=0)
196
+
197
+
198
def get_schema() -> dict:
    """Describe the spiral generator's parameters as a JSON schema.

    Intended for API documentation and request validation.

    Returns:
        The JSON schema dictionary produced by ``SpiralParams.model_json_schema()``.
    """
    schema = SpiralParams.model_json_schema()
    return schema
@@ -0,0 +1,148 @@
1
+ """Spiral dataset generator parameters.
2
+
3
+ This module defines the Pydantic model for spiral dataset generation parameters
4
+ with validation and computation methods.
5
+
6
+ Parameter Aliases:
7
+ Some consumers (JuniperCascor, JuniperCanopy) use different parameter names.
8
+ This module supports the following aliases:
9
+ - `n_points` -> `n_points_per_spiral`
10
+ - `noise_level` -> `noise`
11
+ """
12
+
13
+ from typing import Literal
14
+
15
+ from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
16
+
17
+ from .defaults import (
18
+ MAX_NOISE,
19
+ MAX_POINTS,
20
+ MAX_ROTATIONS,
21
+ MAX_SPIRALS,
22
+ MIN_NOISE,
23
+ MIN_POINTS,
24
+ MIN_ROTATIONS,
25
+ MIN_SPIRALS,
26
+ SPIRAL_DEFAULT_CLOCKWISE,
27
+ SPIRAL_DEFAULT_N_POINTS,
28
+ SPIRAL_DEFAULT_N_ROTATIONS,
29
+ SPIRAL_DEFAULT_N_SPIRALS,
30
+ SPIRAL_DEFAULT_NOISE,
31
+ SPIRAL_DEFAULT_RADIUS,
32
+ SPIRAL_DEFAULT_SEED,
33
+ SPIRAL_DEFAULT_TEST_RATIO,
34
+ SPIRAL_DEFAULT_TRAIN_RATIO,
35
+ )
36
+
37
# Maps alternate parameter names used by some consumers (JuniperCascor,
# JuniperCanopy) to the canonical SpiralParams field names. The same aliases
# are declared on the fields via AliasChoices. This dict is not referenced
# elsewhere in this module; presumably callers use it to normalize raw
# parameter dicts — TODO confirm.
PARAMETER_ALIASES: dict[str, str] = {
    "n_points": "n_points_per_spiral",
    "noise_level": "noise",
}
41
+
42
+
43
class SpiralParams(BaseModel):
    """Parameters for spiral dataset generation.

    Defines the configuration for generating multi-spiral classification datasets
    with support for noise, train/test splitting, and deterministic seeding.

    Attributes:
        n_spirals: Number of spiral arms to generate.
        n_points_per_spiral: Number of points per spiral arm.
        n_rotations: Number of full rotations for each spiral.
        noise: Noise level applied to point positions.
        clockwise: Whether spirals rotate clockwise.
        seed: Random seed for reproducibility (non-negative), or None for an
            unseeded run.
        train_ratio: Fraction of data for training set.
        test_ratio: Fraction of data for test set.
        shuffle: Whether to shuffle the dataset before splitting.
        algorithm: Generation algorithm, 'modern' or 'legacy_cascor'.
        radius: Maximum radius (modern) or max distance parameter (legacy).
        origin: Origin point (x, y) of the spiral center.

    Parameter Aliases:
        For compatibility with JuniperCascor and JuniperCanopy:
        - `n_points` is accepted as an alias for `n_points_per_spiral`
        - `noise_level` is accepted as an alias for `noise`
    """

    # Allow population by field name as well as by the validation aliases.
    model_config = ConfigDict(populate_by_name=True)

    n_spirals: int = Field(
        default=SPIRAL_DEFAULT_N_SPIRALS,
        ge=MIN_SPIRALS,
        le=MAX_SPIRALS,
        description="Number of spiral arms to generate",
    )
    n_points_per_spiral: int = Field(
        default=SPIRAL_DEFAULT_N_POINTS,
        ge=MIN_POINTS,
        le=MAX_POINTS,
        description="Number of points per spiral arm",
        validation_alias=AliasChoices("n_points_per_spiral", "n_points"),
    )
    n_rotations: float = Field(
        default=SPIRAL_DEFAULT_N_ROTATIONS,
        ge=MIN_ROTATIONS,
        le=MAX_ROTATIONS,
        description="Number of full rotations for each spiral",
    )
    noise: float = Field(
        default=SPIRAL_DEFAULT_NOISE,
        ge=MIN_NOISE,
        le=MAX_NOISE,
        description="Noise level applied to point positions",
        validation_alias=AliasChoices("noise", "noise_level"),
    )
    clockwise: bool = Field(
        default=SPIRAL_DEFAULT_CLOCKWISE,
        description="Whether spirals rotate clockwise",
    )
    seed: int | None = Field(
        default=SPIRAL_DEFAULT_SEED,
        # np.random.default_rng rejects negative seeds. Reject them here so the
        # failure is a validation error rather than a crash mid-generation.
        ge=0,
        description="Random seed for reproducibility",
    )
    train_ratio: float = Field(
        default=SPIRAL_DEFAULT_TRAIN_RATIO,
        ge=0.0,
        le=1.0,
        description="Fraction of data for training set",
    )
    test_ratio: float = Field(
        default=SPIRAL_DEFAULT_TEST_RATIO,
        ge=0.0,
        le=1.0,
        description="Fraction of data for test set",
    )
    shuffle: bool = Field(
        default=True,
        description="Whether to shuffle the dataset before splitting",
    )
    algorithm: Literal["modern", "legacy_cascor"] = Field(
        default="modern",
        description="Generation algorithm: 'modern' (linspace+normal noise) or 'legacy_cascor' (sqrt-uniform radii + uniform noise)",
    )
    radius: float = Field(
        default=SPIRAL_DEFAULT_RADIUS,
        gt=0.0,
        le=100.0,
        description="Maximum radius for modern mode, or max distance parameter for legacy mode",
    )
    origin: tuple[float, float] = Field(
        default=(0.0, 0.0),
        description="Origin point (x, y) for spiral center",
    )

    @model_validator(mode="after")
    def validate_ratios_sum(self) -> "SpiralParams":
        """Validate that train_ratio + test_ratio <= 1.0.

        Raises:
            ValueError: If the two ratios sum to more than 1.0.
        """
        if self.train_ratio + self.test_ratio > 1.0:
            raise ValueError(
                f"train_ratio ({self.train_ratio}) + test_ratio ({self.test_ratio}) must be <= 1.0, "
                f"got {self.train_ratio + self.test_ratio}"
            )
        return self

    def total_points(self) -> int:
        """Compute the total number of points in the dataset.

        Returns:
            Total number of points (n_spirals * n_points_per_spiral).
        """
        return self.n_spirals * self.n_points_per_spiral
@@ -0,0 +1,11 @@
1
+ """XOR classification dataset generator."""
2
+
3
+ from juniper_data.generators.xor.generator import VERSION, XorGenerator, get_schema
4
+ from juniper_data.generators.xor.params import XorParams
5
+
6
# Public API of the XOR generator package: the generator class, its Pydantic
# parameter model, the implementation version, and the schema helper.
__all__ = [
    "XorGenerator",
    "XorParams",
    "VERSION",
    "get_schema",
]
+ ]