juniper-data 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. juniper_data/__init__.py +88 -0
  2. juniper_data/__main__.py +78 -0
  3. juniper_data/api/__init__.py +10 -0
  4. juniper_data/api/app.py +111 -0
  5. juniper_data/api/middleware.py +95 -0
  6. juniper_data/api/routes/__init__.py +9 -0
  7. juniper_data/api/routes/datasets.py +414 -0
  8. juniper_data/api/routes/generators.py +125 -0
  9. juniper_data/api/routes/health.py +49 -0
  10. juniper_data/api/security.py +238 -0
  11. juniper_data/api/settings.py +109 -0
  12. juniper_data/core/__init__.py +32 -0
  13. juniper_data/core/artifacts.py +63 -0
  14. juniper_data/core/dataset_id.py +38 -0
  15. juniper_data/core/models.py +135 -0
  16. juniper_data/core/split.py +120 -0
  17. juniper_data/generators/__init__.py +15 -0
  18. juniper_data/generators/arc_agi/__init__.py +11 -0
  19. juniper_data/generators/arc_agi/generator.py +229 -0
  20. juniper_data/generators/arc_agi/params.py +56 -0
  21. juniper_data/generators/checkerboard/__init__.py +15 -0
  22. juniper_data/generators/checkerboard/generator.py +114 -0
  23. juniper_data/generators/checkerboard/params.py +32 -0
  24. juniper_data/generators/circles/__init__.py +11 -0
  25. juniper_data/generators/circles/generator.py +112 -0
  26. juniper_data/generators/circles/params.py +31 -0
  27. juniper_data/generators/csv_import/__init__.py +15 -0
  28. juniper_data/generators/csv_import/generator.py +198 -0
  29. juniper_data/generators/csv_import/params.py +48 -0
  30. juniper_data/generators/gaussian/__init__.py +11 -0
  31. juniper_data/generators/gaussian/generator.py +149 -0
  32. juniper_data/generators/gaussian/params.py +53 -0
  33. juniper_data/generators/mnist/__init__.py +11 -0
  34. juniper_data/generators/mnist/generator.py +124 -0
  35. juniper_data/generators/mnist/params.py +39 -0
  36. juniper_data/generators/spiral/__init__.py +57 -0
  37. juniper_data/generators/spiral/defaults.py +39 -0
  38. juniper_data/generators/spiral/generator.py +206 -0
  39. juniper_data/generators/spiral/params.py +148 -0
  40. juniper_data/generators/xor/__init__.py +11 -0
  41. juniper_data/generators/xor/generator.py +162 -0
  42. juniper_data/generators/xor/params.py +30 -0
  43. juniper_data/storage/__init__.py +120 -0
  44. juniper_data/storage/base.py +279 -0
  45. juniper_data/storage/cached.py +211 -0
  46. juniper_data/storage/hf_store.py +257 -0
  47. juniper_data/storage/kaggle_store.py +333 -0
  48. juniper_data/storage/local_fs.py +232 -0
  49. juniper_data/storage/memory.py +136 -0
  50. juniper_data/storage/postgres_store.py +373 -0
  51. juniper_data/storage/redis_store.py +264 -0
  52. juniper_data/tests/__init__.py +1 -0
  53. juniper_data/tests/conftest.py +68 -0
  54. juniper_data/tests/fixtures/generate_golden_datasets.py +199 -0
  55. juniper_data/tests/integration/__init__.py +1 -0
  56. juniper_data/tests/integration/test_api.py +283 -0
  57. juniper_data/tests/integration/test_e2e_workflow.py +378 -0
  58. juniper_data/tests/integration/test_lifecycle_api.py +304 -0
  59. juniper_data/tests/integration/test_security_integration.py +189 -0
  60. juniper_data/tests/integration/test_storage_workflow.py +259 -0
  61. juniper_data/tests/performance/__init__.py +1 -0
  62. juniper_data/tests/performance/test_generator_benchmarks.py +178 -0
  63. juniper_data/tests/performance/test_storage_benchmarks.py +257 -0
  64. juniper_data/tests/unit/__init__.py +1 -0
  65. juniper_data/tests/unit/test_api_app.py +206 -0
  66. juniper_data/tests/unit/test_api_routes.py +407 -0
  67. juniper_data/tests/unit/test_api_settings.py +100 -0
  68. juniper_data/tests/unit/test_arc_agi_generator.py +525 -0
  69. juniper_data/tests/unit/test_artifacts.py +145 -0
  70. juniper_data/tests/unit/test_cached_store.py +423 -0
  71. juniper_data/tests/unit/test_checkerboard_generator.py +232 -0
  72. juniper_data/tests/unit/test_circles_generator.py +256 -0
  73. juniper_data/tests/unit/test_csv_import_generator.py +345 -0
  74. juniper_data/tests/unit/test_dataset_id.py +181 -0
  75. juniper_data/tests/unit/test_gaussian_generator.py +333 -0
  76. juniper_data/tests/unit/test_hf_store.py +416 -0
  77. juniper_data/tests/unit/test_init.py +93 -0
  78. juniper_data/tests/unit/test_kaggle_store.py +469 -0
  79. juniper_data/tests/unit/test_lifecycle.py +394 -0
  80. juniper_data/tests/unit/test_main.py +127 -0
  81. juniper_data/tests/unit/test_middleware.py +79 -0
  82. juniper_data/tests/unit/test_mnist_generator.py +370 -0
  83. juniper_data/tests/unit/test_postgres_store.py +490 -0
  84. juniper_data/tests/unit/test_redis_store.py +500 -0
  85. juniper_data/tests/unit/test_security.py +281 -0
  86. juniper_data/tests/unit/test_security_boundaries.py +517 -0
  87. juniper_data/tests/unit/test_spiral_generator.py +566 -0
  88. juniper_data/tests/unit/test_split.py +245 -0
  89. juniper_data/tests/unit/test_storage.py +767 -0
  90. juniper_data/tests/unit/test_xor_generator.py +223 -0
  91. juniper_data-0.4.2.dist-info/METADATA +216 -0
  92. juniper_data-0.4.2.dist-info/RECORD +95 -0
  93. juniper_data-0.4.2.dist-info/WHEEL +5 -0
  94. juniper_data-0.4.2.dist-info/licenses/LICENSE +9 -0
  95. juniper_data-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,238 @@
