juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""MNIST dataset generator using Hugging Face datasets.
|
|
2
|
+
|
|
3
|
+
This module provides the MnistGenerator class for loading and preprocessing
|
|
4
|
+
MNIST and Fashion-MNIST datasets from the Hugging Face Hub.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from juniper_data.core.split import shuffle_and_split
|
|
10
|
+
|
|
11
|
+
from .params import MnistParams
|
|
12
|
+
|
|
13
|
+
VERSION = "1.0.0"

# The Hugging Face `datasets` package is an optional dependency. When it is
# missing, HF_AVAILABLE is False and hf_load_dataset is None; the generator
# checks the flag and raises ImportError only when MNIST data is requested.
try:
    from datasets import load_dataset as hf_load_dataset
except ImportError:
    HF_AVAILABLE = False
    hf_load_dataset = None  # type: ignore[assignment]
else:
    HF_AVAILABLE = True
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MnistGenerator:
    """Generator for MNIST and Fashion-MNIST datasets.

    Loads datasets from Hugging Face Hub and converts them to the
    JuniperData format with train/test splits.

    Requires the `datasets` package: pip install datasets

    All methods are static; the generator keeps no state between calls.
    """

    @staticmethod
    def generate(params: MnistParams) -> dict[str, np.ndarray]:
        """Generate a complete MNIST dataset with train/test splits.

        Args:
            params: MnistParams instance defining generation configuration.

        Returns:
            Dictionary containing:
                - X_train: Training features
                - y_train: Training labels
                - X_test: Test features
                - y_test: Test labels
                - X_full: Full dataset features
                - y_full: Full dataset labels

        Raises:
            ImportError: If datasets package is not installed.
        """
        if not HF_AVAILABLE:
            raise ImportError("Hugging Face datasets package not installed. Install with: pip install datasets")

        X, y = MnistGenerator._load_and_preprocess(params)

        split_result = shuffle_and_split(
            X=X,
            y=y,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        return {
            "X_train": split_result["X_train"],
            "y_train": split_result["y_train"],
            "X_test": split_result["X_test"],
            "y_test": split_result["y_test"],
            "X_full": X,
            "y_full": y,
        }

    @staticmethod
    def _load_and_preprocess(params: MnistParams) -> tuple[np.ndarray, np.ndarray]:
        """Load the training split from Hugging Face and convert it to float32 arrays.

        Args:
            params: MnistParams instance. ``n_samples`` caps the number of rows;
                a value larger than the split returns the whole split.

        Returns:
            Tuple of (X, y). X has one row per sample (flattened to 2-D when
            ``params.flatten`` is set). y is either a one-hot ``(n, 10)`` matrix
            or a ``(n, 1)`` column of class indices, both float32.

        Raises:
            ImportError: If the `datasets` package is not installed.
        """
        # Explicit guard (rather than an assert) so direct callers get a clear
        # error instead of "'NoneType' object is not callable".
        if hf_load_dataset is None:
            raise ImportError("Hugging Face datasets package not installed. Install with: pip install datasets")

        # params.dataset is validated by MnistParams (Pydantic) as Literal["mnist", "fashion_mnist"],
        # so this argument to hf_load_dataset is restricted to these known-safe values.
        ds = hf_load_dataset(params.dataset, split="train")  # nosec B615

        if params.seed is not None:
            ds = ds.shuffle(seed=params.seed)

        if params.n_samples is not None:
            # Clamp: `select` raises IndexError for indices past the end of the
            # split, but n_samples is documented as an upper limit.
            ds = ds.select(range(min(params.n_samples, len(ds))))

        # Use bulk column access with numpy formatting for efficient conversion
        ds = ds.with_format("numpy")

        # astype() copies, so the in-place division below never touches the dataset.
        X = np.asarray(ds["image"]).astype(np.float32)
        if params.normalize:
            X /= 255.0  # assumes raw pixel values span 0..255
        if params.flatten:
            X = X.reshape(len(X), -1)

        labels = np.asarray(ds["label"])
        if params.one_hot_labels:
            n_classes = 10  # MNIST and Fashion-MNIST both have ten classes
            y = np.eye(n_classes, dtype=np.float32)[labels]
        else:
            y = labels.astype(np.float32).reshape(-1, 1)

        return X, y
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def get_schema() -> dict:
    """Expose the MnistParams parameter schema.

    Returns:
        The JSON schema dictionary generated by ``MnistParams.model_json_schema()``.
    """
    schema = MnistParams.model_json_schema()
    return schema
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Parameters for the MNIST dataset generator."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal

from pydantic import BaseModel, Field, model_validator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MnistParams(BaseModel):
    """Configuration parameters for MNIST dataset generation.

    Loads and preprocesses MNIST or Fashion-MNIST datasets from
    Hugging Face Hub.
    """

    dataset: Literal["mnist", "fashion_mnist"] = Field(
        default="mnist",
        description="Dataset to load: 'mnist' or 'fashion_mnist'",
    )
    n_samples: int | None = Field(
        default=None,
        ge=1,
        description="Limit number of samples (None for full dataset)",
    )
    flatten: bool = Field(
        default=True,
        description="Flatten images to 1D (784 features) or keep 2D (28x28)",
    )
    normalize: bool = Field(
        default=True,
        description="Normalize pixel values to [0, 1]",
    )
    one_hot_labels: bool = Field(
        default=True,
        description="One-hot encode labels (10 classes)",
    )
    seed: int | None = Field(default=None, ge=0, description="Random seed for reproducibility")
    train_ratio: float = Field(default=0.8, gt=0, le=1, description="Fraction of data for training")
    test_ratio: float = Field(default=0.2, ge=0, le=1, description="Fraction of data for testing")
    shuffle: bool = Field(default=True, description="Shuffle before splitting")

    @model_validator(mode="after")
    def validate_ratios_sum(self) -> "MnistParams":
        """Validate that train_ratio + test_ratio <= 1.0.

        The two fractions partition one dataset, so together they cannot exceed
        it. This is the same contract SpiralParams enforces before its data is
        passed to the shared split helper.
        """
        if self.train_ratio + self.test_ratio > 1.0:
            raise ValueError(
                f"train_ratio ({self.train_ratio}) + test_ratio ({self.test_ratio}) must be <= 1.0, got {self.train_ratio + self.test_ratio}"
            )
        return self
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Spiral dataset generator module."""
|
|
2
|
+
|
|
3
|
+
from .defaults import (
|
|
4
|
+
MAX_NOISE,
|
|
5
|
+
MAX_POINTS,
|
|
6
|
+
MAX_ROTATIONS,
|
|
7
|
+
MAX_SPIRALS,
|
|
8
|
+
MIN_NOISE,
|
|
9
|
+
MIN_POINTS,
|
|
10
|
+
MIN_ROTATIONS,
|
|
11
|
+
MIN_SPIRALS,
|
|
12
|
+
SPIRAL_DEFAULT_CLOCKWISE,
|
|
13
|
+
SPIRAL_DEFAULT_DISTRIBUTION,
|
|
14
|
+
SPIRAL_DEFAULT_N_POINTS,
|
|
15
|
+
SPIRAL_DEFAULT_N_ROTATIONS,
|
|
16
|
+
SPIRAL_DEFAULT_N_SPIRALS,
|
|
17
|
+
SPIRAL_DEFAULT_NOISE,
|
|
18
|
+
SPIRAL_DEFAULT_ORIGIN,
|
|
19
|
+
SPIRAL_DEFAULT_RADIUS,
|
|
20
|
+
SPIRAL_DEFAULT_RANDOM_VALUE_SCALE,
|
|
21
|
+
SPIRAL_DEFAULT_SEED,
|
|
22
|
+
SPIRAL_DEFAULT_TEST_RATIO,
|
|
23
|
+
SPIRAL_DEFAULT_TRAIN_RATIO,
|
|
24
|
+
)
|
|
25
|
+
from .generator import VERSION, SpiralGenerator, get_schema
|
|
26
|
+
from .params import SpiralParams
|
|
27
|
+
|
|
28
|
+
# Public API of the spiral generator package, assembled one group at a time.
# Every `+=` extends the same list, so the final order equals one flat literal.
__all__ = [
    # Generator
    "SpiralGenerator",
    "get_schema",
    "VERSION",
]
__all__ += [
    # Default constants
    "SPIRAL_DEFAULT_N_SPIRALS",
    "SPIRAL_DEFAULT_N_POINTS",
    "SPIRAL_DEFAULT_N_ROTATIONS",
    "SPIRAL_DEFAULT_CLOCKWISE",
    "SPIRAL_DEFAULT_DISTRIBUTION",
    "SPIRAL_DEFAULT_ORIGIN",
    "SPIRAL_DEFAULT_RADIUS",
    "SPIRAL_DEFAULT_NOISE",
    "SPIRAL_DEFAULT_RANDOM_VALUE_SCALE",
    "SPIRAL_DEFAULT_SEED",
    "SPIRAL_DEFAULT_TRAIN_RATIO",
    "SPIRAL_DEFAULT_TEST_RATIO",
]
__all__ += [
    # Validation bounds
    "MIN_SPIRALS",
    "MAX_SPIRALS",
    "MIN_POINTS",
    "MAX_POINTS",
    "MIN_ROTATIONS",
    "MAX_ROTATIONS",
    "MIN_NOISE",
    "MAX_NOISE",
]
__all__ += ["SpiralParams"]  # Pydantic model
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Default values and validation limits for spiral dataset generation.

Every default and every min/max bound used by the spiral dataset generator
lives here. The values were migrated from JuniperCascor's constants_problem.py.
"""

# --- Spiral geometry ---------------------------------------------------------
SPIRAL_DEFAULT_N_SPIRALS: int = 2
SPIRAL_DEFAULT_N_POINTS: int = 97
SPIRAL_DEFAULT_N_ROTATIONS: float = 3.0
SPIRAL_DEFAULT_CLOCKWISE: bool = True
SPIRAL_DEFAULT_DISTRIBUTION: float = 0.8
SPIRAL_DEFAULT_ORIGIN: tuple[float, float] = (0.0, 0.0)
SPIRAL_DEFAULT_RADIUS: float = 10.0

# --- Noise and randomness ----------------------------------------------------
SPIRAL_DEFAULT_NOISE: float = 0.25
SPIRAL_DEFAULT_RANDOM_VALUE_SCALE: float = 0.1
SPIRAL_DEFAULT_SEED: int = 42

# --- Train/test split --------------------------------------------------------
SPIRAL_DEFAULT_TRAIN_RATIO: float = 0.8
SPIRAL_DEFAULT_TEST_RATIO: float = 0.2

# --- Validation bounds (inclusive min/max pairs) -----------------------------
MIN_SPIRALS: int = 2
MAX_SPIRALS: int = 10

MIN_POINTS: int = 10
MAX_POINTS: int = 10_000

MIN_ROTATIONS: float = 0.5
MAX_ROTATIONS: float = 10.0

MIN_NOISE: float = 0.0
MAX_NOISE: float = 2.0
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""Core NumPy-only spiral dataset generator.
|
|
2
|
+
|
|
3
|
+
This module provides the SpiralGenerator class for generating multi-spiral
|
|
4
|
+
classification datasets using only NumPy operations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from juniper_data.core.split import shuffle_and_split
|
|
12
|
+
|
|
13
|
+
from .params import SpiralParams
|
|
14
|
+
|
|
15
|
+
# NOTE(review): if dataset IDs or caches incorporate VERSION, consider bumping it,
# since the modern-mode clockwise=False output below changed — confirm.
VERSION = "1.0.0"


class SpiralGenerator:
    """NumPy-only generator for multi-spiral classification datasets.

    All methods are static to ensure the generator is stateless and side-effect free.
    """

    @staticmethod
    def generate(params: "SpiralParams") -> dict[str, np.ndarray]:
        """Generate a complete spiral dataset with train/test splits.

        Main public API for spiral dataset generation.

        Args:
            params: SpiralParams instance defining generation configuration.

        Returns:
            Dictionary containing:
                - X_train: Training features (n_train, 2)
                - y_train: Training labels (n_train, n_spirals)
                - X_test: Test features (n_test, 2)
                - y_test: Test labels (n_test, n_spirals)
                - X_full: Full dataset features (total_points, 2)
                - y_full: Full dataset labels (total_points, n_spirals)
        """
        rng = np.random.default_rng(params.seed)

        X, y = SpiralGenerator._generate_raw(params, rng)

        split_result = shuffle_and_split(
            X=X,
            y=y,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        return {
            "X_train": split_result["X_train"],
            "y_train": split_result["y_train"],
            "X_test": split_result["X_test"],
            "y_test": split_result["y_test"],
            "X_full": X,
            "y_full": y,
        }

    @staticmethod
    def _generate_raw(params: "SpiralParams", rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
        """Generate raw spiral coordinates and labels.

        Args:
            params: SpiralParams instance defining generation configuration.
            rng: NumPy random generator for reproducibility. Arms are generated
                in order, so the random stream is consumed arm by arm.

        Returns:
            Tuple of (X, y) where:
                - X: Feature array of shape (total_points, 2)
                - y: One-hot label array of shape (total_points, n_spirals)
        """
        arms = [
            SpiralGenerator._generate_spiral_coordinates(
                n_points=params.n_points_per_spiral,
                radius=params.radius,
                n_rotations=params.n_rotations,
                # Space the arms evenly around the full circle.
                angle_offset=2 * np.pi * arm / params.n_spirals,
                clockwise=params.clockwise,
                noise=params.noise,
                rng=rng,
                algorithm=params.algorithm,
                origin=params.origin,
            )
            for arm in range(params.n_spirals)
        ]

        X = np.vstack(arms).astype(np.float32)
        y = SpiralGenerator._create_one_hot_labels(
            n_spirals=params.n_spirals,
            n_points_per_spiral=params.n_points_per_spiral,
        )

        return X, y

    @staticmethod
    def _generate_spiral_coordinates(
        n_points: int,
        radius: float,
        n_rotations: float,
        angle_offset: float,
        clockwise: bool,
        noise: float,
        rng: np.random.Generator,
        algorithm: Literal["modern", "legacy_cascor"] = "modern",
        origin: tuple[float, float] = (0.0, 0.0),
    ) -> np.ndarray:
        """Generate coordinates for a single spiral arm.

        In both algorithms, ``clockwise=False`` negates the polar angle, which
        mirrors the arm across the horizontal axis through ``origin`` and so
        reverses its winding direction.

        Args:
            n_points: Number of points to generate.
            radius: Maximum radius of the spiral.
            n_rotations: Number of full rotations (modern algorithm only).
            angle_offset: Angular offset for this spiral arm.
            clockwise: Whether spiral rotates clockwise.
            noise: Noise level to apply.
            rng: NumPy random generator.
            algorithm: Generation algorithm ('modern' or 'legacy_cascor').
            origin: Origin point (x, y) for spiral center.

        Returns:
            Array of shape (n_points, 2) containing x, y coordinates.
        """
        direction = 1 if clockwise else -1

        if algorithm == "legacy_cascor":
            distance = np.sqrt(rng.random(n_points)) * radius
            theta = direction * (distance + angle_offset)
            # Uniform [0, noise) noise is not zero-mean; presumably kept to
            # reproduce the legacy CasCor data.
            x = np.cos(theta) * distance + SpiralGenerator._make_noise_uniform(n_points, noise, rng)
            y = np.sin(theta) * distance + SpiralGenerator._make_noise_uniform(n_points, noise, rng)
        else:
            radii = np.linspace(0, radius, n_points)
            # Flip the sign of the angle (as the legacy branch does), not of the
            # coordinates. Negating both x and y is only a 180-degree rotation:
            # it keeps the same handedness, and for two arms merely swaps them.
            theta = direction * (np.linspace(0, 2 * np.pi * n_rotations, n_points) + angle_offset)
            x = radii * np.cos(theta) + SpiralGenerator._make_noise(n_points, noise, rng)
            y = radii * np.sin(theta) + SpiralGenerator._make_noise(n_points, noise, rng)

        x += origin[0]
        y += origin[1]

        return np.column_stack([x, y]).astype(np.float32)

    @staticmethod
    def _make_noise(n_points: int, noise: float, rng: np.random.Generator) -> np.ndarray:
        """Generate random noise array using normal distribution.

        Args:
            n_points: Number of noise values to generate.
            noise: Noise scale factor (standard deviation).
            rng: NumPy random generator.

        Returns:
            Array of shape (n_points,) containing scaled random noise.
        """
        return rng.standard_normal(n_points) * noise

    @staticmethod
    def _make_noise_uniform(n_points: int, noise: float, rng: np.random.Generator) -> np.ndarray:
        """Generate uniform random noise in [0, noise).

        Args:
            n_points: Number of noise values to generate.
            noise: Noise scale factor.
            rng: NumPy random generator.

        Returns:
            Array of shape (n_points,) containing uniform random noise.
        """
        return rng.random(n_points) * noise

    @staticmethod
    def _create_one_hot_labels(n_spirals: int, n_points_per_spiral: int) -> np.ndarray:
        """Create one-hot encoded labels for spiral classes.

        Rows are grouped by class in generation order: the first
        ``n_points_per_spiral`` rows belong to class 0, the next to class 1, etc.

        Args:
            n_spirals: Number of spiral classes.
            n_points_per_spiral: Number of points per spiral class.

        Returns:
            Array of shape (total_points, n_spirals) with one-hot encoding.
        """
        # Repeat each identity row once per point of its arm.
        return np.repeat(np.eye(n_spirals, dtype=np.float32), n_points_per_spiral, axis=0)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def get_schema() -> dict:
    """Expose the SpiralParams parameter schema.

    The schema describes every accepted generator parameter, which makes it
    suitable for API documentation and request validation.

    Returns:
        The JSON schema dictionary generated by ``SpiralParams.model_json_schema()``.
    """
    schema = SpiralParams.model_json_schema()
    return schema
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""Spiral dataset generator parameters.
|
|
2
|
+
|
|
3
|
+
This module defines the Pydantic model for spiral dataset generation parameters
|
|
4
|
+
with validation and computation methods.
|
|
5
|
+
|
|
6
|
+
Parameter Aliases:
|
|
7
|
+
Some consumers (JuniperCascor, JuniperCanopy) use different parameter names.
|
|
8
|
+
This module supports the following aliases:
|
|
9
|
+
- `n_points` -> `n_points_per_spiral`
|
|
10
|
+
- `noise_level` -> `noise`
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import Literal
|
|
14
|
+
|
|
15
|
+
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
|
|
16
|
+
|
|
17
|
+
from .defaults import (
|
|
18
|
+
MAX_NOISE,
|
|
19
|
+
MAX_POINTS,
|
|
20
|
+
MAX_ROTATIONS,
|
|
21
|
+
MAX_SPIRALS,
|
|
22
|
+
MIN_NOISE,
|
|
23
|
+
MIN_POINTS,
|
|
24
|
+
MIN_ROTATIONS,
|
|
25
|
+
MIN_SPIRALS,
|
|
26
|
+
SPIRAL_DEFAULT_CLOCKWISE,
|
|
27
|
+
SPIRAL_DEFAULT_N_POINTS,
|
|
28
|
+
SPIRAL_DEFAULT_N_ROTATIONS,
|
|
29
|
+
SPIRAL_DEFAULT_N_SPIRALS,
|
|
30
|
+
SPIRAL_DEFAULT_NOISE,
|
|
31
|
+
SPIRAL_DEFAULT_RADIUS,
|
|
32
|
+
SPIRAL_DEFAULT_SEED,
|
|
33
|
+
SPIRAL_DEFAULT_TEST_RATIO,
|
|
34
|
+
SPIRAL_DEFAULT_TRAIN_RATIO,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# Alternate parameter names used by other Juniper consumers, mapped to the
# canonical SpiralParams field each one populates.
PARAMETER_ALIASES: dict[str, str] = {"n_points": "n_points_per_spiral", "noise_level": "noise"}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SpiralParams(BaseModel):
    """Parameters for spiral dataset generation.

    Defines the configuration for generating multi-spiral classification datasets
    with support for noise, train/test splitting, and deterministic seeding.

    Attributes:
        n_spirals: Number of spiral arms to generate.
        n_points_per_spiral: Number of points per spiral arm.
        n_rotations: Number of full rotations for each spiral.
        noise: Noise level applied to point positions.
        clockwise: Whether spirals rotate clockwise.
        seed: Random seed for reproducibility; must be non-negative when given.
        train_ratio: Fraction of data for training set.
        test_ratio: Fraction of data for test set.
        shuffle: Whether to shuffle the dataset before splitting.
        algorithm: Generation algorithm, 'modern' or 'legacy_cascor'.
        radius: Maximum radius (modern) or maximum distance parameter (legacy).
        origin: (x, y) point at the center of the spirals.

    Parameter Aliases:
        For compatibility with JuniperCascor and JuniperCanopy:
        - `n_points` is accepted as an alias for `n_points_per_spiral`
        - `noise_level` is accepted as an alias for `noise`
    """

    # Allow population by canonical field name as well as by alias.
    model_config = ConfigDict(populate_by_name=True)

    n_spirals: int = Field(
        default=SPIRAL_DEFAULT_N_SPIRALS,
        ge=MIN_SPIRALS,
        le=MAX_SPIRALS,
        description="Number of spiral arms to generate",
    )
    n_points_per_spiral: int = Field(
        default=SPIRAL_DEFAULT_N_POINTS,
        ge=MIN_POINTS,
        le=MAX_POINTS,
        description="Number of points per spiral arm",
        validation_alias=AliasChoices("n_points_per_spiral", "n_points"),
    )
    n_rotations: float = Field(
        default=SPIRAL_DEFAULT_N_ROTATIONS,
        ge=MIN_ROTATIONS,
        le=MAX_ROTATIONS,
        description="Number of full rotations for each spiral",
    )
    noise: float = Field(
        default=SPIRAL_DEFAULT_NOISE,
        ge=MIN_NOISE,
        le=MAX_NOISE,
        description="Noise level applied to point positions",
        validation_alias=AliasChoices("noise", "noise_level"),
    )
    clockwise: bool = Field(
        default=SPIRAL_DEFAULT_CLOCKWISE,
        description="Whether spirals rotate clockwise",
    )
    seed: int | None = Field(
        default=SPIRAL_DEFAULT_SEED,
        # The seed is handed to numpy.random.default_rng, which rejects negative
        # values; validating here surfaces the error before generation starts.
        ge=0,
        description="Random seed for reproducibility",
    )
    train_ratio: float = Field(
        default=SPIRAL_DEFAULT_TRAIN_RATIO,
        ge=0.0,
        le=1.0,
        description="Fraction of data for training set",
    )
    test_ratio: float = Field(
        default=SPIRAL_DEFAULT_TEST_RATIO,
        ge=0.0,
        le=1.0,
        description="Fraction of data for test set",
    )
    shuffle: bool = Field(
        default=True,
        description="Whether to shuffle the dataset before splitting",
    )
    algorithm: Literal["modern", "legacy_cascor"] = Field(
        default="modern",
        description="Generation algorithm: 'modern' (linspace+normal noise) or 'legacy_cascor' (sqrt-uniform radii + uniform noise)",
    )
    radius: float = Field(
        default=SPIRAL_DEFAULT_RADIUS,
        gt=0.0,
        le=100.0,
        description="Maximum radius for modern mode, or max distance parameter for legacy mode",
    )
    origin: tuple[float, float] = Field(
        default=(0.0, 0.0),
        description="Origin point (x, y) for spiral center",
    )

    @model_validator(mode="after")
    def validate_ratios_sum(self) -> "SpiralParams":
        """Validate that train_ratio + test_ratio <= 1.0."""
        if self.train_ratio + self.test_ratio > 1.0:
            raise ValueError(
                f"train_ratio ({self.train_ratio}) + test_ratio ({self.test_ratio}) must be <= 1.0, got {self.train_ratio + self.test_ratio}"
            )
        return self

    def total_points(self) -> int:
        """Compute the total number of points in the dataset.

        Returns:
            Total number of points (n_spirals * n_points_per_spiral).
        """
        return self.n_spirals * self.n_points_per_spiral
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""XOR classification dataset generator."""
|
|
2
|
+
|
|
3
|
+
from juniper_data.generators.xor.generator import VERSION, XorGenerator, get_schema
|
|
4
|
+
from juniper_data.generators.xor.params import XorParams
|
|
5
|
+
|
|
6
|
+
# Names re-exported by ``from juniper_data.generators.xor import *``.
__all__ = ["XorGenerator", "XorParams"]  # classes
__all__ += ["VERSION", "get_schema"]  # module-level metadata and helpers
|