juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95) hide show
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,112 @@
1
+ """Core NumPy-only concentric circles dataset generator.
2
+
3
+ This module provides the CirclesGenerator class for generating
4
+ concentric circles classification datasets using only NumPy operations.
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ from juniper_data.core.split import shuffle_and_split
10
+
11
+ from .params import CirclesParams
12
+
13
+ VERSION = "1.0.0"
14
+
15
+
16
class CirclesGenerator:
    """Stateless NumPy-only builder of two-ring classification datasets.

    Samples are placed on an outer ring (one-hot column 0) and on an inner
    ring (one-hot column 1). Every method is a ``@staticmethod``; nothing is
    stored on the class.
    """

    @staticmethod
    def generate(params: CirclesParams) -> dict[str, np.ndarray]:
        """Build the dataset and its train/test partition.

        Args:
            params: CirclesParams holding sample count, radii, noise, seed
                and the split configuration.

        Returns:
            Dictionary with:
                - X_train, y_train, X_test, y_test: the entries of the same
                  name returned by ``shuffle_and_split``
                - X_full: all generated features, shape (n_samples, 2)
                - y_full: all one-hot labels, shape (n_samples, 2)
        """
        features, labels = CirclesGenerator._generate_raw(
            params,
            np.random.default_rng(params.seed),
        )

        parts = shuffle_and_split(
            X=features,
            y=labels,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        dataset = {key: parts[key] for key in ("X_train", "y_train", "X_test", "y_test")}
        dataset["X_full"] = features
        dataset["y_full"] = labels
        return dataset

    @staticmethod
    def _generate_raw(params: CirclesParams, rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
        """Sample ring coordinates and build the matching one-hot labels.

        Args:
            params: CirclesParams holding sample count, radii and noise level.
            rng: NumPy random generator that supplies angles and noise.

        Returns:
            Tuple ``(X, y)``:
                - X: float32 coordinates, shape (n_samples, 2), outer-ring
                  rows first, inner-ring rows after them
                - y: float32 one-hot labels, shape (n_samples, 2)
        """
        # int() truncates, so any fractional sample goes to the outer ring.
        inner_count = int(params.n_samples * params.inner_ratio)
        outer_count = params.n_samples - inner_count
        full_turn = 2 * np.pi

        def ring(radius: float, count: int) -> np.ndarray:
            """Return ``count`` points at uniformly random angles on a ring."""
            theta = rng.uniform(0, full_turn, count)
            return np.column_stack([radius * np.cos(theta), radius * np.sin(theta)])

        # Draw order matters for seeded reproducibility: outer ring, then
        # inner ring, then (optionally) the noise field.
        outer_ring = ring(params.outer_radius, outer_count)
        inner_ring = ring(params.outer_radius * params.factor, inner_count)
        coords = np.vstack([outer_ring, inner_ring])

        if params.noise > 0:
            coords += rng.standard_normal(coords.shape) * params.noise

        one_hot = np.zeros((params.n_samples, 2), dtype=np.float32)
        one_hot[:outer_count, 0] = 1.0
        one_hot[outer_count:, 1] = 1.0

        return coords.astype(np.float32), one_hot
104
+
105
+
106
def get_schema() -> dict:
    """Expose the parameter schema of the circles generator.

    Returns:
        The dictionary produced by ``CirclesParams.model_json_schema()``.
    """
    schema = CirclesParams.model_json_schema()
    return schema
@@ -0,0 +1,31 @@
1
+ """Parameters for the concentric circles dataset generator."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
class CirclesParams(BaseModel):
    """Configuration parameters for concentric circles dataset generation.

    Generates a binary classification dataset with points on two concentric
    circles - an inner circle and an outer circle.
    """

    # NOTE(review): the class docstring above is left untouched on purpose;
    # pydantic presumably surfaces it as the schema description - confirm
    # before rewording it.

    # Total sample count across both circles; at least 2.
    n_samples: int = Field(default=100, ge=2, description="Total number of samples")
    # Radius of the outer circle; must be strictly positive.
    outer_radius: float = Field(default=1.0, gt=0, description="Radius of the outer circle")
    # Strictly between 0 and 1, so the inner radius is always smaller than
    # the outer radius.
    factor: float = Field(
        default=0.5,
        gt=0,
        lt=1,
        description="Scale factor between inner and outer circles (inner_radius = outer_radius * factor)",
    )
    # 0 disables noise; the generator only adds noise when this is > 0.
    noise: float = Field(default=0.0, ge=0, description="Gaussian noise level added to coordinates")
    # In (0, 1]. The generator uses int(n_samples * inner_ratio) inner
    # samples and assigns the remainder to the outer circle.
    inner_ratio: float = Field(
        default=0.5,
        gt=0,
        le=1,
        description="Fraction of samples on the inner circle",
    )
    # None leaves the random generator unseeded; otherwise a non-negative int.
    seed: int | None = Field(default=None, ge=0, description="Random seed for reproducibility")
    # The split settings below are forwarded unchanged to shuffle_and_split.
    # NOTE(review): nothing here checks that train_ratio + test_ratio <= 1;
    # presumably shuffle_and_split enforces that - confirm.
    train_ratio: float = Field(default=0.8, gt=0, le=1, description="Fraction of data for training")
    test_ratio: float = Field(default=0.2, ge=0, le=1, description="Fraction of data for testing")
    shuffle: bool = Field(default=True, description="Shuffle before splitting")