1
+ """API security: authentication and rate limiting middleware."""
2
+
3
+ import time
4
+ from collections import defaultdict
5
+ from threading import Lock
6
+
7
+ from fastapi import HTTPException, Request, status
8
+ from fastapi.security import APIKeyHeader
9
+
10
+ from .settings import get_settings
11
+
12
# Security scheme object for the "X-API-Key" request header. auto_error=False
# tells FastAPI not to reject requests on its own when the header is absent.
# NOTE(review): not referenced elsewhere in this module (APIKeyAuth reads the
# header directly); presumably used by the app/routes for OpenAPI docs — verify.
api_key_header = APIKeyHeader(auto_error=False, name="X-API-Key")
13
+
14
+
15
class APIKeyAuth:
    """API key authentication handler.

    Validates requests against a configured set of API keys. When no API keys
    are configured, authentication is disabled (open access mode for
    development) and every request is accepted.
    """

    def __init__(self, api_keys: list[str] | None = None) -> None:
        """Initialize with an optional list of valid API keys.

        Args:
            api_keys: List of valid API keys. If None or empty, auth is disabled.
        """
        self._api_keys: set[str] = set(api_keys) if api_keys else set()
        self._enabled = len(self._api_keys) > 0

    @property
    def enabled(self) -> bool:
        """Check if authentication is enabled (at least one key configured)."""
        return self._enabled

    def validate(self, api_key: str | None) -> bool:
        """Validate an API key.

        The candidate is compared against every configured key with
        ``hmac.compare_digest`` so response timing does not reveal how much of
        a key matched. When auth is enabled, an empty or missing key is always
        rejected. This means a stray empty entry in the configured key list
        cannot be matched by sending an empty ``X-API-Key`` header.

        Args:
            api_key: The API key to validate.

        Returns:
            True if auth is disabled or the key is valid, False otherwise.
        """
        import hmac  # local import keeps the module's import block unchanged

        if not self._enabled:
            return True
        if not api_key:
            # Covers both a missing header (None) and an empty value ("").
            return False

        # "surrogatepass" avoids UnicodeEncodeError on odd env/header values.
        candidate = api_key.encode("utf-8", "surrogatepass")
        matched = False
        # No early exit: every configured key is compared so the timing does
        # not depend on which key (if any) matched.
        for valid_key in self._api_keys:
            if hmac.compare_digest(candidate, valid_key.encode("utf-8", "surrogatepass")):
                matched = True
        return matched

    async def __call__(self, request: Request) -> str | None:
        """FastAPI dependency for API key validation.

        Reads the key from the ``X-API-Key`` request header.

        Args:
            request: The incoming request.

        Returns:
            The validated API key, or None if auth is disabled.

        Raises:
            HTTPException: 401 if auth is enabled and the key is invalid/missing.
        """
        api_key = request.headers.get("X-API-Key")

        if not self._enabled:
            return None

        if api_key is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing API key. Provide X-API-Key header.",
            )

        if not self.validate(api_key):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key.",
            )

        return api_key
