juniper-data 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_data/__init__.py +88 -0
- juniper_data/__main__.py +78 -0
- juniper_data/api/__init__.py +10 -0
- juniper_data/api/app.py +111 -0
- juniper_data/api/middleware.py +95 -0
- juniper_data/api/routes/__init__.py +9 -0
- juniper_data/api/routes/datasets.py +414 -0
- juniper_data/api/routes/generators.py +125 -0
- juniper_data/api/routes/health.py +49 -0
- juniper_data/api/security.py +238 -0
- juniper_data/api/settings.py +109 -0
- juniper_data/core/__init__.py +32 -0
- juniper_data/core/artifacts.py +63 -0
- juniper_data/core/dataset_id.py +38 -0
- juniper_data/core/models.py +135 -0
- juniper_data/core/split.py +120 -0
- juniper_data/generators/__init__.py +15 -0
- juniper_data/generators/arc_agi/__init__.py +11 -0
- juniper_data/generators/arc_agi/generator.py +229 -0
- juniper_data/generators/arc_agi/params.py +56 -0
- juniper_data/generators/checkerboard/__init__.py +15 -0
- juniper_data/generators/checkerboard/generator.py +114 -0
- juniper_data/generators/checkerboard/params.py +32 -0
- juniper_data/generators/circles/__init__.py +11 -0
- juniper_data/generators/circles/generator.py +112 -0
- juniper_data/generators/circles/params.py +31 -0
- juniper_data/generators/csv_import/__init__.py +15 -0
- juniper_data/generators/csv_import/generator.py +198 -0
- juniper_data/generators/csv_import/params.py +48 -0
- juniper_data/generators/gaussian/__init__.py +11 -0
- juniper_data/generators/gaussian/generator.py +149 -0
- juniper_data/generators/gaussian/params.py +53 -0
- juniper_data/generators/mnist/__init__.py +11 -0
- juniper_data/generators/mnist/generator.py +124 -0
- juniper_data/generators/mnist/params.py +39 -0
- juniper_data/generators/spiral/__init__.py +57 -0
- juniper_data/generators/spiral/defaults.py +39 -0
- juniper_data/generators/spiral/generator.py +206 -0
- juniper_data/generators/spiral/params.py +148 -0
- juniper_data/generators/xor/__init__.py +11 -0
- juniper_data/generators/xor/generator.py +162 -0
- juniper_data/generators/xor/params.py +30 -0
- juniper_data/storage/__init__.py +120 -0
- juniper_data/storage/base.py +279 -0
- juniper_data/storage/cached.py +211 -0
- juniper_data/storage/hf_store.py +257 -0
- juniper_data/storage/kaggle_store.py +333 -0
- juniper_data/storage/local_fs.py +232 -0
- juniper_data/storage/memory.py +136 -0
- juniper_data/storage/postgres_store.py +373 -0
- juniper_data/storage/redis_store.py +264 -0
- juniper_data/tests/__init__.py +1 -0
- juniper_data/tests/conftest.py +68 -0
- juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
- juniper_data/tests/integration/__init__.py +1 -0
- juniper_data/tests/integration/test_api.py +283 -0
- juniper_data/tests/integration/test_e2e_workflow.py +378 -0
- juniper_data/tests/integration/test_lifecycle_api.py +304 -0
- juniper_data/tests/integration/test_security_integration.py +189 -0
- juniper_data/tests/integration/test_storage_workflow.py +259 -0
- juniper_data/tests/performance/__init__.py +1 -0
- juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
- juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
- juniper_data/tests/unit/__init__.py +1 -0
- juniper_data/tests/unit/test_api_app.py +206 -0
- juniper_data/tests/unit/test_api_routes.py +407 -0
- juniper_data/tests/unit/test_api_settings.py +100 -0
- juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
- juniper_data/tests/unit/test_artifacts.py +145 -0
- juniper_data/tests/unit/test_cached_store.py +423 -0
- juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
- juniper_data/tests/unit/test_circles_generator.py +256 -0
- juniper_data/tests/unit/test_csv_import_generator.py +345 -0
- juniper_data/tests/unit/test_dataset_id.py +181 -0
- juniper_data/tests/unit/test_gaussian_generator.py +333 -0
- juniper_data/tests/unit/test_hf_store.py +416 -0
- juniper_data/tests/unit/test_init.py +93 -0
- juniper_data/tests/unit/test_kaggle_store.py +469 -0
- juniper_data/tests/unit/test_lifecycle.py +394 -0
- juniper_data/tests/unit/test_main.py +127 -0
- juniper_data/tests/unit/test_middleware.py +79 -0
- juniper_data/tests/unit/test_mnist_generator.py +370 -0
- juniper_data/tests/unit/test_postgres_store.py +490 -0
- juniper_data/tests/unit/test_redis_store.py +500 -0
- juniper_data/tests/unit/test_security.py +281 -0
- juniper_data/tests/unit/test_security_boundaries.py +517 -0
- juniper_data/tests/unit/test_spiral_generator.py +566 -0
- juniper_data/tests/unit/test_split.py +245 -0
- juniper_data/tests/unit/test_storage.py +767 -0
- juniper_data/tests/unit/test_xor_generator.py +223 -0
- juniper_data-0.4.2.dist-info/METADATA +216 -0
- juniper_data-0.4.2.dist-info/RECORD +95 -0
- juniper_data-0.4.2.dist-info/WHEEL +5 -0
- juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
- juniper_data-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""Core NumPy-only XOR dataset generator.
|
|
2
|
+
|
|
3
|
+
This module provides the XorGenerator class for generating XOR
|
|
4
|
+
classification datasets using only NumPy operations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from juniper_data.core.split import shuffle_and_split
|
|
10
|
+
|
|
11
|
+
from .params import XorParams
|
|
12
|
+
|
|
13
|
+
# Version string of this XOR generator implementation (presumably recorded
# alongside generated datasets so they can be traced back to this code — TODO
# confirm against callers).
VERSION = "1.0.0"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class XorGenerator:
    """Stateless, NumPy-only builder for XOR classification datasets.

    Points are drawn uniformly from four rectangles, one per quadrant, each
    kept ``margin`` away from both axes. Before optional noise is applied,
    quadrants whose coordinates share a sign are class 0 and the rest class 1:

    - Quadrant 1 (++): x > 0, y > 0 -> Class 0
    - Quadrant 2 (-+): x < 0, y > 0 -> Class 1
    - Quadrant 3 (--): x < 0, y < 0 -> Class 0
    - Quadrant 4 (+-): x > 0, y < 0 -> Class 1

    Every method is a ``staticmethod``; the class keeps no state.
    """

    @staticmethod
    def generate(params: XorParams) -> dict[str, np.ndarray]:
        """Build an XOR dataset and split it into train and test partitions.

        Args:
            params: Generation settings (point counts, extents, margin, noise,
                seed and split configuration).

        Returns:
            Mapping with ``X_train``, ``y_train``, ``X_test`` and ``y_test``
            taken from ``shuffle_and_split``, followed by ``X_full`` and
            ``y_full``: the complete arrays in quadrant order (not shuffled).
            Feature arrays have two columns (x, y); label arrays are one-hot
            with two columns.
        """
        features, labels = XorGenerator._generate_raw(params, np.random.default_rng(params.seed))

        parts = shuffle_and_split(
            X=features,
            y=labels,
            train_ratio=params.train_ratio,
            test_ratio=params.test_ratio,
            seed=params.seed,
            shuffle=params.shuffle,
        )

        dataset = {key: parts[key] for key in ("X_train", "y_train", "X_test", "y_test")}
        dataset["X_full"] = features
        dataset["y_full"] = labels
        return dataset

    @staticmethod
    def _generate_raw(params: XorParams, rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
        """Sample all four quadrants and build the matching one-hot labels.

        Quadrants are sampled in the order 1, 2, 3, 4 (x column before y column
        within each), so a given ``rng`` state always reproduces the same points.

        Args:
            params: Generation settings; reads ``n_points_per_quadrant``,
                ``x_range``, ``y_range``, ``margin`` and ``noise``.
            rng: Random generator every sample is drawn from.

        Returns:
            ``(X, y)`` where ``X`` is float32 with shape ``(4 * n, 2)`` and ``y``
            is a float32 one-hot array of the same shape, rows grouped by quadrant.
        """
        n = params.n_points_per_quadrant
        lo_x, hi_x = params.margin, params.x_range
        lo_y, hi_y = params.margin, params.y_range

        # (x interval, y interval, class column) for quadrants 1..4, in draw order.
        layout = (
            ((lo_x, hi_x), (lo_y, hi_y), 0),
            ((-hi_x, -lo_x), (lo_y, hi_y), 1),
            ((-hi_x, -lo_x), (-hi_y, -lo_y), 0),
            ((lo_x, hi_x), (-hi_y, -lo_y), 1),
        )

        blocks = []
        labels = np.zeros((4 * n, 2), dtype=np.float32)
        for quadrant, ((x_lo, x_hi), (y_lo, y_hi), column) in enumerate(layout):
            blocks.append(
                XorGenerator._generate_quadrant(
                    n_points=n,
                    x_min=x_lo,
                    x_max=x_hi,
                    y_min=y_lo,
                    y_max=y_hi,
                    rng=rng,
                )
            )
            labels[quadrant * n : (quadrant + 1) * n, column] = 1.0

        points = np.vstack(blocks)
        if params.noise > 0:
            # Jitter is added in float64 and only then cast down to float32.
            points += rng.standard_normal(points.shape) * params.noise

        return points.astype(np.float32), labels

    @staticmethod
    def _generate_quadrant(
        n_points: int,
        x_min: float,
        x_max: float,
        y_min: float,
        y_max: float,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Draw ``n_points`` uniform samples from an axis-aligned rectangle.

        Args:
            n_points: How many points to draw.
            x_min: Lower x bound.
            x_max: Upper x bound.
            y_min: Lower y bound.
            y_max: Upper y bound.
            rng: Random generator to draw from.

        Returns:
            Array of shape ``(n_points, 2)``; column 0 holds x, column 1 holds y.
        """
        # All x values are drawn before any y value; the order fixes the RNG stream.
        xs = rng.uniform(x_min, x_max, n_points)
        ys = rng.uniform(y_min, y_max, n_points)
        return np.column_stack((xs, ys))
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def get_schema() -> dict:
    """Describe the XOR generator's parameters as a JSON schema.

    Returns:
        The JSON schema pydantic generates for ``XorParams``.
    """
    schema = XorParams.model_json_schema()
    return schema
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Parameters for the XOR dataset generator."""
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field, model_validator
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class XorParams(BaseModel):
    """Configuration parameters for XOR dataset generation.

    The XOR dataset consists of 4 quadrants around the origin.
    Points in quadrants 1 and 3 (x*y > 0) belong to class 0.
    Points in quadrants 2 and 4 (x*y < 0) belong to class 1.

    Before noise, each quadrant is sampled from ``margin <= |x| <= x_range`` and
    ``margin <= |y| <= y_range``, so ``margin`` may not exceed either range.
    """

    n_points_per_quadrant: int = Field(default=50, ge=1, description="Number of points per quadrant")
    x_range: float = Field(
        default=1.0,
        gt=0,
        description="Maximum absolute x value; x is sampled from the interval [-x_range, x_range]",
    )
    y_range: float = Field(
        default=1.0,
        gt=0,
        description="Maximum absolute y value; y is sampled from the interval [-y_range, y_range]",
    )
    margin: float = Field(default=0.1, ge=0, description="Margin around axes (exclusion zone)")
    noise: float = Field(default=0.0, ge=0, description="Gaussian noise level")
    seed: int | None = Field(default=None, ge=0, description="Random seed for reproducibility")
    train_ratio: float = Field(default=0.8, gt=0, le=1, description="Fraction of data for training")
    test_ratio: float = Field(default=0.2, ge=0, le=1, description="Fraction of data for testing")
    shuffle: bool = Field(default=True, description="Shuffle dataset before train/test split")

    @model_validator(mode="after")
    def check_margin_within_ranges(self) -> "XorParams":
        """Reject a margin wider than the sampling range.

        The generator draws each coordinate with ``uniform(margin, range)``.
        A margin above the range inverts that interval, which NumPy documents
        as undefined behaviour, and leaves no valid region outside the
        exclusion zone.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If ``margin`` exceeds ``x_range`` or ``y_range``.
                Pydantic reports this as a ``ValidationError``.
        """
        if self.margin > self.x_range:
            raise ValueError(f"margin ({self.margin}) must not exceed x_range ({self.x_range})")
        if self.margin > self.y_range:
            raise ValueError(f"margin ({self.margin}) must not exceed y_range ({self.y_range})")
        return self
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Storage module for dataset persistence."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from juniper_data.storage.base import DatasetStore
|
|
6
|
+
from juniper_data.storage.cached import CachedDatasetStore
|
|
7
|
+
from juniper_data.storage.local_fs import LocalFSDatasetStore
|
|
8
|
+
from juniper_data.storage.memory import InMemoryDatasetStore
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from juniper_data.storage.hf_store import HuggingFaceDatasetStore
|
|
12
|
+
from juniper_data.storage.kaggle_store import KaggleDatasetStore
|
|
13
|
+
from juniper_data.storage.postgres_store import PostgresDatasetStore
|
|
14
|
+
from juniper_data.storage.redis_store import RedisDatasetStore
|
|
15
|
+
else:
|
|
16
|
+
try:
|
|
17
|
+
from juniper_data.storage.redis_store import RedisDatasetStore
|
|
18
|
+
except ImportError:
|
|
19
|
+
RedisDatasetStore = None
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
from juniper_data.storage.hf_store import HuggingFaceDatasetStore
|
|
23
|
+
except ImportError:
|
|
24
|
+
HuggingFaceDatasetStore = None
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
from juniper_data.storage.postgres_store import PostgresDatasetStore
|
|
28
|
+
except ImportError:
|
|
29
|
+
PostgresDatasetStore = None
|
|
30
|
+
|
|
31
|
+
try:
|
|
32
|
+
from juniper_data.storage.kaggle_store import KaggleDatasetStore
|
|
33
|
+
except ImportError:
|
|
34
|
+
KaggleDatasetStore = None
|
|
35
|
+
__all__ = [
    "DatasetStore",
    "CachedDatasetStore",
    "LocalFSDatasetStore",
    "InMemoryDatasetStore",
]

# Each optional backend name is always bound at this point: to its class, or to
# None when the import of its module raised ImportError. A plain ``is not None``
# test therefore decides whether the backend is exported.
if RedisDatasetStore is not None:
    __all__.append("RedisDatasetStore")

if HuggingFaceDatasetStore is not None:
    __all__.append("HuggingFaceDatasetStore")

if PostgresDatasetStore is not None:
    __all__.append("PostgresDatasetStore")

if KaggleDatasetStore is not None:
    __all__.append("KaggleDatasetStore")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_redis_store(**kwargs) -> "RedisDatasetStore":  # type: ignore[no-untyped-def]
    """Build a Redis-backed dataset store, importing the backend on demand.

    The backend module is imported at call time, so only callers that actually
    request this store need the redis package.

    Args:
        **kwargs: Forwarded unchanged to the ``RedisDatasetStore`` constructor.

    Returns:
        A new ``RedisDatasetStore``.

    Raises:
        ImportError: If ``juniper_data.storage.redis_store`` cannot be imported,
            e.g. because the redis package is not installed.
    """
    from juniper_data.storage import redis_store as backend

    return backend.RedisDatasetStore(**kwargs)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_hf_store(**kwargs) -> "HuggingFaceDatasetStore":  # type: ignore[no-untyped-def]
    """Build a Hugging Face dataset store, importing the backend on demand.

    The backend module is imported at call time, so only callers that actually
    request this store need the datasets package.

    Args:
        **kwargs: Forwarded unchanged to the ``HuggingFaceDatasetStore`` constructor.

    Returns:
        A new ``HuggingFaceDatasetStore``.

    Raises:
        ImportError: If ``juniper_data.storage.hf_store`` cannot be imported,
            e.g. because the datasets package is not installed.
    """
    from juniper_data.storage import hf_store as backend

    return backend.HuggingFaceDatasetStore(**kwargs)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def get_postgres_store(**kwargs) -> "PostgresDatasetStore":  # type: ignore[no-untyped-def]
    """Build a PostgreSQL dataset store, importing the backend on demand.

    The backend module is imported at call time, so only callers that actually
    request this store need the psycopg2 package.

    Args:
        **kwargs: Forwarded unchanged to the ``PostgresDatasetStore`` constructor.

    Returns:
        A new ``PostgresDatasetStore``.

    Raises:
        ImportError: If ``juniper_data.storage.postgres_store`` cannot be
            imported, e.g. because the psycopg2 package is not installed.
    """
    from juniper_data.storage import postgres_store as backend

    return backend.PostgresDatasetStore(**kwargs)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def get_kaggle_store(**kwargs) -> "KaggleDatasetStore":  # type: ignore[no-untyped-def]
    """Build a Kaggle dataset store, importing the backend on demand.

    The backend module is imported at call time, so only callers that actually
    request this store need the kaggle package.

    Args:
        **kwargs: Forwarded unchanged to the ``KaggleDatasetStore`` constructor.

    Returns:
        A new ``KaggleDatasetStore``.

    Raises:
        ImportError: If ``juniper_data.storage.kaggle_store`` cannot be
            imported, e.g. because the kaggle package is not installed.
    """
    from juniper_data.storage import kaggle_store as backend

    return backend.KaggleDatasetStore(**kwargs)
|
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
"""Abstract base class for dataset storage."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
# from collections.abc import Callable
|
|
6
|
+
from datetime import UTC, datetime
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from juniper_data.core.models import DatasetMeta
|
|
11
|
+
|
|
12
|
+
# from typing import Dict, List, Optional
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DatasetStore(ABC):
    """Abstract dataset storage interface.

    Provides a common interface for storing and retrieving datasets,
    supporting different backends (in-memory, local filesystem, cloud, etc.).

    Subclasses must implement the abstract methods ``save``, ``get_meta``,
    ``get_artifact_bytes``, ``exists``, ``delete`` and ``list_datasets``.
    ``update_meta`` and ``list_all_metadata`` raise ``NotImplementedError``
    here. Backends must override them for ``record_access``,
    ``delete_expired``, ``filter_datasets`` and ``get_stats`` to work.
    """

    @abstractmethod
    def save(
        self,
        dataset_id: str,
        meta: DatasetMeta,
        arrays: dict[str, np.ndarray],
    ) -> None:
        """Save dataset metadata and arrays.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Dataset metadata.
            arrays: Dictionary of numpy arrays (e.g., X_train, y_train, etc.).

        Raises:
            IOError: If the save operation fails.
        """

    @abstractmethod
    def get_meta(self, dataset_id: str) -> DatasetMeta | None:
        """Get dataset metadata.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            Dataset metadata if found, None otherwise.
        """

    @abstractmethod
    def get_artifact_bytes(self, dataset_id: str) -> bytes | None:
        """Get dataset artifact as bytes (NPZ format).

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            NPZ file contents as bytes if found, None otherwise.
        """

    @abstractmethod
    def exists(self, dataset_id: str) -> bool:
        """Check if dataset exists.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset exists, False otherwise.
        """

    @abstractmethod
    def delete(self, dataset_id: str) -> bool:
        """Delete dataset.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset was deleted, False if it didn't exist.
        """

    @abstractmethod
    def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
        """List dataset IDs.

        Args:
            limit: Maximum number of dataset IDs to return.
            offset: Number of dataset IDs to skip.

        Returns:
            List of dataset IDs.
        """

    def update_meta(self, dataset_id: str, meta: DatasetMeta) -> bool:
        """Update dataset metadata.

        The base implementation always raises; backends override it.

        Args:
            dataset_id: Unique identifier for the dataset.
            meta: Updated dataset metadata.

        Returns:
            True if the dataset was updated, False if it didn't exist.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("update_meta not implemented for this storage backend")

    def list_all_metadata(self) -> list[DatasetMeta]:
        """List all dataset metadata (for filtering/stats).

        The base implementation always raises; backends override it.

        Returns:
            List of all DatasetMeta objects.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("list_all_metadata not implemented for this storage backend")

    def record_access(self, dataset_id: str) -> None:
        """Record an access to a dataset (updates last_accessed_at and access_count).

        Does nothing when ``get_meta`` returns None. Otherwise the metadata is
        modified and written back through ``update_meta``, so with the base
        ``update_meta`` this raises ``NotImplementedError``.

        Args:
            dataset_id: Unique identifier for the dataset.
        """
        meta = self.get_meta(dataset_id)
        if meta is None:
            return
        meta.last_accessed_at = datetime.now(UTC)
        meta.access_count += 1
        self.update_meta(dataset_id, meta)

    def is_expired(self, meta: DatasetMeta) -> bool:
        """Check if a dataset has expired based on its TTL.

        Args:
            meta: Dataset metadata.

        Returns:
            True if ``meta.expires_at`` is set and lies in the past, False otherwise.
        """
        if meta.expires_at is None:
            return False
        # Compared against an aware UTC "now"; presumably expires_at is stored
        # timezone-aware (a naive value would make this comparison raise TypeError).
        return datetime.now(UTC) > meta.expires_at

    def delete_expired(self) -> list[str]:
        """Delete all expired datasets.

        Returns:
            IDs of the expired datasets that ``delete`` reported as removed.
        """
        return [
            meta.dataset_id
            for meta in self.list_all_metadata()
            if self.is_expired(meta) and self.delete(meta.dataset_id)
        ]

    def filter_datasets(
        self,
        generator: str | None = None,
        tags: list[str] | None = None,
        tags_match: str = "any",
        created_after: datetime | None = None,
        created_before: datetime | None = None,
        min_samples: int | None = None,
        max_samples: int | None = None,
        include_expired: bool = False,
        limit: int = 100,
        offset: int = 0,
    ) -> tuple[list[DatasetMeta], int]:
        """Filter datasets by various criteria.

        Args:
            generator: Filter by generator name.
            tags: Filter by tags. ``None`` or an empty list applies no tag filter.
            tags_match: ``"all"`` requires every tag in ``tags`` to be present.
                Any other value (normally ``"any"``) requires at least one.
            created_after: Exclude datasets created before this moment.
            created_before: Exclude datasets created after this moment.
            min_samples: Minimum number of samples.
            max_samples: Maximum number of samples.
            include_expired: Include expired datasets.
            limit: Maximum number of results; negative values are treated as 0.
            offset: Number of results to skip; negative values are treated as 0.

        Returns:
            Tuple of (page of matching metadata, newest first; total count before
            pagination).
        """
        # An empty tag list previously excluded everything in "any" mode but
        # nothing in "all" mode; treat it as "no tag constraint" in both.
        wanted_tags = set(tags) if tags else None
        matches: list[DatasetMeta] = []

        for meta in self.list_all_metadata():
            if not include_expired and self.is_expired(meta):
                continue
            if generator is not None and meta.generator != generator:
                continue
            if wanted_tags is not None:
                if tags_match == "all":
                    if not wanted_tags.issubset(meta.tags):
                        continue
                elif wanted_tags.isdisjoint(meta.tags):
                    continue
            if created_after is not None and meta.created_at < created_after:
                continue
            if created_before is not None and meta.created_at > created_before:
                continue
            if min_samples is not None and meta.n_samples < min_samples:
                continue
            if max_samples is not None and meta.n_samples > max_samples:
                continue
            matches.append(meta)

        matches.sort(key=lambda m: m.created_at, reverse=True)
        # Negative bounds would otherwise produce surprising slices such as
        # matches[0:-5] (everything but the last five).
        start = max(offset, 0)
        stop = start + max(limit, 0)
        return matches[start:stop], len(matches)

    def batch_delete(self, dataset_ids: list[str]) -> tuple[list[str], list[str]]:
        """Delete multiple datasets.

        Args:
            dataset_ids: List of dataset IDs to delete.

        Returns:
            Tuple of (deleted IDs, not found IDs). Classification follows the
            boolean returned by ``delete`` for each ID.
        """
        deleted: list[str] = []
        not_found: list[str] = []
        for dataset_id in dataset_ids:
            if self.delete(dataset_id):
                deleted.append(dataset_id)
            else:
                not_found.append(dataset_id)
        return deleted, not_found

    def get_stats(self) -> dict[str, object]:
        """Get aggregate statistics about stored datasets.

        Returns:
            Dictionary with ``total_datasets``, ``total_samples``,
            ``by_generator`` and ``by_tag`` (counts), ``oldest_created_at`` and
            ``newest_created_at`` (None when empty), and ``expired_count``.
        """
        all_meta = self.list_all_metadata()

        if not all_meta:
            return {
                "total_datasets": 0,
                "total_samples": 0,
                "by_generator": {},
                "by_tag": {},
                "oldest_created_at": None,
                "newest_created_at": None,
                "expired_count": 0,
            }

        by_generator: dict[str, int] = {}
        by_tag: dict[str, int] = {}
        total_samples = 0
        expired_count = 0
        created_times = []

        for meta in all_meta:
            by_generator[meta.generator] = by_generator.get(meta.generator, 0) + 1
            for tag in meta.tags:
                by_tag[tag] = by_tag.get(tag, 0) + 1
            total_samples += meta.n_samples
            created_times.append(meta.created_at)
            if self.is_expired(meta):
                expired_count += 1

        return {
            "total_datasets": len(all_meta),
            "total_samples": total_samples,
            "by_generator": by_generator,
            "by_tag": by_tag,
            "oldest_created_at": min(created_times),
            "newest_created_at": max(created_times),
            "expired_count": expired_count,
        }
|