@@ -0,0 +1,15 @@
1
+ """CSV/JSON import generator for custom datasets."""
2
+
3
+ from juniper_data.generators.csv_import.generator import (
4
+ VERSION,
5
+ CsvImportGenerator,
6
+ get_schema,
7
+ )
8
+ from juniper_data.generators.csv_import.params import CsvImportParams
9
+
10
# Public names of the csv_import package: the generator class, its parameter
# model, the generator VERSION string and the schema helper imported above.
__all__ = [
    "CsvImportGenerator",
    "CsvImportParams",
    "VERSION",
    "get_schema",
]
@@ -0,0 +1,198 @@
1
+ """CSV/JSON import generator for custom datasets.
2
+
3
+ This module provides the CsvImportGenerator class for loading
4
+ datasets from CSV and JSON files.
5
+ """
6
+
7
+ import csv
8
+ import json
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+
13
+ from juniper_data.core.split import shuffle_and_split
14
+
15
+ from .params import CsvImportParams
16
+
17
+ VERSION = "1.0.0"
18
+
19
+
20
class CsvImportGenerator:
    """Generator for importing datasets from CSV/JSON files.

    Reads records from a local file, converts them to float32 feature and
    label arrays, and passes them to ``shuffle_and_split`` to obtain the
    train/test partition.

    All methods are static; no state is kept on the class.
    """

    @staticmethod
    def generate(params: CsvImportParams) -> dict[str, np.ndarray]:
        """Generate a dataset from a CSV/JSON file with train/test splits.

        Args:
            params: CsvImportParams instance defining import configuration.

        Returns:
            Dictionary containing:
                - X_train: Training features
                - y_train: Training labels
                - X_test: Test features
                - y_test: Test labels
                - X_full: Full dataset features (file order)
                - y_full: Full dataset labels (file order)

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the format cannot be auto-detected, the file holds
                no records, or a JSON payload cannot be parsed.
        """
        X, y = CsvImportGenerator._load_and_preprocess(params)

        split_result = shuffle_and_split(
            X=X,
            y=y,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        return {
            "X_train": split_result["X_train"],
            "y_train": split_result["y_train"],
            "X_test": split_result["X_test"],
            "y_test": split_result["y_test"],
            "X_full": X,
            "y_full": y,
        }

    @staticmethod
    def _load_and_preprocess(params: CsvImportParams) -> tuple[np.ndarray, np.ndarray]:
        """Resolve the file format, load the records and convert them to arrays.

        Args:
            params: CsvImportParams instance.

        Returns:
            Tuple of (X, y) arrays as produced by ``_convert_to_arrays``.

        Raises:
            FileNotFoundError: If ``params.file_path`` does not exist.
            ValueError: If ``file_format`` is ``"auto"`` and the extension is
                not ``.csv``, ``.json`` or ``.jsonl``.
        """
        # TODO: restrict reads to an explicitly configured allowlisted base directory
        # (e.g., JUNIPER_DATA_IMPORT_DIR) and reject absolute paths and path traversal
        # (..) that escape that directory.
        path = Path(params.file_path)

        if not path.exists():
            raise FileNotFoundError(f"File not found: {params.file_path}")

        file_format = params.file_format
        if file_format == "auto":
            suffix = path.suffix.lower()
            if suffix == ".csv":
                file_format = "csv"
            elif suffix in {".json", ".jsonl"}:
                file_format = "json"
            else:
                raise ValueError(f"Cannot auto-detect format for extension: {suffix}")

        # Anything that is not "csv" at this point is handled as JSON.
        if file_format == "csv":
            data = CsvImportGenerator._load_csv(path, params)
        else:
            data = CsvImportGenerator._load_json(path, params)

        return CsvImportGenerator._convert_to_arrays(data, params)

    @staticmethod
    def _load_csv(path: Path, params: CsvImportParams) -> list[dict]:
        """Load the rows of a CSV file as dictionaries.

        With ``params.header`` set, the first row supplies the column names.
        Otherwise columns are named ``col_0``, ``col_1``, ... based on the
        width of the first row, and every row (including the first) is data.

        Args:
            path: Location of the CSV file.
            params: Supplies ``header`` and ``delimiter``.

        Returns:
            One dict per data row, keyed by column name.

        Raises:
            ValueError: If ``params.header`` is false and the file has no rows.
        """
        with open(path, newline="", encoding="utf-8") as f:
            if params.header:
                reader = csv.DictReader(f, delimiter=params.delimiter)
            else:
                # Peek at the first row only to learn the column count, then
                # rewind so that the row is still read as data.
                csv_reader = csv.reader(f, delimiter=params.delimiter)
                try:
                    first_row = next(csv_reader)
                except StopIteration as e:
                    raise ValueError("CSV file is empty or contains only a header row") from e
                f.seek(0)
                fieldnames = [f"col_{i}" for i in range(len(first_row))]
                reader = csv.DictReader(f, fieldnames=fieldnames, delimiter=params.delimiter)

            # Materialize while the file is still open; the reader is lazy.
            return list(reader)

    @staticmethod
    def _load_json(path: Path, params: CsvImportParams) -> list[dict]:
        """Load records from a JSON array file or a JSON-lines file.

        Content starting with ``[`` is parsed as one JSON document; anything
        else is parsed line by line, skipping blank lines.

        Args:
            path: Location of the JSON/JSONL file.
            params: Unused; kept for a signature parallel to ``_load_csv``.

        Returns:
            The parsed records.
        """
        with open(path, encoding="utf-8") as f:
            content = f.read().strip()

        if content.startswith("["):
            return json.loads(content)
        return [json.loads(line) for line in content.split("\n") if line.strip()]

    @staticmethod
    def _coerce_float(value: object) -> float:
        """Convert a raw cell to float, mapping unparseable values to 0.0."""
        try:
            return float(value)  # type: ignore[arg-type]
        except (ValueError, TypeError):
            return 0.0

    @staticmethod
    def _convert_to_arrays(data: list[dict], params: CsvImportParams) -> tuple[np.ndarray, np.ndarray]:
        """Convert loaded records to feature and label arrays.

        Args:
            data: Records as returned by ``_load_csv`` / ``_load_json``.
            params: Supplies ``feature_columns``, ``label_column``,
                ``normalize_features`` and ``one_hot_labels``.

        Returns:
            Tuple ``(X, y)``:
                - X: float32 array, shape (n_rows, n_feature_columns)
                - y: float32 array, one-hot with shape (n_rows, n_classes)
                  when ``one_hot_labels`` is set, otherwise class indices
                  with shape (n_rows, 1)

        Raises:
            ValueError: If ``data`` is empty.
        """
        if not data:
            raise ValueError("No data found in file")

        if params.feature_columns is not None:
            feature_cols = params.feature_columns
        else:
            # Column names are taken from the first record only.
            feature_cols = [c for c in data[0].keys() if c != params.label_column]

        # Missing or non-numeric cells become 0.0 (best-effort import).
        features = [
            [CsvImportGenerator._coerce_float(row.get(col, 0)) for col in feature_cols]
            for row in data
        ]
        X = np.array(features, dtype=np.float32)

        if params.normalize_features:
            X_min = X.min(axis=0, keepdims=True)
            X_max = X.max(axis=0, keepdims=True)
            X_range = X_max - X_min
            # Constant columns would divide by zero; they map to 0 instead.
            X_range[X_range == 0] = 1
            X = (X - X_min) / X_range

        # Classes are identified by the string form of the raw label. The
        # labels are stringified BEFORE de-duplication so that raw values
        # which differ but print identically (e.g. JSON 1 and "1") form a
        # single class rather than inflating n_classes with an unused column.
        # A record without the label column yields the label "None".
        label_keys = [str(row.get(params.label_column)) for row in data]
        unique_labels = sorted(set(label_keys))  # lexicographic order
        label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
        label_indices = np.array([label_to_idx[key] for key in label_keys])

        if params.one_hot_labels:
            y = np.zeros((len(label_keys), len(unique_labels)), dtype=np.float32)
            y[np.arange(len(label_keys)), label_indices] = 1.0
        else:
            y = label_indices.astype(np.float32).reshape(-1, 1)

        return X, y
190
+
191
+
192
def get_schema() -> dict:
    """Expose the parameter schema of the CSV/JSON import generator.

    Returns:
        The dictionary produced by ``CsvImportParams.model_json_schema()``.
    """
    schema = CsvImportParams.model_json_schema()
    return schema
@@ -0,0 +1,48 @@
1
+ """Parameters for the CSV/JSON import generator."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class CsvImportParams(BaseModel):
    """Configuration parameters for CSV/JSON data import.

    Loads and preprocesses data from CSV or JSON files.
    """

    # NOTE(review): the class docstring above is left untouched on purpose;
    # pydantic presumably surfaces it as the schema description - confirm
    # before rewording it.

    # Required (no default). The generator opens this path as given.
    file_path: str = Field(
        description="Path to the CSV or JSON file to import",
    )
    # "auto" makes the generator choose by extension: .csv -> csv,
    # .json/.jsonl -> json, anything else is rejected with ValueError.
    file_format: Literal["csv", "json", "auto"] = Field(
        default="auto",
        description="File format: 'csv', 'json', or 'auto' (detect from extension)",
    )
    # None means: every column of the first record except label_column.
    feature_columns: list[str] | None = Field(
        default=None,
        description="Column names for features (None = all except label column)",
    )
    # With header=False the CSV columns are named col_0, col_1, ..., so this
    # must be set to one of those names to pick up real labels.
    label_column: str = Field(
        default="label",
        description="Column name for labels",
    )
    # Passed to the csv reader; only used for CSV input.
    delimiter: str = Field(
        default=",",
        description="CSV delimiter character",
    )
    # Only consulted by the CSV loader.
    header: bool = Field(
        default=True,
        description="Whether the file has a header row",
    )
    # False yields a single column of class indices instead of one-hot rows.
    one_hot_labels: bool = Field(
        default=True,
        description="One-hot encode labels",
    )
    # Per-column min-max scaling; constant columns are mapped to 0.
    normalize_features: bool = Field(
        default=False,
        description="Normalize features to [0, 1]",
    )
    # The split settings below are forwarded unchanged to shuffle_and_split.
    seed: int | None = Field(default=None, ge=0, description="Random seed for reproducibility")
    train_ratio: float = Field(default=0.8, gt=0, le=1, description="Fraction of data for training")
    test_ratio: float = Field(default=0.2, ge=0, le=1, description="Fraction of data for testing")
    shuffle: bool = Field(default=True, description="Shuffle before splitting")
@@ -0,0 +1,11 @@
1
+ """Gaussian blobs classification dataset generator."""
2
+
3
+ from juniper_data.generators.gaussian.generator import VERSION, GaussianGenerator, get_schema
4
+ from juniper_data.generators.gaussian.params import GaussianParams
5
+
6
# Public names of the gaussian package: the generator class, its parameter
# model, the generator VERSION string and the schema helper imported above.
__all__ = [
    "GaussianGenerator",
    "GaussianParams",
    "VERSION",
    "get_schema",
]
@@ -0,0 +1,149 @@
1
+ """Core NumPy-only Gaussian blobs dataset generator.
2
+
3
+ This module provides the GaussianGenerator class for generating
4
+ mixture-of-Gaussians classification datasets using only NumPy operations.
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ from juniper_data.core.split import shuffle_and_split
10
+
11
+ from .params import GaussianParams
12
+
13
+ VERSION = "1.0.0"
14
+
15
+
16
class GaussianGenerator:
    """NumPy-only generator for Gaussian blobs classification datasets.

    Generates a mixture-of-Gaussians dataset with configurable centers,
    standard deviations, and noise levels. Each class is sampled from an
    isotropic Gaussian (one scalar standard deviation per class) around its
    center.

    All methods are static; no state is kept on the class.
    """

    @staticmethod
    def generate(params: GaussianParams) -> dict[str, np.ndarray]:
        """Generate a complete Gaussian blobs dataset with train/test splits.

        Args:
            params: GaussianParams instance defining generation configuration.

        Returns:
            Dictionary containing:
                - X_train: Training features (n_train, n_features)
                - y_train: Training labels (n_train, n_classes)
                - X_test: Test features (n_test, n_features)
                - y_test: Test labels (n_test, n_classes)
                - X_full: Full dataset features (total_points, n_features)
                - y_full: Full dataset labels (total_points, n_classes)

        Raises:
            ValueError: If explicit centers do not match ``n_classes`` /
                ``n_features``, or ``class_std`` is an empty list.
        """
        rng = np.random.default_rng(params.seed)

        X, y = GaussianGenerator._generate_raw(params, rng)

        split_result = shuffle_and_split(
            X=X,
            y=y,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        return {
            "X_train": split_result["X_train"],
            "y_train": split_result["y_train"],
            "X_test": split_result["X_test"],
            "y_test": split_result["y_test"],
            "X_full": X,
            "y_full": y,
        }

    @staticmethod
    def _generate_raw(params: GaussianParams, rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
        """Generate raw Gaussian blobs coordinates and labels.

        Args:
            params: GaussianParams instance defining generation configuration.
            rng: NumPy random generator for reproducibility.

        Returns:
            Tuple of (X, y) where:
                - X: float32 feature array of shape (total_points, n_features),
                  rows grouped by class in class order
                - y: float32 one-hot label array of shape (total_points, n_classes)
        """
        centers = GaussianGenerator._get_centers(params, rng)
        stds = GaussianGenerator._get_stds(params)

        per_class = params.n_samples_per_class
        total_points = params.n_classes * per_class
        X = np.zeros((total_points, params.n_features), dtype=np.float32)
        y = np.zeros((total_points, params.n_classes), dtype=np.float32)

        # One draw per class, in class order: the order of RNG calls fixes
        # the output for a given seed, so it must not be rearranged.
        for i in range(params.n_classes):
            start_idx = i * per_class
            end_idx = start_idx + per_class

            # A class_std list shorter than n_classes falls back to its
            # first entry for the classes it does not cover.
            class_std = stds[i] if i < len(stds) else stds[0]
            samples = rng.standard_normal((per_class, params.n_features))
            X[start_idx:end_idx] = samples * class_std + centers[i]

            y[start_idx:end_idx, i] = 1.0

        if params.noise > 0:
            X += rng.standard_normal(X.shape).astype(np.float32) * params.noise

        return X, y

    @staticmethod
    def _get_centers(params: GaussianParams, rng: np.random.Generator) -> np.ndarray:
        """Get or generate class centers.

        Explicit ``params.centers`` are validated against ``n_classes`` and
        ``n_features``. Otherwise the centers are spaced evenly on a circle
        of radius ``center_radius`` in the first two feature dimensions; any
        further dimensions stay at 0.

        Args:
            params: GaussianParams instance.
            rng: NumPy random generator. Not used: both code paths are
                deterministic. Kept so the signature stays unchanged.

        Returns:
            float32 array of shape (n_classes, n_features) with class centers.

        Raises:
            ValueError: If explicit centers have the wrong count or width.
        """
        if params.centers is not None:
            centers = np.array(params.centers, dtype=np.float32)
            if centers.shape[0] != params.n_classes:
                raise ValueError(f"Number of centers ({centers.shape[0]}) must match n_classes ({params.n_classes})")
            if centers.shape[1] != params.n_features:
                raise ValueError(f"Center dimensions ({centers.shape[1]}) must match n_features ({params.n_features})")
            return centers

        centers = np.zeros((params.n_classes, params.n_features), dtype=np.float32)
        angles = np.linspace(0, 2 * np.pi, params.n_classes, endpoint=False)

        for i, angle in enumerate(angles):
            centers[i, 0] = params.center_radius * np.cos(angle)
            # With a single feature only the cosine coordinate exists.
            if params.n_features > 1:
                centers[i, 1] = params.center_radius * np.sin(angle)

        return centers

    @staticmethod
    def _get_stds(params: GaussianParams) -> list[float]:
        """Get standard deviations for each class.

        Args:
            params: GaussianParams instance.

        Returns:
            ``params.class_std`` itself when it is a list (its length is not
            forced to ``n_classes``), otherwise the scalar repeated
            ``n_classes`` times.

        Raises:
            ValueError: If ``class_std`` is an empty list. The caller falls
                back to the first entry for uncovered classes, so an empty
                list would otherwise fail with an opaque IndexError.
        """
        if isinstance(params.class_std, list):
            if not params.class_std:
                raise ValueError("class_std list cannot be empty")
            return params.class_std
        return [params.class_std] * params.n_classes
141
+
142
+
143
def get_schema() -> dict:
    """Expose the parameter schema of the Gaussian blobs generator.

    Returns:
        The dictionary produced by ``GaussianParams.model_json_schema()``.
    """
    schema = GaussianParams.model_json_schema()
    return schema
@@ -0,0 +1,53 @@
1
+ """Parameters for the Gaussian blobs dataset generator."""
2
+
3
+ from pydantic import BaseModel, Field, field_validator
4
+
5
+
6
class GaussianParams(BaseModel):
    """Configuration parameters for Gaussian blobs dataset generation.

    Generates a mixture-of-Gaussians classification dataset with configurable
    class centers, covariance, and noise levels.
    """

    # NOTE(review): the class docstring above is left untouched on purpose;
    # pydantic presumably surfaces it as the schema description - confirm
    # before rewording it.

    n_classes: int = Field(default=2, ge=2, le=10, description="Number of classes/blobs")
    n_samples_per_class: int = Field(default=50, ge=1, description="Number of samples per class")
    n_features: int = Field(default=2, ge=1, description="Number of features/dimensions")
    # Scalar, or a non-empty list of positive values. A list shorter than
    # n_classes is accepted; the generator then reuses the first entry for
    # the classes the list does not cover.
    class_std: float | list[float] = Field(
        default=1.0,
        description="Standard deviation for each class. Single value applies to all classes.",
    )
    # When given, the generator checks that the shape is (n_classes, n_features).
    centers: list[list[float]] | None = Field(
        default=None,
        description="List of class center coordinates. If None, centers are placed on a circle.",
    )
    center_radius: float = Field(
        default=3.0,
        gt=0,
        description="Radius for auto-placed centers when centers is None",
    )
    # 0 disables the extra noise pass; the generator only adds it when > 0.
    noise: float = Field(default=0.0, ge=0, description="Additional Gaussian noise level")
    seed: int | None = Field(default=None, ge=0, description="Random seed for reproducibility")
    # The split settings below are forwarded unchanged to shuffle_and_split.
    train_ratio: float = Field(default=0.8, gt=0, le=1, description="Fraction of data for training")
    test_ratio: float = Field(default=0.2, ge=0, le=1, description="Fraction of data for testing")
    shuffle: bool = Field(default=True, description="Shuffle before splitting")

    @field_validator("class_std")
    @classmethod
    def validate_class_std(cls, v: float | list[float]) -> float | list[float]:
        """Validate that class_std is positive and, if a list, non-empty.

        Args:
            v: The scalar or list supplied for ``class_std``.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If the list is empty, or any value is not positive.
        """
        if isinstance(v, list):
            # all() is vacuously true for [], so an empty list has to be
            # rejected explicitly; the generator would otherwise fail later
            # with an IndexError when it reads the first entry.
            if not v:
                raise ValueError("class_std list cannot be empty")
            if not all(s > 0 for s in v):
                raise ValueError("All class_std values must be positive")
        elif v <= 0:
            raise ValueError("class_std must be positive")
        return v

    @field_validator("centers")
    @classmethod
    def validate_centers(cls, v: list[list[float]] | None) -> list[list[float]] | None:
        """Validate centers structure if provided.

        Only emptiness is checked here; the generator compares the shape
        against ``n_classes`` / ``n_features``.

        Raises:
            ValueError: If an empty list is supplied.
        """
        if v is not None and len(v) == 0:
            raise ValueError("centers list cannot be empty")
        return v
@@ -0,0 +1,11 @@
1
+ """MNIST and Fashion-MNIST dataset generator."""
2
+
3
+ from juniper_data.generators.mnist.generator import VERSION, MnistGenerator, get_schema
4
+ from juniper_data.generators.mnist.params import MnistParams
5
+
6
# Public names of the mnist package: the generator class, its parameter
# model, the generator VERSION string and the schema helper imported above.
__all__ = [
    "MnistGenerator",
    "MnistParams",
    "VERSION",
    "get_schema",
]