81
+
82
+
83
class RateLimiter:
    """In-memory fixed-window rate limiter.

    Tracks request counts per key within fixed time windows. Counter state is
    guarded by a ``threading.Lock`` and lives only in this process. It is
    therefore suitable for single-process deployments; separate worker
    processes would each keep their own counts.

    Windows are measured with ``time.monotonic()``, so system clock
    adjustments (e.g. NTP steps or manual changes) cannot stretch or shrink a
    window. Counters whose window has ended are pruned periodically, so memory
    does not grow with every distinct client ever seen.
    """

    def __init__(
        self,
        requests_per_minute: int = 60,
        window_seconds: int = 60,
        enabled: bool = True,
    ) -> None:
        """Initialize the rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per window (the name
                reflects the default 60-second window).
            window_seconds: Window duration in seconds.
            enabled: Whether rate limiting is enabled.
        """
        self._limit = requests_per_minute
        self._window = window_seconds
        self._enabled = enabled
        # key -> (requests counted in the current window, window start on the
        # monotonic clock). A missing key has no open window.
        self._counters: dict[str, tuple[int, float]] = {}
        self._lock = Lock()
        # Monotonic time of the last sweep for expired counters.
        self._last_prune = time.monotonic()

    @property
    def enabled(self) -> bool:
        """Check if rate limiting is enabled."""
        return self._enabled

    @property
    def limit(self) -> int:
        """Get the rate limit (maximum requests per window)."""
        return self._limit

    @property
    def window(self) -> int:
        """Get the window duration in seconds."""
        return self._window

    def _get_key(self, request: Request, api_key: str | None) -> str:
        """Generate a rate limit key for the request.

        Uses the API key if available, otherwise falls back to the client IP.

        Args:
            request: The incoming request.
            api_key: The authenticated API key, if any.

        Returns:
            A string key for rate limiting.
        """
        if api_key:
            return f"key:{api_key}"
        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}"

    def _seconds_until_reset(self, now: float, window_start: float) -> int:
        """Return whole seconds until the window that began at window_start ends.

        Rounds up and returns at least 1. A truncated value could tell a
        blocked client ``Retry-After: 0`` while it is still blocked.
        """
        import math  # local import keeps the module's import block unchanged

        return max(1, math.ceil(self._window - (now - window_start)))

    def _prune_expired(self, now: float) -> None:
        """Drop counters whose window has ended; sweeps at most once per window.

        A pruned key would be reset on its next request anyway, so pruning
        does not change any limiting decision. Caller must hold ``self._lock``.
        """
        if now - self._last_prune < self._window:
            return
        self._last_prune = now
        stale = [key for key, (_, start) in self._counters.items() if now - start >= self._window]
        for key in stale:
            del self._counters[key]

    def check(self, key: str) -> tuple[bool, int, int]:
        """Check if a request is allowed under the rate limit, counting it if so.

        Args:
            key: The rate limit key.

        Returns:
            Tuple of (allowed, remaining, reset_seconds). reset_seconds is
            rounded up, so clients are never told to retry too early.
        """
        if not self._enabled:
            return (True, self._limit, self._window)

        now = time.monotonic()

        with self._lock:
            self._prune_expired(now)
            entry = self._counters.get(key)

            if entry is None or now - entry[1] >= self._window:
                # First request for this key, or its previous window has
                # ended: open a fresh window that counts this request.
                self._counters[key] = (1, now)
                return (True, self._limit - 1, self._window)

            count, window_start = entry
            reset_in = self._seconds_until_reset(now, window_start)

            if count >= self._limit:
                return (False, 0, reset_in)

            self._counters[key] = (count + 1, window_start)
            return (True, self._limit - count - 1, reset_in)

    async def __call__(self, request: Request, api_key: str | None = None) -> None:
        """FastAPI dependency for rate limit checking.

        Stores the remaining count and reset time on ``request.state``
        (``rate_limit_remaining`` / ``rate_limit_reset``) for later use.

        Args:
            request: The incoming request.
            api_key: The authenticated API key, if any.

        Raises:
            HTTPException: 429 if the rate limit is exceeded.
        """
        if not self._enabled:
            return

        key = self._get_key(request, api_key)
        allowed, remaining, reset_in = self.check(key)

        request.state.rate_limit_remaining = remaining
        request.state.rate_limit_reset = reset_in

        if not allowed:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail=f"Rate limit exceeded. Try again in {reset_in} seconds.",
                headers={
                    "X-RateLimit-Limit": str(self._limit),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(reset_in),
                    "Retry-After": str(reset_in),
                },
            )

    def reset(self) -> None:
        """Reset all rate limit counters. Useful for testing."""
        with self._lock:
            self._counters.clear()
            self._last_prune = time.monotonic()
204
+
205
+
206
# Process-wide singletons, built lazily from get_settings() on first access
# and cleared by reset_security_state().
_api_key_auth: APIKeyAuth | None = None
_rate_limiter: RateLimiter | None = None


def get_api_key_auth() -> APIKeyAuth:
    """Return the shared APIKeyAuth, constructing it from settings on first use."""
    global _api_key_auth
    if _api_key_auth is not None:
        return _api_key_auth
    # getattr with a default tolerates settings objects without the attribute.
    _api_key_auth = APIKeyAuth(getattr(get_settings(), "api_keys", None))
    return _api_key_auth


def get_rate_limiter() -> RateLimiter:
    """Return the shared RateLimiter, constructing it from settings on first use.

    window_seconds is not passed, so the RateLimiter default window applies.
    """
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    cfg = get_settings()
    _rate_limiter = RateLimiter(
        requests_per_minute=getattr(cfg, "rate_limit_requests_per_minute", 60),
        enabled=getattr(cfg, "rate_limit_enabled", False),
    )
    return _rate_limiter


def reset_security_state() -> None:
    """Forget both singletons so the next accessor call rebuilds them (tests)."""
    global _api_key_auth, _rate_limiter
    _api_key_auth = _rate_limiter = None
@@ -0,0 +1,109 @@
1
+ """API configuration settings using pydantic-settings."""
2
+
3
+ # import json
4
+ from functools import lru_cache
5
+
6
+ from pydantic_settings import BaseSettings, SettingsConfigDict
7
+
8
+ # Define Safe and Reasonable Defaults for API Model Config
9
+ _JUNIPER_DATA_ENV_PREFIX: str = "JUNIPER_DATA_"
10
+ _JUNIPER_DATA_ENV_PREFIX_DEFAULT: str = _JUNIPER_DATA_ENV_PREFIX
11
+
12
+ _JUNIPER_DATA_ENV_FILE: str = ".env"
13
+ _JUNIPER_DATA_ENV_FILE_DEFAULT: str = _JUNIPER_DATA_ENV_FILE
14
+
15
+ _JUNIPER_DATA_ENV_FILE_ENCODING: str = "utf-8"
16
+ _JUNIPER_DATA_ENV_FILE_ENCODING_DEFAULT: str = _JUNIPER_DATA_ENV_FILE_ENCODING
17
+
18
+ _JUNIPER_DATA_ENV_CASE_SENSITIVE_ENABLED: bool = True
19
+ _JUNIPER_DATA_ENV_CASE_SENSITIVE_DISABLED: bool = False
20
+ _JUNIPER_DATA_ENV_CASE_SENSITIVE_DEFAULT: bool = _JUNIPER_DATA_ENV_CASE_SENSITIVE_DISABLED
21
+
22
+ _JUNIPER_DATA_ENV_EXTRA_DISABLED: str = "ignore"
23
+ _JUNIPER_DATA_ENV_EXTRA_DEFAULT: str = _JUNIPER_DATA_ENV_EXTRA_DISABLED
24
+
25
+ # Define Safe and Reasonable Defaults for API Settings
26
+ _JUNIPER_DATA_API_DATASET_PATH: str = "./data/datasets"
27
+ _JUNIPER_DATA_API_STORAGE_PATH_DEFAULT: str = _JUNIPER_DATA_API_DATASET_PATH
28
+
29
+ _JUNIPER_DATA_API_HOST_GLOBAL: str = "0.0.0.0" # nosec B104
30
+ _JUNIPER_DATA_API_HOST_LOCAL: str = "127.0.0.1"
31
+ _JUNIPER_DATA_API_HOST_DEFAULT: str = _JUNIPER_DATA_API_HOST_LOCAL
32
+
33
+ _JUNIPER_DATA_API_PORT: int = 8100
34
+ _JUNIPER_DATA_API_PORT_DEFAULT: int = _JUNIPER_DATA_API_PORT
35
+
36
+ _JUNIPER_DATA_API_LOGLEVEL_TRACE: str = "TRACE"
37
+ _JUNIPER_DATA_API_LOGLEVEL_VERBOSE: str = "VERBOSE"
38
+ _JUNIPER_DATA_API_LOGLEVEL_DEBUG: str = "DEBUG"
39
+ _JUNIPER_DATA_API_LOGLEVEL_INFO: str = "INFO"
40
+ _JUNIPER_DATA_API_LOGLEVEL_WARNING: str = "WARNING"
41
+ _JUNIPER_DATA_API_LOGLEVEL_ERROR: str = "ERROR"
42
+ _JUNIPER_DATA_API_LOGLEVEL_CRITICAL: str = "CRITICAL"
43
+ _JUNIPER_DATA_API_LOGLEVEL_FATAL: str = "FATAL"
44
+ _JUNIPER_DATA_API_LOGLEVEL_DEFAULT: str = _JUNIPER_DATA_API_LOGLEVEL_INFO
45
+
46
+ _JUNIPER_DATA_API_RATELIMIT_DISABLED: bool = False
47
+ _JUNIPER_DATA_API_RATELIMIT_ENABLED: bool = True
48
+ _JUNIPER_DATA_API_RATELIMIT_ACTIVE_DEFAULT: bool = _JUNIPER_DATA_API_RATELIMIT_DISABLED
49
+
50
+ _JUNIPER_DATA_API_RATELIMIT_VALUE_SLOW: int = 30 # Requests per Minute
51
+ _JUNIPER_DATA_API_RATELIMIT_VALUE_MID: int = 60 # Requests per Minute
52
+ _JUNIPER_DATA_API_RATELIMIT_VALUE_FAST: int = 120 # Requests per Minute
53
+ _JUNIPER_DATA_API_RATELIMIT_DEFAULT: int = _JUNIPER_DATA_API_RATELIMIT_VALUE_MID
54
+
55
+ _JUNIPER_DATA_API_CORS_ORIGINS_ALL: list[str] = ["*"]
56
+ _JUNIPER_DATA_API_CORS_ORIGINS_NONE: list[str] = []
57
+ _JUNIPER_DATA_API_CORS_ORIGINS_DEFAULT: list[str] = _JUNIPER_DATA_API_CORS_ORIGINS_ALL
58
+
59
+
60
+ _JUNIPER_DATA_API_KEYS_LIST_EMPTY: list[str] | None = None
61
+ _JUNIPER_DATA_API_KEYS_LIST_VALUES: list[str] | None = []
62
+ _JUNIPER_DATA_API_KEYS_LIST_DEFAULT: list[str] | None = _JUNIPER_DATA_API_KEYS_LIST_EMPTY
63
+
64
+
65
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Every field can be overridden by an environment variable named with the
    JUNIPER_DATA_ prefix plus the field name (e.g., JUNIPER_DATA_STORAGE_PATH).
    The same variable can also be set in a UTF-8 ``.env`` file. Names are
    matched case-insensitively, and extra inputs are ignored.

    Security Settings:
        - api_keys: JSON array with one string per valid API key, e.g.
          ``JUNIPER_DATA_API_KEYS='["key1", "key2"]'``. List values are parsed
          as JSON, so ``'["key1,key2"]'`` is ONE key "key1,key2", not two.
          If unset or empty, authentication is disabled (open access).
        - rate_limit_enabled: Enable/disable rate limiting (off by default).
        - rate_limit_requests_per_minute: Max requests per minute per client.
    """

    model_config = SettingsConfigDict(
        env_prefix=_JUNIPER_DATA_ENV_PREFIX_DEFAULT,
        env_file=_JUNIPER_DATA_ENV_FILE_DEFAULT,
        env_file_encoding=_JUNIPER_DATA_ENV_FILE_ENCODING_DEFAULT,
        case_sensitive=_JUNIPER_DATA_ENV_CASE_SENSITIVE_DEFAULT,
        extra=_JUNIPER_DATA_ENV_EXTRA_DEFAULT,
    )

    # Location for stored datasets (default "./data/datasets").
    storage_path: str = _JUNIPER_DATA_API_STORAGE_PATH_DEFAULT

    # Default to a restrictive loopback binding (127.0.0.1) for general,
    # non-containerized environments. To listen on all interfaces (e.g. Docker,
    # Kubernetes), set JUNIPER_DATA_HOST=0.0.0.0. In that case, control access
    # with firewalls/security groups or a reverse proxy.
    host: str = _JUNIPER_DATA_API_HOST_DEFAULT
    port: int = _JUNIPER_DATA_API_PORT_DEFAULT
    log_level: str = _JUNIPER_DATA_API_LOGLEVEL_DEFAULT
    # NOTE(review): the default ["*"] permits any origin; consider restricting it in production.
    cors_origins: list[str] = _JUNIPER_DATA_API_CORS_ORIGINS_DEFAULT

    # None (the default) or an empty list disables API key authentication.
    api_keys: list[str] | None = _JUNIPER_DATA_API_KEYS_LIST_DEFAULT
    rate_limit_enabled: bool = _JUNIPER_DATA_API_RATELIMIT_ACTIVE_DEFAULT
    rate_limit_requests_per_minute: int = _JUNIPER_DATA_API_RATELIMIT_DEFAULT
104
+
105
+
106
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings instance.

    The first call constructs ``Settings()``, which reads the environment and
    the configured ``.env`` file. ``lru_cache`` hands back that same object on
    every later call. Use ``get_settings.cache_clear()`` to force a re-read,
    e.g. in tests.
    """
    cached_settings = Settings()
    return cached_settings
@@ -0,0 +1,32 @@
1
+ """Core module for Juniper Data."""
2
+
3
+ from juniper_data.core.artifacts import arrays_to_bytes, compute_checksum, load_npz, save_npz
4
+ from juniper_data.core.dataset_id import generate_dataset_id
5
+ from juniper_data.core.models import (
6
+ CreateDatasetRequest,
7
+ CreateDatasetResponse,
8
+ DatasetMeta,
9
+ GeneratorInfo,
10
+ PreviewData,
11
+ )
12
+ from juniper_data.core.split import shuffle_and_split, shuffle_data, split_data
13
+
14
# Public API of juniper_data.core; every name listed here is imported above.
# NOTE(review): the lifecycle models in juniper_data.core.models appear not
# to be re-exported here (DatasetListFilter, DatasetListResponse,
# BatchDeleteRequest, BatchDeleteResponse, UpdateTagsRequest, DatasetStats).
# Confirm whether that is intentional.
__all__ = [
    # Dataset ID
    "generate_dataset_id",
    # Split utilities
    "shuffle_and_split",
    "shuffle_data",
    "split_data",
    # Models
    "CreateDatasetRequest",
    "CreateDatasetResponse",
    "DatasetMeta",
    "GeneratorInfo",
    "PreviewData",
    # Artifacts
    "arrays_to_bytes",
    "compute_checksum",
    "load_npz",
    "save_npz",
]
@@ -0,0 +1,63 @@
1
+ """Artifact utilities for NPZ file handling and checksum computation."""
2
+
3
+ import hashlib
4
+ import io
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+
9
+
10
def save_npz(path: Path, arrays: dict[str, np.ndarray]) -> None:
    """Write ``arrays`` to ``path`` as an uncompressed NPZ archive.

    Each dictionary key becomes the member name of its array in the archive.
    Note that ``np.savez`` appends ``.npz`` to a path that lacks that suffix.

    Args:
        path: Destination file path.
        arrays: Mapping of array name to numpy array.
    """
    # numpy's stubs do not model the **kwds form, hence the ignore.
    np.savez(file=path, **arrays)  # type: ignore[arg-type]
18
+
19
+
20
def load_npz(path: Path) -> dict[str, np.ndarray]:
    """Read every array stored in the NPZ file at ``path``.

    The archive is closed before returning; each member is fully read into
    memory first. ``np.load`` is called with its default ``allow_pickle``
    (False), so object arrays cannot be loaded through this helper.

    Args:
        path: Location of the NPZ file.

    Returns:
        Mapping of member name to numpy array.
    """
    with np.load(path) as archive:
        loaded: dict[str, np.ndarray] = {}
        for name in archive.files:
            loaded[name] = archive[name]
    return loaded
31
+
32
+
33
def arrays_to_bytes(arrays: dict[str, np.ndarray]) -> bytes:
    """Serialize ``arrays`` into the bytes of an in-memory NPZ archive.

    Members are written in sorted key order, so the archive layout does not
    depend on the mapping's insertion order.

    Args:
        arrays: Mapping of array name to numpy array.

    Returns:
        The complete NPZ file contents, e.g. for a streaming response.
    """
    sink = io.BytesIO()
    # Sorting items orders by key only (keys are unique, arrays never compared).
    np.savez(sink, **dict(sorted(arrays.items())))  # type: ignore[arg-type]
    return sink.getvalue()
48
+
49
+
50
def compute_checksum(arrays: dict[str, np.ndarray]) -> str:
    """Return the SHA-256 hex digest of the NPZ serialization of ``arrays``.

    The digest covers the bytes produced by ``arrays_to_bytes``, which writes
    members in sorted key order. The result is therefore independent of dict
    insertion order.

    NOTE(review): those bytes also encode each array's dtype/byte order and
    the NPY headers written by the installed numpy. Equality across machines
    or numpy versions is not guaranteed by this function alone.

    Args:
        arrays: Mapping of array name to numpy array.

    Returns:
        64-character lowercase SHA-256 hex digest.
    """
    digest = hashlib.sha256()
    digest.update(arrays_to_bytes(arrays))
    return digest.hexdigest()
@@ -0,0 +1,38 @@
1
+ """Dataset ID generation utilities.
2
+
3
+ This module provides deterministic ID generation for datasets based on
4
+ generator name, version, and parameters.
5
+ """
6
+
7
+ import hashlib
8
+ import json
9
+ from typing import Any
10
+
11
+
12
def generate_dataset_id(generator: str, version: str, params: dict[str, Any]) -> str:
    """Derive a reproducible dataset ID from generator name, version and params.

    The three inputs are serialized to compact JSON with sorted keys, so
    parameter insertion order does not matter. The result is then hashed with
    SHA-256, and the first 16 hex characters of the digest are embedded in
    the ID.

    Args:
        generator: Name of the generator (e.g., "spiral").
        version: Version string (e.g., "v1.0.0").
        params: Generator parameters; must be JSON-serializable.

    Returns:
        ID of the form "{generator}-{version}-{hash[:16]}",
        e.g. "spiral-v1.0.0-a3f8e12b4c567890".

    Raises:
        TypeError: If ``params`` contains values ``json.dumps`` cannot encode.
    """
    payload = json.dumps(
        {"generator": generator, "version": version, "params": params},
        separators=(",", ":"),
        sort_keys=True,
    )
    short_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]
    return f"{generator}-{version}-{short_hash}"
@@ -0,0 +1,135 @@
1
+ """Core Pydantic models for dataset metadata and API responses."""
2
+
3
+ from datetime import datetime
4
+ from typing import Any
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
class DatasetMeta(BaseModel):
    """Dataset metadata (always small, JSON-safe).

    Describes a generated dataset without holding its arrays. It records
    identity, generation parameters, shape and split sizes, class balance,
    artifact formats, and lifecycle bookkeeping.
    """

    # Identity
    dataset_id: str
    generator: str
    generator_version: str

    # Generation Parameters
    params: dict[str, Any]

    # Shape Information
    n_samples: int
    n_features: int
    n_classes: int
    # NOTE(review): presumably n_train + n_test == n_samples; not validated here.
    n_train: int
    n_test: int

    # Class Distribution (str keys for JSON compatibility)
    # presumably class label -> sample count; verify against the generators.
    class_distribution: dict[str, int]

    # Artifacts (defaults to ["npz"]; a fresh list per instance)
    artifact_formats: list[str] = Field(default_factory=lambda: ["npz"])

    # Timestamps
    created_at: datetime

    # Optional fields
    # presumably the SHA-256 hex digest of the stored arrays — confirm with caller.
    checksum: str | None = None

    # Lifecycle management (DATA-016)
    tags: list[str] = Field(default_factory=list)
    ttl_seconds: int | None = None
    # NOTE(review): expires_at is presumably derived from created_at + ttl_seconds
    # by the caller; this model does not enforce that relationship.
    expires_at: datetime | None = None
    last_accessed_at: datetime | None = None
    access_count: int = 0
45
+
46
+
47
class CreateDatasetRequest(BaseModel):
    """Request model for creating a new dataset.

    ``params`` is accepted as an untyped dict here; any per-generator
    validation presumably happens elsewhere (e.g. against the generator's
    params schema) — confirm in the routes.
    """

    generator: str  # generator name, e.g. "spiral"
    params: dict[str, Any] = Field(default_factory=dict)
    # NOTE(review): presumably controls whether the result is stored; confirm in routes.
    persist: bool = True
    tags: list[str] = Field(default_factory=list)
    # Optional lifetime; when provided it must be >= 1 (enforced by ge=1).
    ttl_seconds: int | None = Field(default=None, ge=1, description="Time-to-live in seconds")
55
+
56
+
57
class CreateDatasetResponse(BaseModel):
    """Response model for dataset creation."""

    dataset_id: str
    generator: str
    meta: DatasetMeta  # full metadata of the created dataset
    # presumably a URL path from which the artifact can be downloaded — verify in routes.
    artifact_url: str
64
+
65
+
66
class GeneratorInfo(BaseModel):
    """Information about an available generator.

    The params JSON schema is stored as ``params_schema`` but exchanged under
    the external name ``schema`` (the field alias).
    """

    # Accept input under either the alias ("schema") or the field name
    # ("params_schema"). Previously only the alias worked, so
    # GeneratorInfo(params_schema=...) failed validation. Serialization is
    # unchanged: model_dump() still uses field names unless by_alias=True.
    model_config = {"populate_by_name": True}

    name: str
    version: str
    description: str
    # Named params_schema because a field called ``schema`` would shadow the
    # (deprecated) BaseModel.schema() classmethod; the alias keeps the wire name.
    params_schema: dict[str, Any] = Field(alias="schema")  # JSON schema for params
73
+
74
+
75
class PreviewData(BaseModel):
    """Preview subset of a dataset for visualization.

    Arrays are carried as nested Python lists so the payload is plain JSON.
    """

    n_samples: int  # presumably the number of rows in X_sample / y_sample
    X_sample: list[list[float]]  # one row of feature values per sample
    # NOTE(review): 2-D label rows suggest one-hot encoded targets — confirm.
    y_sample: list[list[float]]
81
+
82
+
83
class DatasetListFilter(BaseModel):
    """Filter criteria for listing datasets.

    Every criterion is optional; None means "do not filter on this".
    NOTE(review): there is no cross-field check that min_samples <= max_samples
    or that created_after <= created_before.
    """

    generator: str | None = None
    tags: list[str] | None = None
    # Restricted to "any" or "all" by the regex; presumably "any" means at
    # least one tag must match and "all" means every tag must match.
    tags_match: str = Field(default="any", pattern="^(any|all)$")
    created_after: datetime | None = None
    created_before: datetime | None = None
    min_samples: int | None = Field(default=None, ge=1)
    max_samples: int | None = Field(default=None, ge=1)
    include_expired: bool = False
94
+
95
+
96
class DatasetListResponse(BaseModel):
    """Response model for filtered dataset listing (one page of results)."""

    datasets: list[DatasetMeta]
    # presumably the number of matches before limit/offset are applied — confirm.
    total: int
    limit: int
    offset: int
103
+
104
+
105
class BatchDeleteRequest(BaseModel):
    """Request model for batch delete operation."""

    # Between 1 and 100 IDs per request (min_length / max_length).
    dataset_ids: list[str] = Field(min_length=1, max_length=100)
109
+
110
+
111
class BatchDeleteResponse(BaseModel):
    """Response model for batch delete operation."""

    deleted: list[str]  # IDs that were removed
    not_found: list[str]  # IDs that did not exist
    # presumably len(deleted); not enforced by this model.
    total_deleted: int
117
+
118
+
119
class UpdateTagsRequest(BaseModel):
    """Request model for updating dataset tags.

    NOTE(review): the order in which additions and removals are applied (and
    what happens if a tag appears in both lists) is decided by the handler,
    not by this model — confirm in the routes.
    """

    add_tags: list[str] = Field(default_factory=list)
    remove_tags: list[str] = Field(default_factory=list)
124
+
125
+
126
class DatasetStats(BaseModel):
    """Aggregate statistics about stored datasets."""

    total_datasets: int
    total_samples: int
    by_generator: dict[str, int]  # presumably generator name -> dataset count
    by_tag: dict[str, int]  # presumably tag -> dataset count
    # None when there are no datasets to take a min/max over (presumably).
    oldest_created_at: datetime | None = None
    newest_created_at: datetime | None = None
    expired_count: int = 0
+ expired_count: int = 0